没错就是这张套路图,据说是 segment fault 的招聘?无聊刷一下题,抛砖引玉。(知乎那边也回了贴,话说破乎的编辑器还真不适合贴代码……)


正好最近在出招聘测试的笔试题,在网上看到跟第 1 题类似的,数字比这个大 10 倍有多。自己做了一下发现如果之前没刷过类似的,要在限时内在纸上写出优化过的代码并不容易,就放弃把它收进题库了……

现成的贴上来:

import time
import math


def is_prime(x):
    if x <= 1:
        return False

    if x == 2:
        return True

    for i in range(3, int(math.sqrt(x)) + 1, 2):
        if x % i == 0:
            return False

    return True


def prime_factorization(x):
    for i in range(2, int(math.sqrt(x)) + 1):
        if x % i == 0 and is_prime(i):
            quotient = x // i

            if is_prime(quotient):
                print(i, quotient)


if __name__ == '__main__':
    t1 = time.time()
    prime_factorization(707829217)
    t2 = time.time()

    print(f'ms: {t2 * 1000 - t1 * 1000}')

输出:

8171 86627
ms: 3.166015625

我能想到的优化点有 3 个:


第二题:

没有用典型的 DP 模板(因为我想了半天都套不上……),题目条件太简单,用了取巧的办法硬是磨出来了,目测反而会比 DFS 快不少。

def get_answer(n):
    """
    这题有个取巧的方法,十位以上是 3 的时候,奇数的个数是 10^i / 2。例如 300-400 之间有 10^2 / 2 = 50 个奇数,如此类推。
    用记忆化搜索的思路把结果缓存起来,然后注意一下最后每位数字的限制,就能直接推出最终结果。

    速度极快,哪怕是 333333333333333333333333333333333333333 这种大数字也就 1ms 左右。

    题中 866278171 这个数字其实不是好例子,因为所有位都不是 3。
    下面注释那几处地方假如没注意,就连 N = 3 都出 bug 的情况下这个数也能得到正确结果……

    测试:1 3 4 9 10 13 29 31 33 39 41 43 100 299 301 303 333 399 400 1000
        1224 1234 1244 10000 33333 1234567 12345678
    """
    count = 0
    nums = []
    results = []

    x = n
    while x > 0:
        # nums 存储 N 的各位数字,个位在前,最高位在最后
        nums.append(x % 10)
        # results 是二维数组,行数为 N 的位数,11 列分别代表 0-9 开头的数字 和总计
        # 把每一位各数字开头的 3 的总数存起来
        results.append([0] * 11)
        x //= 10

    digits = len(nums)

    # 个位特殊处理
    results[0][3] = 1
    results[0][-1] = 1
    if nums[0] >= 3:  # 注意这里
        count += 1

    for i in range(1, digits):
        for j in range(9 + 1):
            results[i][j] += results[i - 1][-1]

            if j == 3:
                results[i][j] += 10 ** i // 2

            results[i][-1] += results[i][j]

            # 注意加的时候只能小于 N 在这一位的数字,否则会超出。
            # 如果这一位为 0 就跳过(在下一位或下几位会包含进去)。
            if j < nums[i]:
                count += results[i][j]

        # 补上 N 十位以上为 3 开头时漏统计的情况(因为不能读缓存)
        if nums[i] == 3:
            count += (sum(nums[j] * 10 ** j for j in range(i)) + 1) // 2

    return count

输出:(没错就是这么快)

441684627
ms: 0.07421875

上面的方法不太好验证正确性,可以配合穷举和 1234567 之类较小的数字做测试:

def brute_force(n):
    """
    如果是现场做笔试题,想不出 DP 至少写上这个凑数……
    效率很低,算个 12345678 都要 7 秒左右,据说算完题中的数要 7-10 分钟。
    """
    count = 0

    i = 1
    while i <= n:
        x = i
        while x > 0:
            if x % 10 == 3:
                count += 1
            x //= 10

        i += 2

    return count


↙↙↙阅读原文可查看相关链接,并与作者交流