2025春季算法练习——Week 3

文章目录

这些题看上去较简单,但是实操起来还有很多细节需要注意。

题目来自 AcWing


830. 单调栈 | 原题链接

遍历原数组中的每个数,然后作如下操作:

  • 将栈顶元素出栈直到栈顶元素小于列表当前元素
  • 将栈顶元素添加到答案列表中
  • 将列表当前元素入栈

在栈中,如果一个元素小于前一个元素,那么这个元素后面所有元素的“左边第一个比它小的数”就不可能是当前这个元素前面的元素。

 1from collections import deque
 2
 3n = int(input().strip())
 4nums = tuple(map(int, input().strip().split()))
 5stack = deque()
 6ans = list()
 7for num in nums:
 8    while stack and stack[-1] >= num:
 9        stack.pop()
10    ans.append(-1 if not stack else stack[-1])
11    stack.append(num)
12
13print(*ans)

154. 滑动窗口 | 原题链接

这道题像是上一道题的进阶版,区别是这道题只能在一个区间范围内找最大/最小值。解决这道题的流程如下:

  • 将已经离开窗口的元素从队首弹出
  • 确保队尾元素大于/小于列表当前元素,否则将元素从队尾出队
  • 将列表当前元素入队
  • 队首元素即为当前窗口的最大/最小值,将这个值添加到答案列表中

为了方便判断某个元素是否已经离开窗口,队列中存储的是元素的下标。

 1from collections import deque
 2
 3n, k = map(int, input().strip().split())
 4k -= 1
 5nums = tuple(map(int, input().strip().split()))
 6minans = list()
 7maxans = list()
 8minq = deque()
 9maxq = deque()
10
11for j in range(n):
12    if minq and minq[0] < j - k:
13        minq.popleft()
14    if maxq and maxq[0] < j - k:
15        maxq.popleft()
16
17    while minq and nums[minq[-1]] >= nums[j]:
18        minq.pop()
19    minq.append(j)
20    while maxq and nums[maxq[-1]] <= nums[j]:
21        maxq.pop()
22    maxq.append(j)
23
24    if j - k >= 0:
25        minans.append(nums[minq[0]])
26        maxans.append(nums[maxq[0]])
27
28print(*minans)
29print(*maxans)

831. KMP字符串 | 原题链接

KMP 算法模板题,核心在于生成 next 数组以及匹配。

 1def build_next(patt):
 2    nxt = [0] * len(patt)
 3    j = 0
 4    for i in range(1, len(patt)):
 5        while j > 0 and patt[i] != patt[j]:
 6            j = nxt[j - 1]
 7        if patt[i] == patt[j]:
 8            j += 1
 9        nxt[i] = j
10    return nxt
11
12
13def kmp_match(patt, string, nxt):
14    i = j = 0
15    lenp = len(patt)
16    while i < len(string):
17        if patt[j] == string[i]:
18            i += 1
19            j += 1
20        elif j > 0:
21            j = nxt[j - 1]
22        else:
23            i += 1
24
25        if j == lenp:
26            yield i - j
27            j = nxt[j - 1]
28
29
30N = int(input().strip())
31P = input().strip()
32M = int(input().strip())
33S = input().strip()
34nxt = build_next(P)
35ans = [i for i in kmp_match(P, S, nxt)]
36print(*ans)

835. Trie字符串统计 | 原题链接

Trie 树模板题。

 1import sys
 2
 3letters = "abcdefghijklmnopqrstuvwxyz"
 4
 5
 6class TrieNode:
 7
 8    def __init__(self):
 9        self.key = 0
10        self.children = {ch: None for ch in letters}
11
12
13data = sys.stdin.read().strip().split("\n")
14trie = TrieNode()
15n = int(data[0])
16for i in range(1, n + 1):
17    q, s = data[i].strip().split()
18    if q == "I":
19        node = trie
20        for c in s:
21            if node.children[c] is None:
22                node.children[c] = TrieNode()
23            node = node.children[c]
24        node.key += 1
25    else:
26        node = trie
27        available = True
28        for c in s:
29            if node.children[c] is None:
30                available = False
31                break
32            node = node.children[c]
33        if node and node.key and available:
34            print(node.key)
35        else:
36            print(0)

