动态规划练习(二)

文章目录

题目来自 AcWing


1050. 鸣人的影分身 | 原题链接

数据量很小,深搜可以 AC。

 1t = int(input().strip())
 2
 3
 4def dfs(last_pow, rest_pow, rest_cnt):
 5    global plans
 6    if rest_pow == 0 and rest_cnt == 0:
 7        plans += 1
 8        return
 9
10    if rest_pow < 0 or rest_cnt < 0:
11        return
12
13    if rest_cnt * last_pow > rest_pow:
14        return
15
16    for p in range(last_pow, rest_pow + 1):
17        dfs(p, rest_pow - p, rest_cnt - 1)
18
19
20for _ in range(t):
21    m, n = map(int, input().strip().split())
22    plans = 0
23    dfs(0, m, n)
24    print(plans)

dp 写法:令 dp[i][j] 表示将 j 单位能量分给 i 个分身的方法数。对于每一个 dp[i][j] 都需要考虑以下情况:

  • 前面 i 个分身中有至少一个分身没有分到能量,此时去掉没有能量的那个分身,方法数不变,因此此时方法数同 i - 1 个分身分配 j 能量的方法数,即 dp[i - 1][j]
  • 前面 i 个分身都有能量,如果将所有分身的能量都减一,方法数不变,已经分配的能量要减去已有分身数量,即 dp[i][j - i]
  • i < j,也就是能量数量小于分身数量时,前面一定有至少一个分身没有能量,所以只能选 dp[i - 1][j]

至于初始化,需要将所有 dp[i][0] 初始化为 1。

 1t = int(input().strip())
 2for _ in range(t):
 3    m, n = map(int, input().strip().split())
 4    dp = [[0 for _ in range(m + 1)] for _ in range(n + 1)]
 5    for i in range(n + 1):
 6        dp[i][0] = 1
 7
 8    for i in range(1, n + 1):
 9        for j in range(1, m + 1):
10            if j >= i:
11                dp[i][j] = dp[i - 1][j] + dp[i][j - i]
12            else:
13                dp[i][j] = dp[i - 1][j]
14    print(dp[-1][-1])

1047. 糖果 | 原题链接

定义 dp[i][j] 表示在前 i 包糖果中,选择的糖果数量对 k 取模为 j 时的最大值。dp[i][j] 要在以下两种情况中取最大值:

  • 不取第 i 包糖果,方法数和 i - 1 包糖果方法数一样且模数不变,即 dp[i - 1][j]
  • 取第 i 包糖果,此时方法数为 i - 1 包糖果,且模数为当前模数减去第 i 包糖果数量的模数,即 dp[i - 1][(j + (k - a[i]) % k)] + a[i]

而对于初始化,dp[0][0] 为 0,而其他位置都应当是负无穷。因为只有 0 个糖果时,余数只能为 0 而不能为其他值,为了不影响计算而这样初始化。

 1n, k = map(int, input().strip().split())
 2nums = [0]
 3for _ in range(n):
 4    nums.append(int(input().strip()))
 5
 6dp = [[-float("inf") for _ in range(k)] for _ in range(n + 1)]
 7dp[0][0] = 0
 8for i in range(1, n + 1):
 9    for j in range(k):
10        dp[i][j] = max(dp[i - 1][j], dp[i - 1][(j + k - (nums[i] % k)) % k] + nums[i])
11print(dp[n][0])

1222. 密码脱落 | 原题链接

这道题是一道区间 DP 问题。题目要求的脱落字符个数,等于现字符串长度减去最大回文子串的长度。

dp[i][j] 表示左右端点分别为 ij 时的最长回文子串的长度。

  • i == j 时,最长回文子串长度为 0,因此 dp[i][j] = 0
  • j - i < 2 时,即当前子串长度为 1 时,该子串一定回文,因此 dp[i][j] = 1
  • s[i] == s[j] 时,由于该位置一定满足回文,当前回文子串长度从 s[i + 1: j] 位置的子串继承来,dp[i][j] = dp[i + 1][j - 1] + 2
  • s[i] != s[j] 时,当前回文子串长度取 s[i + 1: j + 1]s[i][j] 中回文子串最长的一个,所以 dp[i][j] = max(dp[i + 1][j], dp[i][j - 1])

