图论练习

文章目录

题目来自 AcWing


848. 有向图的拓扑序列 | 原题链接

拓扑排序已经做过好几遍了,而且这还是一道模板题。如果一个有向图中有环(自环也是环),那么它一定无法拓扑排序。如果一个图拓扑排序得到的节点数和图中的节点数不同,那么这个图无法被拓扑排序。又因为“有向图可以被拓扑排序”和“图是有向无环图”互为充要条件,所以,则这个图一定不是有向无环图。

 1import sys
 2from collections import defaultdict, deque
 3
 4data = sys.stdin.read().strip().splitlines()
 5n, m = map(int, data[0].strip().split())
 6
 7indegree = [0 for _ in range(n + 1)]
 8graph = defaultdict(list)
 9ans = list()
10
11for i in range(1, m + 1):
12    u, v = map(int, data[i].strip().split())
13    graph[u].append(v)
14    indegree[v] += 1
15
16q = deque()
17for node in range(1, n + 1):
18    if indegree[node] == 0:
19        q.append(node)
20
21ans = list()
22while q:
23    node = q.popleft()
24    ans.append(node)
25    for nxt in graph[node]:
26        indegree[nxt] -= 1
27        if indegree[nxt] == 0:
28            q.append(nxt)
29
30if len(ans) == n:
31    print(*ans)
32else:
33    print(-1)

849. Dijkstra求最短路 I | 原题链接

朴素 Dijkstra 模板题,需要注意当前轮确定的节点不一定有后继节点,当节点的后继节点不存在时,下一轮的状态继承当前轮的状态。具体分析参见图-最短路径-Dijkstra算法AcWing 849. Dijkstra求最短路 I:图解 详细代码

 1import sys
 2from collections import defaultdict
 3
 4data = sys.stdin.read().strip().splitlines()
 5n, m = map(int, data[0].strip().split())
 6graph = defaultdict(list)
 7for i in range(1, m + 1):
 8    x, y, z = map(int, data[i].strip().split())
 9    graph[x].append((y, z))
10
11status = [False] * (n + 1)
12dist = [float("inf")] * (n + 1)
13path = [None] * (n + 1)
14
15dist[1] = 0
16path[1] = [1]
17for _ in range(n):
18    min_dist = float("inf")
19    min_node = None
20    for node in range(1, n + 1):
21        if not status[node] and dist[node] < min_dist:
22            min_dist = dist[node]
23            min_node = node
24
25    if min_node is not None:
26        status[min_node] = True
27
28    for next_node, next_dist in graph[min_node]:
29        if not status[next_node] and dist[min_node] + next_dist < dist[next_node]:
30            dist[next_node] = dist[min_node] + next_dist
31            path[next_node] = path[min_node] + [next_node]
32print(dist[-1] if status[-1] else -1)

题目中没有要求输出路径,但是这段代码也可以记录路径。

850. Dijkstra求最短路 II | 原题链接

这道题数据范围更大,给出的图是稀疏图,需要用堆优化 Dijkstra 解决。用 Python 会 MLE:

 1import sys
 2from collections import defaultdict
 3import heapq
 4
 5data = sys.stdin.read().strip().splitlines()
 6n, m = map(int, data[0].strip().split())
 7graph = defaultdict(list)
 8for i in range(1, m + 1):
 9    x, y, z = map(int, data[i].strip().split())
10    graph[x].append((y, z))
11
12status = [False] * (n + 1)
13dist = [float("inf")] * (n + 1)
14path = [None] * (n + 1)
15
16dist[1] = 0
17path[1] = [1]
18
19pq = [(0, 1)]
20
21while pq:
22    curr_dist, curr_node = heapq.heappop(pq)
23
24    if status[curr_node]:
25        continue
26
27    status[curr_node] = True
28
29    for next_node, next_dist in graph[curr_node]:
30        if not status[next_node] and curr_dist + next_dist < dist[next_node]:
31            dist[next_node] = curr_dist + next_dist
32            path[next_node] = path[curr_node] + [next_node]
33            heapq.heappush(pq, (dist[next_node], next_node))
34
35print(dist[-1] if status[-1] else -1)