143. 最大异或对 | 原题链接

这道题需要用 Trie 树和贪心来解。如果我们希望两个数异或的结果尽可能大,那么这两个数二进制里越多位不同,异或的结果越大,而且高位的权重要低于低位。因此我们将所有的数的二进制从高位到低位存储进一个 Trie 树,并且每一个叶子节点到根节点的距离都是 31。每当读取一个数时,将这个数的二进制存储进树中,同时查找与其不同位尽可能多的二进制数,这两数异或,得到的就是当前处理的数与前面出现过的数的最大异或值。

 1class TrieNode:
 2
 3    def __init__(self):
 4        self.children = [None, None]
 5
 6
 7n = int(input().strip())
 8nums = list(map(int, input().strip().split()))
 9trie = TrieNode()
10max_xor = 0
11for num in nums:
12    node = trie
13    xor_node = trie
14    xor_res = 0
15    for i in range(32)[::-1]:
16        bit = (num >> i) & 1
17        if node.children[bit] is None:
18            node.children[bit] = TrieNode()
19        node = node.children[bit]
20
21        if xor_node.children[1 - bit] is not None:
22            xor_res |= (1 - bit) << i
23            xor_node = xor_node.children[1 - bit]
24        elif xor_node.children[bit] is not None:
25            xor_res |= bit << i
26            xor_node = xor_node.children[bit]
27
28    xor_res ^= num
29    if xor_res > max_xor:
30        max_xor = xor_res
31
32print(max_xor)

Python 中的按位取反运算符 ~ 对有符号数做操作,例如 ~0b1101 = -0b1110,因此要用 1 - bitbit 取反。

836. 合并集合 | 原题链接

由于 Python 有递归深度限制,不用递归硬找祖宗节点的写法可以过 10/11 个测试点。

 1import sys
 2
 3data = sys.stdin.read().strip().split("\n")
 4n, m = map(int, data[0].strip().split())
 5pre = [i for i in range(n + 1)]
 6
 7
 8def merge(n1, n2):
 9    rank1 = rank2 = 0
10    while pre[n1] != n1:
11        n1 = pre[n1]
12        rank1 += 1
13    while pre[n2] != n2:
14        n2 = pre[n2]
15        rank2 += 1
16
17    if n1 == n2:
18        return
19    if rank1 > rank2:
20        pre[n2] = n1
21    else:
22        pre[n1] = n2
23
24
25def query(n1, n2):
26    while pre[n1] != n1:
27        n1 = pre[n1]
28    while pre[n2] != n2:
29        n2 = pre[n2]
30
31    return n1 == n2
32
33
34for i in range(1, m + 1):
35    q, n1, n2 = data[i].strip().split()
36    n1 = int(n1)
37    n2 = int(n2)
38
39    if q == "M":
40        merge(n1, n2)
41    else:
42        print("Yes" if query(n1, n2) else "No")

Java 里可以用 find 函数递归解决:

 1import java.util.Scanner;
 2
 3public class Main {
 4    public static int[] pre;
 5
 6    public static void main(String[] args) {
 7        Scanner sc = new Scanner(System.in);
 8        int n = sc.nextInt();
 9        int m = sc.nextInt();
10        pre = new int[n + 1];
11        for (int i = 1; i <= n; i++) {
12            pre[i] = i;
13        }
14        sc.nextLine();
15
16        while (m-- > 0) {
17            String[] line = sc.nextLine().trim().split("\\s+");
18            String q = line[0].trim();
19            int n1 = Integer.parseInt(line[1].trim());
20            int n2 = Integer.parseInt(line[2].trim());
21
22            if (q.equals("M")) {
23                pre[find(n1)] = find(n2);
24            } else {
25                System.out.println(find(n1) == find(n2) ? "Yes" : "No");
26            }
27        }
28    }
29
30    public static int find(int n) {
31        if (pre[n] != n) {
32            pre[n] = find(pre[n]);
33        }
34        return pre[n];
35    }
36}

