NTT快速数论变换及其实现

NTT快速数论变换及其实现

NTT算法和FFT极为相似,所不同之处只是在于选取的数域,对于FFT而言选取了复数域上的单位根来实现分治,而NTT选取的则是一类特殊的有限域上具有类似性质的原根来实现分治。

本文简述NTT原理并以C++实现NTT。

由于NTT与FFT基本类似,本文只介绍其不同点以及具体实现。
对于FFT可以参考如下博客:
本博客:FFT快速傅里叶变换的实现
_Orchidany的博客:FFT·快速傅里叶变换

本文的参考博客:
「算法笔记」快速数论变换(NTT)
快速数论变换(NTT)及蝴蝶操作构造详解

FFT虽然能以O(n\log n)的复杂度计算多项式的点值表达法,但由于计算过程中涉及浮点数运算,导致计算存在精度问题,并且浮点数运算在一些没有浮点运算单元的机器上计算很慢
于是我们考虑如果计算的多项式系数全为整数,能否利用一些特殊的有限域来实现计算,从而避开浮点数运算的缺点。

NTT计算时会受限于所取模数的大小,在需要的模数很大时,会使用不同模数分别计算NTT,最后再用中国剩余定理合并结果。

定义(原根):设m是正整数,在有限域(\mathbb Z/m\mathbb Z)^*中阶为\phi(m)的数a称为模m的原根,其中\phi为Euler函数。

原根的定义使得a^1,a^2,\cdots,a^{\phi(m)}成为有限域中互不相同的元素,这与\mathbb C中单位根的性质是极其相似的。

不过我们需要对多项式系数按照奇偶性进行二分实现分治,这使得我们需要通过原通常需要用到形如c\cdot2^k+1的质数。如果我们有这样的质数M,和原根g,这时候如果有正整数N=2^l(l\leq k),我们有

\frac{c\cdot 2^k}{N}=c\cdot2^{k-l}

如果取a=g^{\frac{M-1}{N}},计算元素的阶得到\operatorname{ord} a = \phi(M)/\gcd(\frac{M-1}{N}, \phi(M))=c\cdot 2^{k}/c\cdot 2^{k-l}=2^l=N,因而此时a^1, a^2, \cdots, a^N是互不相同的元素,并且构成一个2的幂次阶循环群,这正是我们想要的。

利用Vandermonde矩阵性质,类似FFT那样,我们可以从NTT变换得到逆变换INTT变换,设x(n)整数序列,有如下结论
NTT: X(m) = \sum\limits_{i=0}^{N}x(n)a^{mn}\pmod M.
INTT: X(m) = N^{-1}\sum\limits_{i=0}^{N}x(n)a^{-mn}\pmod M.
这里N^{-1},a^{-mn} \pmod M模意义下的乘法逆元,具体求法可以参考博客:模意义下乘法逆元的算法实现

根据上述构造,我们可以看出NTT计算点值表示时要求多项式的系数数量少于2^k,并且此时模数的选取很关键,通常取如下模数:
998244353=1192^23+1
1004535809=479
2^21+1
他们的原根都是3。

有了这些准备,我们类似FFT中将多项式次数扩展为2的幂次,再进行分治就好了,注意对于幂次计算可以用快速幂。

以下为通过NTT计算大数乘法的C++代码:

#include <cmath>
#include <cstdio>
#include <cstring>

typedef long long int int64;

const int MAXN = 2500005;
const int64 p = 998244353; 

int64 A[MAXN], B[MAXN];

char s1[MAXN], s2[MAXN];
int rev[MAXN], ans[MAXN];
int lim = 1, log2_lim = 0;

inline int64 pow(int64 b, int64 e)
{
    int64 ans = 1;
    while (e) {
        if (e & 1) ans = ans * b % p;
        b = b * b % p;
        e >>= 1;
    }
    return ans;
}

inline void swap(int64 &c1, int64 &c2)
{
    int64 t;
    t = c1; c1 = c2; c2 = t;
}

void FFT(int64 *c, int flag)
{
    for (int i = 0; i < lim; i++)
        if (i < rev[i])                 // 防止重复交换
            swap(c[i], c[rev[i]]);

    for (int j = 1; j < lim; j <<= 1) {                         // log2_lim-j 递归层数
        int j2 = j << 1;
        int64 w = pow(3, ((p - 1) / j2));                       // 注意此处原根的幂次
        if (flag == -1) w = pow(w, p - 2);                      // 通过Fermat小定理求逆
        for (int k = 0; k < lim; k += j2) {                     // k 每层待合并的子问题
            int64 t = 1;
            for (int l = 0; l < j; l++, t = t * w % p) {        // l 每个子问题下的数据偏移量
                int kl = k + l;
                int64 Nx = c[kl], Ny = t * c[kl + j] % p;       // Nx为偶数项,Ny为奇数项
                c[kl] = (Nx + Ny)  % p;                         // 左半部分合并
                c[kl + j] = (Nx - Ny) % p;                      // 右半部分合并
            }
        }
    }

    if (flag == -1) {
        int64 inv = pow(lim, p - 2);                            // 通过Fermat小定理求逆
        for (int i = 0; i <= lim; i++)
            c[i] = c[i] * inv % p;
    }
}

int main()
{
    scanf("%s%s", s1, s2);
    int len_A = strlen(s1), len_B = strlen(s2);

    int len_A_1 = len_A - 1, len_B_1 = len_B - 1;
    for (int i = len_A_1; i >= 0; i--) A[len_A_1 - i] = s1[i] ^ 0x30;
    for (int i = len_B_1; i >= 0; i--) B[len_B_1 - i] = s2[i] ^ 0x30;

    int len_AB = len_A + len_B;
    while (lim < len_AB) lim <<= 1, log2_lim++;

    for (int i = 1; i <= lim; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (log2_lim - 1));          // 初始化最终移动位置

    FFT(A, 1); FFT(B, 1);
    for (int i = 0; i <= lim; i++)
        A[i] = A[i] * B[i] % p;                 // 计算卷积
    FFT(A, -1);
    for (int i = 0; i <= lim; i++) {
        ans[i] += (A[i] + p) % p;               // 由于之前蝴蝶变换出现减法,此处通过该变换保证数值为正
        if (ans[i] >= 10) {                     // 处理十进制进位
            ans[i + 1] += ans[i] / 10;
            ans[i] %= 10;
        }
    }

    lim++;
    while (!ans[lim] && lim >= 1) lim--;
    while (lim >= 0)
        printf("%d", ans[lim--]);
    return 0;
}

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注