枚举、模拟与排序练习

文章目录

题目来自 AcWing


1210. 连号区间数 | 原题链接

暴力枚举的做法就是遍历 lr 并对区间排序,判断区间内的数字是否连续递增。这样做的时间复杂度是 $O(n^3logn)$, 而 $1\leq N\leq10^4$ 的范围显然不支持这样的算法。经过观察发现,一段连续递增的数列中,其最大值减去最小值一定与数列长度相等。利用这个特点,在遍历 lr 时,只需要维护区间内最大值和最小值,然后随时做判断,这样做的时间复杂度为 $O(n^2)$,在 $N$ 的数据范围内够用。

 1n = int(input().strip())
 2nums = tuple(map(int, input().strip().split()))
 3
 4count = 0
 5for l in range(n):
 6    max_val = nums[l]
 7    min_val = nums[l]
 8    for r in range(l, n):
 9        if nums[r] > max_val:
10            max_val = nums[r]
11        elif nums[r] < min_val:
12            min_val = nums[r]
13
14        if max_val - min_val == r - l:
15            count += 1
16
17print(count)

1236. 递增三元组 | 原题链接

这道题还是针对枚举的优化。暴力枚举要 i j k 依次枚举 a b c 中的元素并判断组成的三元组是否是递增的,这样的时间复杂度为 $O(n^3)$ 远不能 AC。

假如先对三个数组排序,遍历 $A_i$ 的同时通过二分查找确定 $B_i$,遍历 b 中大于等于 $B_i$ 的元素,再通过二分查找找到 $C_i$,而且 c 中在 $C_i$ 后的元素也都可选。这样的时间复杂度为 $O(n^2logn)$,距离 AC 还有一段距离。

现在还要找出重复运算的部分。当我们计算完一次 $B_i$ 时,如果 $B_{i+1}=B_i$,可以直接使用 $B_i$ 的值而不需要重新计算一次。另外当 $A_i$ 确定时,通过二分查找找到的值为边界值,实际上 b 中所有大于等于它的值都可选择,也就是说一个 $B_i$ 可能会被不同的 $A_i$ 重复计算多次。因此考虑存储每一个 $B_i$ 对应可选的 $C_i$ 数量,并使用“后缀和”避免遍历 $B_i$。此时算法时间复杂度为 $O(nlogn)$,这样可以 AC。

 1from bisect import bisect_right
 2from itertools import accumulate
 3
 4n = int(input().strip())
 5a = list(map(int, input().strip().split()))
 6b = list(map(int, input().strip().split()))
 7c = list(map(int, input().strip().split()))
 8
 9a.sort()
10b.sort()
11c.sort()
12
13b_count = [0] * n
14for i in range(n):
15    if i > 0 and b[i] == b[i - 1]:
16        b_count[i] = b_count[i - 1]
17        continue
18    c_idx = bisect_right(c, b[i])
19    b_count[i] = n - c_idx
20
21b_postfix_sum = tuple(accumulate([0] + b_count[::-1]))
22
23count = 0
24for i in range(n):
25    b_idx = bisect_right(b, a[i])
26    count += b_postfix_sum[n - b_idx]
27
28print(count)

这道题也有别的解法。可以先枚举 $B_i$ 随后在 a 中找到比 $B_i$ 小的数和在 c 中找到比 $B_i$ 大的数。找数这部分也有两种方法:二分查找和前缀和。二分查找比较好理解,这种做法的时间复杂度为 $O(nlogn)$。

 1from collections import defaultdict
 2from bisect import bisect_right, bisect_left
 3
 4n = int(input().strip())
 5a = list(map(int, input().strip().split()))
 6b = list(map(int, input().strip().split()))
 7c = list(map(int, input().strip().split()))
 8
 9a.sort()
10c.sort()
11
12d = defaultdict(int)
13
14count = 0
15for i in range(n):
16    base = b[i]
17    if d.get(base) is not None:
18        count += d[base]
19        continue
20    left = bisect_left(a, base)
21    right = bisect_right(c, base)
22    count += left * (n - right)
23    d[base] = left * (n - right)
24print(count)

