树状数组、线段树与差分练习
文章目录
重要
树状数组与线段树的适用范围是“单点修改,区间查询”。树状数组更适合用于动态维护前缀和,而线段树还可用于求区间最值
树状数组和线段树都是在线方法,支持动态修改与维护。
题目来自 AcWing
1264. 动态求连续区间和 | 原题链接
这是一道树状数组模板题。树状数组最基础的用法就是动态维护前缀和。下面出现的 `lowbit`、`add`、`query` 都是维护树状数组的基本函数。
import sys

# State shared by add() and query(): n is the array length and tree is the
# 1-indexed Fenwick array. solve() rebinds both before using them.
n = 0
tree = [0]


def lowbit(x):
    """Return the value of the lowest set bit of x (e.g. 12 -> 4)."""
    return x & -x


def add(x, val):
    """Add val to position x (1-indexed) of the Fenwick tree."""
    while x <= n:
        tree[x] += val
        x += lowbit(x)


def query(x):
    """Return the sum of positions 1..x."""
    res = 0
    while x > 0:
        res += tree[x]
        x -= lowbit(x)
    return res


def solve(data):
    """Process the input lines of problem 1264 and return the query answers.

    data[0] is "n m", data[1] holds the n initial values and each of the
    next m lines is "k a b": k == 0 asks for the sum of [a, b], k == 1 adds
    b to element a.
    """
    global n, tree
    n, m = map(int, data[0].split())
    nums = [0] + list(map(int, data[1].split()))
    tree = [0] * (n + 1)
    for i in range(1, n + 1):
        add(i, nums[i])

    answers = []
    for line in data[2:m + 2]:
        k, a, b = map(int, line.split())
        if k == 0:
            answers.append(query(b) - query(a - 1))
        elif k == 1:
            add(a, b)
    return answers


def main():
    """Read the whole input from stdin and print one answer per line."""
    for answer in solve(sys.stdin.read().strip().split("\n")):
        print(answer)


if __name__ == "__main__":
    main()
线段树写法:
import sys

# Heap-ordered node array: node u has children 2u and 2u + 1.
# solve() allocates it before calling build().
tree = []
# 1-indexed input values that build() copies into the leaves.
nums = (0,)


class Node:
    """Segment-tree node covering [l, r] and caching the sum of that range."""

    def __init__(self, l=0, r=0, sum=0):
        self.l = l
        self.r = r
        self.sum = sum

    def __repr__(self):
        return f"l={self.l}, r={self.r}, sum={self.sum}"


def pushup(u):
    """Recompute the sum of node u from its two children."""
    tree[u].sum = tree[u << 1].sum + tree[u << 1 | 1].sum


def build(u, l, r):
    """Make node u cover [l, r] and build its whole subtree from nums."""
    if l == r:
        tree[u] = Node(l, r, nums[r])
        return
    tree[u].l = l
    tree[u].r = r
    mid = (l + r) >> 1
    build(u << 1, l, mid)
    build(u << 1 | 1, mid + 1, r)
    pushup(u)


def query(u, l, r):
    """Return the sum of [l, r] restricted to the subtree rooted at u."""
    node = tree[u]
    if node.l >= l and node.r <= r:
        return node.sum
    mid = (node.l + node.r) >> 1
    total = 0
    if l <= mid:
        total += query(u << 1, l, r)
    if r > mid:
        total += query(u << 1 | 1, l, r)
    return total


def modify(u, x, v):
    """Add v to element x, refreshing every sum on the way back up."""
    node = tree[u]
    if node.l == node.r:
        node.sum += v
        return
    mid = (node.l + node.r) >> 1
    if x <= mid:
        modify(u << 1, x, v)
    else:
        modify(u << 1 | 1, x, v)
    pushup(u)


def solve(data):
    """Process the input lines of problem 1264 and return the query answers.

    data[0] is "n m", data[1] holds the n initial values and each of the
    next m lines is "k a b": k == 0 asks for the sum of [a, b], any other k
    adds b to element a.
    """
    global tree, nums
    n, m = map(int, data[0].split())
    nums = (0,) + tuple(map(int, data[1].split()))
    tree = [Node(0, 0, 0) for _ in range(4 * n)]
    build(1, 1, n)

    answers = []
    for line in data[2:m + 2]:
        k, a, b = map(int, line.split())
        if k == 0:
            answers.append(query(1, a, b))
        else:
            modify(1, a, b)
    return answers


def main():
    """Read the whole input from stdin and print one answer per line."""
    for answer in solve(sys.stdin.read().strip().split("\n")):
        print(answer)


