树状数组、线段树与差分练习

文章目录

重要

树状数组与线段树的适用范围是“单点修改,区间查询”。树状数组更适合用于动态维护前缀和,而线段树还可用于求区间最值
树状数组和线段树都是在线方法,支持动态修改与维护。

题目来自 AcWing


1264. 动态求连续区间和 | 原题链接

这是一道树状数组模板题。树状数组最基础的用法就是动态维护前缀和。下面出现的 lowbit add query 操作都是维护树状数组的基本函数。

 1import sys
 2
 3
 4def lowbit(x):
 5    return x & -x
 6
 7
 8def add(x, val):
 9    while x <= n:
10        tree[x] += val
11        x += lowbit(x)
12
13
14def query(x):
15    res = 0
16    while x > 0:
17        res += tree[x]
18        x -= lowbit(x)
19    return res
20
21
22data = sys.stdin.read().strip().split("\n")
23n, m = map(int, data[0].strip().split())
24nums = [0] + list(map(int, data[1].strip().split()))
25tree = [0] * (n + 1)
26for i in range(1, n + 1):
27    add(i, nums[i])
28for i in range(2, m + 2):
29    k, a, b = map(int, data[i].strip().split())
30    if k == 0:
31        print(query(b) - query(a - 1))
32    elif k == 1:
33        add(a, b)

线段树写法:

 1import sys
 2
 3
 4class Node:
 5
 6    def __init__(self, l=0, r=0, sum=0):
 7        self.l = l
 8        self.r = r
 9        self.sum = sum
10
11    def __repr__(self):
12        return f"l={self.l}, r={self.r}, sum={self.sum}"
13
14
15def pushup(u):
16    tree[u].sum = tree[u << 1].sum + tree[u << 1 | 1].sum
17
18
19def build(u, l, r):
20    tree[u].l = l
21    tree[u].r = r
22    if l == r:
23        tree[u] = Node(l, r, nums[r])
24    else:
25        mid = (l + r) >> 1
26        build(u << 1, l, mid)
27        build(u << 1 | 1, mid + 1, r)
28        pushup(u)
29
30
31def query(u, l, r):
32    if tree[u].l >= l and tree[u].r <= r:
33        return tree[u].sum
34    mid = (tree[u].l + tree[u].r) >> 1
35    s = 0
36    if l <= mid:
37        s += query(u << 1, l, r)
38    if r > mid:
39        s += query(u << 1 | 1, l, r)
40    return s
41
42
43def modify(u, x, v):
44    if tree[u].l == tree[u].r:
45        tree[u].sum += v
46    else:
47        mid = (tree[u].l + tree[u].r) >> 1
48        if x <= mid:
49            modify(u << 1, x, v)
50        else:
51            modify(u << 1 | 1, x, v)
52        pushup(u)
53
54
55data = sys.stdin.read().strip().split("\n")
56n, m = map(int, data[0].strip().split())
57nums = (0,) + tuple(map(int, data[1].strip().split()))
58tree = [Node(0, 0, 0) for _ in range(4 * n)]
59build(1, 1, n)
60
61for i in range(2, m + 2):
62    k, a, b = map(int, data[i].strip().split())
63    if k == 0:
64        print(query(1, a, b))
65    else:
66        modify(1, a, b)

pushup build query modify 为线段树的基本操作。如果 Python 版本为 3.7 及以上,可以使用 dataclasses.dataclass 修饰器快捷定义 Node 类。

 1import sys
 2from dataclasses import dataclass
 3
 4
 5@dataclass
 6class Node:
 7    l: int
 8    r: int
 9    sum: int