而使用前缀和则需要定义一个数组 cnt,其中 cnt[i] 表示数组 a 中的 i 出现的次数。对 cnt 数组求前缀和,就可以快速求出 a 中小于 $B_i$ 的元素数量。这种方法不需要对 ac 排序,但是遍历 cnt 的时间复杂度是 $A_i$ 的数据范围。本题中所有数据最大值均为 $10^5$,所以这种算法的时间复杂度为 $O(n)$。

 1from collections import defaultdict
 2from itertools import accumulate
 3
 4n = int(input().strip())
 5a = list(map(int, input().strip().split()))
 6b = list(map(int, input().strip().split()))
 7c = list(map(int, input().strip().split()))
 8
 9d = defaultdict(int)
10cnt_a = [0] * (100001)
11cnt_c = [0] * (100001)
12for i in range(n):
13    cnt_a[a[i]] += 1
14    cnt_c[c[i]] += 1
15
16prefix_a = tuple(accumulate(cnt_a))
17prefix_c = tuple(accumulate(cnt_c))
18
19count = 0
20for base in b:
21    if d.get(base) is not None:
22        count += d[base]
23        continue
24    left = prefix_a[base - 1] if base >= 1 else 0
25    right = n - prefix_c[base] if base <= 100000 else 0
26    count += left * right
27    d[base] = left * right
28
29print(count)

1245. 特别数的和 | 原题链接

1n = int(input().strip())
2s = 0
3for i in range(n + 1):
4    if any(c in str(i) for c in "2019"):
5        s += i
6print(s)

就算 $N$ 到 $10^6$ 也不会超时。

1204. 错误票据 | 原题链接

输入最多有 100 行,用 sys.stdin.read() 显然是更好的选择。

 1from collections import defaultdict
 2import sys
 3
 4data = sys.stdin.read().strip().split("\n")
 5n = int(data[0].strip())
 6
 7d = defaultdict(int)
 8min_val, max_val = 100001, 0
 9for dt in data[1:]:
10    line = map(int, dt.strip().split())
11    for l in line:
12        d[l] += 1
13        if l < min_val:
14            min_val = l
15        elif l > max_val:
16            max_val = l
17
18m, n = 0, 0
19for i in range(min_val, max_val + 1):
20    if d[i] == 0:
21        m = i
22    elif d[i] == 2:
23        n = i
24
25print(m, n)

466. 回文日期 | 原题链接

这道题真正想让你做的不是判断回文而是判断日期。

 1start = int(input().strip())
 2end = int(input().strip())
 3
 4
 5def is_leap(year):
 6    if (year % 4 == 0 and year % 100 != 0) or year % 400 == 0:
 7        return True
 8    return False
 9
10
11def is_valid_date(date):
12    year = date // 10000
13    month = (date % 10000) // 100
14    day = date % 100
15    if day < 1 or day > 31:
16        return False
17    if month < 1 or month > 12:
18        return False
19    if month in (4, 6, 9, 11) and day > 30:
20        return False
21    if month == 2:
22        if day > 29:
23            return False
24        if not is_leap(year) and day > 28:
25            return False
26    return True
27
28
29def next_valid_date(date):
30    year = date // 10000
31    month = (date % 10000) // 100
32    if month == 12:
33        year += 1
34        month = 1
35        return year * 10000 + month * 100 + 1
36    else:
37        month += 1
38        return year * 10000 + month * 100 + 1
39
40
41def is_palindrome(n):
42    return str(n) == str(n)[::-1]
43
44
45current = start
46count = 0
47while current <= end:
48    if is_valid_date(current):
49        if is_palindrome(current):
50            count += 1
51        current += 1
52    else:
53        current = next_valid_date(current)
54
55print(count)