如果不用 Python,用邻接矩阵或者邻接表存一个图会很麻烦,所以采用链式前向星来存储这个图,并配合堆优化 Dijkstra 解题:

 1import java.util.*;
 2
 3public class Main {
 4
 5    static class E {
 6        public int to, next;
 7        public long w;
 8
 9        public E(int to, long w, int next) {
10            this.to = to;
11            this.w = w;
12            this.next = next;
13        }
14    }
15
16    static E[] edges;
17    static int[] head;
18    static int cnt = 0;
19
20    public static void main(String args[]) {
21        int n, m;
22        try (Scanner sc = new Scanner(System.in)) {
23            n = sc.nextInt();
24            m = sc.nextInt();
25            edges = new E[m + 1];
26            head = new int[n + 1];
27            Arrays.fill(head, -1);
28            for (int i = 1; i <= m; i++) {
29                int x = sc.nextInt(), y = sc.nextInt();
30                long z = sc.nextLong();
31                addEdge(x, y, z);
32            }
33        }
34
35        boolean[] status = new boolean[n + 1];
36        long[] dist = new long[n + 1];
37        Arrays.fill(dist, Long.MAX_VALUE);
38        dist[1] = 0;
39
40        PriorityQueue<long[]> pq = new PriorityQueue<>(Comparator.comparingLong(a -> a[0]));
41        pq.add(new long[] { 0, 1 });
42
43        while (!pq.isEmpty()) {
44            long[] curr = pq.poll();
45            long curr_dist = curr[0];
46            int curr_node = (int) curr[1];
47
48            if (status[curr_node]) {
49                continue;
50            }
51            status[curr_node] = true;
52
53            for (int i = head[curr_node]; i != -1; i = edges[i].next) {
54                E edge = edges[i];
55                if (edge == null) {
56                    continue;
57                }
58                int next_node = edge.to;
59                long next_dist = edge.w;
60                if (!status[next_node] && curr_dist + next_dist < dist[next_node]) {
61                    dist[next_node] = curr_dist + next_dist;
62                    pq.add(new long[] { dist[next_node], next_node });
63                }
64            }
65        }
66
67        System.out.println(status[n] ? dist[n] : -1);
68    }
69
70    public static void addEdge(int u, int v, long w) {
71        edges[++cnt] = new E(v, w, head[u]);
72        head[u] = cnt;
73    }
74}

853. 有边数限制的最短路 | 原题链接

一般 Bellman-Ford 算法写法:

 1import sys
 2from collections import defaultdict
 3
 4INF = float("inf")
 5data = sys.stdin.read().strip().splitlines()
 6n, m, k = map(int, data[0].strip().split())
 7graph = defaultdict(list)
 8for i in range(1, m + 1):
 9    x, y, z = map(int, data[i].strip().split())
10    graph[x].append((y, z))
11
12dist = [INF] * (n + 1)
13dist[1] = 0
14for _ in range(n - 1):
15    for curr_node in range(1, n + 1):
16        for next_node, next_dist in graph[curr_node]:
17            dist[next_node] = min(dist[next_node], dist[curr_node] + next_dist)
18
19
20for curr_node in range(1, n + 1):
21    for next_node, next_dist in graph[curr_node]:
22        if dist[curr_node] + next_dist < dist[next_node]:
23            dist[next_node] = -INF
24
25print(*dist[1:])

Dijkstra 算法无法处理带有负权边的情况,此时就要用到 Bellman-Ford 算法。该算法时间复杂度为 $O(EV)$,其中 $E$ 为边数,$V$ 为节点数。而堆优化后的 Dijkstra 算法时间复杂度为 $(O(E+V)\log V)$。

在 Bellman-Ford 最短路算法中,遍历边的顺序不重要。因为有效的距离更新是从出发点开始向外扩散的,因此循环了 $k$ 次时恰好能更新到距离出发点 $k$ 条边的节点。

 1import sys
 2
 3INF = float("inf")
 4data = sys.stdin.read().strip().splitlines()
 5n, m, k = map(int, data[0].strip().split())
 6edges = list()
 7for i in range(1, m + 1):
 8    edges.append(list(map(int, data[i].strip().split())))
 9
10dist = [INF] * (n + 1)
11dist[1] = 0
12for _ in range(k):
13    backup = dist.copy()
14    for edge in edges:
15        u, v, w = edge
16        dist[v] = min(dist[v], backup[u] + w)
17
18print(dist[-1] if dist[-1] != INF else "impossible")

backup 数组是为了防止串联导致距离更新错误。

851. spfa求最短路 | 原题链接

