代码均为做严格测试,仅供参考
分治法基本原理
将原问题分解为几个规模较小但类似于原问题的子问题,递归的求解这些子问题。然后再合并这些子问题的解来建立原问题的解。递归求解这些子问题,然后再合并这些子问题的解来建立原问题的解。
分治法在分层递归时都有三个步骤:
分解原问题为若干子问题,这些子问题是原问题规模较小的实例。解决这些子问题,递归的求解各个子问题。然而若子问题的规模足够小。则直接求解。合并这些子问题的解成原问题的解。
问题描述
两个N次多项式相乘,最直接的复杂度为
O(n2)
,运用傅里叶变换,则可以吧多项式相乘的复杂度转化为
nlog(n)
输入输出均采用系数表达,假设n是2的幂,否则通过添加系数为0的高阶系数。算法准备过程如下:
加倍次数界:通过添加n个系数为0的高阶系数,把多项式A(x)和B(x)变为次数界为2n的多项式,并构造其系数表达。求值:通过应用2n阶的FFT计算出A(x)和B(x)的长度为2n的点值表达。这些点值表达式中包含了两个多项式在2n次单位根处的取值。逐点相乘:把A(x)和B(x)的值逐点相乘,可以计算出多项式C(x)=A(x)B(x)长度为2n的点值表达,这个表示中包含了C(x)在每个2n单位根处的值。插值:通过对2n个点值对应用fft,计算其逆DFT,就可以构造出多项式C(x)的系数表达。
基本概念和定理
单位复数根
n次单位复数根是满足
ωn=1
的复数
ω
,这些根正好有n个,分别是
e2πikn(k=0,1,2...n−1)
其中
e2πin
被称作主n次单位根。
ωj∗ωk=ω(k+j)modn
消去引理
对于任意整数n>=0和k>=0,以及d>=0
ωdkdn=ωkn
折半引理
如果n>0为偶数,那么n个n次单位复数根的平方的集合就是n/2个n/2次单位复数根的集合。
对任意非负整数k,我们有
(ωkn)2=ωkn/2
求和引理
对任意整数n>=1和不能被n整除的非负整数k,有
∑n−1j=0(ωkn)j=0
算法实现
DFT
我们希望计算次数界为n的多项式A(x)=
∑n−1j=0ajxj
在n个n次单位复数根处的值,假设A以系数形式给出:
a=(a0,a1,a2...an−1)
。接下来对k=0,1,2,..n-1,定义结果
yk
:
yk=A(ωkn)=∑n−1j=0ajωkjn
向量
y=(y0,y1,...yn−1)
就是系数向量a=(
a0,a1,...,an−1
)的离散傅里叶变换DFT
FFT
通过使用快速傅里叶变换的方法,利用复数单位根的特殊性质,我们就可以在
θ(nlgn)
的时间内计算出DFT(a)。
首先分别定义两个新的次数界为n/2的多项式
A[0](x)=a0+a2+...an−2xn/2−1
A[1](x)=a1+a3+...an−1xn/2−1
分别包含了所有偶数下标的系数和奇数下标的系数。
A(x)=A[0](X2)+xA[1](x2)
因而求A(x)在
ω0n,ω1n,…,ωn−1n
处的值得问题转化为:
求次数界为n/2的多项式
A[0](x)+xA[1](x)
在点
(ω0)2…(ωn−1)2
处的取值
用递归方法计算fft的伪代码如下
RECURSIVE_FFT(
a[])
{
n=
a.lenth
if(n==
1)
return a
wn=e^(
2*
pi*i/n)
w=
1
a0[]=(a0,a2,a4...)
a1[]=(a1,a3,a5...)
y0[]=RECURSIVE_FFT(a0[])
y1[]=RECURSIVE_FFT(a1[])
for k=
0 to n/
2
y[k]=y0[k]+w*y1[k]
y[k+n/
2]=y0[k]-w*y1[k]
w=w*wn
return y
}
计算出逆DFT。将fft算法进行修改,将a与y互换,用
ω−1n替换ωn
,并将计算结果的每个数除以n。
算法复杂度分析以及可能的优化
T(n)=2T(n/2)+θ(n)=θ(nlgn)
从算法实现上来看,整体的时间复杂度无法进行优化。
但是可以把递归的算法改成迭代的形式实现。从而节省栈中的空间。同时迭代算法可以做到常数上的优化。
优化实现主要代码
int rev(
int k,
int n){
int res=
0;
while(n){
int x=k&
1;
res=res*
2+x;
k>>=
1;
n>>=
1;
}
return res;
}
void bit_reverse(
vector<Complex> a,
vector<Complex> A){
int n=(
int)a.size();
A.resize(n);
for(
int k=
0;k<n;k++){
A[rev(k,n-
1)]=a[k];
}
}
vector<Complex> iterative_fft(
vector<Complex> a,
double op){
vector<Complex> A;
bit_reverse(a, A);
int n=(
int)a.size();
for(
int s=
0;(
1<<s)<n;s++){
int m=
1<<s;
int temp=
2*m;
Complex wm=Complex(
cos(pi/m*op),
sin(pi/m*op));
for(
int k=
0;k<n;k+=temp)
{
Complex w=Complex(
1,
0);
for(
int j=
0;j<m;j++)
{
Complex t=w*A[k+j+m];
Complex u=A[k+j];
A[k+j]=u+t;
A[k+j+m]=u-t;
w=w*wm;
}
}
}
return A;
}
整体实现代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <complex>
using namespace std;
#define _ sync_with_stdio(false)
typedef long long ll;
typedef complex<
double> Complex;
const double pi=
acos(-
1);
const int INF=
0x7fffffff;
vector<Complex> recursive_fft(
vector<Complex> a,
double op){
vector<Complex> y;
int n=(
int)a.size();
if(n==
1)
return a;
Complex w=Complex(
1,
0);
Complex wn=Complex(
cos(
2*pi/(n*op)),
sin(
2*pi/(n*op)));
vector<Complex> a0,a1;
for(
int i=
0;i<n;i++){
if(i&
1){
a1.push_back(a[i]);
}
else{
a0.push_back(a[i]);
}
}
vector<Complex> y0=recursive_fft(a0, op);
vector<Complex> y1=recursive_fft(a1, op);
y.resize(n);
for(
int k=
0;k<=n/
2-
1;k++){
y[k]=y0[k]+w*y1[k];
y[k+n/
2]=y0[k]-w*y1[k];
w=w*wn;
}
return y;
}
int rev(
int k,
int n){
int res=
0;
while(n){
int x=k&
1;
res=res*
2+x;
k>>=
1;
n>>=
1;
}
return res;
}
void bit_reverse(
vector<Complex> a,
vector<Complex> A){
int n=(
int)a.size();
A.resize(n);
for(
int k=
0;k<n;k++){
A[rev(k,n-
1)]=a[k];
}
}
vector<Complex> iterative_fft(
vector<Complex> a,
double op){
vector<Complex> A;
bit_reverse(a, A);
int n=(
int)a.size();
for(
int s=
0;(
1<<s)<n;s++){
int m=
1<<s;
int temp=
2*m;
Complex wm=Complex(
cos(pi/m*op),
sin(pi/m*op));
for(
int k=
0;k<n;k+=temp)
{
Complex w=Complex(
1,
0);
for(
int j=
0;j<m;j++)
{
Complex t=w*A[k+j+m];
Complex u=A[k+j];
A[k+j]=u+t;
A[k+j+m]=u-t;
w=w*wm;
}
}
}
return A;
}
int main() {
int n,m;
cout<<
"请输入第一个多项式的长度:";
cin>>n;
cout<<
"请依次输入第一个多项式的系数:";
vector<Complex> s0,s1;
for(
int i=
0;i<n;i++){
double x;
cin>>x;
s0.push_back(Complex(x,
0));
}
cout<<
"请输入第二个多项式的长度:";
cin>>m;
cout<<
"请依次输入第二个多项式的系数:";
for(
int i=
0;i<m;i++){
double x;
cin>>x;
s1.push_back(Complex(x,
0));
}
int MAX=max(n,m);
int temp=
1;
while(temp<MAX){
temp<<=
1;
}
MAX=temp*
2;
for(
int i=n;i<MAX;i++)
s0.push_back(Complex(
0,
0));
for(
int i=m;i<MAX;i++)
s1.push_back(Complex(
0,
0));
vector <Complex> ans0=recursive_fft(s0,
1);
vector<Complex> ans1=recursive_fft(s1,
1);
vector<Complex> ans;
ans.resize(MAX);
for(
int i=
0;i<MAX;i++){
ans[i]=ans0[i]*ans1[i];
}
vector<Complex> res=recursive_fft(ans, -
1);
cout<<
"相乘之后结果序列表示为:";
for(
int i=
0;i<n+m-
1;i++){
if(i==n+m-
2)
cout<<res[i].real()/(MAX)<<endl;
else
cout<<res[i].real()/(MAX)<<
" ";
}
}