Python 我们有 datetime 模块判断日期,datetime.datetime.strptime() 函数可以用来将一个字符串时间解析为 datetime.datetime 对象,具体用法是 strptime(date_string: str, format: str)format 是形如 %Y%m%d 的字符串。然而如果 date_str 并不是一个合法的时间字符串时,函数会抛出异常 ValueError,因此需要配合 try-except 语句使用。另外 strptime() 的返回值是 datetime 对象,因此这个操作本身极耗时。

Python 也有直接遍历日期的方法。代码写出来大概是这样:

 1from datetime import datetime, timedelta
 2
 3start_str = input().strip()
 4end_str = input().strip()
 5
 6start = datetime.strptime(start_str, "%Y%m%d")
 7end = datetime.strptime(end_str, "%Y%m%d")
 8
 9current = start
10count = 0
11while current <= end:
12    current.strftime("%Y%m%d")
13    current_str = current.strftime("%Y%m%d")
14    if current_str == current_str[::-1]:
15        count += 1
16    current += timedelta(days=1)
17
18print(count)

这里直接用 datetime.datetime 对象迭代,耗时更长。当然对于这道题,输入的范围无非就是 10000101 到 99991231(有一组极限数据确实是这俩),大不了直接打表。

787. 归并排序 | 原题链接

模板题。归并排序讲解

递归写法:

 1_ = input()
 2nums = list(map(int, input().strip().split()))
 3
 4
 5def merge(left, right):
 6    result = []
 7    i = j = 0
 8    while i < len(left) and j < len(right):
 9        if left[i] <= right[j]:
10            result.append(left[i])
11            i += 1
12        else:
13            result.append(right[j])
14            j += 1
15
16    while i < len(left):
17        result.append(left[i])
18        i += 1
19    while j < len(right):
20        result.append(right[j])
21        j += 1
22    return result
23
24
25def merge_sort(nums):
26    if len(nums) <= 1:
27        return nums
28
29    mid = len(nums) // 2
30    left = merge_sort(nums[:mid])
31    right = merge_sort(nums[mid:])
32    return merge(left, right)
33
34
35print(*merge_sort(nums))

非递归写法:

 1n = int(input().strip())
 2nums = list(map(int, input().strip().split()))
 3
 4
 5def merge(left, right):
 6    result = []
 7    i = j = 0
 8    while i < len(left) and j < len(right):
 9        if left[i] <= right[j]:
10            result.append(left[i])
11            i += 1
12        else:
13            result.append(right[j])
14            j += 1
15
16    while i < len(left):
17        result.append(left[i])
18        i += 1
19    while j < len(right):
20        result.append(right[j])
21        j += 1
22    return result
23
24
25def merge_sort(nums):
26    size = 1
27    while size < n:
28        for left in range(0, n, 2 * size):
29            mid = min(left + size, n)
30            right = min(left + 2 * size, n)
31            nums[left:right] = merge(nums[left:mid], nums[mid:right])
32        size *= 2
33    return nums
34
35
36print(*merge_sort(nums))

归并排序的时间复杂度为 $O(nlogn)$,空间复杂度为 $O(n)$,且归并排序是一种稳定的排序。非递归归并排序有可能会将左右子序列划分地极不均衡,导致效率降低,而递归归并排序的左右子序列长度会更平衡。

1219. 移动距离 | 原题链接

因为一个同余类是从 0 开始,而题目中的楼号是从 1 开始,因此处理前将楼号减去 1,然后再处理。

 1w, m, n = map(int, input().strip().split())
 2
 3
 4def get_pos(w, a):
 5    x = (a - 1) // w
 6    k = (a - 1) % (2 * w)
 7    if k < w:
 8        return x, k
 9    else:
10        return x, 2 * w - k - 1
11
12
13mx, my = get_pos(w, m)
14nx, ny = get_pos(w, n)
15print(abs(mx - nx) + abs(my - ny))

1229. 日期问题 | 原题链接

 1date_list = tuple(map(int, input().strip().split("/")))
 2
 3
 4def is_leap(year):
 5    return (year % 4 == 0 and year % 100 != 0) or year % 400 == 0
 6
 7
 8def is_valid_date(date: str):
 9    year, month, day = map(int, date.split("-"))