10
11
12def pushup(u):
13    tree[u].sum = tree[u << 1].sum + tree[u << 1 | 1].sum
14
15
16def build(u, l, r):
17    tree[u].l = l
18    tree[u].r = r
19    if l == r:
20        tree[u] = Node(l, r, nums[r])
21    else:
22        mid = (l + r) >> 1
23        build(u << 1, l, mid)
24        build(u << 1 | 1, mid + 1, r)
25        pushup(u)
26
27
28def query(u, l, r):
29    if tree[u].l >= l and tree[u].r <= r:
30        return tree[u].sum
31    mid = (tree[u].l + tree[u].r) >> 1
32    s = 0
33    if l <= mid:
34        s += query(u << 1, l, r)
35    if r > mid:
36        s += query(u << 1 | 1, l, r)
37    return s
38
39
40def modify(u, x, v):
41    if tree[u].l == tree[u].r:
42        tree[u].sum += v
43    else:
44        mid = (tree[u].l + tree[u].r) >> 1
45        if x <= mid:
46            modify(u << 1, x, v)
47        else:
48            modify(u << 1 | 1, x, v)
49        pushup(u)
50
51
52data = sys.stdin.read().strip().split("\n")
53n, m = map(int, data[0].strip().split())
54nums = (0,) + tuple(map(int, data[1].strip().split()))
55tree = [Node(0, 0, 0) for _ in range(4 * n)]
56build(1, 1, n)
57
58for i in range(2, m + 2):
59    k, a, b = map(int, data[i].strip().split())
60    if k == 0:
61        print(query(1, a, b))
62    else:
63        modify(1, a, b)

线段树使用堆存储,空间占用为 $4n$,且线段树的时空消耗大于树状数组。

1265. 数星星 | 原题链接

输入是按照纵坐标排序,纵坐标相同再按照横坐标排序。因此每当输入一个点时,其纵坐标一定是当前所有点中最大的,而要找出在其左下方的点,只需要找出当前所有点中横坐标小于他的点即可。因此这道题需要动态维护前缀和,使用树状数组解决。

 1import sys
 2from collections import defaultdict
 3
 4
 5def lowbit(x):
 6    return x & -x
 7
 8
 9def add(nums, x, val, n):
10    while x <= n:
11        nums[x] += val
12        x += lowbit(x)
13    return nums
14
15
16def query(nums, x, n):
17    res = 0
18    while x > 0:
19        res += nums[x]
20        x -= lowbit(x)
21    return res
22
23
24MAX_COUNT = 32005
25data = sys.stdin.read().strip().split("\n")
26n = int(data[0].strip())
27ans = defaultdict(int)
28tree = [0] * MAX_COUNT
29for i in range(1, n + 1):
30    x, y = map(int, data[i].strip().split())
31    x += 1
32    tree = add(tree, x, 1, MAX_COUNT)
33    count = query(tree, x, MAX_COUNT)
34    ans[count - 1] += 1
35for i in range(n):
36    print(ans[i])

注意

树状数组的下标从 $1$ 开始,而题目中的坐标从 $0$ 开始,因此需要给所有横坐标加 $1$。

1270. 数列区间最大值 | 原题链接

这道题的数据量太大了,每次查询时在整个数列中两两比较的时间复杂度为 $O(nm)$,即使用分治法比较最大值,时间复杂度也是 $O(mlogn)$。

 1import sys
 2
 3
 4def query(x, y):
 5    if x == y:
 6        return nums[y]
 7    else:
 8        mid = (x + y) // 2
 9        return max(query(x, mid), query(mid + 1, y))
10
11
12data = sys.stdin.read().strip().split("\n")
13n, m = map(int, data[0].strip().split())
14nums = (0,) + tuple(map(int, data[1].strip().split()))
15for i in range(2, m + 2):
16    x, y = map(int, data[i].strip().split())
17    print(query(x, y))

这道题可以用线段树解决,只需要修改 query 操作即可实现求区间最大值的功能。经过线段树预处理后,此算法时间复杂度为 $O(logn)$,但是 Python 写这道题的线段树仍然会超时,加上 lru_cache 还会 MLE……

 1#define _CRT_SECURE_NO_WARNINGS
 2
 3#include <algorithm>
 4#include <cstdio>
 5
 6using namespace std;
 7
 8struct Node {
 9    int l, r, max;
10} tree[400010];
11
12int nums[100010] = { 0 };
13
14void pushup(int u) {
15    tree[u].max = max(tree[u << 1].max, tree[u << 1 | 1].max);
16}
17
18void build(int u, int l, int r) {
19    tree[u].l = l;
20    tree[u].r = r;
21    if (l == r) {
22        tree[u] = { l, r, nums[r] };
23    }
24    else {
25        int mid = (l + r) >> 1;
26        build(u << 1, l, mid);
27        build(u << 1 | 1, mid + 1, r);
28        pushup(u);
29    }
30}
31
32int query(int u, int l, int r) {
33    if (tree[u].l >= l && tree[u].r <= r) {
34        return tree[u].max;
35    }
36    int mid = (tree[u].l + tree[u].r) >> 1;
37    int m = -0x7fffffff;
38    if (l <= mid) {
39        m = max(m, query(u << 1, l, r));
40    }
41    if (r > mid) {
42        m = max(m, query(u << 1 | 1, l, r));
43    }
44    return m;
45}
46
47int main() {
48    int n, m;
49    scanf("%d%d", &n, &m);
50    for (int i = 1; i <= n; i++) {
51        scanf("%d", &nums[i]);
52    }
53    build(1, 1, n);
54    for (int i = 0; i < m; i++) {
55        int x, y;
56        scanf("%d%d", &x, &y);
57        printf("%d\n", query(1, x, y));
58    }
59    return 0;
60}