对应的 Python 写法(不能 AC):

 1import sys
 2
 3data = sys.stdin.read().strip().split("\n")
 4n, m = map(int, data[0].strip().split())
 5pre = [i for i in range(n + 1)]
 6
 7
 8def find(n):
 9    if pre[n] != n:
10        pre[n] = find(pre[n])
11    return pre[n]
12
13
14for i in range(1, m + 1):
15    q, n1, n2 = data[i].strip().split()
16    n1 = int(n1)
17    n2 = int(n2)
18
19    if q == "M":
20        pre[find(n1)] = find(n2)
21    else:
22        print("Yes" if find(n1) == find(n2) else "No")

837. 连通块中点的数量 | 原题链接

这道题比上一道题多的部分是查找连通块数量。可以用 size 记录每一个节点的连通块数量,其中 size[i] 表示以 i 为根节点的连通块数量。

 1import sys
 2
 3data = sys.stdin.read().strip().split("\n")
 4n, m = map(int, data[0].strip().split())
 5pre = [i for i in range(n + 1)]
 6size = [1 for _ in range(n + 1)]
 7
 8
 9def connect(n1, n2):
10    rank1 = rank2 = 0
11    while pre[n1] != n1:
12        n1 = pre[n1]
13        rank1 += 1
14    while pre[n2] != n2:
15        n2 = pre[n2]
16        rank2 += 1
17
18    if rank1 > rank2:
19        pre[n2] = n1
20        size[n1] += size[n2]
21    else:
22        pre[n1] = n2
23        size[n2] += size[n1]
24
25
26def query1(n1, n2):
27    while pre[n1] != n1:
28        n1 = pre[n1]
29    while pre[n2] != n2:
30        n2 = pre[n2]
31    print("Yes" if n1 == n2 else "No")
32
33
34def query2(n):
35    while pre[n] != n:
36        n = pre[n]
37    print(size[n])
38
39
40functions = {"C": connect, "Q1": query1, "Q2": query2}
41for i in range(1, m + 1):
42    q = data[i].strip().split()
43    functions[q[0]](*map(int, q[1:]))

没想到 Python 可以 AC。

 1import sys
 2
 3data = sys.stdin.read().strip().split("\n")
 4n, m = map(int, data[0].strip().split())
 5pre = [i for i in range(n + 1)]
 6size = [1 for _ in range(n + 1)]
 7
 8
 9def find(n):
10    if pre[n] != n:
11        pre[n] = find(pre[n])
12    return pre[n]
13
14
15def connect(n1, n2):
16    root1, root2 = find(n1), find(n2)
17    pre[root1] = root2
18    if root1 != root2:
19        size[root2] += size[root1]
20
21
22def query1(n1, n2):
23    print("Yes" if find(n1) == find(n2) else "No")
24
25
26def query2(n):
27    print(size[find(n)])
28
29
30functions = {"C": connect, "Q1": query1, "Q2": query2}
31for i in range(1, m + 1):
32    q = data[i].strip().split()
33    functions[q[0]](*map(int, q[1:]))

递归写法也能 AC。

240. 食物链 | 原题链接

并查集的高级用法:带权并查集。

已知总共只有三种生物,可以定义并查集中每一个节点到根节点的距离模 3 得到的值为某生物与根节点生物的关系,如果这个值为 0 则表示当前节点物种与根节点物种同类;1 表示当前节点物种吃根节点物种;2 表示当前节点物种被根节点物种吃。用 dist 数组存储每个节点与根节点之间的距离关系。两物种间的关系由他们到根节点的距离之差决定。

 1import sys
 2
 3data = sys.stdin.read().strip().split("\n")
 4n, k = map(int, data[0].strip().split())
 5pre = [i for i in range(n + 1)]
 6dist = [0 for _ in range(n + 1)]
 7
 8
 9def find(n):