10    if year < 1960 or year > 2059:
11        return False
12    if month < 1 or month > 12:
13        return False
14    if day < 1 or day > 31:
15        return False
16    if month in (4, 6, 9, 11) and day > 30:
17        return False
18    if month == 2:
19        if day > 29:
20            return False
21        if not is_leap(year) and day > 28:
22            return False
23    return True
24
25
26def parseYMD(date_list: list):
27    year, month, day = date_list
28    if year < 60:
29        date = f"20{year:02d}-{month:02d}-{day:02d}"
30    else:
31        date = f"19{year:02d}-{month:02d}-{day:02d}"
32
33    if is_valid_date(date):
34        return date
35    else:
36        return None
37
38
39def parseMDY(date_list: list):
40    month, day, year = date_list
41    if year < 60:
42        date = f"20{year:02d}-{month:02d}-{day:02d}"
43    else:
44        date = f"19{year:02d}-{month:02d}-{day:02d}"
45
46    if is_valid_date(date):
47        return date
48    else:
49        return None
50
51
52def parseDMY(date_list: list):
53    day, month, year = date_list
54    if year < 60:
55        date = f"20{year:02d}-{month:02d}-{day:02d}"
56    else:
57        date = f"19{year:02d}-{month:02d}-{day:02d}"
58
59    if is_valid_date(date):
60        return date
61    else:
62        return None
63
64
65formats = set()
66ymd = parseYMD(date_list)
67if ymd:
68    formats.add(ymd)
69mdy = parseMDY(date_list)
70if mdy:
71    formats.add(mdy)
72dmy = parseDMY(date_list)
73if dmy:
74    formats.add(dmy)
75
76print(*sorted(list(formats), key=lambda x: x.split("-")), sep="\n")

因为数据量很小,还可以直接用 datetime 模块直接转换格式。

 1from datetime import datetime
 2
 3date = input().strip()
 4ans = set()
 5
 6try:
 7    datetime_obj = datetime.strptime(date, "%y/%m/%d")
 8    ans.add(datetime_obj)
 9except ValueError:
10    pass
11
12try:
13    datetime_obj = datetime.strptime(date, "%m/%d/%y")
14    ans.add(datetime_obj)
15except ValueError:
16    pass
17
18try:
19    datetime_obj = datetime.strptime(date, "%d/%m/%y")
20    ans.add(datetime_obj)
21except ValueError:
22    pass
23
24print(*map(lambda x: x.strftime("%Y-%m-%d"), sorted(list(ans))), sep="\n")

datetime 对象里重载了比较运算符,可以直接用 sorted 函数对 datetime 列表排序。

1231. 航班时间 | 原题链接

从 A 地飞到 B 地,两地显示的时间差为实际飞行时间+时差,返程的则是实际飞行时间-时差,因此将两段飞行时间相加再除以 2 就是实际飞行时间。

犯懒了,直接用 datetime 实现。

 1from datetime import datetime, timedelta
 2
 3n = int(input().strip())
 4for _ in range(n):
 5    range_a = input().strip().split()
 6    start_a = datetime.strptime(range_a[0], "%H:%M:%S")
 7    end_a = datetime.strptime(range_a[1], "%H:%M:%S")
 8    if len(range_a) > 2:
 9        day = int(range_a[2][2:-1])
10        end_a += timedelta(days=day)
11    duration_a = end_a - start_a
12
13    range_b = input().strip().split()
14    start_b = datetime.strptime(range_b[0], "%H:%M:%S")
15    end_b = datetime.strptime(range_b[1], "%H:%M:%S")
16    if len(range_b) > 2:
17        day = int(range_b[2][2:-1])
18        end_b += timedelta(days=day)
19    duration_b = end_b - start_b
20
21    time = datetime.min + (duration_a + duration_b) / 2
22    print(time.strftime("%H:%M:%S"))