在初始化时,将 dp 中的全部值初始化为 1,而 i == j 的位置单独初始化为 0。遍历的顺序为先遍历子序列长度 lengthrange(2, len(s) + 1)),内层遍历左端点 irange(n - length + 1)),最后计算出右端点 j = i + length - 1

 1s = input().strip()
 2n = len(s)
 3
 4dp = [[0 for _ in range(n)] for _ in range(n)]
 5for i in range(n):
 6    dp[i][i] = 1
 7
 8for length in range(2, n + 1):
 9    for i in range(n - length + 1):
10        j = i + length - 1
11        if s[i] == s[j]:
12            dp[i][j] = dp[i + 1][j - 1] + 2
13        else:
14            dp[i][j] = max(dp[i + 1][j], dp[i][j - 1])
15print(n - dp[0][n - 1])

另一种更常见写法是从大到小遍历 i(因为状态从 i + 1 位置转移到 i 位置),从小到大遍历 j(原因同上),其余内容不变。

 1s = input().strip()
 2n = len(s)
 3dp = [[0 for _ in range(n)] for _ in range(n)]
 4
 5for i in range(n)[::-1]:
 6    dp[i][i] = 1
 7    for j in range(i + 1, n):
 8        if s[i] == s[j]:
 9            dp[i][j] = dp[i + 1][j - 1] + 2
10        else:
11            dp[i][j] = max(dp[i + 1][j], dp[i][j - 1])
12
13print(n - dp[0][n - 1])

求最长回文区间也可以使用深搜解决。例如对于字符串 eacbba,由于左右两端的字符不一样,因此左右两端的字符只能选择一个作为回文子串的一部分,也就是要从 eacbbacbba 中选择一个最长子串。而在 acbba 中,左右两端的字符相同,因此这一字符串中最长回文子串的长度为 cbb 中最长回文子串长度加二。

 1from collections import defaultdict
 2
 3s = input().strip()
 4n = len(s)
 5memo = defaultdict(int)
 6
 7
 8def dfs(i, j):
 9    if i == j:
10        return 1
11    elif i == j + 1:
12        return 0
13
14    if memo[(i, j)] != 0:
15        return memo[(i, j)]
16    if s[i] == s[j]:
17        l = dfs(i + 1, j - 1) + 2
18    else:
19        l = max(dfs(i + 1, j), dfs(i, j - 1))
20    memo[(i, j)] = l
21    return l
22
23
24print(n - dfs(0, n - 1))

突然意识到 cachelru_cache 也能当记忆化搜索使……

 1from functools import lru_cache
 2import sys
 3
 4sys.setrecursionlimit(10000)
 5s = input().strip()
 6n = len(s)
 7
 8
 9@lru_cache(maxsize=None)
10def dfs(i, j):
11    if i == j:
12        return 1
13    elif i == j + 1:
14        return 0
15
16    if s[i] == s[j]:
17        return dfs(i + 1, j - 1) + 2
18    else:
19        return max(dfs(i + 1, j), dfs(i, j - 1))
20
21
22print(n - dfs(0, n - 1))

1220. 生命之树 | 原题链接

这道题是一道树形 DP 问题。题目要求是在无向图中找到一个连通块,使其节点权值之和最大。

树形 DP 往往与深搜配合使用。令 dp[i] 表示以第 i 个节点为根的子树最大权值之和。对于节点 i,有如下两种选择:

  • 不选其子树,此时这个节点上的最大权值就是他自身,也就是 dp[i] = nodes[i]
  • 选择其子树,此时这个节点上的最大权值是它自身权值加上子树权值,也就是 dp[i] = nodes[i] + dp[child]

由于是选择连通块,因此只要 dp[child] > 0 即可选择以 child 为根的子树。

另外代码中有深搜,如果用 Python 写就会报错 RecursionError,因此选用 Java 解决(好久没写了,顺带练手,反正考试要考)。

 1import java.util.Scanner;
 2import java.util.List;
 3import java.util.ArrayList;
 4
 5public class Main {
 6
 7    private static int[] nodes;
 8    private static List<List<Integer>> tree;
 9    private static long[] dp;
10
11    public static void main(String[] args) {
12        Scanner sc = new Scanner(System.in);
13        int n = sc.nextInt();
14        nodes = new int[n];
15        for (int i = 0; i < n; i++) {
16            nodes[i] = sc.nextInt();
17        }
18        tree = new ArrayList<List<Integer>>(n);
19        for (int i = 0; i < n; i++) {
20            tree.add(new ArrayList<Integer>());
21        }
22        for (int i = 0; i < 2 * (n - 1); i += 2) {
23            int u = sc.nextInt() - 1;
24            int v = sc.nextInt() - 1;
25            tree.get(u).add(v);
26            tree.get(v).add(u);
27        }
28        sc.close();
29
30        dp = new long[n];
31        dfs(0, -1);
32        long max_val = Long.MIN_VALUE;
33        for (long v : dp) {
34            max_val = Math.max(max_val, v);
35        }
36        System.out.println(max_val);
37    }
38
39    public static void dfs(int curr, int last) {
40        dp[curr] = nodes[curr];
41        for (int node : tree.get(curr)) {
42            if (node == last) {
43                continue;
44            }
45            dfs(node, curr);
46            if (dp[node] > 0) {
47                dp[curr] += dp[node];
48            }
49        }
50    }
51}

