2025春季算法练习——Week 3
文章目录
这些题看上去较简单,但是实操起来还有很多细节需要注意。
题目来自 AcWing
830. 单调栈 | 原题链接
遍历原数组中的每个数,然后作如下操作:
- 将栈顶元素出栈直到栈顶元素小于列表当前元素
- 将栈顶元素添加到答案列表中
- 将列表当前元素入栈
在栈中,如果一个元素小于前一个元素,那么这个元素后面所有元素的“左边第一个比它小的数”就不可能是当前这个元素前面的元素。
1from collections import deque
2
3n = int(input().strip())
4nums = tuple(map(int, input().strip().split()))
5stack = deque()
6ans = list()
7for num in nums:
8 while stack and stack[-1] >= num:
9 stack.pop()
10 ans.append(-1 if not stack else stack[-1])
11 stack.append(num)
12
13print(*ans)
154. 滑动窗口 | 原题链接
这道题像是上一道题的进阶版,区别是这道题只能在一个区间范围内找最大/最小值。解决这道题的流程如下:
- 将已经离开窗口的元素从队首弹出
- 确保队尾元素大于/小于列表当前元素,否则将元素从队尾出队
- 将列表当前元素入队
- 队首元素即为当前窗口的最大/最小值,将这个值添加到答案列表中
为了方便判断某个元素是否已经离开窗口,队列中存储的是元素的下标。
1from collections import deque
2
3n, k = map(int, input().strip().split())
4k -= 1
5nums = tuple(map(int, input().strip().split()))
6minans = list()
7maxans = list()
8minq = deque()
9maxq = deque()
10
11for j in range(n):
12 if minq and minq[0] < j - k:
13 minq.popleft()
14 if maxq and maxq[0] < j - k:
15 maxq.popleft()
16
17 while minq and nums[minq[-1]] >= nums[j]:
18 minq.pop()
19 minq.append(j)
20 while maxq and nums[maxq[-1]] <= nums[j]:
21 maxq.pop()
22 maxq.append(j)
23
24 if j - k >= 0:
25 minans.append(nums[minq[0]])
26 maxans.append(nums[maxq[0]])
27
28print(*minans)
29print(*maxans)
831. KMP字符串 | 原题链接
KMP 算法模板题,核心在于生成 next
数组以及匹配。
1def build_next(patt):
2 nxt = [0] * len(patt)
3 j = 0
4 for i in range(1, len(patt)):
5 while j > 0 and patt[i] != patt[j]:
6 j = nxt[j - 1]
7 if patt[i] == patt[j]:
8 j += 1
9 nxt[i] = j
10 return nxt
11
12
13def kmp_match(patt, string, nxt):
14 i = j = 0
15 lenp = len(patt)
16 while i < len(string):
17 if patt[j] == string[i]:
18 i += 1
19 j += 1
20 elif j > 0:
21 j = nxt[j - 1]
22 else:
23 i += 1
24
25 if j == lenp:
26 yield i - j
27 j = nxt[j - 1]
28
29
30N = int(input().strip())
31P = input().strip()
32M = int(input().strip())
33S = input().strip()
34nxt = build_next(P)
35ans = [i for i in kmp_match(P, S, nxt)]
36print(*ans)
835. Trie字符串统计 | 原题链接
Trie 树模板题。
1import sys
2
3letters = "abcdefghijklmnopqrstuvwxyz"
4
5
6class TrieNode:
7
8 def __init__(self):
9 self.key = 0
10 self.children = {ch: None for ch in letters}
11
12
13data = sys.stdin.read().strip().split("\n")
14trie = TrieNode()
15n = int(data[0])
16for i in range(1, n + 1):
17 q, s = data[i].strip().split()
18 if q == "I":
19 node = trie
20 for c in s:
21 if node.children[c] is None:
22 node.children[c] = TrieNode()
23 node = node.children[c]
24 node.key += 1
25 else:
26 node = trie
27 available = True
28 for c in s:
29 if node.children[c] is None:
30 available = False
31 break
32 node = node.children[c]
33 if node and node.key and available:
34 print(node.key)
35 else:
36 print(0)
143. 最大异或对 | 原题链接
这道题需要用 Trie 树和贪心来解。如果我们希望两个数异或的结果尽可能大,那么这两个数二进制里越多位不同,异或的结果越大,而且高位的权重要低于低位。因此我们将所有的数的二进制从高位到低位存储进一个 Trie 树,并且每一个叶子节点到根节点的距离都是 31。每当读取一个数时,将这个数的二进制存储进树中,同时查找与其不同位尽可能多的二进制数,这两数异或,得到的就是当前处理的数与前面出现过的数的最大异或值。
1class TrieNode:
2
3 def __init__(self):
4 self.children = [None, None]
5
6
7n = int(input().strip())
8nums = list(map(int, input().strip().split()))
9trie = TrieNode()
10max_xor = 0
11for num in nums:
12 node = trie
13 xor_node = trie
14 xor_res = 0
15 for i in range(32)[::-1]:
16 bit = (num >> i) & 1
17 if node.children[bit] is None:
18 node.children[bit] = TrieNode()
19 node = node.children[bit]
20
21 if xor_node.children[1 - bit] is not None:
22 xor_res |= (1 - bit) << i
23 xor_node = xor_node.children[1 - bit]
24 elif xor_node.children[bit] is not None:
25 xor_res |= bit << i
26 xor_node = xor_node.children[bit]
27
28 xor_res ^= num
29 if xor_res > max_xor:
30 max_xor = xor_res
31
32print(max_xor)
Python 中的按位取反运算符 ~
对有符号数做操作,例如 ~0b1101 = -0b1110
,因此要用 1 - bit
对 bit
取反。
836. 合并集合 | 原题链接
由于 Python 有递归深度限制,不用递归硬找祖宗节点的写法可以过 10/11 个测试点。
1import sys
2
3data = sys.stdin.read().strip().split("\n")
4n, m = map(int, data[0].strip().split())
5pre = [i for i in range(n + 1)]
6
7
8def merge(n1, n2):
9 rank1 = rank2 = 0
10 while pre[n1] != n1:
11 n1 = pre[n1]
12 rank1 += 1
13 while pre[n2] != n2:
14 n2 = pre[n2]
15 rank2 += 1
16
17 if n1 == n2:
18 return
19 if rank1 > rank2:
20 pre[n2] = n1
21 else:
22 pre[n1] = n2
23
24
25def query(n1, n2):
26 while pre[n1] != n1:
27 n1 = pre[n1]
28 while pre[n2] != n2:
29 n2 = pre[n2]
30
31 return n1 == n2
32
33
34for i in range(1, m + 1):
35 q, n1, n2 = data[i].strip().split()
36 n1 = int(n1)
37 n2 = int(n2)
38
39 if q == "M":
40 merge(n1, n2)
41 else:
42 print("Yes" if query(n1, n2) else "No")
Java 里可以用 find
函数递归解决:
1import java.util.Scanner;
2
3public class Main {
4 public static int[] pre;
5
6 public static void main(String[] args) {
7 Scanner sc = new Scanner(System.in);
8 int n = sc.nextInt();
9 int m = sc.nextInt();
10 pre = new int[n + 1];
11 for (int i = 1; i <= n; i++) {
12 pre[i] = i;
13 }
14 sc.nextLine();
15
16 while (m-- > 0) {
17 String[] line = sc.nextLine().trim().split("\\s+");
18 String q = line[0].trim();
19 int n1 = Integer.parseInt(line[1].trim());
20 int n2 = Integer.parseInt(line[2].trim());
21
22 if (q.equals("M")) {
23 pre[find(n1)] = find(n2);
24 } else {
25 System.out.println(find(n1) == find(n2) ? "Yes" : "No");
26 }
27 }
28 }
29
30 public static int find(int n) {
31 if (pre[n] != n) {
32 pre[n] = find(pre[n]);
33 }
34 return pre[n];
35 }
36}
对应的 Python 写法(不能 AC):
1import sys
2
3data = sys.stdin.read().strip().split("\n")
4n, m = map(int, data[0].strip().split())
5pre = [i for i in range(n + 1)]
6
7
8def find(n):
9 if pre[n] != n:
10 pre[n] = find(pre[n])
11 return pre[n]
12
13
14for i in range(1, m + 1):
15 q, n1, n2 = data[i].strip().split()
16 n1 = int(n1)
17 n2 = int(n2)
18
19 if q == "M":
20 pre[find(n1)] = find(n2)
21 else:
22 print("Yes" if find(n1) == find(n2) else "No")
837. 连通块中点的数量 | 原题链接
这道题比上一道题多的部分是查找连通块数量。可以用 size
记录每一个节点的连通块数量,其中 size[i]
表示以 i
为根节点的连通块数量。
1import sys
2
3data = sys.stdin.read().strip().split("\n")
4n, m = map(int, data[0].strip().split())
5pre = [i for i in range(n + 1)]
6size = [1 for _ in range(n + 1)]
7
8
9def connect(n1, n2):
10 rank1 = rank2 = 0
11 while pre[n1] != n1:
12 n1 = pre[n1]
13 rank1 += 1
14 while pre[n2] != n2:
15 n2 = pre[n2]
16 rank2 += 1
17
18 if rank1 > rank2:
19 pre[n2] = n1
20 size[n1] += size[n2]
21 else:
22 pre[n1] = n2
23 size[n2] += size[n1]
24
25
26def query1(n1, n2):
27 while pre[n1] != n1:
28 n1 = pre[n1]
29 while pre[n2] != n2:
30 n2 = pre[n2]
31 print("Yes" if n1 == n2 else "No")
32
33
34def query2(n):
35 while pre[n] != n:
36 n = pre[n]
37 print(size[n])
38
39
40functions = {"C": connect, "Q1": query1, "Q2": query2}
41for i in range(1, m + 1):
42 q = data[i].strip().split()
43 functions[q[0]](*map(int, q[1:]))
没想到 Python 可以 AC。
1import sys
2
3data = sys.stdin.read().strip().split("\n")
4n, m = map(int, data[0].strip().split())
5pre = [i for i in range(n + 1)]
6size = [1 for _ in range(n + 1)]
7
8
9def find(n):
10 if pre[n] != n:
11 pre[n] = find(pre[n])
12 return pre[n]
13
14
15def connect(n1, n2):
16 root1, root2 = find(n1), find(n2)
17 pre[root1] = root2
18 if root1 != root2:
19 size[root2] += size[root1]
20
21
22def query1(n1, n2):
23 print("Yes" if find(n1) == find(n2) else "No")
24
25
26def query2(n):
27 print(size[find(n)])
28
29
30functions = {"C": connect, "Q1": query1, "Q2": query2}
31for i in range(1, m + 1):
32 q = data[i].strip().split()
33 functions[q[0]](*map(int, q[1:]))
递归写法也能 AC。
240. 食物链 | 原题链接
并查集的高级用法:带权并查集。
已知总共只有三种生物,可以定义并查集中每一个节点到根节点的距离模 3 得到的值为某生物与根节点生物的关系,如果这个值为 0 则表示当前节点物种与根节点物种同类;1 表示当前节点物种吃根节点物种;2 表示当前节点物种被根节点物种吃。用 dist
数组存储每个节点与根节点之间的距离关系。两物种间的关系由他们到根节点的距离之差决定。
1import sys
2
3data = sys.stdin.read().strip().split("\n")
4n, k = map(int, data[0].strip().split())
5pre = [i for i in range(n + 1)]
6dist = [0 for _ in range(n + 1)]
7
8
9def find(n):
10 if pre[n] != n:
11 t = find(pre[n])
12 dist[n] += dist[pre[n]]
13 pre[n] = t
14 return pre[n]
15
16
17count = 0
18for i in range(1, k + 1):
19 d, x, y = map(int, data[i].strip().split())
20
21 if x > n or y > n:
22 count += 1
23 continue
24
25 px, py = find(x), find(y)
26 if d == 1:
27 if px == py and (dist[x] - dist[y]) % 3 != 0:
28 count += 1
29 continue
30 elif px != py:
31 pre[px] = py
32 dist[px] = dist[y] - dist[x]
33 else:
34 if px == py and ((dist[x] - dist[y] - 1) % 3 != 0):
35 count += 1
36 continue
37 elif px != py:
38 pre[px] = py
39 dist[px] = dist[y] + 1 - dist[x]
40
41print(count)
838. 堆排序 | 原题链接
类似于没必要的排序2。
1import heapq
2
3n, m = map(int, input().strip().split())
4nums = list(map(int, input().strip().split()))
5
6heap = list()
7for num in nums:
8 heapq.heappush(heap, num)
9
10res = list()
11for _ in range(m):
12 res.append(heapq.heappop(heap))
13print(*res)
839. 模拟堆 | 原题链接
这道题要手搓堆,需要单独记录元素插入的顺序,需要有插入顺序到元素位置的映射,也要有元素位置到插入顺序的映射。
1import sys
2
3
4class MinHeap:
5 def __init__(self):
6 self.heap = [0]
7 self.position = {}
8
9 def _heapify_up(self, idx):
10 while idx > 1 and self.heap[idx] < self.heap[idx // 2]:
11 self._swap(idx, idx // 2)
12 idx //= 2
13
14 def _heapify_down(self, idx):
15 n = len(self.heap) - 1
16 while idx * 2 <= n:
17 smallest = idx * 2
18 if smallest + 1 <= n and self.heap[smallest + 1] < self.heap[smallest]:
19 smallest += 1
20 if self.heap[idx] <= self.heap[smallest]:
21 break
22 self._swap(idx, smallest)
23 idx = smallest
24
25 def _swap(self, i, j):
26 self.position[self.heap[i][1]] = j
27 self.position[self.heap[j][1]] = i
28 self.heap[i], self.heap[j] = self.heap[j], self.heap[i]
29
30 def insert(self, val, id):
31 self.heap.append((val, id))
32 self.position[id] = len(self.heap) - 1
33 self._heapify_up(len(self.heap) - 1)
34
35 def get_min(self):
36 return self.heap[1][0]
37
38 def delete_min(self):
39 min_element_id = self.heap[1][1]
40 self._swap(1, len(self.heap) - 1)
41 self.heap.pop()
42 del self.position[min_element_id]
43 if len(self.heap) > 1:
44 self._heapify_down(1)
45
46 def delete_kth(self, id):
47 idx = self.position[id]
48 self._swap(idx, len(self.heap) - 1)
49 self.heap.pop()
50 del self.position[id]
51 if idx < len(self.heap):
52 self._heapify_up(idx)
53 self._heapify_down(idx)
54
55 def modify_kth(self, id, new_val):
56 idx = self.position[id]
57 old_val = self.heap[idx][0]
58 self.heap[idx] = (new_val, id)
59 if new_val < old_val:
60 self._heapify_up(idx)
61 else:
62 self._heapify_down(idx)
63
64
65def solve(operations):
66 heap = MinHeap()
67 inserted = {}
68 counter = 0
69
70 for op in operations:
71 command = op[0]
72
73 if command == "I":
74 x = int(op[1])
75 counter += 1
76 inserted[counter] = x
77 heap.insert(x, counter)
78
79 elif command == "PM":
80 print(heap.get_min())
81
82 elif command == "DM":
83 heap.delete_min()
84
85 elif command == "D":
86 k = int(op[1])
87 if k in inserted:
88 heap.delete_kth(k)
89
90 elif command == "C":
91 k = int(op[1])
92 new_val = int(op[2])
93 if k in inserted:
94 inserted[k] = new_val
95 heap.modify_kth(k, new_val)
96
97
98data = sys.stdin.read().strip().splitlines()
99n = int(data[0])
100operations = [line.split() for line in data[1:]]
101solve(operations)
840. 模拟散列表 | 原题链接
1import sys
2
3N = int(1e5)
4hashmap = [None] * N
5
6
7def insert(x):
8 pos = x % N
9 while hashmap[pos] is not None:
10 pos = (pos + 1) % N
11 hashmap[pos] = x
12
13
14def query(x):
15 pos = x % N
16 while hashmap[pos] is not None:
17 if hashmap[pos] == x:
18 print("Yes")
19 return
20 pos = (pos + 1) % N
21 print("No")
22
23
24functions = {"I": insert, "Q": query}
25data = sys.stdin.read().strip().splitlines()
26n = int(data[0])
27for i in range(1, n + 1):
28 line = data[i].strip().split()
29 functions[line[0]](int(line[1]))
841. 字符串哈希 | 原题链接
暴力 TLE(11/13):
1import sys
2
3data = sys.stdin.read().strip().splitlines()
4_, m = map(int, data[0].strip().split())
5s = data[1].strip()
6for i in range(2, m + 2):
7 l1, r1, l2, r2 = map(int, data[i].strip().split())
8 print("Yes" if s[l1 - 1 : r1] == s[l2 - 1 : r2] else "No")
字符串哈希的基本思路是将字符串转化为一个整数哈希值,预处理计算所有前缀的哈希值,然后通过哈希值的差来快速计算任意子串的哈希值。具体做法是将字符串转化为 $P$ 进制数,同时为了防止数据太大,存储时要对 $Q$ 取模。按照经验,当 $P=131$ 或 $P=13331$ 且 $Q=2^{64}$ 时发生哈希冲突的概率极低。
使用 weight
存储 $P$ 进制中每一位的位权,即 $weight_i=weight_{i-1}\cdot P,weight_0=1$;hash_table
存储字符串前缀的哈希,有 hash_table[i] = hash_table[i - 1] * P + ord(string[i - 1])
。计算某一段字符串哈希的方法是 hash = hash_table[r] - hash_table[l - 1] * weight[r - l + 1]
。
以十进制 $1435$ 为例,weight = [1, 10, 100, 1000], hash_table = [1, 14, 145, 1453]
,这个数字第 $3$、$4$ 位的数字可以表示为 $35=1453-14\times10^2$。
另一个需要注意的问题是,不能有字符映射到 $0$。假设字符 a
映射到 $0$,那么即使 a
和 aaa
不同,他们的哈希值都是 $0$。
1import sys
2
3MOD = 2**64
4BASE = 131
5
6data = sys.stdin.read().strip().splitlines()
7n, m = map(int, data[0].strip().split())
8s = data[1].strip()
9
10hash_table = [0] * (n + 1)
11weight = [1] * (n + 1)
12for i in range(1, n + 1):
13 hash_table[i] = (hash_table[i - 1] * BASE + ord(s[i - 1])) % MOD
14 weight[i] = (weight[i - 1] * BASE) % MOD
15
16for i in range(2, m + 2):
17 l1, r1, l2, r2 = map(int, data[i].strip().split())
18 hash1 = (hash_table[r1] - hash_table[l1 - 1] * weight[r1 - l1 + 1] % MOD) % MOD
19 hash2 = (hash_table[r2] - hash_table[l2 - 1] * weight[r2 - l2 + 1] % MOD) % MOD
20 print("Yes" if hash1 == hash2 else "No")