1215. 小朋友排队 | 原题链接

这道题最基础的做法是冒泡排序记录每个小朋友交换的次数,随后通过交换次数计算不满意程度。这样的时间复杂度是 $O(n^2)$ 必然超时。

 1n = int(input().strip())
 2nums = list(map(int, input().strip().split()))
 3height = [[nums[i], 0] for i in range(n)]
 4
 5
 6def bubble_sort(height):
 7    for i in range(1, n):
 8        j = i - 1
 9        while j >= 0 and height[j][0] > height[j + 1][0]:
10            height[j], height[j + 1] = height[j + 1], height[j]
11            height[j][1] += 1
12            height[j + 1][1] += 1
13            j -= 1
14
15
16bubble_sort(height)
17ans = 0
18for h in height:
19    ans += h[1] * (h[1] + 1) // 2
20print(ans)

假设原数组中有 $k$ 个逆序对,那么要是想通过交换相邻两个元素的方式将其排序,则至少需要交换 $k$ 次。而对于原数列中的任意元素 $a_i$,为了将数列排序,则需要将在 $a_i$ 左侧且大于 $a_i$ 的元素交换到 $a_i$ 右侧,在 $a_i$ 右侧且小于 $a_i$ 的元素交换到 $a_i$ 左侧,因此 $a_i$ 需要交换的次数为 $a_i$ 左侧大于 $a_i$ 的元素数量与 $a_i$ 右侧小于 $a_i$ 的元素数量之和。

为了计算 $a_i$ 的交换次数,可以利用树状数组动态维护 cnt 数组前缀和,其中 cnt[i] 表示在原数组任意位置时 i 元素出现次数。在 $a_i$ 前大于 $a_i$ 的元素数量可以表示为当前已遍历的数组长度减去 cnt 在当前位置的前缀和,而在 $a_i$ 后小于 $a_i$ 的元素数量则可以通过反向遍历原数组时求前缀和得到。

例如对于序列 [1, 3, 5, 4, 2],当遍历到 $4$ 时 cnt 数组为 [1, 0, 1, 0, 1](1-indexed), 其前缀和数组为 [0, 1, 1, 2, 2, 3],在 $4$ 前且大于 $4$ 的元素只有 $5$。

 1n = int(input().strip())
 2nums = tuple(map(lambda x: int(x) + 1, input().strip().split()))
 3
 4def lowbit(x):
 5    return x & -x
 6
 7def add(tree, x, val, n):
 8    while x <= n:
 9        tree[x] += val
10        x += lowbit(x)
11    return tree
12
13def query(tree, x):
14    res = 0
15    while x > 0:
16        res += tree[x]
17        x -= lowbit(x)
18    return res
19
20count = [0] * (n + 1000005)
21tree_prefix = [0] * (n + 1000005)
22for i in range(n):
23    count[i] += i - query(tree_prefix, nums[i])
24    tree_prefix = add(tree_prefix, nums[i], 1, 1000005)
25
26tree_postfix = [0] * (n + 1000005)
27for i in range(n)[::-1]:
28    count[i] += query(tree_postfix, nums[i] - 1)
29    tree_postfix = add(tree_postfix, nums[i], 1, 1000005)
30
31s = 0
32for i in count:
33    s += i * (i + 1) // 2
34print(s)

1228. 油漆面积 | 原题链接

这道题目要求多个矩形面积并,使用到的方法是扫描线法,而扫描线需要用到线段树数据结构,因此需要用扫描线的题几乎都很难,而这道题属于扫描线模板题目。

 1from functools import total_ordering
 2
 3
 4@total_ordering
 5class Segment:
 6    def __init__(self, x=0, y1=0, y2=0, k=0):
 7        self.x = x
 8        self.y1 = y1
 9        self.y2 = y2
