NTT算法和FFT极为相似,所不同之处只是在于选取的数域,对于FFT而言选取了复数域上的单位根来实现分治,而NTT选取的则是一类特殊的有限域上具有类似性质的原根来实现分治。
本文简述NTT原理并以C++实现NTT。
由于NTT与FFT基本类似,本文只介绍其不同点以及具体实现。
对于FFT可以参考如下博客:
本博客:FFT快速傅里叶变换的实现
_Orchidany的博客:FFT·快速傅里叶变换
本文的参考博客:
「算法笔记」快速数论变换(NTT)
快速数论变换(NTT)及蝴蝶操作构造详解
FFT虽然能以O(n\log n)
的复杂度计算多项式的点值表达法,但由于计算过程中涉及浮点数运算,导致计算存在精度问题,并且浮点数运算在一些没有浮点运算单元的机器上计算很慢。
于是我们考虑如果计算的多项式系数全为整数,能否利用一些特殊的有限域来实现计算,从而避开浮点数运算的缺点。
NTT计算时会受限于所取模数的大小,在需要的模数很大时,会使用不同模数分别计算NTT,最后再用中国剩余定理合并结果。
定义(原根):设m
是正整数,在有限域(\mathbb Z/m\mathbb Z)^*
中阶为\phi(m)
的数a
称为模m
的原根,其中\phi
为Euler函数。
原根的定义使得a^1,a^2,\cdots,a^{\phi(m)}
成为有限域中互不相同的元素,这与\mathbb C
中单位根的性质是极其相似的。
不过我们需要对多项式系数按照奇偶性进行二分实现分治,这使得我们需要通过原通常需要用到形如c\cdot2^k+1
的质数。如果我们有这样的质数M
,和原根g
,这时候如果有正整数N=2^l(l\leq k)
,我们有
\frac{c\cdot 2^k}{N}=c\cdot2^{k-l}
如果取a=g^{\frac{M-1}{N}}
,计算元素的阶得到\operatorname{ord} a = \phi(M)/\gcd(\frac{M-1}{N}, \phi(M))=c\cdot 2^{k}/c\cdot 2^{k-l}=2^l=N
,因而此时a^1, a^2, \cdots, a^N
是互不相同的元素,并且构成一个2的幂次阶循环群,这正是我们想要的。
利用Vandermonde矩阵性质,类似FFT那样,我们可以从NTT变换得到逆变换INTT变换,设x(n)
为整数序列,有如下结论
NTT: X(m) = \sum\limits_{i=0}^{N}x(n)a^{mn}\pmod M
.
INTT: X(m) = N^{-1}\sum\limits_{i=0}^{N}x(n)a^{-mn}\pmod M
.
这里N^{-1},a^{-mn} \pmod M
为模意义下的乘法逆元,具体求法可以参考博客:模意义下乘法逆元的算法实现
根据上述构造,我们可以看出NTT计算点值表示时要求多项式的系数数量少于2^k
,并且此时模数的选取很关键,通常取如下模数:
998244353=1192^23+1
1004535809=4792^21+1
他们的原根都是3。
有了这些准备,我们类似FFT中将多项式次数扩展为2的幂次,再进行分治就好了,注意对于幂次计算可以用快速幂。
以下为通过NTT计算大数乘法的C++代码:
#include <cmath>
#include <cstdio>
#include <cstring>
typedef long long int int64;
const int MAXN = 2500005;
const int64 p = 998244353;
int64 A[MAXN], B[MAXN];
char s1[MAXN], s2[MAXN];
int rev[MAXN], ans[MAXN];
int lim = 1, log2_lim = 0;
inline int64 pow(int64 b, int64 e)
{
int64 ans = 1;
while (e) {
if (e & 1) ans = ans * b % p;
b = b * b % p;
e >>= 1;
}
return ans;
}
inline void swap(int64 &c1, int64 &c2)
{
int64 t;
t = c1; c1 = c2; c2 = t;
}
void FFT(int64 *c, int flag)
{
for (int i = 0; i < lim; i++)
if (i < rev[i]) // 防止重复交换
swap(c[i], c[rev[i]]);
for (int j = 1; j < lim; j <<= 1) { // log2_lim-j 递归层数
int j2 = j << 1;
int64 w = pow(3, ((p - 1) / j2)); // 注意此处原根的幂次
if (flag == -1) w = pow(w, p - 2); // 通过Fermat小定理求逆
for (int k = 0; k < lim; k += j2) { // k 每层待合并的子问题
int64 t = 1;
for (int l = 0; l < j; l++, t = t * w % p) { // l 每个子问题下的数据偏移量
int kl = k + l;
int64 Nx = c[kl], Ny = t * c[kl + j] % p; // Nx为偶数项,Ny为奇数项
c[kl] = (Nx + Ny) % p; // 左半部分合并
c[kl + j] = (Nx - Ny) % p; // 右半部分合并
}
}
}
if (flag == -1) {
int64 inv = pow(lim, p - 2); // 通过Fermat小定理求逆
for (int i = 0; i <= lim; i++)
c[i] = c[i] * inv % p;
}
}
int main()
{
scanf("%s%s", s1, s2);
int len_A = strlen(s1), len_B = strlen(s2);
int len_A_1 = len_A - 1, len_B_1 = len_B - 1;
for (int i = len_A_1; i >= 0; i--) A[len_A_1 - i] = s1[i] ^ 0x30;
for (int i = len_B_1; i >= 0; i--) B[len_B_1 - i] = s2[i] ^ 0x30;
int len_AB = len_A + len_B;
while (lim < len_AB) lim <<= 1, log2_lim++;
for (int i = 1; i <= lim; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (log2_lim - 1)); // 初始化最终移动位置
FFT(A, 1); FFT(B, 1);
for (int i = 0; i <= lim; i++)
A[i] = A[i] * B[i] % p; // 计算卷积
FFT(A, -1);
for (int i = 0; i <= lim; i++) {
ans[i] += (A[i] + p) % p; // 由于之前蝴蝶变换出现减法,此处通过该变换保证数值为正
if (ans[i] >= 10) { // 处理十进制进位
ans[i + 1] += ans[i] / 10;
ans[i] %= 10;
}
}
lim++;
while (!ans[lim] && lim >= 1) lim--;
while (lim >= 0)
printf("%d", ans[lim--]);
return 0;
}