这道题用到 SPFA 算法,或者说是队列优化的 Bellman-Ford 算法。因为只有当某一节点的 dist 减小,其后继节点的 dist 才有可能减小。Bellman-Ford 算法的时间复杂度为 $O(EV)$,此算法的时间复杂度范围是 $O(E+V)$ 到 $O(EV)$。

 1import sys
 2from collections import defaultdict, deque
 3
 4INF = float("inf")
 5data = sys.stdin.read().strip().splitlines()
 6n, m = map(int, data[0].strip().split())
 7graph = defaultdict(list)
 8for i in range(1, m + 1):
 9    x, y, z = map(int, data[i].strip().split())
10    graph[x].append((y, z))
11
12dist = [INF] * (n + 1)
13dist[1] = 0
14in_queue = [False] * (n + 1)
15in_queue[1] = True
16
17q = deque([1])
18while q:
19    curr_node = q.popleft()
20    in_queue[curr_node] = False
21    for next_node, next_dist in graph[curr_node]:
22        if dist[curr_node] + next_dist < dist[next_node]:
23            dist[next_node] = dist[curr_node] + next_dist
24            if not in_queue[next_node]:
25                q.append(next_node)
26                in_queue[next_node] = True
27
28print(dist[-1] if dist[-1] != INF else "impossible")

上面这段代码可以 AC,但是无法处理带有负环(负权回路)的情况。如果图中有负环,那么程序仍会陷入死循环。

852. spfa判断负环 | 原题链接

SPFA 判断负环的方式也很简单:记录每个节点入队的次数,如果某个节点入队了 $n$ 次,那么这个图中一定有负环。

然而,题目中给出的图不一定是连通图,而 SPFA 算法只能用于求单源最短路径,因此如果仅从 $1$ 号节点出发可能找不到图中的负环。为了解决这个问题,我们不妨建立一个虚拟源节点,并且从这个节点向其他各个节点连接一条权值为 $0$ 的有向边,此时新图中含有负环等价于原图中含有负环。

 1import sys
 2from collections import defaultdict, deque
 3
 4INF = float("inf")
 5data = sys.stdin.read().strip().splitlines()
 6graph = defaultdict(list)
 7n, m = map(int, data[0].strip().split())
 8for i in range(1, m + 1):
 9    x, y, z = map(int, data[i].strip().split())
10    graph[x].append((y, z))
11for i in range(1, n + 1):
12    graph[0].append((i, 0))
13
14dist = [INF] * (n + 1)
15dist[0] = 0
16inq = [False] * (n + 1)
17inq[0] = True
18count = [0] * (n + 1)
19
20q = deque((0,))
21while q:
22    curr_node = q.popleft()
23    inq[curr_node] = False
24    count[curr_node] += 1
25    if count[curr_node] >= n + 1:
26        print("Yes")
27        exit()
28    for next_node, next_dist in graph[curr_node]:
29        if dist[curr_node] + next_dist < dist[next_node]:
30            dist[next_node] = dist[curr_node] + next_dist
31            if not inq[next_node]:
32                q.append(next_node)
33                inq[next_node] = True
34
35print("No")

854. Floyd求最短路 | 原题链接

Floyd 算法修改邻接矩阵时的遍历顺序依次是中转节点、起始节点、目标节点。如果从起始节点到中转节点或者从中转节点到目标节点的距离为正无穷就直接 continue。注意图中可能存在重边,因此读取边的信息时,保留两节点间距离权重最小的边。

 1import sys
 2
 3INF = float("inf")
 4data = sys.stdin.read().strip().splitlines()
 5n, m, k = map(int, data[0].strip().split())
 6graph = [[INF if i != j else 0 for i in range(n + 1)] for j in range(n + 1)]
 7for i in range(1, m + 1):
 8    x, y, z = map(int, data[i].strip().split())
 9    graph[x][y] = min(graph[x][y], z)
10
11for nt in range(1, n + 1):
12    for ns in range(1, n + 1):
13        if graph[ns][nt] == INF:
14            continue
15        for nd in range(1, n + 1):
16            if graph[nt][nd] == INF:
17                continue
18            graph[ns][nd] = min(graph[ns][nd], graph[ns][nt] + graph[nt][nd])
19
20for i in range(m + 1, m + k + 1):
21    x, y = map(int, data[i].strip().split())
22    print(graph[x][y] if graph[x][y] != INF else "impossible")