1226. 包子凑数 | 原题链接

这道题相当于是1205. 买不到的数目的进阶版,买不到的数目中给出两个数,而这道题会给出多个数。

已知对于两个数 $a,b$,这两个数最大“买不到的数目”为 $(a-1)(b-1)$,而如果是三个或者更多,这个数字反而会更小。因此对于上界为 $100$ 的若干个数,其 Frobenius 数一定不超过 $10000$。

我们不妨先尝试使用一种更容易理解的做法:令所有基数组成的集合为 $A$,所有“可组成的数”的集合为 $B$,假设 $m\in B$,那么 $\exist k\in A,m-k\in B$。因此只需从 $0$($0\in B$)开始依次枚举所有数即可,这样做大约会循环 $10^6$ 层,在时间上可接受。

 1import sys
 2from math import gcd
 3from functools import reduce
 4
 5data = list(map(int, sys.stdin.read().strip().split("\n")))
 6n = data[0]
 7nums = data[1:]
 8
 9g = reduce(gcd, nums)
10if g != 1:
11    print("INF")
12    exit()
13lcm = lambda x, y: x * y // gcd(x, y)
14l = min(10000, reduce(lcm, nums))
15
16status = [False] * l
17status[0] = True
18for i in range(l):
19    for num in nums:
20        if i >= num:
21            status[i] |= status[i - num]
22        if status[i]:
23            break
24print(status.count(False))

这是 Python 3.8 及更低版本的 Python 中的写法。对于 3.9 及更高版本的 Python,math.lcm 可以导入,而且 math.gcdmath.lcm 都可以接收两个以上参数。

下面是一种贪心写法,类似于买不到的数目。

seq 列表存储可以组合出的数;动态维护一个列表 nxt,用于存储下一个可能的组合值;用 ptr 列表维护每一个数字在 seq 中的下标,用于生成下一个组合。每一次迭代时,找出 nxt 中的最小值追加到 seq 列表中,随后根据 ptr 中指向的 seq 中的值求出下一个可能的组合值并更新到 nxt 中。

 1import sys
 2from math import gcd
 3from functools import reduce
 4
 5data = list(map(int, sys.stdin.read().strip().split("\n")))
 6n = data[0]
 7nums = data[1:]
 8
 9g = reduce(gcd, nums)
10if g != 1:
11    print("INF")
12    exit()
13lcm = lambda x, y: x * y // gcd(x, y)
14l = min(10010, reduce(lcm, nums))
15
16seq = [0]
17nxt = nums[:]
18ptr = [0] * n
19while seq[-1] < l:
20    mini = min(nxt)
21    seq.append(mini)
22    for i in range(n):
23        if mini == nxt[i]:
24            ptr[i] += 1
25            nxt[i] = nums[i] + seq[ptr[i]]
26print(l - len(seq) + 1)

除此之外,这道题也可以用广搜解决。如果一个数是“可组成的”,那么它加上任意一个 $a_i$ 得到的数都是“可组成的”。从 $0$ 开始搜索并标记每一个“可组成的”数,剩下的就是题目要求找出的数。为了防止程序做重复工作,如果某一个数已经被标记过,就不需要再次遍历通过这个数得到的数了。

 1import sys
 2from math import gcd
 3from functools import reduce
 4from collections import deque
 5
 6data = list(map(int, sys.stdin.read().strip().split("\n")))
 7n = data[0]
 8nums = data[1:]
 9
10g = reduce(gcd, nums)
11if g != 1:
12    print("INF")
13    exit()
14lcm = lambda x, y: x * y // gcd(x, y)
15l = min(10000, reduce(lcm, nums))
16
17status = [False] * l
18q = deque((0,))
19while q:
20    num = q.popleft()
21    if status[num]:
22        continue
23    status[num] = True
24    for k in nums:
25        nxt = num + k
26        if nxt < l:
27            q.append(nxt)
28print(status.count(False))