10    if pre[n] != n:
11        t = find(pre[n])
12        dist[n] += dist[pre[n]]
13        pre[n] = t
14    return pre[n]
15
16
17count = 0
18for i in range(1, k + 1):
19    d, x, y = map(int, data[i].strip().split())
20
21    if x > n or y > n:
22        count += 1
23        continue
24
25    px, py = find(x), find(y)
26    if d == 1:
27        if px == py and (dist[x] - dist[y]) % 3 != 0:
28            count += 1
29            continue
30        elif px != py:
31            pre[px] = py
32            dist[px] = dist[y] - dist[x]
33    else:
34        if px == py and ((dist[x] - dist[y] - 1) % 3 != 0):
35            count += 1
36            continue
37        elif px != py:
38            pre[px] = py
39            dist[px] = dist[y] + 1 - dist[x]
40
41print(count)

838. 堆排序 | 原题链接

类似于没必要的排序2

 1import heapq
 2
 3n, m = map(int, input().strip().split())
 4nums = list(map(int, input().strip().split()))
 5
 6heap = list()
 7for num in nums:
 8    heapq.heappush(heap, num)
 9
10res = list()
11for _ in range(m):
12    res.append(heapq.heappop(heap))
13print(*res)

839. 模拟堆 | 原题链接

这道题要手搓堆,需要单独记录元素插入的顺序,需要有插入顺序到元素位置的映射,也要有元素位置到插入顺序的映射。

  1import sys
  2
  3
  4class MinHeap:
  5    def __init__(self):
  6        self.heap = [0]
  7        self.position = {}
  8
  9    def _heapify_up(self, idx):
 10        while idx > 1 and self.heap[idx] < self.heap[idx // 2]:
 11            self._swap(idx, idx // 2)
 12            idx //= 2
 13
 14    def _heapify_down(self, idx):
 15        n = len(self.heap) - 1
 16        while idx * 2 <= n:
 17            smallest = idx * 2
 18            if smallest + 1 <= n and self.heap[smallest + 1] < self.heap[smallest]:
 19                smallest += 1
 20            if self.heap[idx] <= self.heap[smallest]:
 21                break
 22            self._swap(idx, smallest)
 23            idx = smallest
 24
 25    def _swap(self, i, j):
 26        self.position[self.heap[i][1]] = j
 27        self.position[self.heap[j][1]] = i
 28        self.heap[i], self.heap[j] = self.heap[j], self.heap[i]
 29
 30    def insert(self, val, id):
 31        self.heap.append((val, id))
 32        self.position[id] = len(self.heap) - 1
 33        self._heapify_up(len(self.heap) - 1)
 34
 35    def get_min(self):
 36        return self.heap[1][0]
 37
 38    def delete_min(self):
 39        min_element_id = self.heap[1][1]
 40        self._swap(1, len(self.heap) - 1)
 41        self.heap.pop()
 42        del self.position[min_element_id]
 43        if len(self.heap) > 1:
 44            self._heapify_down(1)
 45
 46    def delete_kth(self, id):
 47        idx = self.position[id]
 48        self._swap(idx, len(self.heap) - 1)
 49        self.heap.pop()
 50        del self.position[id]
 51        if idx < len(self.heap):
 52            self._heapify_up(idx)
 53            self._heapify_down(idx)
 54
 55    def modify_kth(self, id, new_val):
 56        idx = self.position[id]
 57        old_val = self.heap[idx][0]
 58        self.heap[idx] = (new_val, id)
 59        if new_val < old_val:
 60            self._heapify_up(idx)
 61        else:
 62            self._heapify_down(idx)
 63
 64
 65def solve(operations):
 66    heap = MinHeap()
 67    inserted = {}
 68    counter = 0
 69
 70    for op in operations:
 71        command = op[0]
 72
 73        if command == "I":
 74            x = int(op[1])
 75            counter += 1
 76            inserted[counter] = x
 77            heap.insert(x, counter)
 78
 79        elif command == "PM":
 80            print(heap.get_min())
 81
 82        elif command == "DM":
 83            heap.delete_min()
 84
 85        elif command == "D":
 86            k = int(op[1])
 87            if k in inserted:
 88                heap.delete_kth(k)
 89
 90        elif command == "C":
 91            k = int(op[1])
 92            new_val = int(op[2])
 93            if k in inserted:
 94                inserted[k] = new_val
 95                heap.modify_kth(k, new_val)
 96
 97
 98data = sys.stdin.read().strip().splitlines()
 99n = int(data[0])
100operations = [line.split() for line in data[1:]]
101solve(operations)

840. 模拟散列表 | 原题链接

 1import sys
 2
 3N = int(1e5)
 4hashmap = [None] * N
 5
 6
 7def insert(x):
 8    pos = x % N
 9    while hashmap[pos] is not None:
10        pos = (pos + 1) % N
11    hashmap[pos] = x
12
13
14def query(x):
15    pos = x % N
16    while hashmap[pos] is not None:
17        if hashmap[pos] == x:
18            print("Yes")
19            return
20        pos = (pos + 1) % N
21    print("No")
22
23
24functions = {"I": insert, "Q": query}
25data = sys.stdin.read().strip().splitlines()
26n = int(data[0])
27for i in range(1, n + 1):
28    line = data[i].strip().split()
29    functions[line[0]](int(line[1]))

841. 字符串哈希 | 原题链接

暴力 TLE(11/13):

1import sys
2
3data = sys.stdin.read().strip().splitlines()
4_, m = map(int, data[0].strip().split())
5s = data[1].strip()
6for i in range(2, m + 2):
7    l1, r1, l2, r2 = map(int, data[i].strip().split())
8    print("Yes" if s[l1 - 1 : r1] == s[l2 - 1 : r2] else "No")

字符串哈希的基本思路是将字符串转化为一个整数哈希值,预处理计算所有前缀的哈希值,然后通过哈希值的差来快速计算任意子串的哈希值。具体做法是将字符串转化为 $P$ 进制数,同时为了防止数据太大,存储时要对 $Q$ 取模。按照经验,当 $P=131$ 或 $P=13331$ 且 $Q=2^{64}$ 时发生哈希冲突的概率极低。

使用 weight 存储 $P$ 进制中每一位的位权,即 $weight_i=weight_{i-1}\cdot P,weight_0=1$;hash_table 存储字符串前缀的哈希,有 hash_table[i] = hash_table[i - 1] * P + ord(string[i - 1])。计算某一段字符串哈希的方法是 hash = hash_table[r] - hash_table[l - 1] * weight[r - l + 1]

以十进制 $1435$ 为例,weight = [1, 10, 100, 1000], hash_table = [1, 14, 145, 1453],这个数字第 $3$、$4$ 位的数字可以表示为 $35=1453-14\times10^2$。

另一个需要注意的问题是,不能有字符映射到 $0$。假设字符 a 映射到 $0$,那么即使 aaaa 不同,他们的哈希值都是 $0$。

 1import sys
 2
 3MOD = 2**64
 4BASE = 131
 5
 6data = sys.stdin.read().strip().splitlines()
 7n, m = map(int, data[0].strip().split())
 8s = data[1].strip()
 9
10hash_table = [0] * (n + 1)
11weight = [1] * (n + 1)
12for i in range(1, n + 1):
13    hash_table[i] = (hash_table[i - 1] * BASE + ord(s[i - 1])) % MOD
14    weight[i] = (weight[i - 1] * BASE) % MOD
15
16for i in range(2, m + 2):
17    l1, r1, l2, r2 = map(int, data[i].strip().split())
18    hash1 = (hash_table[r1] - hash_table[l1 - 1] * weight[r1 - l1 + 1] % MOD) % MOD
19    hash2 = (hash_table[r2] - hash_table[l2 - 1] * weight[r2 - l2 + 1] % MOD) % MOD
20    print("Yes" if hash1 == hash2 else "No")

系列文章