int getsum(int l, int r, int s, int t, int p) { // [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号 if (l <= s && t <= r) return d[p]; // 当前区间为询问区间的子集时直接返回当前区间的和 int m = s + ((t - s) >> 1), sum = 0; if (l <= m) sum += getsum(l, r, s, m, p * 2); // 如果左儿子代表的区间 [s, m] 与询问区间有交集, 则递归查询左儿子 if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1); // 如果右儿子代表的区间 [m + 1, t] 与询问区间有交集, 则递归查询右儿子 return sum;}
Python
def getsum(l, r, s, t, p): # [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号 if l <= s and t <= r: return d[p] # 当前区间为询问区间的子集时直接返回当前区间的和 m = s + ((t - s) >> 1) sum = 0 if l <= m: sum = sum + getsum(l, r, s, m, p * 2) # 如果左儿子代表的区间 [s, m] 与询问区间有交集, 则递归查询左儿子 if r > m: sum = sum + getsum(l, r, m + 1, t, p * 2 + 1) # 如果右儿子代表的区间 [m + 1, t] 与询问区间有交集, 则递归查询右儿子 return sum
// [l, r] 为修改区间, c 为被修改的元素的变化量, [s, t] 为当前节点包含的区间, p// 为当前节点的编号void update(int l, int r, int c, int s, int t, int p) { // 当前区间为修改区间的子集时直接修改当前节点的值,然后打标记,结束修改 if (l <= s && t <= r) { d[p] += (t - s + 1) * c, b[p] += c; return; } int m = s + ((t - s) >> 1); if (b[p] && s != t) { // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值 d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m); b[p * 2] += b[p], b[p * 2 + 1] += b[p]; // 将标记下传给子节点 b[p] = 0; // 清空当前节点的标记 } if (l <= m) update(l, r, c, s, m, p * 2); if (r > m) update(l, r, c, m + 1, t, p * 2 + 1); d[p] = d[p * 2] + d[p * 2 + 1];}
Python
def update(l, r, c, s, t, p): # [l, r] 为修改区间, c 为被修改的元素的变化量, [s, t] 为当前节点包含的区间, p # 为当前节点的编号 if l <= s and t <= r: d[p] = d[p] + (t - s + 1) * c b[p] = b[p] + c return # 当前区间为修改区间的子集时直接修改当前节点的值, 然后打标记, 结束修改 m = s + ((t - s) >> 1) if b[p] and s != t: # 如果当前节点的懒标记非空, 则更新当前节点两个子节点的值和懒标记值 d[p * 2] = d[p * 2] + b[p] * (m - s + 1) d[p * 2 + 1] = d[p * 2 + 1] + b[p] * (t - m) # 将标记下传给子节点 b[p * 2] = b[p * 2] + b[p] b[p * 2 + 1] = b[p * 2 + 1] + b[p] # 清空当前节点的标记 b[p] = 0 if l <= m: update(l, r, c, s, m, p * 2) if r > m: update(l, r, c, m + 1, t, p * 2 + 1) d[p] = d[p * 2] + d[p * 2 + 1]
区间查询(区间求和):
[list2tab]
C++
int getsum(int l, int r, int s, int t, int p) { // [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号 if (l <= s && t <= r) return d[p]; // 当前区间为询问区间的子集时直接返回当前区间的和 int m = s + ((t - s) >> 1); if (b[p]) { // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值 d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m); b[p * 2] += b[p], b[p * 2 + 1] += b[p]; // 将标记下传给子节点 b[p] = 0; // 清空当前节点的标记 } int sum = 0; if (l <= m) sum = getsum(l, r, s, m, p * 2); if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1); return sum;}
Python
def getsum(l, r, s, t, p): # [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p为当前节点的编号 if l <= s and t <= r: return d[p] # 当前区间为询问区间的子集时直接返回当前区间的和 m = s + ((t - s) >> 1) if b[p]: # 如果当前节点的懒标记非空, 则更新当前节点两个子节点的值和懒标记值 d[p * 2] = d[p * 2] + b[p] * (m - s + 1) d[p * 2 + 1] = d[p * 2 + 1] + b[p] * (t - m) # 将标记下传给子节点 b[p * 2] = b[p * 2] + b[p] b[p * 2 + 1] = b[p * 2 + 1] + b[p] # 清空当前节点的标记 b[p] = 0 sum = 0 if l <= m: sum = getsum(l, r, s, m, p * 2) if r > m: sum = sum + getsum(l, r, m + 1, t, p * 2 + 1) return sum
如果你是要实现区间修改为某一个值而不是加上某一个值的话,代码如下:
[list2tab]
C++
void update(int l, int r, int c, int s, int t, int p) { if (l <= s && t <= r) { d[p] = (t - s + 1) * c, b[p] = c, v[p] = 1; return; } int m = s + ((t - s) >> 1); // 额外数组储存是否修改值 if (v[p]) { d[p * 2] = b[p] * (m - s + 1), d[p * 2 + 1] = b[p] * (t - m); b[p * 2] = b[p * 2 + 1] = b[p]; v[p * 2] = v[p * 2 + 1] = 1; v[p] = 0; } if (l <= m) update(l, r, c, s, m, p * 2); if (r > m) update(l, r, c, m + 1, t, p * 2 + 1); d[p] = d[p * 2] + d[p * 2 + 1];}int getsum(int l, int r, int s, int t, int p) { if (l <= s && t <= r) return d[p]; int m = s + ((t - s) >> 1); if (v[p]) { d[p * 2] = b[p] * (m - s + 1), d[p * 2 + 1] = b[p] * (t - m); b[p * 2] = b[p * 2 + 1] = b[p]; v[p * 2] = v[p * 2 + 1] = 1; v[p] = 0; } int sum = 0; if (l <= m) sum = getsum(l, r, s, m, p * 2); if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1); return sum;}
Python
def update(l, r, c, s, t, p): if l <= s and t <= r: d[p] = (t - s + 1) * c b[p] = c v[p] = 1 return m = s + ((t - s) >> 1) if v[p]: d[p * 2] = b[p] * (m - s + 1) d[p * 2 + 1] = b[p] * (t - m) b[p * 2] = b[p * 2 + 1] = b[p] v[p * 2] = v[p * 2 + 1] = 1 v[p] = 0 if l <= m: update(l, r, c, s, m, p * 2) if r > m: update(l, r, c, m + 1, t, p * 2 + 1) d[p] = d[p * 2] + d[p * 2 + 1]def getsum(l, r, s, t, p): if l <= s and t <= r: return d[p] m = s + ((t - s) >> 1) if v[p]: d[p * 2] = b[p] * (m - s + 1) d[p * 2 + 1] = b[p] * (t - m) b[p * 2] = b[p * 2 + 1] = b[p] v[p * 2] = v[p * 2 + 1] = 1 v[p] = 0 sum = 0 if l <= m: sum = getsum(l, r, s, m, p * 2) if r > m: sum = sum + getsum(l, r, m + 1, t, p * 2 + 1) return sum
动态开点线段树
前面讲到堆式储存的情况下,需要给线段树开 4n 大小的数组。为了节省空间,我们可以不一次性建好树,而是在最初只建立一个根结点代表整个区间。当我们需要访问某个子区间时,才建立代表这个区间的子结点。这样我们不再使用 2p 和 2p+1 代表 p 结点的儿子,而是用 ls 和 rs 记录儿子的编号。总之,动态开点线段树的核心思想就是:结点只有在有需要的时候才被创建。
单次操作的时间复杂度是不变的,为 O(logn)。由于每次操作都有可能创建并访问全新的一系列结点,因此 m 次单点操作后结点的数量规模是 O(mlogn)。最多也只需要 2n−1 个结点,没有浪费。
单点修改:
// root 表示整棵线段树的根结点;cnt 表示当前结点个数int n, cnt, root;int sum[n * 2], ls[n * 2], rs[n * 2];// 用法:update(root, 1, n, x, f); 其中 x 为待修改节点的编号void update(int& p, int s, int t, int x, int f) { // 引用传参 if (!p) p = ++cnt; // 当结点为空时,创建一个新的结点 if (s == t) { sum[p] += f; return; } int m = s + ((t - s) >> 1); if (x <= m) update(ls[p], s, m, x, f); else update(rs[p], m + 1, t, x, f); sum[p] = sum[ls[p]] + sum[rs[p]]; // pushup}
区间询问:
// 用法:query(root, 1, n, l, r);int query(int p, int s, int t, int l, int r) { if (!p) return 0; // 如果结点为空,返回 0 if (s >= l && t <= r) return sum[p]; int m = s + ((t - s) >> 1), ans = 0; if (l <= m) ans += query(ls[p], s, m, l, r); if (r > m) ans += query(rs[p], m + 1, t, l, r); return ans;}
假设货架上从左到右摆放了 N 种商品,并且依次标号为 1 到 N,其中标号为 i 的商品的价格为 Pi。小 Hi 的每次操作分为两种可能,第一种是修改价格:小 Hi 给出一段区间 [L,R] 和一个新的价格 NewP,所有标号在这段区间中的商品的价格都变成 NewP。第二种操作是询问:小 Hi 给出一段区间 [L,R],而小 Ho 要做的便是计算出所有标号在这段区间中的商品的总价格,然后告诉小 Hi。
参考代码
#include <iostream>int n, a[100005], d[270000], b[270000];void build(int l, int r, int p) { // 建树 if (l == r) { d[p] = a[l]; return; } int m = l + ((r - l) >> 1); build(l, m, p << 1), build(m + 1, r, (p << 1) | 1); d[p] = d[p << 1] + d[(p << 1) | 1];}void update(int l, int r, int c, int s, int t, int p) { // 更新,可以参考前面两个例题 if (l <= s && t <= r) { d[p] = (t - s + 1) * c, b[p] = c; return; } int m = s + ((t - s) >> 1); if (b[p]) { d[p << 1] = b[p] * (m - s + 1), d[(p << 1) | 1] = b[p] * (t - m); b[p << 1] = b[(p << 1) | 1] = b[p]; b[p] = 0; } if (l <= m) update(l, r, c, s, m, p << 1); if (r > m) update(l, r, c, m + 1, t, (p << 1) | 1); d[p] = d[p << 1] + d[(p << 1) | 1];}int getsum(int l, int r, int s, int t, int p) { // 取得答案,和前面一样 if (l <= s && t <= r) return d[p]; int m = s + ((t - s) >> 1); if (b[p]) { d[p << 1] = b[p] * (m - s + 1), d[(p << 1) | 1] = b[p] * (t - m); b[p << 1] = b[(p << 1) | 1] = b[p]; b[p] = 0; } int sum = 0; if (l <= m) sum = getsum(l, r, s, m, p << 1); if (r > m) sum += getsum(l, r, m + 1, t, (p << 1) | 1); return sum;}int main() { std::ios::sync_with_stdio(false); std::cin >> n; for (int i = 1; i <= n; i++) std::cin >> a[i]; build(1, n, 1); int q, i1, i2, i3, i4; std::cin >> q; while (q--) { std::cin >> i1 >> i2 >> i3; if (i1 == 0) std::cout << getsum(i2, i3, 1, n, 1) << std::endl; else std::cin >> i4, update(i2, i3, i4, 1, n, 1); } return 0;}