真的只是入门。
可能会接着更 分治FFT 或者 任意模数NTT?
前置知识
复数 也可以参考高中数学课本,这里只会介绍 fft 需要的(默认已经入门复数)。
多项式的相关概念。
点值表示法:假设 \(f(x)\) 是一个 \(n-1\) 次多项式,那么将 \(n\) 个 不同的 \(x\) 代入,可以得到 \(n\) 个 \(y\)。这 \(n\) 个点对 \((x,y)\) 唯一确定了该多项式。那么就可以通过多项式求出其点值表示,也可以反过来。
注意:以下 \(n\) 均为 \(2\) 的整数次幂,若不足则补零(显然不会影响结果)
FFT 简介
用来加速多项式乘法的一个东西。而普通多项式乘法是 \(O(n^2)\) 的。但是如果是点值表示法,则是 \(O(n)\) 的(比如 \(c(x) = a(x)\times b(x)\),那么只需要枚举 \(2\times n\) 个不同的 \(x\) 即可求出 \(c\) 的点值表示)。
考虑如何快速将两个多项式 \(a(x),b(x)\) 转成点值表示,再将一个多项式 \(c(x)\) 从点值表示转化回来即可。
这里就用到 FFT 了。
离散傅里叶变换
其实就是朴素版 FFT。
就是上面代入的 \(n\) 个 \(x\) 为 \(n\) 个复数。但这 \(n\) 个复数不是随便找的,而是 \(n\) 次单位根。
简单介绍一下,就是 \(x^n=1\) 在复数意义下所有的根。显然这玩意有 \(n\) 个。将这几个根按幅角从小到大排序,从 \(0\) 开始编号,第 \(i\) 个 \(n\) 次单位根记为 \(w_n^i\)。没有特殊声明时,一般特指第一个单位根,简记为 \(w_n\)。
可以算出 \(w_n^k=\cos(\frac{2k\pi}{n})+i\sin(\frac{2k\pi}{n})\)。
有两个性质:
- \(w_{2n}^{2k}=w_n^k\)。画个图理解一下。
- \(w_{n}^{k+\frac{n}{2}}-w_{n}^{k}\),画图发现它们关于原点对称。
画个图出来,大概这样子。
好看是好看, 但是为什么非要选择这 \(n\) 个点呢?
有个结论,将 \(a(x)\) 的离散傅里叶变换的结果作为 \(b(x)\) 的系数,将单位根的倒数 \(w_n^0,w_n^{-1},w_n^{-2},\cdots,w_n^{-n+1}\) 代入以后,得到的每个数再除以 \(n\),就是 \(a(x)\) 的各项系数。
证明
设 \((b_0,b_1,\cdots,b_{n-1})\) 为 \(A(x)=a_0+a_1x+a_2x^2+\cdots+a_nx^{n-1}\) 离散傅里叶变换的结果。
设 \(B(x) = b_0+b_1x+\cdots+b_nx^{n-1}\),然后将那几个单位根的倒数代入得到一个新的离散傅里叶变换结果 \((c_0,c_1,\cdots,c_{n-1})\)。
有
发现当且仅当 \(j=k\) 时,后面式子的值为 \(n\),反之,后面的式子值为 \(0\)。
那么就有了 \(a_i=\frac{c_i}{n}\)。
然后这样就是 \(O(n^2)\) 的了,没啥用啊。
这是因为傅里叶爷爷 (1768年3月21日~1830年5月16日) 没有见过计算机(世界上第一台通用计算机“ENIAC”于1946年2月14日在美国宾夕法尼亚大学诞生),所以他不需要考虑时间复杂度,但是后人就要优化了。
考虑一个多项式 \(A(x)=a_0+a_1x+a_2x^2+\cdots+a^{n-1}x^{n-1}\) 要求离散傅里叶变换,将一个 \(w_n\) 代入。
将 \(A(x)\) 的每一项按照下标奇偶性分组,设 \(A_1(x)=a_0+a_2x+\cdots+a^{n-2}x^{\frac{n}{2}-1},A_2(x) = a_1 + a_3x + \cdots + a_{n - 1}x^{\frac{n}{2} - 1}\)。
显然有 \(A(x) = A_1(x^2)+xA_2(x^2)\)。
将 \(w_n^k(k<\frac{n}{2})\) 代入,有 \(A(w_n^k)=A_1(w_n^{2k})+w_n^kA_2(w_n^{2k})=A_1(w_{\frac{n}{2}}^k)+w_{n}^kA_2(w_{\frac{n}{2}}^k)\)。
那么将 \(w_n^{k+\frac{n}{2}}\) 代入,最后得到一个 \(A_1(w_n^{2k})-w_n^kA_2(w_n^{2k})\)。
发现这两个式子只有一个常数项不同,那么求第一个式子可以顺便将第二个式子求出来,因为第一个式子取遍 \([0,\frac{n}{2}-1]\),第二个式子取遍 \([\frac{n}{2},n]\),所以将原问题缩小了一半,而且缩小后的问题符合原问题性质,然后就这么递归分治下去即可。
时间复杂度 \(O(n\log n)\)。
【模板】多项式乘法(FFT)
递归实现,跑的好像不是很慢。
个人感觉递归实现便于理解,可以先打递归实现试试水。
code
#include<bits/stdc++.h>
using namespace std;
#define rep(i,s,t,p) for(int i = s;i <= t;i += p)
#define drep(i,s,t,p) for(int i = s;i >= t;i -= p)
#ifdef LOCALauto I = freopen("in.in","r",stdin),O = freopen("out.out","w",stdout);
#elseauto I = stdin,O = stdout;
#endif
using ll = long long;using ull = unsigned long long;
using db = double;using ldb = long double;
using comp = complex<db>;
const db pi = acos(-1);
const int N = 4e6 + 10;
int n,m;comp a[N],b[N];
void fft(comp *a,int n,int inv){if(n == 1) return;//最后了,直接 return 就好啦。int m = n>>1;comp a1[m+5],a2[m+5];rep(i,0,n,2) a1[i>>1] = a[i],a2[i>>1] = a[i+1];//奇偶分组fft(a1,m,inv);fft(a2,m,inv);comp W = comp(cos(2.0*pi/n),inv*sin(2.0*pi/n)),w = comp(1,0);//单位根。for(int i = 0;i < m;++i,w = w*W){a[i] = a1[i] + w*a2[i];//求A1。a[i+m] = a1[i] - w*a2[i];//求A2。}
}
signed main(){cin.tie(nullptr)->sync_with_stdio(false);cin>>n>>m;rep(i,0,n,1) cin>>a[i];rep(i,0,m,1) cin>>b[i];int lim = 1;while(lim <= n+m) lim <<= 1;//凑成 2^nfft(a,lim,1);fft(b,lim,1);//转成点值表示rep(i,0,lim,1) a[i] = a[i]*b[i];//点值乘fft(a,lim,-1);//转成系数表示rep(i,0,n+m,1) cout<<(int)(a[i].real()/lim+0.5)<<' ';
}
但是相比于常写的迭代写法来说,这种写法还是比较慢。(在洛谷这道题,最后一个数据点,递归跑了 1000ms,迭代跑了 600多ms)
那么如何写成迭代法呢?
考虑最后进行操作的序列,发现其实最后操作的序列就是将下标二进制位翻转排序。
手动模拟一下应该很好理解,就是先以二进制位下第零位奇偶分组,然后每组中以第一位奇偶分组,以此类推。
那么就可以知道每个数最后应该在的位置,这个可以递推出来,假如当前位置为 \(i\),那么它会在 (to[i>>1]>>1)|((i&1)<<(ct-1))
,其中 \(to_i\) 表示第 \(i\) 个数应该在的位置,显然有 i=to[to[i]]
,\(ct\) 是二进制位数,即 \(n=2^{ct}\)。
然后就可以递推出结果了。
code
#include<bits/stdc++.h>
using namespace std;
#define rep(i,s,t,p) for(int i = s;i <= t;i += p)
#define drep(i,s,t,p) for(int i = s;i >= t;i -= p)
#ifdef LOCALauto I = freopen("in.in","r",stdin),O = freopen("out.out","w",stdout);
#elseauto I = stdin,O = stdout;
#endif
using ll = long long;using ull = unsigned long long;
using db = double;using ldb = long double;
using comp = complex<db>;
const int N = 4e6 + 10;
const db pi = acos(-1);
int n,m,to[N],ct;
comp a[N],b[N];
void fft(comp *a,int n,int type){rep(i,0,n-1,1) if(i < to[i]) swap(a[i],a[to[i]]);for(int mid = 1;mid < n;mid <<= 1){comp W = comp(cos(pi/mid),type*sin(pi/mid));for(int Res = mid<<1,j = 0;j < n;j += Res){comp w = comp(1,0);for(int k = 0;k < mid; ++k,w = w*W){comp x = a[j+k],y = w*a[j+mid+k];a[j+k] = x + y;a[j+mid+k] = x-y;}}}
}
signed main(){cin.tie(nullptr)->sync_with_stdio(false);cin>>n>>m;rep(i,0,n,1) cin>>a[i];rep(i,0,m,1) cin>>b[i];int lim = 1;while(lim <= (n+m)) lim <<= 1,ct++;rep(i,0,lim-1,1) to[i] = (to[i>>1]>>1)|((i&1)<<(ct-1));fft(a,lim,1);fft(b,lim,1);rep(i,0,lim,1) a[i] = a[i]*b[i];fft(a,lim,-1);rep(i,0,n+m,1) cout<<(int)(a[i].real()/lim+0.5)<<' ';
}
实际应用中还可以预处理单位根,但我写丑了,跑的还没不预处理的快(甚至不如递归)
code
#include<bits/stdc++.h>
using namespace std;
#define rep(i,s,t,p) for(int i = s;i <= t;i += p)
#define drep(i,s,t,p) for(int i = s;i >= t;i -= p)
#ifdef LOCALauto I = freopen("in.in","r",stdin),O = freopen("out.out","w",stdout);
#elseauto I = stdin,O = stdout;
#endif
using ll = long long;using ull = unsigned long long;
using db = double;using ldb = long double;
using comp = complex<db>;
const int N = 4e6 + 10;
const db pi = acos(-1);
int n,m,to[N],ct;
comp a[N],b[N],urt[N],iurt[N];
void fft(comp *a,comp *urt,int n){rep(i,0,n-1,1) if(i < to[i]) swap(a[i],a[to[i]]);for(int mid = 1;mid < n;mid <<= 1){for(int Res = mid<<1,j = 0;j < n;j += Res){for(int k = 0;k < mid; ++k){comp x = a[j+k],y = urt[n/mid*k]*a[j+mid+k];a[j+k] = x + y;a[j+mid+k] = x-y;}}}
}
signed main(){cin.tie(nullptr)->sync_with_stdio(false);cin>>n>>m;rep(i,0,n,1) cin>>a[i];rep(i,0,m,1) cin>>b[i];int lim = 1;while(lim <= (n+m)) lim <<= 1,ct++;comp W = comp(cos(pi/lim),sin(pi/lim));urt[0] = comp(1,0);rep(i,1,lim,1) urt[i] = urt[i-1]*W;W = conj(W);iurt[0] = comp(1,0);rep(i,1,lim,1) iurt[i] = iurt[i-1]*W;rep(i,0,lim-1,1) to[i] = (to[i>>1]>>1)|((i&1)<<(ct-1));fft(a,urt,lim);fft(b,urt,lim);rep(i,0,lim,1) a[i] = a[i]*b[i];fft(a,iurt,lim);rep(i,0,n+m,1) cout<<(int)(a[i].real()/lim+0.5)<<' ';
}
例题:【模板】高精度乘法 | A*B Problem 升级版
没有压位,和高精一样写即可,就是乘法由暴力乘换成了 fft。
code
#include<bits/stdc++.h>
using namespace std;
#define rep(i,s,t,p) for(int i = s;i <= t;i += p)
#define drep(i,s,t,p) for(int i = s;i >= t;i -= p)
#ifdef LOCALauto I = freopen("in.in","r",stdin),O = freopen("out.out","w",stdout);
#elseauto I = stdin,O = stdout;
#endif
using ll = long long;using ull = unsigned long long;
using db = double;using ldb = long double;
using comp = complex<db>;
const db pi = acos(-1);
const int N = 4e6 + 10;
string s1,s2;
int n,m,ct,to[N],ans[N];comp a[N],b[N];
void fft(comp *a,int type,int n){rep(i,0,n,1) if(i < to[i]) swap(a[i],a[to[i]]);for(int mid = 1;mid < n;mid <<= 1){comp W = comp(cos(pi/mid),type*sin(pi/mid));for(int Res = mid<<1,j = 0;j < n;j += Res){comp w = comp(1,0);for(int k = 0;k < mid; ++k,w = w*W){auto x = a[j+k],y = w*a[j+mid+k];a[j+k] = x+y;a[j+mid+k] = x-y;}}}
}
signed main(){cin.tie(nullptr)->sync_with_stdio(false);cin>>s1>>s2;n = s1.size() - 1,m = s2.size() - 1;reverse(s1.begin(),s1.end());reverse(s2.begin(),s2.end());rep(i,0,n,1) a[i] = comp(s1[i]-'0',0);rep(i,0,m,1) b[i] = comp(s2[i]-'0',0);int lim = 1;while(lim <= (n+m)) lim <<= 1,ct++;rep(i,0,lim,1) to[i] = (to[i>>1]>>1)|((i&1)<<(ct-1));fft(a,1,lim);fft(b,1,lim);rep(i,0,lim,1) a[i] = a[i]*b[i];fft(a,-1,lim);rep(i,0,lim,1) ans[i] = (int)(a[i].real()/lim+0.5);rep(i,0,lim,1) ans[i+1] += ans[i]/10,ans[i] %= 10;int now = lim;while(!ans[now]) now--;drep(i,now,0,1) cout<<ans[i];cout<<'\n';
}
NTT
其实和 FFT 没啥区别,但是建议学明白 FFT 后再看。
前置知识是 原根,顺便不要脸地挂一下之前写过的原根笔记。
考虑 FFT 的弊端,复数为 double
类型,常数大,而且有精度误差。
那么什么时候可以用整数代替 double
,当有模数的时候就行了。
为什么用单位根当 FFT 代入的那几个数,因为单位根有许多优秀的性质,而在模意义下,原根同样有类似的性质。
以下记模数为 \(Mod\),\(g\) 为 \(Mod\) 的原根(假设下文所说的东西在模 \(Mod\) 的意义下一定都存在)。
其实就是将 \(w_n\) 替换为 \(g^{\frac{Mod-1}{n}}\),然后那几个性质一样证明就好了。
但是我们要保证 \(n|(Mod-1)\),而又因为 \(n\) 一直为 \(2\) 的整数次幂,所以就要保证 \(Mod-1\) 为 \(p2^k\),其中 \(2^k\ge n\)
就比如常用模数 \(998244353=7\times17\times2^{23}+1\),所以它最多可以做 \(n=8388608=2^23\) 的 NTT。
代码实现和 FFT 没什么两样,还是以 【模板】多项式乘法(FFT) 为例。
确实比 FFT 快一点,但是我的 FFT 因为懒所以没有手写 complex 类,不知道手写以后会不会和 NTT 差不多。
code
#include<bits/stdc++.h>
using namespace std;
#define rep(i,s,t,p) for(int i = s;i <= t;i += p)
#define drep(i,s,t,p) for(int i = s;i >= t;i -= p)
#ifdef LOCALauto I = freopen("in.in","r",stdin),O = freopen("out.out","w",stdout);
#elseauto I = stdin,O = stdout;
#endif
using ll = long long;using ull = unsigned long long;
using db = double;using ldb = long double;
const int N = 4e6 + 10,Mod = 998244353,G = 3,Gi = 332748118;
int n,m,a[N],b[N],to[N],ct;
int qpow(int a,int b,int Mod = Mod){int res = 1;for(;b;b >>= 1,a = 1ll*a*a%Mod)if(b&1) res = 1ll*res*a%Mod;return res;
}
void ntt(int *a,int n,bool type){rep(i,0,n-1,1) if(i < to[i]) swap(a[i],a[to[i]]);for(int mid = 1;mid < n;mid <<= 1){int W = qpow(type?G:Gi,(Mod-1)/(mid<<1));for(int Res = mid<<1,j = 0;j < n;j += Res){int w = 1;for(int k = 0;k < mid; ++k,w = 1ll*w*W%Mod){int x = a[j+k],y = 1ll*w*a[j+k+mid]%Mod;a[j+k] = (x+y)%Mod;a[j+k+mid] = (x-y+Mod)%Mod;}}}
}
signed main(){cin.tie(nullptr)->sync_with_stdio(false);cin>>n>>m;rep(i,0,n,1) cin>>a[i];rep(i,0,m,1) cin>>b[i];int lim = 1;while(lim <= (n+m)) lim <<= 1,ct++;rep(i,0,lim,1) to[i] = (to[i>>1]>>1)|((i&1)<<(ct-1));ntt(a,lim,true);ntt(b,lim,true);rep(i,0,lim,1) a[i] = 1ll*a[i]*b[i]%Mod;ntt(a,lim,false);int Inv = qpow(lim,Mod-2,Mod);rep(i,0,n+m,1) cout<<1ll*a[i]*Inv%Mod<<' ';cout<<'\n';
}