10        self.k = k
11
12    def __repr__(self):
13        return f"x={self.x}, y1={self.y1}, y2={self.y2}, k={self.k}"
14
15    def __eq__(self, other):
16        return self.x == other.x
17
18    def __lt__(self, other):
19        return self.x < other.x
20
21
22class Node:
23    def __init__(self, l=0, r=0, cnt=0, len=0):
24        self.l = l
25        self.r = r
26        self.cnt = cnt
27        self.len = len
28
29    def __repr__(self):
30        return f"l={self.l}, r={self.r}, cnt={self.cnt}, len={self.len}"
31
32
33def build(u, l, r):
34    tree[u].l = l
35    tree[u].r = r
36    if l == r:
37        return
38
39    mid = (l + r) >> 1
40    build(u << 1, l, mid)
41    build(u << 1 | 1, mid + 1, r)
42
43
44def pushup(u):
45    if tree[u].cnt > 0:
46        tree[u].len = tree[u].r - tree[u].l + 1
47    elif tree[u].l == tree[u].r:
48        tree[u].len = 0
49    else:
50        tree[u].len = tree[u << 1].len + tree[u << 1 | 1].len
51
52
53def modify(u, l, r, k):
54    if tree[u].l >= l and tree[u].r <= r:
55        tree[u].cnt += k
56        pushup(u)
57        return
58
59    mid = (tree[u].l + tree[u].r) >> 1
60    if l <= mid:
61        modify(u << 1, l, r, k)
62    if r > mid:
63        modify(u << 1 | 1, l, r, k)
64    pushup(u)
65
66
67segments = list()
68n = int(input().strip())
69for _ in range(n):
70    x1, y1, x2, y2 = map(int, input().strip().split())
71    segments.append(Segment(x1, y1, y2, 1))
72    segments.append(Segment(x2, y1, y2, -1))
73segments.sort()
74
75tree = [Node() for _ in range(4 * 100000 + 5)]
76build(1, 0, 100000)
77
78res = 0
79for i in range(len(segments)):
80    if i > 0:
81        res += tree[1].len * (segments[i].x - segments[i - 1].x)
82    modify(1, segments[i].y1, segments[i].y2 - 1, segments[i].k)
83
84print(res)

如果不会就暴力做,下面的代码能通过 3/10 个测试点。

 1import sys
 2
 3data = sys.stdin.read().strip().split("\n")
 4n = int(data[0].strip())
 5area = set()
 6for k in range(1, n + 1):
 7    x1, y1, x2, y2 = map(int, data[k].strip().split())
 8    for i in range(x1, x2):
 9        for j in range(y1, y2):
10            area.add((i, j))
11print(len(area))

1237. 螺旋折线 | 原题链接

我最初的想法是让一个点沿着螺旋线走到原点,具体做法是每次移动都让点走向上一个拐点,然后统计距离。

 1x, y = map(int, input().strip().split())
 2length = 0
 3while x != 0 or y != 0:
 4    if y >= -x and y <= x:
 5        length += x - y
 6        y = x
 7    if y >= x and y >= -x:
 8        length += y + x
 9        x = -y
10    if y <= -x and y >= x + 1:
11        length += y - x - 1
12        y = x + 1
13    if y <= x + 1 and y <= -x:
14        length += -x - y
15        x = -y
16
17print(length)

然而这样做会超时。数据范围是 $-10^9\leq X,Y\leq10^9$。如果测试点使用极限数据,那么这段代码需要大约 $O(4\cdot max(X,Y))$ 的时间才能解决。优化的方法是找规律。首先将点转到最近的 $y=-x,y>0$ 拐点上,不难发现对于这些拐点,其到圆点的螺旋折线距离为 $2\sum^{y-1}_{i=1}i$。

 1x, y = map(int, input().strip().split())
 2length = 0
 3if x != 0 or y != 0:
 4    if y < -x and y >= x + 1:
 5        length += y - x - 1
 6        y = x + 1
 7    if y <= x + 1 and y <= -x:
 8        length += -x - y
 9        x = -y
10    if y >= -x and y <= x:
11        length += x - y
12        y = x
13    if y >= x and y >= -x:
14        length += y + x
15        x = -y
16
17print(length + 2 * y * (2 * y - 1))