1070. 括号配对 | 原题链接

这道题也是区间 DP 问题,有些类似于密码脱落,区别在于这道题里 []() 也是一种合法组合。所以当某一区间左右两端的括号不匹配时,其最长合法子序列的长度应当为区间内某一点左右两侧合法子序列的长度之和,即这一点需要在左右两端点遍历。

同样,这道题有两种写法:

DFS 写法:

 1from functools import lru_cache
 2
 3s = input().strip()
 4n = len(s)
 5
 6@lru_cache(maxsize=None)
 7def dfs(i, j):
 8    if i >= j:
 9        return 0
10
11    length = 0
12    if (s[i], s[j]) in (('(', ')'), ('[', ']')):
13        length = dfs(i + 1, j - 1) + 2
14    for k in range(i, j):
15        length = max(length, dfs(i, k) + dfs(k + 1, j))
16    return length
17
18print(n - dfs(0, n - 1))

DP 写法:

 1s = input().strip()
 2n = len(s)
 3
 4dp = [[0 for _ in range(n)] for _ in range(n)]
 5for i in range(n)[::-1]:
 6    for j in range(i + 1, n):
 7        if (s[i], s[j]) in (('(', ')'), ('[', ']')):
 8            dp[i][j] = dp[i + 1][j - 1] + 2
 9        for k in range(i, j):
10            dp[i][j] = max(dp[i][j], dp[i][k] + dp[k + 1][j])
11
12print(n - dp[0][n - 1])

1078. 旅游规划 | 原题链接

之前在1207. 大臣的旅费中遇到过树的直径,当时使用了两次深搜以及两次广搜解决问题,因此我想到用两次广搜找出树的直径再找到路径上的节点。但是这棵树可能有多条直径,而要想找到所有直径上的点,两次广搜可是不够的……用多次广搜只能骗分,没法 AC。

 1import sys
 2from collections import defaultdict, deque
 3
 4data = sys.stdin.read().strip().split("\n")
 5n = int(data[0].strip())
 6tree = defaultdict(list)
 7for i in range(1, n):
 8    a, b = map(int, data[i].strip().split())
 9    tree[a].append(b)
10    tree[b].append(a)
11
12def bfs(node):
13    visited = [False] * (n)
14    depth = defaultdict(set)
15    q = deque()
16    q.append((node, [node]))
17    visited[node] = True
18    max_len = 0
19    max_path = list()
20    while q:
21        curr, path = q.popleft()
22        visited[curr] = True
23        has_next = False
24        for nxt in tree[curr]:
25            if not visited[nxt]:
26                has_next = True
27                q.append((nxt, path + [nxt]))
28        if not has_next:
29            depth[len(path)] |= set(path)
30            if len(path) > max_len:
31                max_len = len(path)
32                max_path = path
33    return depth
34
35d = bfs(0)
36length = max(tuple(d.keys()))
37nodes = tuple(d[length])
38ans = set()
39max_l = 0
40for node in nodes:
41    p = bfs(node)
42    ls = max(tuple(p.keys()))
43    if ls == max_l:
44        ans |= p[ls]
45    elif ls > max_l:
46        max_l = ls
47        ans = p[ls]
48print(*sorted(ans), sep="\n")

首先先搜一次,找到全部满足“最远点”的端点,然后第二次广搜从这些端点开始,找出最长路径上的全部节点。也就是说第一次广搜能搜到多少节点,后面就要再进行多少轮广搜。这样能拿 50% 得分。

其实除了两次深搜/广搜外,使用树形 DP 也可以求出树的直径。

 1import sys
 2from collections import defaultdict
 3
 4data = sys.stdin.read().strip().split("\n")
 5n = int(data[0].strip())
 6tree = defaultdict(list)
 7for i in range(1, n):
 8    u, v = map(int, data[i].strip().split())
 9    tree[u].append(v)
10    tree[v].append(u)
11ans = 0
12
13
14def dfs_depth(last, curr):
15    global ans
16    children = list()
17    for node in tree[curr]:
18        if node != last:
19            children.append(dfs_depth(curr, node))
20    if len(children) == 0:
21        return 0
22    elif len(children) == 1:
23        t = children[0] + 1
24    else:
25        children.sort(reverse=True)
26        t = children[0] + children[1] + 2
27    ans = max(ans, t)
28    return children[0] + 1
29
30
31t = dfs_depth(-1, 0)
32print(ans)

