basic
回顾下页式(page-oritented)存储: 有一个page为树的根页; 页中包含多个键以及子页的索引; 子页则负责以父页中某两个相邻键为边界的连续范围内的所有键.
Segment Tree与其类似, 每个节点代表一个区间中的某统计值(如: 最大/小值, 和, 积等), 节点的子节点按序负责父节点区间中的某一子区间(一般均匀分配).
以binary segment tree为例, 每个节点的区间分为两个子区间, 节点值为该区间的和, 示意图如下:
故我们还需要存储每个节点所负责区间的起止位置, 即除了val(此处为sum)还需有 start/end变量.
operations
基于这种结构, 提供的操作(API):
- initialize(datas): 初始化一棵segment tree, 同二叉树的递归构建, 参数应有: datas——源数据,start——负责区间的起始位置,end——负责区间的结束位置.
- update(i, num): 根据位置i和新值num更新datas中的某数, 实质是据此更新segment tree. 更新过程也是递归过程, 将i为当前节点对应区间进行对比以确认是否更新val, 并递归调用i所处的对应子区间的子树进行更新.
- rangeSum(i, j): 给定子区间[i,j], 并不断递归查sum值. [i,j]在segment tree上可能被分为多个子区间进行存储, 则根据所求的值进行迭代并返回, 如求和则将各区间的值求和, 求max则在递归返回时不断比较并返回当前max.
code
以 [LeetCode 307. Range Sum Query - Mutable][2] 为例, Segment Tree的Python实现如下:
class SegmentNode:
def __init__(self, start, end):
self.start, self.end, self.sum = start, end, 0 # the start/end/sum of the interval
self.left, self.right = None, None # left/right interval
class NumArray:
def __init__(self, nums: list):
def buildTree(l, r):
if l > r: # irregular parameters
return None
if l == r: # leaf node
n = SegmentNode(l, r)
n.sum = nums[l]
return n
mid, root = (l + r) // 2, SegmentNode(l, r)
root.left, root.right = buildTree(l, mid), buildTree(mid + 1, r) # recursively build the tree
root.sum = root.left.sum + root.right.sum # update the sum from children
return root
self.root = buildTree(0, len(nums) - 1)
def update(self, i: int, val: int) -> None:
def updateTree(root, i, val):
if root.start == root.end: # the leaf node to update
root.sum = val
return val
mid = (root.start + root.end) // 2 # then recursively update the tree
if i <= mid:
updateTree(root.left, i, val)
else:
updateTree(root.right, i, val)
root.sum = root.left.sum + root.right.sum # update the sum from children
return root.sum
updateTree(self.root, i, val)
def sumRange(self, i: int, j: int) -> int:
def findNode(root, x, y):
if root.start == x and root.end == y: # just the interval
return root.sum
mid = (root.start + root.end) // 2
if y <= mid:
return findNode(root.left, x, y) # the interval belongs to left
if x > mid:
return findNode(root.right, x, y) # the interval belongs to right
return findNode(root.left, x, mid) + findNode(root.right, mid + 1, y) # cross left and right
return findNode(self.root, i, j)
网友评论