797. 差分 | 原题链接

一维差分模板题。差分序列 $b$ 与 前缀和序列 $s$ 的关系为 $b_i=s_i-s_{i-1},s_0=0$。

 1n, m = map(int, input().strip().split())
 2nums = (0,) + tuple(map(int, input().strip().split()))
 3diff = [nums[i] - nums[i - 1] for i in range(1, n + 1)] + [0]
 4for _ in range(m):
 5    l, r, c = map(int, input().strip().split())
 6    diff[l - 1] += c
 7    diff[r] -= c
 8ans = [0]
 9for i in range(n):
10    ans.append(ans[i] + diff[i])
11print(*ans[1:])

798. 差分矩阵 | 原题链接

二维差分模板题。根据二维前缀和公式 $s_{i,j}=b_{i,j}+s_{i-1,j}+s_{i,j-1}-s_{i-1,j-1}$ 可以推出 $b_{i,j}=s_{i,j}-s_{i-1,j}-s_{i,j-1}+s_{i-1,j-1}$。对 $s$ 矩阵中 $(x_1,y_1)$ 到 $(x_2,y_2)$ 区间内加上 $c$ 等同于 $b_{x_1,y_1}+c,\space b_{x_1,y_2+1}-c,\space b_{x_2+1,y_1}-c,\space b_{x_2+1,y_2+1}+c$

 1n, m, q = map(int, input().strip().split())
 2mat = [[0] * (m + 1)]
 3for _ in range(n):
 4    row = [0] + list(map(int, input().strip().split()))
 5    mat.append(row)
 6
 7diff = [[0 for _ in range(m + 1)] for _ in range(n + 1)]
 8for i in range(n):
 9    for j in range(m):
10        diff[i][j] = mat[i + 1][j + 1] - mat[i][j + 1] - mat[i + 1][j] + mat[i][j]
11
12for _ in range(q):
13    x1, y1, x2, y2, c = map(int, input().strip().split())
14    diff[x1 - 1][y1 - 1] += c
15    diff[x2][y1 - 1] -= c
16    diff[x1 - 1][y2] -= c
17    diff[x2][y2] += c
18
19ans = [[0 for _ in range(m + 1)] for _ in range(n + 1)]
20for i in range(1, n + 1):
21    for j in range(1, m + 1):
22        ans[i][j] = (
23            diff[i - 1][j - 1] + ans[i - 1][j] + ans[i][j - 1] - ans[i - 1][j - 1]
24        )
25for i in range(1, n + 1):
26    print(*ans[i][1:])

1232. 三体攻击 | 原题链接

这道题的思路不是很难,但是代码写起来很麻烦,而且即使给了 5s/256MB 的时空,代码也依然容易 TLE/MLE。大体思路就是通过三维差分数组记录每次攻击的结果,然后使用二分查找刚好有 0 的三维矩阵。(如果每次攻击完都查找一遍绝对会超时)。下面的代码能过 11/12 个测试点。

 1import sys
 2
 3
 4def readints():
 5    return list(map(int, sys.stdin.read().split()))
 6
 7
 8data = readints()
 9ptr = 0
10A, B, C, m = data[ptr], data[ptr + 1], data[ptr + 2], data[ptr + 3]
11ptr += 4
12d = [0] * (A * B * C)
13for i in range(A * B * C):
14    d[i] = data[ptr]
15    ptr += 1
16
17attacks = []
18for _ in range(m):
19    x1, x2, y1, y2, z1, z2, h = data[ptr : ptr + 7]
20    ptr += 7
21    attacks.append((x1, x2, y1, y2, z1, z2, h))
22
23
24def check(k):
25    diff = [[[0 for _ in range(C + 2)] for _ in range(B + 2)] for _ in range(A + 2)]
26
27    for t in range(k):
28        x1, x2, y1, y2, z1, z2, h = attacks[t]
29        diff[x1][y1][z1] += h
30        diff[x2 + 1][y1][z1] -= h
31        diff[x1][y2 + 1][z1] -= h
32        diff[x1][y1][z2 + 1] -= h
33        diff[x2 + 1][y2 + 1][z1] += h
34        diff[x2 + 1][y1][z2 + 1] += h
35        diff[x1][y2 + 1][z2 + 1] += h
36        diff[x2 + 1][y2 + 1][z2 + 1] -= h
37
38    for i in range(1, A + 1):
39        for j in range(1, B + 1):
40            for k in range(1, C + 1):
41                diff[i][j][k] += (
42                    diff[i - 1][j][k]
43                    + diff[i][j - 1][k]
44                    + diff[i][j][k - 1]
45                    - diff[i - 1][j - 1][k]
46                    - diff[i - 1][j][k - 1]
47                    - diff[i][j - 1][k - 1]
48                    + diff[i - 1][j - 1][k - 1]
49                )
50                idx = (i - 1) * B * C + (j - 1) * C + (k - 1)
51                if diff[i][j][k] > d[idx]:
52                    return True
53    return False
54
55
56left = 1
57right = m
58ans = m
59
60while left <= right:
61    mid = (left + right) // 2
62    if check(mid):
63        ans = mid
64        right = mid - 1
65    else:
66        left = mid + 1
67
68print(ans)

