枚举、模拟与排序练习
文章目录
题目来自 AcWing
1210. 连号区间数 | 原题链接
暴力枚举的做法就是遍历 l
和 r
并对区间排序,判断区间内的数字是否连续递增。这样做的时间复杂度是 $O(n^3logn)$, 而 $1\leq N\leq10^4$ 的范围显然不支持这样的算法。经过观察发现,一段连续递增的数列中,其最大值减去最小值一定与数列长度相等。利用这个特点,在遍历 l
与 r
时,只需要维护区间内最大值和最小值,然后随时做判断,这样做的时间复杂度为 $O(n^2)$,在 $N$ 的数据范围内够用。
1n = int(input().strip())
2nums = tuple(map(int, input().strip().split()))
3
4count = 0
5for l in range(n):
6 max_val = nums[l]
7 min_val = nums[l]
8 for r in range(l, n):
9 if nums[r] > max_val:
10 max_val = nums[r]
11 elif nums[r] < min_val:
12 min_val = nums[r]
13
14 if max_val - min_val == r - l:
15 count += 1
16
17print(count)
1236. 递增三元组 | 原题链接
这道题还是针对枚举的优化。暴力枚举要 i
j
k
依次枚举 a
b
c
中的元素并判断组成的三元组是否是递增的,这样的时间复杂度为 $O(n^3)$ 远不能 AC。
假如先对三个数组排序,遍历 $A_i$ 的同时通过二分查找确定 $B_i$,遍历 b
中大于等于 $B_i$ 的元素,再通过二分查找找到 $C_i$,而且 c
中在 $C_i$ 后的元素也都可选。这样的时间复杂度为 $O(n^2logn)$,距离 AC 还有一段距离。
现在还要找出重复运算的部分。当我们计算完一次 $B_i$ 时,如果 $B_{i+1}=B_i$,可以直接使用 $B_i$ 的值而不需要重新计算一次。另外当 $A_i$ 确定时,通过二分查找找到的值为边界值,实际上 b
中所有大于等于它的值都可选择,也就是说一个 $B_i$ 可能会被不同的 $A_i$ 重复计算多次。因此考虑存储每一个 $B_i$ 对应可选的 $C_i$ 数量,并使用“后缀和”避免遍历 $B_i$。此时算法时间复杂度为 $O(nlogn)$,这样可以 AC。
1from bisect import bisect_right
2from itertools import accumulate
3
4n = int(input().strip())
5a = list(map(int, input().strip().split()))
6b = list(map(int, input().strip().split()))
7c = list(map(int, input().strip().split()))
8
9a.sort()
10b.sort()
11c.sort()
12
13b_count = [0] * n
14for i in range(n):
15 if i > 0 and b[i] == b[i - 1]:
16 b_count[i] = b_count[i - 1]
17 continue
18 c_idx = bisect_right(c, b[i])
19 b_count[i] = n - c_idx
20
21b_postfix_sum = tuple(accumulate([0] + b_count[::-1]))
22
23count = 0
24for i in range(n):
25 b_idx = bisect_right(b, a[i])
26 count += b_postfix_sum[n - b_idx]
27
28print(count)
这道题也有别的解法。可以先枚举 $B_i$ 随后在 a
中找到比 $B_i$ 小的数和在 c
中找到比 $B_i$ 大的数。找数这部分也有两种方法:二分查找和前缀和。二分查找比较好理解,这种做法的时间复杂度为 $O(nlogn)$。
1from collections import defaultdict
2from bisect import bisect_right, bisect_left
3
4n = int(input().strip())
5a = list(map(int, input().strip().split()))
6b = list(map(int, input().strip().split()))
7c = list(map(int, input().strip().split()))
8
9a.sort()
10c.sort()
11
12d = defaultdict(int)
13
14count = 0
15for i in range(n):
16 base = b[i]
17 if d.get(base) is not None:
18 count += d[base]
19 continue
20 left = bisect_left(a, base)
21 right = bisect_right(c, base)
22 count += left * (n - right)
23 d[base] = left * (n - right)
24print(count)
而使用前缀和则需要定义一个数组 cnt
,其中 cnt[i]
表示数组 a
中的 i
出现的次数。对 cnt
数组求前缀和,就可以快速求出 a
中小于 $B_i$ 的元素数量。这种方法不需要对 a
和 c
排序,但是遍历 cnt
的时间复杂度是 $A_i$ 的数据范围。本题中所有数据最大值均为 $10^5$,所以这种算法的时间复杂度为 $O(n)$。
1from collections import defaultdict
2from itertools import accumulate
3
4n = int(input().strip())
5a = list(map(int, input().strip().split()))
6b = list(map(int, input().strip().split()))
7c = list(map(int, input().strip().split()))
8
9d = defaultdict(int)
10cnt_a = [0] * (100001)
11cnt_c = [0] * (100001)
12for i in range(n):
13 cnt_a[a[i]] += 1
14 cnt_c[c[i]] += 1
15
16prefix_a = tuple(accumulate(cnt_a))
17prefix_c = tuple(accumulate(cnt_c))
18
19count = 0
20for base in b:
21 if d.get(base) is not None:
22 count += d[base]
23 continue
24 left = prefix_a[base - 1] if base >= 1 else 0
25 right = n - prefix_c[base] if base <= 100000 else 0
26 count += left * right
27 d[base] = left * right
28
29print(count)
1245. 特别数的和 | 原题链接
1n = int(input().strip())
2s = 0
3for i in range(n + 1):
4 if any(c in str(i) for c in "2019"):
5 s += i
6print(s)
就算 $N$ 到 $10^6$ 也不会超时。
1204. 错误票据 | 原题链接
输入最多有 100 行,用 sys.stdin.read()
显然是更好的选择。
1from collections import defaultdict
2import sys
3
4data = sys.stdin.read().strip().split("\n")
5n = int(data[0].strip())
6
7d = defaultdict(int)
8min_val, max_val = 100001, 0
9for dt in data[1:]:
10 line = map(int, dt.strip().split())
11 for l in line:
12 d[l] += 1
13 if l < min_val:
14 min_val = l
15 elif l > max_val:
16 max_val = l
17
18m, n = 0, 0
19for i in range(min_val, max_val + 1):
20 if d[i] == 0:
21 m = i
22 elif d[i] == 2:
23 n = i
24
25print(m, n)
466. 回文日期 | 原题链接
这道题真正想让你做的不是判断回文而是判断日期。
1start = int(input().strip())
2end = int(input().strip())
3
4
5def is_leap(year):
6 if (year % 4 == 0 and year % 100 != 0) or year % 400 == 0:
7 return True
8 return False
9
10
11def is_valid_date(date):
12 year = date // 10000
13 month = (date % 10000) // 100
14 day = date % 100
15 if day < 1 or day > 31:
16 return False
17 if month < 1 or month > 12:
18 return False
19 if month in (4, 6, 9, 11) and day > 30:
20 return False
21 if month == 2:
22 if day > 29:
23 return False
24 if not is_leap(year) and day > 28:
25 return False
26 return True
27
28
29def next_valid_date(date):
30 year = date // 10000
31 month = (date % 10000) // 100
32 if month == 12:
33 year += 1
34 month = 1
35 return year * 10000 + month * 100 + 1
36 else:
37 month += 1
38 return year * 10000 + month * 100 + 1
39
40
41def is_palindrome(n):
42 return str(n) == str(n)[::-1]
43
44
45current = start
46count = 0
47while current <= end:
48 if is_valid_date(current):
49 if is_palindrome(current):
50 count += 1
51 current += 1
52 else:
53 current = next_valid_date(current)
54
55print(count)
Python 我们有 datetime
模块判断日期,datetime.datetime.strptime()
函数可以用来将一个字符串时间解析为 datetime.datetime
对象,具体用法是 strptime(date_string: str, format: str)
,format
是形如 %Y%m%d
的字符串。然而如果 date_str
并不是一个合法的时间字符串时,函数会抛出异常 ValueError
,因此需要配合 try-except
语句使用。另外 strptime()
的返回值是 datetime
对象,因此这个操作本身极耗时。
Python 也有直接遍历日期的方法。代码写出来大概是这样:
1from datetime import datetime, timedelta
2
3start_str = input().strip()
4end_str = input().strip()
5
6start = datetime.strptime(start_str, "%Y%m%d")
7end = datetime.strptime(end_str, "%Y%m%d")
8
9current = start
10count = 0
11while current <= end:
12 current.strftime("%Y%m%d")
13 current_str = current.strftime("%Y%m%d")
14 if current_str == current_str[::-1]:
15 count += 1
16 current += timedelta(days=1)
17
18print(count)
这里直接用 datetime.datetime
对象迭代,耗时更长。当然对于这道题,输入的范围无非就是 10000101 到 99991231(有一组极限数据确实是这俩),大不了直接打表。
787. 归并排序 | 原题链接
模板题。归并排序讲解
递归写法:
1_ = input()
2nums = list(map(int, input().strip().split()))
3
4
5def merge(left, right):
6 result = []
7 i = j = 0
8 while i < len(left) and j < len(right):
9 if left[i] <= right[j]:
10 result.append(left[i])
11 i += 1
12 else:
13 result.append(right[j])
14 j += 1
15
16 while i < len(left):
17 result.append(left[i])
18 i += 1
19 while j < len(right):
20 result.append(right[j])
21 j += 1
22 return result
23
24
25def merge_sort(nums):
26 if len(nums) <= 1:
27 return nums
28
29 mid = len(nums) // 2
30 left = merge_sort(nums[:mid])
31 right = merge_sort(nums[mid:])
32 return merge(left, right)
33
34
35print(*merge_sort(nums))
非递归写法:
1n = int(input().strip())
2nums = list(map(int, input().strip().split()))
3
4
5def merge(left, right):
6 result = []
7 i = j = 0
8 while i < len(left) and j < len(right):
9 if left[i] <= right[j]:
10 result.append(left[i])
11 i += 1
12 else:
13 result.append(right[j])
14 j += 1
15
16 while i < len(left):
17 result.append(left[i])
18 i += 1
19 while j < len(right):
20 result.append(right[j])
21 j += 1
22 return result
23
24
25def merge_sort(nums):
26 size = 1
27 while size < n:
28 for left in range(0, n, 2 * size):
29 mid = min(left + size, n)
30 right = min(left + 2 * size, n)
31 nums[left:right] = merge(nums[left:mid], nums[mid:right])
32 size *= 2
33 return nums
34
35
36print(*merge_sort(nums))
归并排序的时间复杂度为 $O(nlogn)$,空间复杂度为 $O(n)$,且归并排序是一种稳定的排序。非递归归并排序有可能会将左右子序列划分地极不均衡,导致效率降低,而递归归并排序的左右子序列长度会更平衡。
1219. 移动距离 | 原题链接
因为一个同余类是从 0 开始,而题目中的楼号是从 1 开始,因此处理前将楼号减去 1,然后再处理。
1w, m, n = map(int, input().strip().split())
2
3
4def get_pos(w, a):
5 x = (a - 1) // w
6 k = (a - 1) % (2 * w)
7 if k < w:
8 return x, k
9 else:
10 return x, 2 * w - k - 1
11
12
13mx, my = get_pos(w, m)
14nx, ny = get_pos(w, n)
15print(abs(mx - nx) + abs(my - ny))
1229. 日期问题 | 原题链接
1date_list = tuple(map(int, input().strip().split("/")))
2
3
4def is_leap(year):
5 return (year % 4 == 0 and year % 100 != 0) or year % 400 == 0
6
7
8def is_valid_date(date: str):
9 year, month, day = map(int, date.split("-"))
10 if year < 1960 or year > 2059:
11 return False
12 if month < 1 or month > 12:
13 return False
14 if day < 1 or day > 31:
15 return False
16 if month in (4, 6, 9, 11) and day > 30:
17 return False
18 if month == 2:
19 if day > 29:
20 return False
21 if not is_leap(year) and day > 28:
22 return False
23 return True
24
25
26def parseYMD(date_list: list):
27 year, month, day = date_list
28 if year < 60:
29 date = f"20{year:02d}-{month:02d}-{day:02d}"
30 else:
31 date = f"19{year:02d}-{month:02d}-{day:02d}"
32
33 if is_valid_date(date):
34 return date
35 else:
36 return None
37
38
39def parseMDY(date_list: list):
40 month, day, year = date_list
41 if year < 60:
42 date = f"20{year:02d}-{month:02d}-{day:02d}"
43 else:
44 date = f"19{year:02d}-{month:02d}-{day:02d}"
45
46 if is_valid_date(date):
47 return date
48 else:
49 return None
50
51
52def parseDMY(date_list: list):
53 day, month, year = date_list
54 if year < 60:
55 date = f"20{year:02d}-{month:02d}-{day:02d}"
56 else:
57 date = f"19{year:02d}-{month:02d}-{day:02d}"
58
59 if is_valid_date(date):
60 return date
61 else:
62 return None
63
64
65formats = set()
66ymd = parseYMD(date_list)
67if ymd:
68 formats.add(ymd)
69mdy = parseMDY(date_list)
70if mdy:
71 formats.add(mdy)
72dmy = parseDMY(date_list)
73if dmy:
74 formats.add(dmy)
75
76print(*sorted(list(formats), key=lambda x: x.split("-")), sep="\n")
因为数据量很小,还可以直接用 datetime
模块直接转换格式。
1from datetime import datetime
2
3date = input().strip()
4ans = set()
5
6try:
7 datetime_obj = datetime.strptime(date, "%y/%m/%d")
8 ans.add(datetime_obj)
9except ValueError:
10 pass
11
12try:
13 datetime_obj = datetime.strptime(date, "%m/%d/%y")
14 ans.add(datetime_obj)
15except ValueError:
16 pass
17
18try:
19 datetime_obj = datetime.strptime(date, "%d/%m/%y")
20 ans.add(datetime_obj)
21except ValueError:
22 pass
23
24print(*map(lambda x: x.strftime("%Y-%m-%d"), sorted(list(ans))), sep="\n")
datetime
对象里重载了比较运算符,可以直接用 sorted
函数对 datetime
列表排序。
1231. 航班时间 | 原题链接
从 A 地飞到 B 地,两地显示的时间差为实际飞行时间+时差,返程的则是实际飞行时间-时差,因此将两段飞行时间相加再除以 2 就是实际飞行时间。
犯懒了,直接用 datetime
实现。
1from datetime import datetime, timedelta
2
3n = int(input().strip())
4for _ in range(n):
5 range_a = input().strip().split()
6 start_a = datetime.strptime(range_a[0], "%H:%M:%S")
7 end_a = datetime.strptime(range_a[1], "%H:%M:%S")
8 if len(range_a) > 2:
9 day = int(range_a[2][2:-1])
10 end_a += timedelta(days=day)
11 duration_a = end_a - start_a
12
13 range_b = input().strip().split()
14 start_b = datetime.strptime(range_b[0], "%H:%M:%S")
15 end_b = datetime.strptime(range_b[1], "%H:%M:%S")
16 if len(range_b) > 2:
17 day = int(range_b[2][2:-1])
18 end_b += timedelta(days=day)
19 duration_b = end_b - start_b
20
21 time = datetime.min + (duration_a + duration_b) / 2
22 print(time.strftime("%H:%M:%S"))
两个 datetime
对象相减得到的是 timedelta
对象,timedelta
对象相加还是 timedelta
,但是 datetime
和 timedelta
相加是 datetime
。timedelta
没有 strftime
方法所以要先转换成 datetime
。datetime.min
是常量 0001-01-01 00:00:00
,加上这个对结果没影响。
在 Python 的 datetime
模块中,除了 %H
、%M
和 %S
用于表示小时、分钟和秒之外,还有许多其他的格式符可用于表示日期和时间。以下是一些常用的时间和日期格式符:
Python datetime
常用格式符:
-
年
%Y
:四位数字的年份(例如:2025)%y
:两位数字的年份(例如:25 表示 2025)
-
月份
%m
:两位数字的月份(01 到 12)%B
:完整的月份名称(例如:January)%b
:月份的缩写(例如:Jan)
-
日期
%d
:两位数字的日期(01 到 31)%j
:一年中的第几天(001 到 366)
-
星期
%A
:完整的星期几名称(例如:Monday)%a
:星期几的缩写(例如:Mon)%w
:星期中的数字(0 = Sunday, 6 = Saturday)
-
时间
%H
:24小时制的小时(00 到 23)%I
:12小时制的小时(01 到 12)%M
:分钟(00 到 59)%S
:秒(00 到 59)%p
:AM 或 PM 标记(大写)
-
其他
%c
:当地的日期和时间表示(例如:Mon Mar 2 15:34:20 2025)%x
:当地的日期表示(例如:03/02/25)%X
:当地的时间表示(例如:15:34:20)%%
:输出一个百分号(%
)
1241. 外卖店优先级 | 原题链接
最后一个测试点总是 TLE 死活过不了,我已经在尽力优化了。
1import sys
2from collections import defaultdict
3
4data = sys.stdin.read().strip().split("\n")
5n, m, t = map(int, data[0].strip().split())
6orders = defaultdict(list)
7for i in range(1, m + 1):
8 ti, d = map(int, data[i].strip().split())
9 orders[ti].append(d)
10restaurants = [0] * (n + 1)
11prior = [False] * (n + 1)
12
13last_time = 0
14for ts in sorted(orders.keys()):
15
16 for i in range(1, n + 1):
17 restaurants[i] -= ts - last_time - 1
18 if restaurants[i] < 0:
19 restaurants[i] = 0
20
21 increased = set()
22 for d in orders[ts]:
23 restaurants[d] += 2
24 increased.add(d)
25 if restaurants[d] > 5:
26 prior[d] = True
27
28 for i in range(1, n + 1):
29 if i not in increased:
30 restaurants[i] -= 1
31 if restaurants[i] <= 3:
32 prior[i] = False
33 if restaurants[i] < 0:
34 restaurants[i] = 0
35 last_time = ts
36
37rest_time = t - ts
38if rest_time > 0:
39 for i in range(1, n + 1):
40 if restaurants[i] <= 3 + rest_time:
41 prior[i] = False
42
43print(prior.count(True))
788. 逆序对的数量 | 原题链接
解法是归并排序的变体。至于为什么能够转化为归并排序,详细解释在这里。
1n = int(input().strip())
2nums = list(map(int, input().strip().split()))
3
4
5def merge_sort(nums):
6 if len(nums) <= 1:
7 return nums, 0
8
9 mid = len(nums) // 2
10 left, left_count = merge_sort(nums[:mid])
11 right, right_count = merge_sort(nums[mid:])
12
13 tmp = list()
14 i = j = 0
15 count = 0
16 while i < len(left) and j < len(right):
17 if left[i] <= right[j]:
18 tmp.append(left[i])
19 i += 1
20 else:
21 tmp.append(right[j])
22 j += 1
23 count += len(left) - i
24
25 tmp += left[i:]
26 tmp += right[j:]
27
28 return tmp, count + left_count + right_count
29
30
31print(merge_sort(nums)[1])