if __name__ == "__main__":
    main()
`pushup`、`build`、`query`、`modify` 为线段树的基本操作。如果 Python 版本为 3.7 及以上,可以使用 `dataclasses.dataclass` 装饰器快捷定义 `Node` 类。
import sys
from dataclasses import dataclass

# Heap-ordered node array: node u has children 2u and 2u + 1.
# solve() allocates it before calling build().
tree = []
# 1-indexed input values that build() copies into the leaves.
nums = (0,)


@dataclass
class Node:
    """Segment-tree node covering [l, r] and caching the sum of that range."""

    l: int
    r: int
    sum: int


def pushup(u):
    """Recompute the sum of node u from its two children."""
    tree[u].sum = tree[u << 1].sum + tree[u << 1 | 1].sum


def build(u, l, r):
    """Make node u cover [l, r] and build its whole subtree from nums."""
    if l == r:
        tree[u] = Node(l, r, nums[r])
        return
    tree[u].l = l
    tree[u].r = r
    mid = (l + r) >> 1
    build(u << 1, l, mid)
    build(u << 1 | 1, mid + 1, r)
    pushup(u)


def query(u, l, r):
    """Return the sum of [l, r] restricted to the subtree rooted at u."""
    node = tree[u]
    if node.l >= l and node.r <= r:
        return node.sum
    mid = (node.l + node.r) >> 1
    total = 0
    if l <= mid:
        total += query(u << 1, l, r)
    if r > mid:
        total += query(u << 1 | 1, l, r)
    return total


def modify(u, x, v):
    """Add v to element x, refreshing every sum on the way back up."""
    node = tree[u]
    if node.l == node.r:
        node.sum += v
        return
    mid = (node.l + node.r) >> 1
    if x <= mid:
        modify(u << 1, x, v)
    else:
        modify(u << 1 | 1, x, v)
    pushup(u)


def solve(data):
    """Process the input lines of problem 1264 and return the query answers.

    data[0] is "n m", data[1] holds the n initial values and each of the
    next m lines is "k a b": k == 0 asks for the sum of [a, b], any other k
    adds b to element a.
    """
    global tree, nums
    n, m = map(int, data[0].split())
    nums = (0,) + tuple(map(int, data[1].split()))
    tree = [Node(0, 0, 0) for _ in range(4 * n)]
    build(1, 1, n)

    answers = []
    for line in data[2:m + 2]:
        k, a, b = map(int, line.split())
        if k == 0:
            answers.append(query(1, a, b))
        else:
            modify(1, a, b)
    return answers


def main():
    """Read the whole input from stdin and print one answer per line."""
    for answer in solve(sys.stdin.read().strip().split("\n")):
        print(answer)


if __name__ == "__main__":
    main()
线段树使用堆存储,空间占用为 $4n$,且线段树的时空消耗大于树状数组。
1265. 数星星 | 原题链接
输入是按照纵坐标排序,纵坐标相同再按照横坐标排序。因此每当输入一个点时,其纵坐标一定是当前所有点中最大的,而要找出在其左下方的点,只需要找出当前所有点中横坐标小于他的点即可。因此这道题需要动态维护前缀和,使用树状数组解决。
import sys
from collections import defaultdict

# Upper bound for the shifted x coordinate (presumably sized from the
# problem's coordinate limit plus some slack).
MAX_COUNT = 32005


def lowbit(x):
    """Return the value of the lowest set bit of x."""
    return x & -x


def add(nums, x, val, n):
    """Add val to position x of the Fenwick array nums (valid indices 1..n).

    nums is updated in place and also returned.
    """
    while x <= n:
        nums[x] += val
        x += lowbit(x)
    return nums


def query(nums, x, n):
    """Return the sum of positions 1..x of nums (n is accepted but unused)."""
    res = 0
    while x > 0:
        res += nums[x]
        x -= lowbit(x)
    return res


def solve(data):
    """Return how many stars have level 0, 1, ..., n - 1.

    data[0] is the star count n; each following line is "x y". The level of
    a star is the number of earlier stars whose x is not greater than its
    own.
    """
    n = int(data[0])
    ans = defaultdict(int)
    # One slot more than MAX_COUNT so that index MAX_COUNT, which add() may
    # touch because it loops while x <= n, is a valid index.
    tree = [0] * (MAX_COUNT + 1)
    for line in data[1:n + 1]:
        x, _y = map(int, line.split())
        x += 1  # the Fenwick tree is 1-indexed, the coordinates start at 0
        add(tree, x, 1, MAX_COUNT)
        # The prefix count includes the star itself, hence the - 1.
        ans[query(tree, x, MAX_COUNT) - 1] += 1
    return [ans[i] for i in range(n)]