这段代码用深搜控制节点访问顺序,每个节点都只会访问一次,因此时间复杂度是 $O(n)$。具体分析可见树形 DP:树的直径【基础算法精讲 23】

对这段代码稍作修改,就可以得到题解。下面的代码中,dfs_d 函数不仅要计算树的直径,还要记录每一个节点向下走的最长和第二长的路径长度;而 dfs_up 函数则用于判断每一个节点是否在直径,具体方法是判断其延伸出的最长两条路径长度之和是否等于树的直径。两次深搜中每一个节点都只会被访问一次,时间复杂度为 $O(n)$。详细题解见AcWing 1078. 旅游规划

 1import sys
 2from collections import defaultdict
 3
 4data = sys.stdin.read().strip().split("\n")
 5n = int(data[0].strip())
 6tree = defaultdict(list)
 7for i in range(1, n):
 8    u, v = map(int, data[i].strip().split())
 9    tree[u].append(v)
10    tree[v].append(u)
11
12ans = 0
13d = [[0, 0] for _ in range(n)]
14n1 = [0] * n
15
16
17def dfs_d(last, curr):
18    global ans
19    children = list()
20    for node in tree[curr]:
21        if node != last:
22            children.append((dfs_d(curr, node), node))
23    if len(children) == 0:
24        return 0
25    elif len(children) == 1:
26        t = children[0][0] + 1
27        d[curr][0] = t
28        n1[curr] = children[0][1]
29    else:
30        children.sort(reverse=True)
31        t = children[0][0] + children[1][0] + 2
32        d[curr][0] = children[0][0] + 1
33        d[curr][1] = children[1][0] + 1
34        n1[curr] = children[0][1]
35    ans = max(ans, t)
36    return children[0][0] + 1
37
38
39up = [0] * n
40
41
42def dfs_up(parent, curr):
43    for node in tree[curr]:
44        if node == parent:
45            continue
46        up[node] = up[curr] + 1
47        if n1[curr] == node:
48            up[node] = max(up[node], d[curr][1] + 1)
49        else:
50            up[node] = max(up[node], d[curr][0] + 1)
51        dfs_up(curr, node)
52
53
54_ = dfs_d(-1, 0)
55dfs_up(-1, 0)
56print(ans)
57path = set()
58for i in range(n):
59    p = [d[i][0], d[i][1], up[i]]
60    p.sort(reverse=True)
61    if p[0] + p[1] == ans:
62        path.add(i)
63print(*sorted(path), sep="\n")

1243. 糖果 | 原题链接

这道题可以用状态压缩 DP解决。令 dp[i][j] 表示只考虑前 i 包糖果,且已取得的糖果状态为 j 时最少需要拿的糖果袋数。这里的状态 j 为使用二进制表示的糖果种类,例如在有 $5$ 种口味的糖且已经获得了 $2,3,5$ 三种口味糖时的状态表示为 10110,即对于每一包糖果,其状态表示为 reduce(lambda x, y: x | 1 << (y - 1), package, 0),因此 dp 数组的横坐标最大为 1 << M

状态转移方程:dp[i][j] = min(dp[i - 1][j], dp[i - 1][j & (~package[i])] + 1)dp[i - 1][j] 表示不拿第 i 包糖果时最小包数,即和只在前 i - 1 包中考虑时的最小包数相同;dp[i - 1][j & (~package[i])] + 1 中,pacakge[i] 表示第 i 包糖果的状态,j & (~package[i]) 表示拿第 i 包糖果前的状态。关于此状态转移的正确性,在这篇题解中有讨论。此外,状态转移方程也可以写作 dp[i][j] = min(dp[i - 1][j], dp[i - 1][j & (pacakge[i])] + 1)

这道题要求的是最小值,dp 初始化时应当尽可能大,而 dp[0][0] 初始化为 $1$,最终答案为 dp[N][(1 << M) - 1]。这道题的时空开销较大,所以需要用滚动数组优化且用 C++ 实现。

 1#include <iostream>
 2#include <cstring>
 3#include <algorithm>
 4
 5using namespace std;
 6
 7int dp[1048577] = { 0 };
 8int candies[105] = { 0 };
 9const int INF = 0x3f3f3f3f;
