本来这里我想写仔细一点,但我想,傅里叶变换确实比较复杂,况且,我也无法在没有任何参照的情况下写出下面的代码,所以我也就贴上代码,然后写写我在实现过程中的一些想法。下面的代码基本上都是和《算法导论》上面的实现一样的。
多项式的两种表示方法
说实话,第一次看到多项式式的点值形式的时候,我是比较惊讶的,因为在我的第一感觉是,点值无法确定一个多项式,但书中用了一个很简单的说明就让我明白了我的直觉是错的。对于多项式点值表达式的乘法也让我印象深刻。其实比较简单,如果,那么对于点,必然存在,也就是对应两个点的直接相乘。
以前学《信号与系统》的时候,我就知道时域卷积对应于频域相乘,但是我还是不太清楚为什么,这里给了我一点直观的感觉。
算法的基本思路
多项式相乘,如果使用直接的方法,那么需要的时间,而点值形式的多项式乘法只需要,这时候我们很自然的想到,能否找到一种快速的方法,实现两种形式的转化。
先考虑从系数形式到点值形式的转化,根据霍纳法则,我们可以在的时间内求得一个点对应的值,但一共有n个点,也需要。这时候,我们想到了一些算法设计的技巧,比如分治法。我们可以把多项式按照奇偶划分成两部分,然后分别求解,然后合并。则看上去很不错,但写出递归式:
这里不好的地方在于,递归过程中,我们要求的n个点没有减少,如果我们能在递归过程中,减少对点的求解,也就是说,让根之间有关系,减少计算,如果能减少到一半,那么递归式变成了:
这样,这个算法就比直接求解要优秀。
这里,选择的特别的根是单位复数根,它有一个很重要的性质。
利用这个性质,可以在计算的过程中,实现减少点的计算。
计算完了,我们还要转回来,根据复数根的性质,我们发现,这个矩阵的逆矩阵很好求,就是虚部取反,除以n,然后问题有回到了前面。
#include <iostream>
#include <complex>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;
const double PI = acos(-1);
void fft(vector<complex<double> > &a ,int n, int op){
if(n == 1){
return;
}
//divide
vector<complex<double> > a0(n/2,0);
vector<complex<double> > a1(n/2,0);
for(int i=0;i<n/2;i++){
a0[i] = a[i*2];
a1[i] = a[i*2+1];
}
//conquer
fft(a0,n/2,op);
fft(a1,n/2,op);
//merge
complex<double> w(1,0); //通过巧妙的选取根来减少计算
complex<double> wn(cos(2 * PI / n),sin(2 * PI * op /n));
for(int i=0;i<n/2;i++){ //如果选取的根不特殊的话,需要对n个值分别进行计算
a[i] = a0[i] + w*a1[i];
a[i+n/2] = a0[i] - w * a1[i]; //根据单位负数根的性质可以很快求出另一个根对应的值
w = w * wn;
}
}
void fft2(vector<complex<double> > &a ,int n, int op){
//开始输入的时候位置为对应计数比特翻转之后的值
int len = (int)log2(n) - 1;
vector<int> rev(n,0);
for(int i=0;i<n;i++){ //计算翻转位置
//利用和位置为i/2的翻转数的关系,提高速度
rev[i] = (rev[i>>1]>>1) | ((i&0x01)<<len);
}
//改变输入数组的顺序
for(int i=0;i<n;i++){
if(i < rev[i]){
swap(a[i],a[rev[i]]);
}
}
//模拟递归的计算顺序,从底层开始,一个log2(n)层
for(int i = 1;i< n; i *= 2){
complex<double> wn(cos( PI / i),sin(PI * op / i));
for(int k=0;k<n;k+= 2 * i){
complex<double> w(1,0);
for(int j=0;j<i;j++){
complex<double> t = w * a[k + j + i];
complex<double> u = a[k + j];
a[k+j] = u + t;
a[k+j+i] = u - t;
w *= wn;
}
}
}
if(op == -1){
for(int i=0;i<n;i++){
a[i] /= n;
}
}
}
/*
2 2
1 2 1
2 0 1
2 4 3 2 1
*/
int main(int argc, char const *argv[])
{
int n,m;
cin>>n>>m;
int n1= 1;
while(n1<=m+n){
n1 <<= 1;
}
vector<complex<double> > v1(n1,0);
vector<complex<double>> v2(n1,0);
vector<complex<double>> res(n1,0);
int tmp;
for(int i=0;i<=n;i++){
cin>>v1[i];
}
for(int i=0;i<=m;i++){
cin>>v2[i];
}
//将系数形式转化为点值形式
fft(v1,n1,1);
fft(v2,n1,1);
//利用点值形式计算多项式乘法
for(int i=0;i<n1;i++){
res[i] = v1[i] * v2[i];
// cout<<v1[i]<<" "<<v2[i]<<" "<<res[i]<<endl;
}
//将点值表达转化为系数形式
fft(res,n1,-1);
for (int i = 0; i <= m+n; i++){
printf("%.0lf ", floor(res[i].real()/n1));
}
system("pause");
return 0;
}
代码的注释写得挺多的,我个人只能说对这个算法有一个基本的理解,主要是为了记录学习的过程。如果有错误,请多多包涵。
网友评论