def main():
    """Read the whole input from stdin and print one count per line."""
    for count in solve(sys.stdin.read().strip().split("\n")):
        print(count)


if __name__ == "__main__":
    main()
注意
树状数组的下标从 $1$ 开始,而题目中的坐标从 $0$ 开始,因此需要给所有横坐标加 $1$。
1270. 数列区间最大值 | 原题链接
这道题的数据量太大了:每次查询都在区间内逐个比较,总时间复杂度为 $O(nm)$。下面的分治写法虽然把区间一分为二,但每次查询仍然要访问区间内的每一个元素,单次查询还是 $O(n)$,总时间复杂度依然是 $O(nm)$,同样会超时。
import sys

# 1-indexed input values read by query(); solve() rebinds it.
nums = (0,)


def query(x, y):
    """Return max(nums[x..y]) by splitting the range in half recursively.

    Every element of the range is still visited, so one call costs
    O(y - x + 1).
    """
    if x == y:
        return nums[y]
    mid = (x + y) // 2
    return max(query(x, mid), query(mid + 1, y))


def solve(data):
    """Return the maximum of every queried range.

    data[0] is "n m", data[1] holds the values and each of the next m lines
    is "x y" (1-indexed, inclusive).
    """
    global nums
    _n, m = map(int, data[0].split())
    nums = (0,) + tuple(map(int, data[1].split()))
    answers = []
    for line in data[2:m + 2]:
        x, y = map(int, line.split())
        answers.append(query(x, y))
    return answers


def main():
    """Read the whole input from stdin and print one answer per line."""
    for answer in solve(sys.stdin.read().strip().split("\n")):
        print(answer)


if __name__ == "__main__":
    main()
这道题可以用线段树解决,只需要修改 `query` 操作即可实现求区间最大值的功能。线段树经过 $O(n)$ 的预处理后,每次查询的时间复杂度为 $O(\log n)$,总时间复杂度为 $O(n+m\log n)$。但是 Python 写这道题的线段树仍然会超时,加上 `lru_cache` 还会 MLE……
#define _CRT_SECURE_NO_WARNINGS

#include <algorithm>
#include <cstdio>

using namespace std;

// Segment-tree node: covers [l, r] and caches the maximum of that range.
// The array is heap-ordered: node u has children 2u and 2u + 1.
struct Node {
    int l, r, max;
} tree[400010];

// 1-indexed input values; build() copies them into the leaves.
int nums[100010] = { 0 };

// Recompute the maximum of node u from its two children.
void pushup(int u) {
    tree[u].max = max(tree[u << 1].max, tree[u << 1 | 1].max);
}