最短路算法比较

算法 时间复杂度 最短路类型 作用于 能否检测负环
Dijkstra

$O((V + E) \log V)$

单源最短路 非负权图,稠密图 不能
堆优化 Dijkstra $O(E + V \log V)$ 单源最短路 非负权图,稀疏图 不能
Bellman-Ford $O(VE)$ 单源最短路 任意图,稠密图
SPFA $O(E)$(平均情况)
$O(VE)$(最坏情况)
单源最短路 任意图,稀疏图
Floyd $O(V^3)$ 多源最短路 任意图

858. Prim算法求最小生成树 | 原题链接

Prim 算法又被称为“加点法”,即每次选择距离已连通部分最近的节点并将其添加到连通部分中。

 1import sys
 2from collections import defaultdict
 3
 4INF = float("inf")
 5data = sys.stdin.read().strip().splitlines()
 6graph = defaultdict(list)
 7n, m = map(int, data[0].strip().split())
 8status = [False] * (n + 1)
 9dist = [INF] * (n + 1)
10pre = [-1] * (n + 1)
11for i in range(1, m + 1):
12    u, v, w = map(int, data[i].strip().split())
13    graph[u].append((v, w))
14    graph[v].append((u, w))
15
16dist[1] = 0
17total_w = 0
18for i in range(n):
19    min_dist = INF
20    min_node = -1
21    for node in range(1, n + 1):
22        if not status[node] and dist[node] < min_dist:
23            min_dist = dist[node]
24            min_node = node
25    status[min_node] = True
26    total_w += dist[min_node]
27
28    for next_node, next_dist in graph[min_node]:
29        if not status[next_node] and next_dist < dist[next_node]:
30            dist[next_node] = next_dist
31            pre[next_node] = min_node
32
33print(total_w if all(status[1:]) else "impossible")

上面是朴素 Prim 算法的写法,用这段代码足以 AC。上面这段代码与 Dijkstra 算法的代码极为相似,核心区别在于此算法的 dist 用于存储某节点到连通部分的最短距离。看上去 Prim 算法也可以使用堆优化来处理稀疏图的最小生成树,然而对于稀疏图的最小生成树,Kruskal 算法更为常用。

859. Kruskal算法求最小生成树 | 原题链接

Kruskal 算法又称为“加边法”,每次选取权值最小的边,如果这条边两侧的端点不连通,则将这条边添加到结果中,而判断两点是否连通则需要用到并查集。在合并的同时用 size 数组记录当前连通块的大小,当所有边都处理完时,如果最后一次处理的节点(即并查集的根节点)连通块大小与节点数量不相等,则说明有些节点没有被添加到最小生成树中,输出 "impossible"。

 1import sys
 2import heapq
 3
 4data = sys.stdin.read().strip().splitlines()
 5n, m = map(int, data[0].strip().split())
 6edges = list()
 7pre = [i for i in range(n + 1)]
 8size = [1] * (n + 1)
 9for i in range(1, m + 1):
10    u, v, w = map(int, data[i].strip().split())
11    heapq.heappush(edges, (w, u, v))
12
13
14def find(n):
15    if pre[n] != n:
16        pre[n] = find(pre[n])
17    return pre[n]
18
19
20total_w = 0
21last_v = -1
22while edges:
23    w, u, v = heapq.heappop(edges)
24    pu, pv = find(u), find(v)
25    if pu != pv:
26        pre[pu] = pv
27        size[pv] += size[pu]
28        last_v = pv
29        total_w += w
30
31print(total_w if size[last_v] == n else "impossible")

860. 染色法判定二分图 | 原题链接

用两种颜色给图中所有节点染色,如果能够确保每个节点与其邻接节点不同色,则这个图是二分图。二分图不一定是连通图,所以一定要遍历每一个节点。

广搜写法:

 1import sys
 2from collections import defaultdict, deque
 3
 4data = sys.stdin.read().strip().splitlines()
 5n, m = map(int, data[0].strip().split())
 6graph = defaultdict(list)
 7for i in range(1, m + 1):
 8    u, v = map(int, data[i].strip().split())
 9    graph[u].append(v)