10
11int main() {
12    memset(dp, INF, sizeof(dp));
13    dp[0] = 0;
14    int N, M, K;
15    cin >> N >> M >> K;
16    for (int i = 1; i <= N; i++) {
17        int p = 0, q;
18        for (int j = 0; j < K; j++) {
19            cin >> q;
20            p |= 1 << (q - 1);
21        }
22        candies[i] = p;
23    }
24
25    for (int i = 1; i <= N; i++) {
26        for (int j = 0; j < 1 << M; j++) {
27            dp[j] = min(dp[j], dp[j & (~candies[i])] + 1);
28        }
29    }
30
31    cout << ((dp[(1 << M) - 1] == INF) ? -1 : dp[(1 << M) - 1]);
32    return 0;
33}

1217. 垒骰子 | 原题链接

这道题最直接的思路是用深搜从下往上搜一遍。由于每一个骰子在确定好上下面之后,四个面可以旋转,所以每一层搜出来的结果都有乘四再加到前一层去。

 1from collections import defaultdict
 2
 3MOD = int(1e9 + 7)
 4path = {0: 0, 1: 4, 2: 5, 3: 6, 4: 1, 5: 2, 6: 3}
 5banned = defaultdict(set)
 6n, m = map(int, input().strip().split())
 7for _ in range(m):
 8    a, b = map(int, input().strip().split())
 9    banned[a].add(b)
10    banned[b].add(a)
11banned[0] = set()
12
13
14def dfs(down, depth):
15    if depth == n:
16        return 1
17
18    up = path[down]
19    count = 0
20    for i in range(1, 7):
21        if i not in banned[up]:
22            count += 4 * dfs(i, depth + 1)
23            count %= MOD
24    return count
25
26
27print(dfs(0, 0) % MOD)

这样写能过 7/10 个测试点。每一个骰子都是一层递归,而骰子的数据范围在 $10^9$,就算是智子来了也得爆栈。

把这段代码改成 DP,也许能多过几个测试点……吧。

 1from collections import defaultdict
 2
 3MOD = int(1e9 + 7)
 4path = {0: 0, 1: 4, 2: 5, 3: 6, 4: 1, 5: 2, 6: 3}
 5banned = defaultdict(set)
 6n, m = map(int, input().strip().split())
 7for _ in range(m):
 8    a, b = map(int, input().strip().split())
 9    banned[a].add(b)
10    banned[b].add(a)
11banned[0] = set()
12
13dp = [[0 for _ in range(7)] for _ in range(n)]
14dp[0] = [0, 1, 1, 1, 1, 1, 1]
15for i in range(1, n):
16    for j in range(1, 7):
17        for down in range(1, 7):
18            if down not in banned[j]:
19                dp[i][path[down]] += dp[i - 1][j] * 4
20                dp[i][j] %= MOD
21
22print((sum(dp[-1]) * 4) % MOD)

这次还是过了 7 个测试点,但是不是 RE 而是 TLE。

观察发现相邻两层的状态矩阵存在递推关系,因此可以用矩阵快速幂优化:

 1from collections import defaultdict
 2
 3MOD = int(1e9 + 7)
 4MAT = [[4 for _ in range(6)] for _ in range(6)]
 5f0 = [[1], [1], [1], [1], [1], [1]]
 6E = [[0 for _ in range(6)] for _ in range(6)]
 7for i in range(6):
 8    E[i][i] = 1
 9path = {1: 4, 2: 5, 3: 6, 4: 1, 5: 2, 6: 3}
10banned = defaultdict(set)
11
12n, m = map(int, input().strip().split())
13for _ in range(m):
14    a, b = map(int, input().strip().split())
15    MAT[path[a] - 1][b - 1] = 0
16    MAT[path[b] - 1][a - 1] = 0
17
18
19def mat_mul(a, b):
20    x = len(a)
21    y = len(b[0])
22    ans = [[0 for _ in range(y)] for _ in range(x)]
23    for i in range(x):
24        for j in range(y):
25            for k in range(len(b)):
26                ans[i][j] += (a[i][k] * b[k][j]) % MOD
27    return ans
28
29
30n -= 1
31while n:
32    if n & 1:
33        E = mat_mul(MAT, E)
34    MAT = mat_mul(MAT, MAT)
35    n >>= 1
36res = mat_mul(E, f0)
37
38ans = 0
39for row in res:
40    ans += sum(row) % MOD
41print((ans * 4) % MOD)

详细解释可见AcWing 1217. 【DFS, 动态规划】垒骰子

系列文章