用 FFT 和 NTT 解决多项式乘法
多项式乘法是一个很常见重要的问题。给两个多项式A(x)∑i0n−1aixi,B(x)∑j0m−1bjxjA(x)i0∑n−1aixi,B(x)j0∑m−1bjxj那么它们的乘积 C(x)A(x)B(x)C(x)A(x)B(x) 的第 kk 项系数是ck∑ijkaibjckijk∑aibj这就是序列 aa 和 bb 的卷积。直接按定义算需要枚举所有 i,ji,j复杂度是 O(nm)O(nm)。假定 n,mn,m 规模相当这个复杂度就是平方的很快就不够用了。大整数乘法也是同一个问题。把一个整数按进制 ββ 拆成若干位X∑iaiβi,Y∑jbjβjXi∑aiβi,Yj∑bjβj那么 XYXY 在处理进位之前第 kk 位的原始值就是 ∑ijkaibj∑ijkaibj。也就是说多项式乘法、卷积、大整数乘法在核心计算上是同一个结构只是最后解释结果的方式不同多项式保留系数卷积保留序列大整数还要做一遍进位。系数表示和点值表示一个次数小于 nn 的多项式可以用它的 nn 个系数表示A(x)a0a1x⋯an−1xn−1A(x)a0a1x⋯an−1xn−1这叫系数表示。它很适合做加法因为对应系数相加即可但相乘两个多项式会变成卷积比较麻烦。另一种表示方式是点值表示。选 nn 个不同的点 x0,x1,…,xn−1x0,x1,…,xn−1记录(xi,A(xi))(xi,A(xi))这些点值也可以唯一确定多项式。用点值表示做乘法非常轻松如果 C(x)A(x)B(x)C(x)A(x)B(x)那么C(xi)A(xi)B(xi)C(xi)A(xi)B(xi)所以只要逐点相乘就行。因此要避开平方的卷积就要把系数表示转化成点值表示进行点乘运算然后把点值表示还原成系数表示。那么如何在系数表示和点值表示之间快速转换FFT 快速傅里叶变换朴素求值要对每个点代入一次多项式求一次是 O(n)O(n)一共 nn 个点就是 O(n2)O(n2)。而 FFT 的关键是选择一组非常特殊的点使求值在 O(nlogn)O(nlogn) 内完成。FFT 使用复数域里的单位根。长度为 nn 时取ωe2πi/nωe2πi/n取用它是因为他有一个性质 ωn1ωn1并且 ω0,ω1,…,ωn−1ω0,ω1,…,ωn−1 两两不同。对一个系数序列 a0,…,an−1a0,…,an−1它的离散傅里叶变换可以写成a^k∑j0n−1ajωjka^kj0∑n−1ajωjk从信号与系统里傅里叶变换的意义来说把 a0,a1...an−1a0,a1...an−1 转化成 a^0,a^1...a^n−1a^0,a^1...a^n−1 是从时域到频域的转换而从多项式的角度来说这其实就是把多项式 A(x)a0a1x⋯an−1xn−1A(x)a0a1x⋯an−1xn−1 分别求解 A(1),A(ω),...A(ωn−1)A(1),A(ω),...A(ωn−1)。相比随便选 nn 个点求解来说求解这些特殊点的值可以利用其数学性质加速。蝶形变换单位根有很强的对称性。假设 nn 是 2 的整次幂把多项式按偶数次和奇数次拆开A(x)A0(x2)xA1(x2)A(x)A0(x2)xA1(x2)其中A0(x)a0a2xa4x2⋯A0(x)a0a2xa4x2⋯A1(x)a1a3xa5x2⋯A1(x)a1a3xa5x2⋯因此A(ωk)A0(ω2k)ωkA1(ω2k)A(ωk)A0(ω2k)ωkA1(ω2k)A(ωkn/2)A0(ω2k)−ωkA1(ω2k)A(ωkn/2)A0(ω2k)−ωkA1(ω2k)而 ω2ω2 也是一个单位根所以可以递归求解。设 uA0(ω2k),vA1(ω2kuA0(ω2k),vA1(ω2k那么原多项式在两个对应点上的值是A(ωk)uωkvA(ωk)uωkvA(ωkn/2)u−ωkvA(ωkn/2)u−ωkv这就是蝶形合并。它说明要算 A(1),A(ω),A(ω2),…,A(ωn−1)A(1),A(ω),A(ω2),…,A(ωn−1)可以先分别算偶数部分和奇数部分在 1,ω2,ω4,…,ωn−21,ω2,ω4,…,ωn−2 上的值再用加减和乘以 ωkωk 合并。一次拆分把长度为 nn 的问题变成两个长度为 n/2n/2 的问题每层合并花 O(n)O(n)一共有 lognlogn 层所以整体是 O(nlogn)O(nlogn)。在迭代版 FFT 中通常先做 bit-reversal 重排然后从长度 22 的小块开始合并块长依次翻倍。每一层会遍历所有块每个块里做若干个蝶形变换。逆变换正变换把系数变成点值逆变换把点值插值回系数。DFT 的逆变换形式是aj1n∑k0n−1a^kω−jkajn1k0∑n−1a^kω−jk也就是说逆变换和正变换几乎一样只是把 ωω 换成 ω−1ω−1最后所有结果再除以 nn。这个形式在 NTT 里也会原样保留只是“除以 nn”会变成乘上 nn 在模意义下的逆元。用 FFT 做卷积的流程很短先把两个序列补零到长度 NN其中 NN 至少是 nm−1nm−1通常取不小于它的最小二次幂。然后分别 FFT逐点相乘再逆 FFT。最终前 nm−1nm−1 项就是卷积结果。从 FFT 到 NTT 数论变换FFT 用复数速度快但有浮点误差。对于整数卷积尤其是竞赛、密码学或需要完全精确的场景更常用 NTT。NTT 可以看作“把 FFT 搬到模质数意义下”。现在在模 pp 的有限域里如果存在一个元素 gg它的幂能生成所有非零元素那么 gg 是模 pp 的原根。换句话说1,g,g2... gp−21,g,g2... gp−2刚好遍历了 [1,p)[1,p)而 gp−1gp−1 再次回到 11。比如 33 是模 77 的一个原根1→3→2→6→4→5→11→3→2→6→4→5→1但 22 不是。由于模质数 pp 下的非零元素有 p−1p−1 个我们想做长度为 nn 的 NTT就需要 nn 整除 p−1p−1。换一个稍大的例子。取 p17p17可以验证 g3g3 是一个原根。现在想做长度为 88 的 NTT可以构造ω3(17−1)/8≡9(mod17)ω3(17−1)/8≡9(mod17)验证一下 98≡1(mod17)98≡1(mod17)并且在 1,2,…,71,2,…,7 次幂时都不会提前变成 11。这时我们就发现1,9,92,…,971,9,92,…,97 就可以扮演 FFT 里 1,ω,ω2,…,ω71,ω,ω2,…,ω7 的角色。同理ω3ω3 是一个长度 16 的 NTT 单位根ω13ω13 是一个长度 4 的 NTT 单位根。更普遍的来说若 gg 是模 pp 的原根则ωg(p−1)/nmodpωg(p−1)/nmodp就是一个 nn 阶单位根。这样FFT 里关于单位根的推导仍然成立只是所有加法、减法、乘法都放在模 pp 下。复数单位根 ωe2πi/nωe2πi/n 变成模意义下的单位根 ωg(p−1)/nωg(p−1)/n在逆变换里原理也完全相同除以 nn 变成乘 n−1modpn−1modp乘法逆元常用模数 998244353MOD998244353MOD998244353 很常用首先他是质数并且998244353119⋅2231998244353119⋅2231它足够大最多可以生成长度 223223 的 NTT 单位根。对多数算法题来说这个长度已经足够大。同时这个数的大小刚好满足 2MOD2MOD 不超过 int32 的上限 MOD2MOD2 没有超过 int64 的上限利于实际实现。33 是他的一个原根。通用情况那么更大的情况呢和 FFT 的区别是NTT 下卷积系数的运算也在模 pp 意义下进行若系数最终没有超过 pp 就好说但如果超过就会在取模意义下绕回来我们这时无法得知算出的 cici 实际上是 cici 还是 cipcip。假设输入系数非负长度为 LL每个系数不超过 MM那么卷积中单个系数的上界大约是 LM2LM2。如果这个值可能超过模数就要么选择更大的可用模数要么使用多个 NTT 质数分别计算一次再通过中国剩余定理合并。具体来说选几个互质模数 p1,p2,…,prp1,p2,…,pr分别算出ckmodp1,ckmodp2,…,ckmodprckmodp1,ckmodp2,…,ckmodpr只要真实的 ckck 小于 Pp1p2⋯prPp1p2⋯pr那么用 CRT 就可以唯一恢复 ckck。实际工程里常见做法是使用多个 NTT-friendly prime例如 998244353998244353、10045358091004535809、469762049469762049 。对于大整数乘法还可以通过控制进制来降低单个卷积系数的上界。例如把十进制字符串拆成 103103 或 104104 进制而不是直接用 109109 进制。这个情况下进制越大卷积长度越短但中间系数越容易溢出模数进制越小长度更长但结果更安全。这是一个很实际的 trade-off。#include bits/stdc.husing namespace std;const int MOD 998244353;const int G 3;long long mod_pow(long long a, long long e) {long long r 1;while (e) {if (e 1) r r * a % MOD;a a * a % MOD;e 1;}return r;}void ntt(vectorint a, bool invert) {int n (int)a.size();for (int i 1, j 0; i n; i) {int bit n 1;for (; j bit; bit 1) j ^ bit;j ^ bit;if (i j) swap(a[i], a[j]);}for (int len 2; len n; len 1) {int wlen (int)mod_pow(G, (MOD - 1) / len);if (invert) wlen (int)mod_pow(wlen, MOD - 2);for (int i 0; i n; i len) {long long w 1;int half len 1;for (int j 0; j half; j) {int u a[i j];int v (int)(a[i j half] * w % MOD);int x u v;if (x MOD) x - MOD;int y u - v;if (y 0) y MOD;a[i j] x;a[i j half] y;w w * wlen % MOD;}}}