074. 数字阶乘链(Digit factorial chains)

我们都知道数字145各位数的阶乘之和等于这个数本身,即:

$$ 1!+4!+5!=1+24+120=145 $$
另外一个数169可能大家知道的不多,它可以产生一个返回到自身的最长链条,事实上总共只有三个数有这样的性质:
$$ \begin{aligned} &169 → 363601 → 1454 → 169\\ &871 → 45361 → 871\\ &872 → 45362 → 872\\ \end{aligned} $$
不难证明每个超始数最终都会出现循环,比如:
$$ \begin{aligned} &69 → 363600 → 1454 → 169 → 363601 (→ 1454)\\ &78 → 45360 → 871 → 45361 (→ 871)\\ &540 → 145 (→ 145)\\ \end{aligned} $$
可以看到,从69开始可以产生包含五个不重复元素的链条,但是在一百万以下的超始数中,最长的不重复链条包含六十个元素。求在一百万以下的超始数中,有多少个链条恰好包含六十个不重复的元素?

分析:这道题的思路相对比较直接,最重要的优化技巧是要把已经计算过链条长度的数字缓存下来,从而加快计算之后数字链条长度的速度。题目中已经给出了若干数字,它们的链条长度是已知的,我们可以将其保存一个字典,字典的键为数字,对应的值为链条长度,例如169, 363601, 1454的链条长度都为三。

对每一个数字\(N\),我们建立一个数组\(arr\)保存其阶乘链条中的数。如果一个数\(a\)在上面的字典中已经出现了,则我们不需要再计算下去,直接用数组的长度加上数字\(a\)的链条长度即可。否则,我们检查数\(a\)是否在数组\(arr\)中已经出现过,如果已经出现过,则也可以停止计算,返回数组\(arr\)的长度即为数字\(N\)的的链条长度。如果以上两个条件都不满足,则我们计算数字\(a\)的各位数阶乘之和\(b\),并把\(b\)添加到数组\(arr\)中,继续循环。最后,我们在字典中添加数字\(N\)的链条长度。另外一个优化技巧是,为了让我们在计算不同数字的链条长度可以共享同一个缓存字典,我们把这个字典设为全局变量,并在计算链条长度的函数中使用global关键字,使其可以修改字典这个全局变量,这样可以大大提高算法的效率。

实际上,这道题还有一个更高效的算法。从题意我们容易得知,如果一个数的链条长度为六十,那么这个数的全排列也应该是六十,比如经过尝试我们得知1479的链条长度是六十,则4179, 1749, 1974等等数字的链条长度也是六十。因此,我们考虑从零至九这十个数字中,可放回的抽取一个数、两个数、三个数以至六个数,这样可以构成5004个组合。然后我们遍历这5004个组合构成的数字,只要一个数字的链条长度是六十,则其所有全排列的链条长度也是六十,需要注意的是,零和一的阶乘都是一,所以计算全排列的个数会更加复杂,需要用到多项式系数,并在它的基础上进行调整。最后我们把满足条件的数字对应的全排列的个数加起来,即为题目所求。事实上,在一百万以内,最小的满足条件的四位数是1479,它的全排列共有\(4!=24\)个,考虑到零和一的阶乘相同,则4079也是满足条件的数,因为零不能做首位数字,则其对应的全排列个数为\(3\times3!=18\)个。最小的满足条件的六位数是223479,其中数字二出现了两次,则其对应的全排列为\(6!/2=360\)个,则一百万以下链条长度为六十的数字共有\(24+18+360=402\)个。这种思路的实现代码更加复杂,考虑到第一个思路的代码效率已经比较高,所以我只实现了第一个思路,感兴趣的同学可以自己实现一下第二个思路。

第一个思路的代码如下:

# time cost = 487 ms ± 2.44 ms

DT = {169:3,363601:3,1454:3,871:2,45361:2,872:2,45362:2,69:5,78:4,540:2}

def digit_fact_sum(x):
    fac = [1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880]
    res = sum([fac[int(i)] for i in str(x)])
    return res

def chain_length(n):
    arr,m = [],n
    global DT
    while True:
        if n in DT:
            length = len(arr) + DT[n]
            break
        elif n in arr:
            length = len(arr)
            break
        else:
            arr.append(n)
            n = digit_fact_sum(n)
    DT[m] = length
    return length

def main(N=10**6):
    c = 0
    for i in range(1,N):
        if chain_length(i) == 60:
            c += 1
    return c