给定一个整数数组(下标由 0 到 n-1,其中 n 表示数组的规模),以及一个查询列表。每一个查询包含两个整数 `[start, end]`。对于每个查询,计算出数组中从下标 start 到 end(包含两端)之间的数的总和,并将结果依次放入结果列表中返回。
注意事项
样例
对于数组 `[1,2,7,8,5]`,查询 `[(1,2),(0,4),(2,4)]`,返回 `[9,23,20]`。
挑战
O(logN) time for each query
代码
- 无需拆分区间
/**
 * Definition of Interval:
 * public class Interval {
 *     int start, end;
 *     Interval(int start, int end) {
 *         this.start = start;
 *         this.end = end;
 *     }
 * }
 */
public class Solution {
    /**
     * Segment-tree node covering the inclusive index range [start, end] of the input array.
     * The sum is stored as a {@code long} so that adding many {@code int} values cannot overflow.
     * Declared {@code static}: a node never needs a reference to the enclosing Solution.
     */
    static class SegmentTreeNode {
        public int start;
        public int end;
        public long sum;
        SegmentTreeNode left;
        SegmentTreeNode right;

        public SegmentTreeNode(int start, int end, long sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = null;
            this.right = null;
        }
    }

    /**
     * Builds a segment tree over A[start..end] (inclusive).
     *
     * @param start first index covered by the subtree
     * @param end last index covered by the subtree
     * @param A source array
     * @return root of the subtree, or {@code null} when {@code start > end} (empty range)
     */
    public SegmentTreeNode build(int start, int end, int[] A) {
        if (start > end) {
            return null;
        }
        if (start == end) {
            return new SegmentTreeNode(start, end, A[start]);
        }
        SegmentTreeNode root = new SegmentTreeNode(start, end, 0);
        int mid = start + (end - start) / 2;
        root.left = build(start, mid, A);
        root.right = build(mid + 1, end, A);
        // start < end implies start <= mid < end, so both child ranges are non-empty
        // and both children are non-null here.
        root.sum = root.left.sum + root.right.sum;
        return root;
    }

    /**
     * Returns the sum of the elements whose indices lie in both [start, end] and the range
     * covered by {@code root}.
     *
     * <p>Ranges that do not overlap the tree (including an empty tree, an inverted query with
     * {@code start > end}, or indices outside the array) contribute 0 instead of failing.
     *
     * @param root subtree to query; may be {@code null} (empty array)
     * @param start first index of the query range, inclusive
     * @param end last index of the query range, inclusive
     * @return the sum over the overlapping indices
     */
    public long query(SegmentTreeNode root, int start, int end) {
        // No node, empty query, or no overlap with this node's range: nothing to add.
        // This guard also stops recursion before it could reach a leaf's null children.
        if (root == null || start > end || end < root.start || start > root.end) {
            return 0;
        }
        // Node range fully inside the query: use the precomputed sum.
        if (start <= root.start && end >= root.end) {
            return root.sum;
        }
        // Partial overlap: a partially covered node is never a leaf, so both children exist.
        return query(root.left, start, end) + query(root.right, start, end);
    }

    /** Tree built by the most recent {@link #intervalSum} call. */
    SegmentTreeNode root;

    /**
     * Answers each range-sum query against {@code A}.
     *
     * @param A An integer list
     * @param queries An query list; each interval is an inclusive [start, end] index range
     * @return The result list, one sum per query in the same order
     */
    public List<Long> intervalSum(int[] A, List<Interval> queries) {
        root = build(0, A.length - 1, A);
        List<Long> list = new ArrayList<>(queries.size());
        for (Interval interval : queries) {
            list.add(query(root, interval.start, interval.end));
        }
        return list;
    }
}
本题有个坑需要注意:query 要求返回 long,因此在 SegmentTreeNode 的定义中也要把 sum 声明为 long,否则用 int 累加区间和时可能发生溢出。
- 手动拆分区间
public class Solution {
    /**
     * Segment-tree node covering the inclusive index range [start, end] of the input array.
     * Declared {@code static}: a node never needs a reference to the enclosing Solution.
     */
    static class SegmentTreeNode {
        public int start, end;
        public Long sum;
        public SegmentTreeNode left, right;

        public SegmentTreeNode(int start, int end, Long sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = this.right = null;
        }
    }

    /**
     * Builds a segment tree over A[start..end] (inclusive).
     *
     * @param start first index covered by the subtree
     * @param end last index covered by the subtree
     * @param A source array
     * @return root of the subtree, or {@code null} when {@code start > end} (empty range)
     */
    public SegmentTreeNode build(int start, int end, int[] A) {
        if (start > end) {
            return null;
        }
        SegmentTreeNode root = new SegmentTreeNode(start, end, 0L);
        if (start != end) {
            // Overflow-safe midpoint; both halves are non-empty because start < end.
            int mid = start + (end - start) / 2;
            root.left = build(start, mid, A);
            root.right = build(mid + 1, end, A);
            root.sum = root.left.sum + root.right.sum;
        } else {
            root.sum = Long.valueOf(A[start]);
        }
        return root;
    }

    /**
     * Returns the sum of the elements whose indices lie in both [start, end] and the range
     * covered by {@code root}.
     *
     * <p>The query is first clamped to the node's range. Each recursive call therefore receives
     * a sub-range of its child's range, and the exact-match test eventually succeeds. Parts of
     * the query outside the array contribute 0.
     *
     * @param root subtree to query; may be {@code null} (empty array)
     * @param start first index of the query range, inclusive
     * @param end last index of the query range, inclusive
     * @return the sum over the overlapping indices
     */
    public Long query(SegmentTreeNode root, int start, int end) {
        if (root == null) {
            return 0L;
        }
        int lo = Math.max(start, root.start);
        int hi = Math.min(end, root.end);
        if (lo > hi) { // disjoint from this node (or an inverted query)
            return 0L;
        }
        if (lo == root.start && hi == root.end) { // exact match: precomputed sum
            return root.sum;
        }
        int mid = root.start + (root.end - root.start) / 2;
        long total = 0L;
        if (lo <= mid) { // left part: [lo, min(hi, mid)]
            total += query(root.left, lo, Math.min(hi, mid));
        }
        if (hi > mid) { // right part: [max(lo, mid + 1), hi]
            total += query(root.right, Math.max(lo, mid + 1), hi);
        }
        return total;
    }

    /**
     * Answers each range-sum query against {@code A}.
     *
     * @param A the integer array
     * @param queries the query list; each interval is an inclusive [start, end] index range
     * @return one sum per query, in query order
     */
    public ArrayList<Long> intervalSum(int[] A,
                                       ArrayList<Interval> queries) {
        SegmentTreeNode root = build(0, A.length - 1, A);
        ArrayList<Long> ans = new ArrayList<>(queries.size());
        for (Interval in : queries) {
            ans.add(query(root, in.start, in.end));
        }
        return ans;
    }
}
网友评论