Data Structures and Algorithms - The Segment Tree

Naowal Rahman

Introduction

In the world of computer science, data structures store data in particular ways, and they go hand in hand with algorithms that traverse them and perform operations on them. While many data structures and algorithms are widely used and well known, one data structure is especially interesting: the segment tree. A segment tree answers range queries over an array, such as the sum of a contiguous range of elements, in \(\mathcal{O}(\log n)\) time, and it also lets individual elements be updated in \(\mathcal{O}(\log n)\) time. It is flexible, too: by changing how two segments are combined (for example, taking a minimum or maximum instead of a sum), the same structure can solve a large number of problems. It also scales well to very large arrays, because it requires only \(\mathcal{O}(n)\) memory. The segment tree is a very useful data structure and is worth taking a look at.

How It Works

Suppose we needed to efficiently answer sum queries on an array \(\{a_0, a_1,\dots, a_{n-1}\}\), where \(n\) is the length of the array, using just one data structure. Given a start index \(s\) and an end index \(e\), a segment tree should let us find the sum \(a_s + a_{s+1} + \dots + a_e\).
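For comparison, the most direct way to answer such a query is to loop over the range and add up the elements, which takes time proportional to the length of the range, \(\mathcal{O}(n)\) in the worst case, for every single query. The short sketch below shows that brute-force approach. The name naive_range_sum is hypothetical and not part of the segment tree implementation, but a function like it is handy as a reference when testing one.

```python
def naive_range_sum(arr, s, e):
    """Sum of arr[s..e] (both ends inclusive), found by looping over the range."""
    total = 0
    for k in range(s, e + 1):  # visits e - s + 1 elements, so O(n) per query
        total += arr[k]
    return total

arr = [-3, 2, 4, 9, 6, -8]
print(naive_range_sum(arr, 2, 5))  # 4 + 9 + 6 + (-8) = 11
```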

The key to this data structure is the word “segment”. The root of the tree represents the whole array, \(\{a_0, \dots, a_{n-1}\}\), and stores the sum of all of its elements. That segment is then split at the middle index \(m = s + \lfloor (e - s)/2 \rfloor\) into two child segments: the first half, \(\{a_0, \dots, a_m\}\), and the second half, \(\{a_{m+1}, \dots, a_{n-1}\}\). At the root, \(s = 0\) (the start of the array) and \(e = n - 1\) (the end of the array), so \(m = \lfloor (n-1)/2 \rfloor\). Each child segment is split in the same way, and the splitting continues until it is no longer possible, that is, until a segment holds a single element and \(s = e\). This forms a binary tree in which every internal vertex has exactly two children and every leaf holds one array element. Since there are \(n\) leaves, the tree has exactly \(2n - 1\) vertices, so the number of vertices is linear in \(n\). Visually represented, if the array \(\{-3, 2, 4, 9, 6, -8\}\) was given, the tree would look something like this, where the root segment \(\{0 \dots n-1\}\) contains all indices in the array and holds the total sum of the array:

Figure: Segment tree of \(\{-3, 2, 4, 9, 6, -8\}\)
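The same tree can be printed as indented text. The sketch below assumes the split rule described above, mid = s + (e - s) // 2; the generator segments is a hypothetical helper, and each node’s sum is computed directly with slicing purely for display, not the way a segment tree actually obtains it.

```python
def segments(ss, se, depth=0):
    """Yield (depth, ss, se) for every node of the tree over indices ss..se, root first."""
    yield depth, ss, se
    if ss == se:                  # a single index: this node is a leaf
        return
    mid = ss + (se - ss) // 2     # same split rule as getMid
    yield from segments(ss, mid, depth + 1)
    yield from segments(mid + 1, se, depth + 1)

arr = [-3, 2, 4, 9, 6, -8]
for depth, ss, se in segments(0, len(arr) - 1):
    # Slicing is only used here to label each node with its segment's sum
    print("    " * depth + f"[{ss}..{se}] sum = {sum(arr[ss:se + 1])}")
```

For the example array this prints the root [0..5] with sum 10, its children [0..2] (sum 3) and [3..5] (sum 7), and so on down to the six single-element leaves, eleven nodes in total.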

Because segment trees follow this branched structure, whenever the length of the array, \(n\), is not a power of 2, not all levels of the segment tree will be completely filled. In general, however, the height of any segment tree is \(\mathcal{O}(\log n)\), more precisely \(\lceil \log_2 n \rceil\) levels below the root, because the size of the segments is cut roughly in half at every level you go down the tree.
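Both facts can be checked numerically under the split rule used here, where the left child of a segment with \(L\) elements receives \(\lceil L/2 \rceil\) of them. The hypothetical helpers height and node_count below compute the height recursively and count every node of the tree.

```python
import math

def height(length):
    """Edges from the root down to the deepest leaf for a segment of `length` elements."""
    if length == 1:
        return 0
    # getMid's split gives the left child ceil(length / 2) elements, the larger half
    return 1 + height((length + 1) // 2)

def node_count(length):
    """Total number of nodes in the tree built over `length` elements."""
    if length == 1:
        return 1
    left = (length + 1) // 2
    return 1 + node_count(left) + node_count(length - left)

for n in (6, 8, 1000):
    print(n, height(n), math.ceil(math.log2(n)), node_count(n))
```

For \(n = 6\) this prints a height of 3, matching \(\lceil \log_2 6 \rceil = 3\), and 11 nodes, which is \(2n - 1\); for \(n = 1000\) the height is only 10.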

Implementation

Now that we’ve seen the logic behind a segment tree, let’s implement one in Python to answer our sum queries. After importing the necessary functions, ceil and log2 from the math module, the first step is to implement a function, getMid, that returns the middle index of the segment bounded by \(s\) and \(e\).

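from math import ceil, log2

# Middle index of the segment [s, e]: the left child covers [s, mid],
# and the right child covers [mid + 1, e].
def getMid(s, e):
    return s + (e - s) // 2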
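A natural question is why getMid is not simply written as (s + e) // 2. In Python the two expressions always agree, since Python integers cannot overflow; the s + (e - s) // 2 form is a habit carried over from languages with fixed-width integers, such as C or Java, where s + e can overflow for very large indices. A quick sketch confirming the agreement and showing the split points for the example array:

```python
def getMid(s, e):
    return s + (e - s) // 2

# Exhaustively compare both forms on small non-negative bounds
for s in range(50):
    for e in range(s, 50):
        assert getMid(s, e) == (s + e) // 2

# Root of the 6-element example splits at 2; the segment [3..5] splits at 4
print(getMid(0, 5), getMid(3, 5))  # 2 4
```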

We can then create a recursive function, getSumUtil, to get the sum of values in a given range of the array. This function takes several parameters: st is the segment tree (stored as a Python list), si is the index of the current node in st, ss and se are the starting and ending indices of the segment represented by the current node, and qs and qe are the starting and ending indices of the queried range.

The function distinguishes three cases. If the segment of the current node lies entirely inside the queried range, the node’s stored sum is returned. If the segment lies entirely outside the queried range, the function returns 0. Otherwise, the segment partially overlaps the range, so we compute the middle of ss and se using getMid and recursively query both children, adding their results. The recursion continues until every branch reaches one of the first two cases.

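def getSumUtil(st, ss, se, qs, qe, si):
    # Case 1: this node's segment [ss, se] lies entirely inside [qs, qe]
    if qs <= ss and qe >= se:
        return st[si]

    # Case 2: this node's segment lies entirely outside [qs, qe]
    if se < qs or ss > qe:
        return 0

    # Case 3: partial overlap, so split the segment and query both halves
    mid = getMid(ss, se)

    # 2 * si + 1 and 2 * si + 2 are the child nodes of si
    return (getSumUtil(st, ss, mid, qs, qe, 2 * si + 1) +
            getSumUtil(st, mid + 1, se, qs, qe, 2 * si + 2))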
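To see the three cases in action, the sketch below adds a hypothetical log parameter to a copy of getSumUtil that records which case each visited node falls into. It uses a compact hypothetical build helper in place of the constructor described later, assuming the same layout (children of node si at 2 * si + 1 and 2 * si + 2) and allocating 4n slots, which is always enough for that layout.

```python
def build(arr):
    """Hypothetical compact builder using the same node layout as the article."""
    n = len(arr)
    st = [0] * (4 * n)
    def fill(ss, se, si):
        if ss == se:
            st[si] = arr[ss]
            return
        mid = ss + (se - ss) // 2
        fill(ss, mid, 2 * si + 1)
        fill(mid + 1, se, 2 * si + 2)
        st[si] = st[2 * si + 1] + st[2 * si + 2]
    fill(0, n - 1, 0)
    return st

def traced_sum(st, ss, se, qs, qe, si, log):
    """getSumUtil plus a `log` list recording the case hit at each visited node."""
    if qs <= ss and qe >= se:          # fully inside [qs, qe]
        log.append(("full", ss, se))
        return st[si]
    if se < qs or ss > qe:             # disjoint from [qs, qe]
        log.append(("outside", ss, se))
        return 0
    log.append(("partial", ss, se))    # partial overlap: recurse into both children
    mid = ss + (se - ss) // 2
    return (traced_sum(st, ss, mid, qs, qe, 2 * si + 1, log) +
            traced_sum(st, mid + 1, se, qs, qe, 2 * si + 2, log))

arr = [-3, 2, 4, 9, 6, -8]
st = build(arr)
log = []
print(traced_sum(st, 0, len(arr) - 1, 2, 5, 0, log))  # 11
for case, ss, se in log:
    print(case, ss, se)
```

For the query from index 2 to 5, the log shows the root and [0..2] as partial overlaps, [0..1] as outside, and [2..2] and [3..5] as fully covered, so the answer 11 is assembled from just those two stored sums (4 + 7).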

The next step is to define a recursive function, updateValueUtil, that updates every node whose segment contains a given index. For this function, two new parameters, i and diff, need to be introduced. The first one, i, is the index of the element to be updated. The second one, diff, is the value to be added to every node whose segment contains index i.

As a base case, the function returns without doing anything if index i lies outside the segment of the current node. If i is inside the segment, we add diff to this node’s sum and then process its children. In essence, this function propagates a single change through the segment tree. It recursively calls itself on the child nodes of the current node, which are at indices 2 * si + 1 and 2 * si + 2. The call on the child whose segment does not contain i returns immediately, and the recursion stops once it reaches the leaf for i (where ss equals se), so only one node per level is modified, \(\mathcal{O}(\log n)\) nodes in total.

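def updateValueUtil(st, ss, se, i, diff, si):
    # Base case: index i is not inside this node's segment [ss, se]
    if i < ss or i > se:
        return

    # i is inside this segment, so this node's sum changes by diff
    st[si] = st[si] + diff

    # Unless this node is a leaf, continue into both children
    if se != ss:
        mid = getMid(ss, se)
        updateValueUtil(st, ss, mid, i, diff, 2 * si + 1)
        updateValueUtil(st, mid + 1, se, i, diff, 2 * si + 2)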
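The sketch below records which nodes actually change when the value at index 3 of the example array goes from 9 to 12, that is, diff = 3. The build helper and the extra touched list added to a copy of updateValueUtil are hypothetical, introduced only for illustration.

```python
def build(arr):
    """Hypothetical compact builder using the same node layout as the article."""
    n = len(arr)
    st = [0] * (4 * n)
    def fill(ss, se, si):
        if ss == se:
            st[si] = arr[ss]
            return
        mid = ss + (se - ss) // 2
        fill(ss, mid, 2 * si + 1)
        fill(mid + 1, se, 2 * si + 2)
        st[si] = st[2 * si + 1] + st[2 * si + 2]
    fill(0, n - 1, 0)
    return st

def traced_update(st, ss, se, i, diff, si, touched):
    """updateValueUtil that also records the segment of every node it modifies."""
    if i < ss or i > se:       # index i is not inside this node's segment
        return
    st[si] += diff
    touched.append((ss, se))
    if se != ss:
        mid = ss + (se - ss) // 2
        traced_update(st, ss, mid, i, diff, 2 * si + 1, touched)
        traced_update(st, mid + 1, se, i, diff, 2 * si + 2, touched)

arr = [-3, 2, 4, 9, 6, -8]
st = build(arr)
touched = []
traced_update(st, 0, len(arr) - 1, 3, 12 - arr[3], 0, touched)  # a[3]: 9 -> 12
print(touched)  # [(0, 5), (3, 5), (3, 4), (3, 3)]
print(st[0])    # 13, the root sum after adding 3
```

Only the four segments on the path from the root to the leaf for index 3 are modified, one per level, which is why a point update costs \(\mathcal{O}(\log n)\).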

Using this function, we can create another function, updateValue, that updates a value in the input array and therefore in the segment tree. If the input index is invalid, the function prints a message and returns without changing anything. Otherwise, it computes the difference between the new and old values, stores the new value in the array, and finally updates the affected nodes of the segment tree. This makes update queries easy to perform. Wrapping the recursive utility function in this way is not strictly necessary, but it gives callers a simpler interface and keeps input validation in one place.

def updateValue(arr, st, n, i, new_val):
    # Valid indices are 0 .. n - 1; i == n would fail at arr[i]
    if i < 0 or i >= n:
        print("invalid input", end = "")
        return

    diff = new_val - arr[i]
    arr[i] = new_val
    updateValueUtil(st, 0, n - 1, i, diff, 0)

To answer sum queries, we create a function, getSum, that wraps the sum utility function defined earlier. The purpose of this wrapper is to validate the query before running it, which keeps the code cleaner and less prone to error. If the range is invalid (qs below 0, qe beyond the last index, or qs greater than qe), -1 is returned; otherwise, the function returns the result of the sum utility function, which is the sum of the elements from index qs to qe.

def getSum(st, n, qs, qe): 
    if (qs < 0 or qe > n - 1 or qs > qe):
        print("invalid input", end = "") 
        return -1 
    
    return getSumUtil(st, 0, n - 1, qs, qe, 0) 
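One caveat about this design: -1 is also a legitimate sum. In the example array \(\{-3, 2, 4, 9, 6, -8\}\), the range from index 0 to 1 sums to \(-3 + 2 = -1\), so a caller cannot tell that answer apart from the error signal. A common alternative in Python is to raise an exception instead. The sketch below does this with a hypothetical function get_sum_checked; it repeats getSumUtil and uses a compact hypothetical build helper so that it runs on its own.

```python
def build(arr):
    """Hypothetical compact builder using the same node layout as the article."""
    n = len(arr)
    st = [0] * (4 * n)
    def fill(ss, se, si):
        if ss == se:
            st[si] = arr[ss]
            return
        mid = ss + (se - ss) // 2
        fill(ss, mid, 2 * si + 1)
        fill(mid + 1, se, 2 * si + 2)
        st[si] = st[2 * si + 1] + st[2 * si + 2]
    fill(0, n - 1, 0)
    return st

def getSumUtil(st, ss, se, qs, qe, si):
    if qs <= ss and qe >= se:
        return st[si]
    if se < qs or ss > qe:
        return 0
    mid = ss + (se - ss) // 2
    return (getSumUtil(st, ss, mid, qs, qe, 2 * si + 1) +
            getSumUtil(st, mid + 1, se, qs, qe, 2 * si + 2))

def get_sum_checked(st, n, qs, qe):
    """Like getSum, but reports an invalid range by raising instead of returning -1."""
    if qs < 0 or qe > n - 1 or qs > qe:
        raise IndexError(f"invalid query range [{qs}, {qe}] for n = {n}")
    return getSumUtil(st, 0, n - 1, qs, qe, 0)

arr = [-3, 2, 4, 9, 6, -8]
st = build(arr)
print(get_sum_checked(st, len(arr), 0, 1))  # -1 here is a real sum, not an error code
try:
    get_sum_checked(st, len(arr), 4, 2)     # qs > qe, so this is rejected
except IndexError as err:
    print("rejected:", err)
```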

With all the functions that operate on the segment tree finished, the only thing left to do is to actually construct a segment tree. First, we need a recursive utility function, constructSTUtil, that builds a segment tree st from a given array arr. Here, the parameter si is the index of the current node in st, and ss and se are the bounds of its segment. If the segment contains only one element (ss equals se), the function stores arr[ss] in the current node and returns it. If it contains more than one element, the function gets the middle index of the segment, recursively builds the left and right children, and stores the sum of their values in the current node.

Then, a final constructor function, constructST, makes use of this utility function. It computes the height of the segment tree, \(x = \lceil \log_2 n \rceil\), and from it the maximum number of nodes, \(2 \cdot 2^x - 1\). It then allocates a list of that size, fills it by calling constructSTUtil, and returns the finished segment tree st.

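def constructSTUtil(arr, ss, se, st, si):
    # Leaf: the segment holds a single element
    if ss == se:
        st[si] = arr[ss]
        return arr[ss]

    # Internal node: build both children, then store the sum of their sums
    mid = getMid(ss, se)
    st[si] = (constructSTUtil(arr, ss, mid, st, si * 2 + 1) +
              constructSTUtil(arr, mid + 1, se, st, si * 2 + 2))

    return st[si]

def constructST(arr, n):
    x = int(ceil(log2(n)))          # height of the segment tree
    max_size = 2 * int(2 ** x) - 1  # node count of a perfect binary tree of height x
    st = [0] * max_size
    constructSTUtil(arr, 0, n - 1, st, 0)
    return st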
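To make the flat-list layout concrete, the sketch below runs the construction functions on the example array and prints the resulting list. Node si stores its children at 2 * si + 1 and 2 * si + 2, and slots that no node uses simply stay 0. The final loop checks that the allocation \(2 \cdot 2^{\lceil \log_2 n \rceil} - 1\) always stays below \(4n\), which is why the memory use is \(\mathcal{O}(n)\).

```python
from math import ceil, log2

def getMid(s, e):
    return s + (e - s) // 2

def constructSTUtil(arr, ss, se, st, si):
    if ss == se:                      # leaf
        st[si] = arr[ss]
        return arr[ss]
    mid = getMid(ss, se)              # internal node: sum of both children
    st[si] = (constructSTUtil(arr, ss, mid, st, si * 2 + 1) +
              constructSTUtil(arr, mid + 1, se, st, si * 2 + 2))
    return st[si]

def constructST(arr, n):
    x = int(ceil(log2(n)))
    max_size = 2 * int(2 ** x) - 1
    st = [0] * max_size
    constructSTUtil(arr, 0, n - 1, st, 0)
    return st

arr = [-3, 2, 4, 9, 6, -8]
st = constructST(arr, len(arr))
print(len(st))  # 15, since x = 3 and 2 * 2**3 - 1 = 15
print(st)       # [10, 3, 7, -1, 4, 15, -8, -3, 2, 0, 0, 9, 6, 0, 0]

# The allocation never reaches 4n slots
for n in range(1, 1025):
    assert 2 * 2 ** ceil(log2(n)) - 1 < 4 * n
```

Here slots 9, 10, 13, and 14 are unused, because the nodes at indices 4 ([2..2]) and 6 ([5..5]) are leaves and have no children.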

With that, everything needed to build and use a segment tree is finished. We can now write some driver code to try it out. First, we build the segment tree from the example array and answer a sum query for a given start and end index. Then, we perform an update query that changes the value at a given index, along with its corresponding nodes, and query the same range again to get the updated sum.

if __name__ == "__main__":
    arr = [-3, 2, 4, 9, 6, -8]
    n = len(arr) 
    
    st = constructST(arr, n)
    print("sum: ", getSum(st, n, 2, 5))
    
    updateValue(arr, st, n, 3, 12) 
    print("updated sum: ", getSum(st, n, 2, 5), end = "") 

The program first prints \(11\), the sum of the elements from index \(2\) to \(5\), and then \(14\), the updated sum. This is correct: \(4 + 9 + 6 + (-8) = 11\), and when the value at index \(3\) is changed from \(9\) to \(12\), the sum increases by \(3\), giving \(14\).
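A single hand-checked example is reassuring, but a randomized comparison against plain Python slicing covers many more array lengths and update patterns. The sketch below repeats the recursive functions from this article, calling the utility functions directly to skip the printing wrappers, and checks every sum query against sum(arr[qs:qe + 1]). The seed, array sizes, and value ranges are arbitrary choices.

```python
import random
from math import ceil, log2

def getMid(s, e):
    return s + (e - s) // 2

def getSumUtil(st, ss, se, qs, qe, si):
    if qs <= ss and qe >= se:
        return st[si]
    if se < qs or ss > qe:
        return 0
    mid = getMid(ss, se)
    return (getSumUtil(st, ss, mid, qs, qe, 2 * si + 1) +
            getSumUtil(st, mid + 1, se, qs, qe, 2 * si + 2))

def updateValueUtil(st, ss, se, i, diff, si):
    if i < ss or i > se:
        return
    st[si] += diff
    if se != ss:
        mid = getMid(ss, se)
        updateValueUtil(st, ss, mid, i, diff, 2 * si + 1)
        updateValueUtil(st, mid + 1, se, i, diff, 2 * si + 2)

def constructSTUtil(arr, ss, se, st, si):
    if ss == se:
        st[si] = arr[ss]
        return arr[ss]
    mid = getMid(ss, se)
    st[si] = (constructSTUtil(arr, ss, mid, st, 2 * si + 1) +
              constructSTUtil(arr, mid + 1, se, st, 2 * si + 2))
    return st[si]

def constructST(arr, n):
    x = int(ceil(log2(n)))
    st = [0] * (2 * int(2 ** x) - 1)
    constructSTUtil(arr, 0, n - 1, st, 0)
    return st

random.seed(0)  # fixed seed so the run is reproducible
for _ in range(200):
    n = random.randint(1, 40)
    arr = [random.randint(-50, 50) for _ in range(n)]
    st = constructST(arr, n)
    for _ in range(30):
        if random.random() < 0.5:
            # Point update: same steps as updateValue, minus the validation
            i = random.randrange(n)
            new_val = random.randint(-50, 50)
            diff = new_val - arr[i]
            arr[i] = new_val
            updateValueUtil(st, 0, n - 1, i, diff, 0)
        else:
            # Range query, compared against a brute-force slice sum
            qs = random.randrange(n)
            qe = random.randrange(qs, n)
            assert getSumUtil(st, 0, n - 1, qs, qe, 0) == sum(arr[qs:qe + 1])

print("all random checks passed")
```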

Conclusion

The segment tree has proved to be very interesting, and its features allow for efficient queries of various types. Thanks to its binary tree structure, each range query and each point update takes \(\mathcal{O}(\log n)\) time, which on large arrays is far faster than approaches that loop over the range and take \(\mathcal{O}(n)\) time per query. Hopefully, this introduction to segment trees has given you a different and unique view of the world of data structures and algorithms.