10    graph[v].append(u)
11
12vert = [-1 for _ in range(n + 1)]
13for i in range(1, n + 1):
14    if vert[i] == -1:
15        vert[i] = 1
16        q = deque((i,))
17        while q:
18            curr_node = q.popleft()
19            for next_node in graph[curr_node]:
20                if vert[next_node] == -1:
21                    vert[next_node] = vert[curr_node] ^ 1
22                    q.append(next_node)
23                elif vert[next_node] == vert[curr_node]:
24                    print("No")
25                    exit()
26
27print("Yes" if vert[1:].count(-1) == 0 else "No")

广搜写法(Python 会爆栈):

 1import sys
 2from collections import defaultdict
 3
 4data = sys.stdin.read().strip().splitlines()
 5n, m = map(int, data[0].strip().split())
 6graph = defaultdict(list)
 7for i in range(1, m + 1):
 8    u, v = map(int, data[i].strip().split())
 9    graph[u].append(v)
10    graph[v].append(u)
11vert = [-1 for _ in range(n + 1)]
12is_bipartite = True
13
14
15def dfs(curr_node, last_val):
16    global is_bipartite
17    if vert[curr_node] == last_val:
18        is_bipartite = False
19        return
20
21    if vert[curr_node] != -1:
22        return
23
24    vert[curr_node] = last_val ^ 1
25    for next_node in graph[curr_node]:
26        dfs(next_node, vert[curr_node])
27        if not is_bipartite:
28            return
29
30
31for i in range(1, n + 1):
32    if vert[i] == -1:
33        dfs(i, 0)
34
35print("Yes" if is_bipartite else "No")
 1import java.util.*;
 2
 3public class Main {
 4    static HashMap<Integer, ArrayList<Integer>> graph;
 5    static int[] vert;
 6    static boolean isBipartite = true;
 7
 8    public static void main(String[] args) {
 9        Scanner sc = new Scanner(System.in);
10        int n = sc.nextInt();
11        int m = sc.nextInt();
12        graph = new HashMap<Integer, ArrayList<Integer>>();
13        vert = new int[n + 1];
14        Arrays.fill(vert, -1);
15        while (m-- > 0) {
16            int u = sc.nextInt();
17            int v = sc.nextInt();
18
19            ArrayList<Integer> uNext = graph.getOrDefault(u, new ArrayList<Integer>());
20            uNext.add(v);
21            graph.put(u, uNext);
22
23            ArrayList<Integer> vNext = graph.getOrDefault(v, new ArrayList<Integer>());
24            vNext.add(u);
25            graph.put(v, vNext);
26        }
27        sc.close();
28
29        for (int i = 1; i <= n; i++) {
30            if (vert[i] == -1) {
31                dfs(i, 0);
32            }
33        }
34
35        System.out.println(isBipartite ? "Yes" : "No");
36    }
37
38    public static void dfs(int curr_node, int last_val) {
39        if (vert[curr_node] == last_val) {
40            isBipartite = false;
41            return;
42        }
43
44        if (vert[curr_node] != -1) {
45            return;
46        }
47
48        vert[curr_node] = last_val ^ 1;
49        for (int next_node : graph.getOrDefault(curr_node, new ArrayList<Integer>())) {
50            dfs(next_node, vert[curr_node]);
51            if (!isBipartite) {
52                return;
53            }
54        }
55    }
56}

861. 二分图的最大匹配 | 原题链接

求最大匹配的思路是,遍历左侧节点,同时让左侧每个节点尝试匹配右侧节点,如果右侧节点是未匹配的状态则直接匹配;如果右侧节点已经匹配,则尝试让右侧节点当前的匹配去找一个新的匹配,并且递归重复这一步骤,如果最终都能够匹配则返回 True,否则返回 False。这种算法只适用于计算无权二分图的最大匹配。

 1import sys
 2from collections import defaultdict
 3
 4data = sys.stdin.read().strip().splitlines()
 5n1, n2, m = map(int, data[0].strip().split())
 6graph = defaultdict(list)
 7pair = [None] * (n2 + 1)
 8for i in range(1, m + 1):
 9    u, v = map(int, data[i].strip().split())
10    graph[u].append(v)
11
12
13def find(curr_node):
14    for next_node in graph[curr_node]:
15        if not status[next_node]:
16            status[next_node] = True
17            if pair[next_node] is None or find(pair[next_node]):
18                pair[next_node] = curr_node
19                return True
20    return False
21
22
23res = 0
24for node in range(1, n1 + 1):
25    status = [False] * (n2 + 1)
26    if find(node):
27        res += 1
28print(res)

系列文章