// Make node u cover [l, r] and build its whole subtree from nums.
void build(int u, int l, int r) {
    if (l == r) {
        tree[u] = { l, r, nums[r] };
        return;
    }
    tree[u].l = l;
    tree[u].r = r;
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// Return the maximum of [l, r] (l <= r, overlapping node u's range).
int query(int u, int l, int r) {
    if (tree[u].l >= l && tree[u].r <= r) {
        return tree[u].max;
    }
    int mid = (tree[u].l + tree[u].r) >> 1;
    // Descend only into the halves that intersect [l, r]; this avoids
    // needing a "minus infinity" sentinel value.
    if (r <= mid) {
        return query(u << 1, l, r);
    }
    if (l > mid) {
        return query(u << 1 | 1, l, r);
    }
    return max(query(u << 1, l, r), query(u << 1 | 1, l, r));
}
47int main() {
48 int n, m;
49 scanf("%d%d", &n, &m);
50 for (int i = 1; i <= n; i++) {
51 scanf("%d", &nums[i]);
52 }
53 build(1, 1, n);
54 for (int i = 0; i < m; i++) {
55 int x, y;
56 scanf("%d%d", &x, &y);
57 printf("%d\n", query(1, x, y));
58 }
59 return 0;
60}
1215. 小朋友排队 | 原题链接
这道题最基础的做法是冒泡排序记录每个小朋友交换的次数,随后通过交换次数计算不满意程度。这样的时间复杂度是 $O(n^2)$ 必然超时。
def bubble_sort(height):
    """Sort height in place by adjacent swaps, counting swaps per element.

    height is a list of [value, swap_count] pairs. Each swap increments the
    counter of both elements involved. Runs in O(n^2).
    """
    # Uses len(height) rather than a global n so the function is self-contained.
    for i in range(1, len(height)):
        j = i - 1
        while j >= 0 and height[j][0] > height[j + 1][0]:
            height[j], height[j + 1] = height[j + 1], height[j]
            height[j][1] += 1
            height[j + 1][1] += 1
            j -= 1


def solve(nums):
    """Return the total unhappiness for the given heights.

    A child that takes part in k swaps contributes 1 + 2 + ... + k.
    """
    height = [[value, 0] for value in nums]
    bubble_sort(height)
    return sum(swaps * (swaps + 1) // 2 for _, swaps in height)


def main():
    """Read n and the heights from stdin and print the answer."""
    n = int(input().strip())
    nums = list(map(int, input().strip().split()))
    print(solve(nums[:n]))


if __name__ == "__main__":
    main()
假设原数组中有 $k$ 个逆序对,那么要是想通过交换相邻两个元素的方式将其排序,则至少需要交换 $k$ 次。而对于原数列中的任意元素 $a_i$,为了将数列排序,则需要将在 $a_i$ 左侧且大于 $a_i$ 的元素交换到 $a_i$ 右侧,在 $a_i$ 右侧且小于 $a_i$ 的元素交换到 $a_i$ 左侧,因此 $a_i$ 需要交换的次数为 $a_i$ 左侧大于 $a_i$ 的元素数量与 $a_i$ 右侧小于 $a_i$ 的元素数量之和。
为了计算 $a_i$ 的交换次数,可以利用树状数组动态维护 `cnt` 数组的前缀和,其中 `cnt[v]` 表示已遍历的元素中值为 `v` 的元素个数。在 $a_i$ 前大于 $a_i$ 的元素数量等于已遍历的元素个数减去 `cnt` 在 $a_i$ 处的前缀和(即已遍历元素中不大于 $a_i$ 的个数);而在 $a_i$ 后小于 $a_i$ 的元素数量,则可以在反向遍历原数组时查询 `cnt` 在 $a_i-1$ 处的前缀和得到。
例如对于序列 `[1, 3, 5, 4, 2]`,当遍历到 $4$ 时(尚未把 $4$ 加入),`cnt` 数组为 `[1, 0, 1, 0, 1]`(1-indexed),其前缀和数组为 `[0, 1, 1, 2, 2, 3]`。此时已遍历 $3$ 个元素,其中不大于 $4$ 的有 $2$ 个,因此在 $4$ 前且大于 $4$ 的元素有 $3-2=1$ 个,即只有 $5$。
def lowbit(x):
    """Return the value of the lowest set bit of x."""
    return x & -x


def add(tree, x, val, n):
    """Add val to position x of the Fenwick array tree (valid indices 1..n).

    tree is updated in place and also returned.
    """
    while x <= n:
        tree[x] += val
        x += lowbit(x)
    return tree


def query(tree, x):
    """Return the sum of positions 1..x of tree."""
    res = 0
    while x > 0:
        res += tree[x]
        x -= lowbit(x)
    return res


def solve(heights):
    """Return the total unhappiness for the given heights.

    Child i is swapped once for every taller child in front of it and once
    for every shorter child behind it; k swaps cost 1 + 2 + ... + k.
    """
    # Shift by one so that a height of 0 becomes a valid 1-based index.
    nums = tuple(h + 1 for h in heights)
    n = len(nums)
    limit = max(nums, default=0)  # largest index the trees must hold
    count = [0] * n

    # Forward pass: earlier children that are strictly taller.
    tree_prefix = [0] * (limit + 1)
    for i, h in enumerate(nums):
        count[i] += i - query(tree_prefix, h)
        add(tree_prefix, h, 1, limit)

    # Backward pass: later children that are strictly shorter.
    tree_postfix = [0] * (limit + 1)
    for i in range(n - 1, -1, -1):
        count[i] += query(tree_postfix, nums[i] - 1)
        add(tree_postfix, nums[i], 1, limit)

    return sum(c * (c + 1) // 2 for c in count)


def main():
    """Read n and the heights from stdin and print the answer."""
    n = int(input().strip())
    heights = list(map(int, input().strip().split()))
    print(solve(heights[:n]))


if __name__ == "__main__":
    main()
1228. 油漆面积 | 原题链接
这道题目要求多个矩形面积并,使用到的方法是扫描线法,而扫描线需要用到线段树数据结构,因此需要用扫描线的题几乎都很难,而这道题属于扫描线模板题目。
from functools import total_ordering

# Heap-ordered node array of the segment tree; solve() allocates it.
tree = []


@total_ordering
class Segment:
    """Vertical edge of a rectangle at x spanning [y1, y2].

    k is +1 for the edge where the rectangle starts and -1 where it ends.
    Segments compare by x only.
    """

    def __init__(self, x=0, y1=0, y2=0, k=0):
        self.x = x
        self.y1 = y1
        self.y2 = y2
        self.k = k

    def __repr__(self):
        return f"x={self.x}, y1={self.y1}, y2={self.y2}, k={self.k}"

    def __eq__(self, other):
        return self.x == other.x

    def __lt__(self, other):
        return self.x < other.x


class Node:
    """Segment-tree node over the unit intervals l..r.

    cnt is the number of rectangles covering the whole node and len is the
    covered length inside it.
    """

    def __init__(self, l=0, r=0, cnt=0, len=0):
        self.l = l
        self.r = r
        self.cnt = cnt
        self.len = len

    def __repr__(self):
        return f"l={self.l}, r={self.r}, cnt={self.cnt}, len={self.len}"


def build(u, l, r):
    """Make node u cover [l, r] and set up its whole subtree."""
    tree[u].l = l
    tree[u].r = r
    if l == r:
        return
    mid = (l + r) >> 1
    build(u << 1, l, mid)
    build(u << 1 | 1, mid + 1, r)


def pushup(u):
    """Recompute the covered length of node u."""
    node = tree[u]
    if node.cnt > 0:
        node.len = node.r - node.l + 1  # fully covered
    elif node.l == node.r:
        node.len = 0  # uncovered leaf
    else:
        node.len = tree[u << 1].len + tree[u << 1 | 1].len


def modify(u, l, r, k):
    """Add k to the cover count of the unit intervals l..r (l <= r)."""
    node = tree[u]
    if node.l >= l and node.r <= r:
        node.cnt += k
        pushup(u)
        return
    mid = (node.l + node.r) >> 1
    if l <= mid:
        modify(u << 1, l, r, k)
    if r > mid:
        modify(u << 1 | 1, l, r, k)
    pushup(u)


def solve(rects):
    """Return the area of the union of axis-aligned rectangles.

    rects is an iterable of (x1, y1, x2, y2) with non-negative integer
    coordinates; the two corners may be given in either order.
    """
    global tree
    segments = []
    for x1, y1, x2, y2 in rects:
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        if x1 == x2 or y1 == y2:
            # Zero area; an empty y range must never reach modify().
            continue
        segments.append(Segment(x1, y1, y2, 1))
        segments.append(Segment(x2, y1, y2, -1))
    if not segments:
        return 0
    segments.sort()

    # Unit interval p stands for [p, p + 1], so indices 0..top - 1 suffice.
    top = max(seg.y2 for seg in segments)
    tree = [Node() for _ in range(4 * top + 5)]
    build(1, 0, top - 1)

    res = 0
    prev_x = segments[0].x
    for seg in segments:
        # Area swept since the previous edge = covered length * width.
        res += tree[1].len * (seg.x - prev_x)
        modify(1, seg.y1, seg.y2 - 1, seg.k)
        prev_x = seg.x
    return res


def main():
    """Read the rectangles from stdin and print the area of their union."""
    n = int(input().strip())
    rects = [tuple(map(int, input().strip().split())) for _ in range(n)]
    print(solve(rects))


if __name__ == "__main__":
    main()
如果不会就暴力做,下面的代码能通过 3/10 个测试点。
import sys


def solve(data):
    """Return the union area by collecting every covered unit cell in a set.

    data[0] is the rectangle count n; each following line is
    "x1 y1 x2 y2". Memory and time grow with the total area, so this only
    works for small inputs.
    """
    n = int(data[0])
    area = set()
    for line in data[1:n + 1]:
        x1, y1, x2, y2 = map(int, line.split())
        # The two corners may be given in either order.
        xs = range(min(x1, x2), max(x1, x2))
        ys = range(min(y1, y2), max(y1, y2))
        area.update((i, j) for i in xs for j in ys)
    return len(area)


def main():
    """Read the whole input from stdin and print the area."""
    print(solve(sys.stdin.read().strip().split("\n")))


if __name__ == "__main__":
    main()
1237. 螺旋折线 | 原题链接
我最初的想法是让一个点沿着螺旋线走到原点,具体做法是每次移动都让点走向上一个拐点,然后统计距离。
def walk_length(x, y):
    """Return the spiral distance from the origin to (x, y).

    The point is walked backwards along the spiral, one straight edge at a
    time, until it reaches the origin. Each loop iteration handles one full
    turn, so the running time grows with max(|x|, |y|).
    """
    length = 0
    while x != 0 or y != 0:
        if -x <= y <= x:
            # On the right edge: move up to the top-right corner.
            length += x - y
            y = x
        if y >= x and y >= -x:
            # On the top edge: move left to the top-left corner.
            length += y + x
            x = -y
        if y <= -x and y >= x + 1:
            # On the left edge: move down to the bottom-left corner.
            length += y - x - 1
            y = x + 1
        if y <= x + 1 and y <= -x:
            # On the bottom edge: move right to the bottom-right corner of
            # the next inner ring.
            length += -x - y
            x = -y
    return length


def main():
    """Read "x y" from stdin and print the spiral distance."""
    x, y = map(int, input().strip().split())
    print(walk_length(x, y))


if __name__ == "__main__":
    main()
然而这样做会超时。数据范围是 $-10^9\leq X,Y\leq10^9$。如果测试点使用极限数据,那么这段代码需要大约 $O(4\cdot\max(|X|,|Y|))$ 的时间才能解决。优化的方法是找规律。首先让点沿螺旋线走回最近的位于 $y=-x,\ y>0$ 上的拐点 $(-y,y)$,不难发现这些拐点到原点的螺旋折线距离为 $2\sum_{i=1}^{2y-1}i=2y(2y-1)$。
def spiral_length(x, y):
    """Return the spiral distance from the origin to (x, y) in O(1).

    The point is walked backwards along at most one turn of the spiral to
    the corner (-y, y) on the line y = -x with y >= 0; that corner lies
    2 * y * (2 * y - 1) away from the origin.
    """
    if x == 0 and y == 0:
        return 0

    length = 0
    if y < -x and y >= x + 1:
        # On the left edge: move down to the bottom-left corner.
        length += y - x - 1
        y = x + 1
    if y <= x + 1 and y <= -x:
        # On the bottom edge: move right to the bottom-right corner.
        length += -x - y
        x = -y
    if -x <= y <= x:
        # On the right edge: move up to the top-right corner.
        length += x - y
        y = x
    if y >= x and y >= -x:
        # On the top edge: move left to the top-left corner (-y, y).
        length += y + x
        x = -y
    return length + 2 * y * (2 * y - 1)


def main():
    """Read "x y" from stdin and print the spiral distance."""
    x, y = map(int, input().strip().split())
    print(spiral_length(x, y))


if __name__ == "__main__":
    main()
797. 差分 | 原题链接
一维差分模板题。差分序列 $b$ 与 前缀和序列 $s$ 的关系为 $b_i=s_i-s_{i-1},s_0=0$。
from itertools import accumulate


def range_add(nums, ops):
    """Apply range increments to nums and return the resulting list.

    ops is an iterable of (l, r, c): add c to every element of the 1-indexed
    inclusive range [l, r]. nums itself is not modified.
    """
    nums = list(nums)
    n = len(nums)
    # diff[i] = nums[i] - nums[i - 1]; the extra trailing slot absorbs the
    # "- c" of updates whose r equals n.
    diff = [cur - prev for prev, cur in zip([0] + nums[:-1], nums)] + [0]
    for l, r, c in ops:
        diff[l - 1] += c
        diff[r] -= c
    # The prefix sums of the difference array are the updated values.
    return list(accumulate(diff[:n]))


def main():
    """Read the array and the updates from stdin and print the final array."""
    n, m = map(int, input().strip().split())
    nums = list(map(int, input().strip().split()))[:n]
    ops = [tuple(map(int, input().strip().split())) for _ in range(m)]
    print(*range_add(nums, ops))


if __name__ == "__main__":
    main()
798. 差分矩阵 | 原题链接
二维差分模板题。根据二维前缀和公式 $s_{i,j}=b_{i,j}+s_{i-1,j}+s_{i,j-1}-s_{i-1,j-1}$ 可以推出 $b_{i,j}=s_{i,j}-s_{i-1,j}-s_{i,j-1}+s_{i-1,j-1}$。对 $s$ 矩阵中 $(x_1,y_1)$ 到 $(x_2,y_2)$ 区间内加上 $c$ 等同于 $b_{x_1,y_1}+c,\space b_{x_1,y_2+1}-c,\space b_{x_2+1,y_1}-c,\space b_{x_2+1,y_2+1}+c$
def matrix_range_add(mat, ops):
    """Apply sub-matrix increments to mat and return the resulting matrix.

    mat is a list of equally long rows. ops is an iterable of
    (x1, y1, x2, y2, c): add c to every cell of the 1-indexed inclusive
    sub-matrix from (x1, y1) to (x2, y2). mat itself is not modified.
    """
    n = len(mat)
    m = len(mat[0]) if mat else 0
    # A zero row and a zero column in front simplify the boundary cases.
    padded = [[0] * (m + 1)] + [[0] + list(row) for row in mat]

    # diff is 0-indexed with one spare row and column for the "+1" corners.
    diff = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n):
        for j in range(m):
            diff[i][j] = (
                padded[i + 1][j + 1] - padded[i][j + 1]
                - padded[i + 1][j] + padded[i][j]
            )

    for x1, y1, x2, y2, c in ops:
        diff[x1 - 1][y1 - 1] += c
        diff[x2][y1 - 1] -= c
        diff[x1 - 1][y2] -= c
        diff[x2][y2] += c

    # The 2-D prefix sums of diff rebuild the updated matrix.
    ans = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            ans[i][j] = (
                diff[i - 1][j - 1] + ans[i - 1][j]
                + ans[i][j - 1] - ans[i - 1][j - 1]
            )
    return [row[1:] for row in ans[1:]]


def main():
    """Read the matrix and the updates from stdin and print the final matrix."""
    n, m, q = map(int, input().strip().split())
    mat = [list(map(int, input().strip().split()))[:m] for _ in range(n)]
    ops = [tuple(map(int, input().strip().split())) for _ in range(q)]
    for row in matrix_range_add(mat, ops):
        print(*row)


if __name__ == "__main__":
    main()
1232. 三体攻击 | 原题链接
这道题的思路不是很难,但是代码写起来很麻烦,而且即使给了 5s/256MB 的时空,代码也依然容易 TLE/MLE。大体思路就是通过三维差分数组记录前若干次攻击累计造成的伤害,然后二分查找第一次使某艘战舰受到的总伤害超过其防御力的攻击。(如果每次攻击完都把整个三维矩阵检查一遍绝对会超时)。下面的代码能过 11/12 个测试点。
import sys

# Problem state shared with check(); solve() rebinds all of these.
A = B = C = 0   # grid dimensions
d = []          # defence of cell (i, j, k) at ((i - 1) * B + (j - 1)) * C + (k - 1)
attacks = []    # (x1, x2, y1, y2, z1, z2, h) tuples in input order


def readints():
    """Return every whitespace-separated integer on stdin as a list."""
    return list(map(int, sys.stdin.read().split()))


def check(k):
    """Return True if some cell has taken more damage than its defence
    after the first k attacks."""
    diff = [[[0] * (C + 2) for _ in range(B + 2)] for _ in range(A + 2)]

    # 3-D difference update: eight corner writes per attack.
    for t in range(k):
        x1, x2, y1, y2, z1, z2, h = attacks[t]
        diff[x1][y1][z1] += h
        diff[x2 + 1][y1][z1] -= h
        diff[x1][y2 + 1][z1] -= h
        diff[x1][y1][z2 + 1] -= h
        diff[x2 + 1][y2 + 1][z1] += h
        diff[x2 + 1][y1][z2 + 1] += h
        diff[x1][y2 + 1][z2 + 1] += h
        diff[x2 + 1][y2 + 1][z2 + 1] -= h

    # In-place 3-D prefix sums turn diff into the total damage per cell.
    for i in range(1, A + 1):
        for j in range(1, B + 1):
            for z in range(1, C + 1):
                diff[i][j][z] += (
                    diff[i - 1][j][z]
                    + diff[i][j - 1][z]
                    + diff[i][j][z - 1]
                    - diff[i - 1][j - 1][z]
                    - diff[i - 1][j][z - 1]
                    - diff[i][j - 1][z - 1]
                    + diff[i - 1][j - 1][z - 1]
                )
                idx = (i - 1) * B * C + (j - 1) * C + (z - 1)
                if diff[i][j][z] > d[idx]:
                    return True
    return False


def solve(data):
    """Return the 1-based index of the first attack after which a cell explodes.

    data is the flat list of input integers: A, B, C, m, then A * B * C
    defence values, then 7 integers per attack. If no prefix of the attacks
    makes check() true the result is m, as in the original script.
    """
    global A, B, C, d, attacks
    A, B, C, m = data[:4]
    cells = A * B * C
    d = data[4:4 + cells]
    first = 4 + cells
    attacks = [tuple(data[p:p + 7]) for p in range(first, first + 7 * m, 7)]

    # check() is monotone in k, so binary search for the smallest true k.
    left, right, ans = 1, m, m
    while left <= right:
        mid = (left + right) // 2
        if check(mid):
            ans = mid
            right = mid - 1
        else:
            left = mid + 1
    return ans


def main():
    """Read the input from stdin and print the answer."""
    print(solve(readints()))


if __name__ == "__main__":
    main()
用 C++ 可以 AC:
#define _CRT_SECURE_NO_WARNINGS

#include <cstdio>
#include <cstring>

using namespace std;

const int N = 10000000;
// s: defence of every cell; b: difference array / accumulated damage;
// w: the attacks, 7 integers each (x1, x2, y1, y2, z1, z2, h).
// The unused array sc of the original listing has been dropped.
int s[N] = { 0 }, b[N] = { 0 }, w[N] = { 0 };
int A, B, C, m;

// Flatten the 1-based coordinates (i, j, k) into an index of s / b.
// The result is negative when i, j or k is 0 at the low end of the grid.
int get(int i, int j, int k) {
    return ((i - 1) * B + (j - 1)) * C + (k - 1);
}

// Value of b at (i, j, k), treating positions before the start of the
// array as 0 instead of reading out of bounds.
static long long at(int i, int j, int k) {
    int idx = get(i, j, k);
    return idx < 0 ? 0 : b[idx];
}

// Return true if some cell has taken more damage than its defence after
// the first cnt attacks.
bool check(int cnt) {
    memset(b, 0, sizeof(b));
    for (int t = 0; t < cnt; t++) {
        const int *a = w + t * 7;
        int x1 = a[0], x2 = a[1];
        int y1 = a[2], y2 = a[3];
        int z1 = a[4], z2 = a[5];
        int h = a[6];

        // 3-D difference update: eight corner writes per attack.
        b[get(x1, y1, z1)] += h;
        b[get(x2 + 1, y1, z1)] -= h;
        b[get(x1, y2 + 1, z1)] -= h;
        b[get(x1, y1, z2 + 1)] -= h;
        b[get(x2 + 1, y2 + 1, z1)] += h;
        b[get(x2 + 1, y1, z2 + 1)] += h;
        b[get(x1, y2 + 1, z2 + 1)] += h;
        b[get(x2 + 1, y2 + 1, z2 + 1)] -= h;
    }

    for (int i = 1; i <= A; i++) {
        for (int j = 1; j <= B; j++) {
            for (int k = 1; k <= C; k++) {
                int p = get(i, j, k);
                // Accumulate in 64 bits: the sum of several int terms may
                // not fit in an int.
                long long total = (long long)b[p]
                    + at(i - 1, j, k) + at(i, j - 1, k) + at(i, j, k - 1)
                    - at(i - 1, j - 1, k) - at(i - 1, j, k - 1) - at(i, j - 1, k - 1)
                    + at(i - 1, j - 1, k - 1);
                if (total > s[p]) {
                    return true;
                }
                // total <= s[p] here, so it does not exceed INT_MAX
                // (assumes damage values h are non-negative).
                b[p] = (int)total;
            }
        }
    }
    return false;
}
48int main() {
49 scanf("%d %d %d %d", &A, &B, &C, &m);
50 for (int i = 1; i <= A; i++) {
51 for (int j = 1; j <= B; j++) {
52 for (int k = 1; k <= C; k++) {
53 scanf("%d", &s[get(i, j, k)]);
54 }
55 }
56 }
57 for (int i = 0; i < m * 7; i += 7) {
58 scanf("%d %d %d %d %d %d %d", &w[i], &w[i + 1], &w[i + 2], &w[i + 3], &w[i + 4], &w[i + 5], &w[i + 6]);
59 }
60
61 int left = 1, right = m, ans = m;
62 while (left <= right) {
63 int mid = (left + right) >> 1;
64 if (check(mid)) {
65 ans = mid;
66 right = mid - 1;
67 }
68 else {
69 left = mid + 1;
70 }
71 }
72 printf("%d\n", ans);
73 return 0;
74}
因为 $A\times B\times C\leq 10^6$,如果将每一个维度都开 $10^6$ 就会爆内存,因此用一个 `get` 函数做三维到一维的压缩,这样只用开一个 $10^6$ 的一维数组就可以。因为差分会写到 $x_2+1$、$y_2+1$、$z_2+1$ 这些超出原范围的位置,需要保留一些“余量”,因此实际操作的时候开一个 $10^7$ 的数组。