用 C++ 可以 AC:

 1#define _CRT_SECURE_NO_WARNINGS
 2
 3#include <cstdio>
 4#include <cstring>
 5
 6using namespace std;
 7
 8const int N = 10000000;
 9int s[N] = { 0 }, b[N] = { 0 }, w[N] = { 0 }, sc[N] = { 0 };
10int A, B, C, m;
11
12int get(int i, int j, int k) {
13    return ((i - 1) * B + (j - 1)) * C + (k - 1);
14}
15
16bool check(int m) {
17    memset(b, 0, sizeof(int) * N);
18    int x1, x2, y1, y2, z1, z2, h;
19    for (int i = 0; i < m * 7; i += 7) {
20        x1 = w[i], x2 = w[i + 1];
21        y1 = w[i + 2], y2 = w[i + 3];
22        z1 = w[i + 4], z2 = w[i + 5];
23        h = w[i + 6];
24
25        b[get(x1, y1, z1)] += h;
26        b[get(x2 + 1, y1, z1)] -= h;
27        b[get(x1, y2 + 1, z1)] -= h;
28        b[get(x1, y1, z2 + 1)] -= h;
29        b[get(x2 + 1, y2 + 1, z1)] += h;
30        b[get(x2 + 1, y1, z2 + 1)] += h;
31        b[get(x1, y2 + 1, z2 + 1)] += h;
32        b[get(x2 + 1, y2 + 1, z2 + 1)] -= h;
33    }
34
35    for (int i = 1; i <= A; i++) {
36        for (int j = 1; j <= B; j++) {
37            for (int k = 1; k <= C; k++) {
38                b[get(i, j, k)] += b[get(i - 1, j, k)] + b[get(i, j - 1, k)] + b[get(i, j, k - 1)] - b[get(i - 1, j - 1, k)] - b[get(i - 1, j, k - 1)] - b[get(i, j - 1, k - 1)] + b[get(i - 1, j - 1, k - 1)];
39                if (b[get(i, j, k)] > s[get(i, j, k)]) {
40                    return true;
41                }
42            }
43        }
44    }
45    return false;
46}
47
48int main() {
49    scanf("%d %d %d %d", &A, &B, &C, &m);
50    for (int i = 1; i <= A; i++) {
51        for (int j = 1; j <= B; j++) {
52            for (int k = 1; k <= C; k++) {
53                scanf("%d", &s[get(i, j, k)]);
54            }
55        }
56    }
57    for (int i = 0; i < m * 7; i += 7) {
58        scanf("%d %d %d %d %d %d %d", &w[i], &w[i + 1], &w[i + 2], &w[i + 3], &w[i + 4], &w[i + 5], &w[i + 6]);
59    }
60
61    int left = 1, right = m, ans = m;
62    while (left <= right) {
63        int mid = (left + right) >> 1;
64        if (check(mid)) {
65            ans = mid;
66            right = mid - 1;
67        }
68        else {
69            left = mid + 1;
70        }
71    }
72    printf("%d\n", ans);
73    return 0;
74}

因为 $A\time B\time C\leq 10^6$,如果将每一个维度都开 $10^6$ 就会爆内存,因此用一个 get 函数做三维到一维的压缩,这样只用开一个 $10^6$ 的一维数组就可以。因为需要保留一些“余量”,因此实际操作的时候开一个 $10^7$ 的数组。

相关系列文章