两个 datetime 对象相减得到的是 timedelta 对象,timedelta 对象相加还是 timedelta,但是 datetimetimedelta 相加是 datetimetimedelta 没有 strftime 方法所以要先转换成 datetimedatetime.min 是常量 0001-01-01 00:00:00,加上这个对结果没影响。

在 Python 的 datetime 模块中,除了 %H%M%S 用于表示小时、分钟和秒之外,还有许多其他的格式符可用于表示日期和时间。以下是一些常用的时间和日期格式符:

Python datetime 常用格式符:

    • %Y:四位数字的年份(例如:2025)
    • %y:两位数字的年份(例如:25 表示 2025)
  • 月份

    • %m:两位数字的月份(01 到 12)
    • %B:完整的月份名称(例如:January)
    • %b:月份的缩写(例如:Jan)
  • 日期

    • %d:两位数字的日期(01 到 31)
    • %j:一年中的第几天(001 到 366)
  • 星期

    • %A:完整的星期几名称(例如:Monday)
    • %a:星期几的缩写(例如:Mon)
    • %w:星期中的数字(0 = Sunday, 6 = Saturday)
  • 时间

    • %H:24小时制的小时(00 到 23)
    • %I:12小时制的小时(01 到 12)
    • %M:分钟(00 到 59)
    • %S:秒(00 到 59)
    • %p:AM 或 PM 标记(大写)
  • 其他

    • %c:当地的日期和时间表示(例如:Mon Mar 2 15:34:20 2025)
    • %x:当地的日期表示(例如:03/02/25)
    • %X:当地的时间表示(例如:15:34:20)
    • %%:输出一个百分号(%

1241. 外卖店优先级 | 原题链接

最后一个测试点总是 TLE 死活过不了,我已经在尽力优化了。

 1import sys
 2from collections import defaultdict
 3
 4data = sys.stdin.read().strip().split("\n")
 5n, m, t = map(int, data[0].strip().split())
 6orders = defaultdict(list)
 7for i in range(1, m + 1):
 8    ti, d = map(int, data[i].strip().split())
 9    orders[ti].append(d)
10restaurants = [0] * (n + 1)
11prior = [False] * (n + 1)
12
13last_time = 0
14for ts in sorted(orders.keys()):
15
16    for i in range(1, n + 1):
17        restaurants[i] -= ts - last_time - 1
18        if restaurants[i] < 0:
19            restaurants[i] = 0
20
21    increased = set()
22    for d in orders[ts]:
23        restaurants[d] += 2
24        increased.add(d)
25        if restaurants[d] > 5:
26            prior[d] = True
27
28    for i in range(1, n + 1):
29        if i not in increased:
30            restaurants[i] -= 1
31        if restaurants[i] <= 3:
32            prior[i] = False
33        if restaurants[i] < 0:
34            restaurants[i] = 0
35    last_time = ts
36
37rest_time = t - ts
38if rest_time > 0:
39    for i in range(1, n + 1):
40        if restaurants[i] <= 3 + rest_time:
41            prior[i] = False
42
43print(prior.count(True))

788. 逆序对的数量 | 原题链接

解法是归并排序的变体。至于为什么能够转化为归并排序,详细解释在这里

 1n = int(input().strip())
 2nums = list(map(int, input().strip().split()))
 3
 4
 5def merge_sort(nums):
 6    if len(nums) <= 1:
 7        return nums, 0
 8
 9    mid = len(nums) // 2
10    left, left_count = merge_sort(nums[:mid])
11    right, right_count = merge_sort(nums[mid:])
12
13    tmp = list()
14    i = j = 0
15    count = 0
16    while i < len(left) and j < len(right):
17        if left[i] <= right[j]:
18            tmp.append(left[i])
19            i += 1
20        else:
21            tmp.append(right[j])
22            j += 1
23            count += len(left) - i
24
25    tmp += left[i:]
26    tmp += right[j:]
27
28    return tmp, count + left_count + right_count
29
30
31print(merge_sort(nums)[1])

相关系列文章