[Leetcode] Segment Tree

Segment Tree

Video: https://www.youtube.com/watch?v=rYBtViWXYeI

A segment tree (the discrete version) is a balanced binary tree: for n elements its height is O(log n).

Each leaf node (segment) represents a single array element. Each internal node covers the union of its children’s ranges.

Operations:

  • build(start, end, vals) -> O(n)
  • update(index, value) -> O(log n)
  • rangeQuery(start, end) -> O(log n) for an aggregate such as a sum; O(log n + k) when the k covering segments must be reported individually
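Example: the tree built for nums = [2, 1, 5, 3, 4] (each node shows its sum and the index range it covers):

              15 (0-4)
             /        \
        8 (0-2)        7 (3-4)
        /     \        /     \
    3 (0-1)  5 (2)   3 (3)   4 (4)
     /   \
  2 (0)  1 (1)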
class SegmentTreeNode:
    """One segment of the tree.

    Stores the inclusive index range [start, end], the sum of the values in
    that range, and references to the left and right child nodes.
    """

    def __init__(self, start, end, sum, left, right):
        # Range bounds first, then the aggregate, then the child links.
        self.start, self.end = start, end
        self.sum = sum
        self.left, self.right = left, right

class SegmentTree:
    """Sum segment-tree operations over ``SegmentTreeNode`` objects.

    The class keeps no state of its own: every method receives the (sub)tree
    root explicitly. All index ranges are inclusive on both ends.
    """

    def buildTree(self, start, end, vals):
        """Build the tree covering ``vals[start..end]`` and return its root.

        Returns None for an empty range (``start > end``, e.g. ``vals`` is
        empty). Without that check, ``buildTree(0, -1, [])`` would recurse
        forever. Each node is created exactly once, so construction is O(n).
        """
        if start > end:
            return None
        if start == end:
            return SegmentTreeNode(start, end, vals[start], None, None)
        mid = (start + end) // 2
        leftTree = self.buildTree(start, mid, vals)
        rightTree = self.buildTree(mid + 1, end, vals)
        # An internal node's sum is the sum of its two halves.
        return SegmentTreeNode(
            start, end, leftTree.sum + rightTree.sum, leftTree, rightTree
        )

    def updateTree(self, root, index, val):
        """Set element ``index`` to ``val`` and refresh sums along the path.

        Descends like a binary search on the midpoint, touching one node per
        level.

        Raises:
            IndexError: if ``root`` is None or ``index`` lies outside
                ``[root.start, root.end]``.
        """
        if root is None or not root.start <= index <= root.end:
            raise IndexError("index out of range")
        if root.start == root.end:
            root.sum = val
            return
        mid = (root.start + root.end) // 2
        if index <= mid:
            self.updateTree(root.left, index, val)
        else:
            self.updateTree(root.right, index, val)
        # !important: a child changed, so this node's cached sum is stale.
        root.sum = root.left.sum + root.right.sum

    def querySum(self, root, i, j):
        """Return the sum of elements ``i..j`` (inclusive) under ``root``.

        Raises:
            IndexError: if ``root`` is None or ``[i, j]`` is not a non-empty
                sub-range of ``[root.start, root.end]``.
        """
        if root is None or not root.start <= i <= j <= root.end:
            raise IndexError("query range out of bounds")
        if root.start == i and root.end == j:
            return root.sum
        mid = (root.start + root.end) // 2
        if j <= mid:
            return self.querySum(root.left, i, j)
        if i > mid:
            return self.querySum(root.right, i, j)
        # The range straddles the midpoint: split it between the two halves.
        return self.querySum(root.left, i, mid) + self.querySum(
            root.right, mid + 1, j
        )

307 Range Sum Query - Mutable

Leetcode: https://leetcode.com/problems/range-sum-query-mutable/

class Node:
    """Segment-tree node used by NumArray.

    Holds the inclusive index range [start, end], the sum of that range
    (stored under the attribute name ``sum``), and the two child nodes.
    """

    def __init__(self, start, end, sums, left, right):
        self.start, self.end, self.sum = start, end, sums
        self.left, self.right = left, right

class NumArray:
    """LeetCode 307 (Range Sum Query - Mutable) backed by a sum segment tree.

    Construction is O(n); ``update`` and ``sumRange`` each visit O(log n) nodes.
    """

    def __init__(self, nums: "List[int]"):
        # The annotation is quoted so the class also loads outside LeetCode,
        # where ``List`` is not pre-imported from ``typing``.
        # ``root`` is None when ``nums`` is empty (see buildTree).
        self.root = self.buildTree(0, len(nums) - 1, nums)

    def buildTree(self, start, end, nums):
        """Build the subtree covering ``nums[start..end]`` and return its root.

        Returns None for an empty range (``start > end``). Without that check,
        an empty ``nums`` would make ``buildTree(0, -1)`` recurse forever.
        """
        if start > end:
            return None
        if start == end:
            return Node(start, end, nums[start], None, None)
        mid = (start + end) // 2
        leftTree = self.buildTree(start, mid, nums)
        rightTree = self.buildTree(mid + 1, end, nums)
        root = Node(start, end, leftTree.sum + rightTree.sum, leftTree, rightTree)
        return root

    def update(self, index: int, val: int) -> None:
        """Set ``nums[index] = val``.

        Raises:
            IndexError: if ``index`` is outside the array.
        """
        self.updateNode(self.root, index, val)

    def updateNode(self, node, index, val):
        """Write ``val`` at leaf ``index`` under ``node`` and refresh ancestor sums.

        Raises:
            IndexError: if ``node`` is None or ``index`` lies outside
                ``[node.start, node.end]``.
        """
        if node is None or not node.start <= index <= node.end:
            raise IndexError("index out of range")
        if node.start == node.end:
            node.sum = val
            return
        mid = (node.start + node.end) // 2
        if index > mid:
            self.updateNode(node.right, index, val)
        else:
            self.updateNode(node.left, index, val)
        # A child changed, so this node's cached sum must be recomputed.
        node.sum = node.left.sum + node.right.sum

    def sumRange(self, left: int, right: int) -> int:
        """Return ``nums[left] + ... + nums[right]`` (inclusive).

        Raises:
            IndexError: if ``[left, right]`` is not a valid non-empty range.
        """
        return self.sumRangeNode(self.root, left, right)

    def sumRangeNode(self, node, left, right):
        """Return the sum of positions ``left..right`` under ``node``.

        Raises:
            IndexError: if ``node`` is None or ``[left, right]`` is not a
                non-empty sub-range of ``[node.start, node.end]``.
        """
        if node is None or not node.start <= left <= right <= node.end:
            raise IndexError("range out of bounds")
        if node.start == left and node.end == right:
            return node.sum

        mid = (node.start + node.end) // 2
        if right <= mid:
            return self.sumRangeNode(node.left, left, right)
        if left > mid:
            return self.sumRangeNode(node.right, left, right)
        # The range straddles mid: answer each half separately.
        return self.sumRangeNode(node.left, left, mid) + self.sumRangeNode(
            node.right, mid + 1, right
        )