美文网首页
Fast Fourier transform

Fast Fourier transform (number-theoretic transform, NTT, modulo 998244353)

作者: fo0Old | 来源:发表于2017-08-29 10:15 被阅读0次
namespace NTT
{
    const int mod=119<<23|1; //998244353
    const int g=3;           //原根
    int wn[25];              //顺时针旋转因子
    int wr[25];              //逆时针旋转因子
    int inv[25];             //2^i的逆元

    void init()
    {
        inv[0]=wr[1]=1;
        inv[1]=mod-mod/2;
        for(int i=2;i<=23;++i)
            inv[i]=md(1ll*inv[i-1]*inv[1]);
        wn[23]=15311432;    //qpow(3,119)
        wr[23]=469870224;   //inv[wn[23]]
        for(int i=22;i>=1;--i)
        {
            wn[i]=1ll*wn[i+1]*wn[i+1]%mod;
            wr[i]=1ll*wr[i+1]*wr[i+1]%mod;
        }
    }

    void ntt(ll xs[],int m,int logm,bool dft=true)
    {
        int invm=inv[logm];
        for(int i=0,j=m>>1,k=0;i<m;++i)
        {
            if(k>i)swap(xs[k],xs[i]);
            while(k&j)k^=j,j>>=1;
            k|=j,j=m>>1;
        }
        for(int i=1,rat=2;rat<=m;++i,rat<<=1)
        {
            int l=rat>>1,w=dft?wn[i]:wr[i];
            for(int j=0,wx=1;j<l;++j,wx=md(1ll*wx*w))
                for(int k=j;k<m;k+=rat)
                {
                    int t=md(xs[k]-md(wx*xs[k+l]));
                    xs[k]=md(xs[k]+md(wx*xs[k+l]));
                    xs[k+l]=t;
                    if(m==rat && !dft)
                    {
                        xs[k]=md(xs[k]*invm);
                        xs[k+l]=md(xs[k+l]*invm);
                    }
                }
        }
    }
};

多项式乘法

struct Polynomial
{
    static const int __=2.7e5;        //>2^(logn+2);
    static const int mod=119<<23|1; //998244353
    static const int g=3;           //原根
    static int wn[25];              //顺时针旋转因子
    static int wr[25];              //逆时针旋转因子
    static int inv[25];             //2^i的逆元

    static ll md(ll x)
    {
        if(x<=-mod || x>=mod)
            x%=mod;
        if(x<0)x+=mod;
        return x;
    }

    static ll qpow(ll x,ll y)
    {
        ll res=1;
        for(;y;y>>=1,x=md(x*x))
            if(y&1)res=md(res*x);
        return res;
    }

    static void init()
    {
        inv[0]=wr[1]=1;
        inv[1]=mod-mod/2;
        for(int i=2;i<=23;++i)
            inv[i]=md(1ll*inv[i-1]*inv[1]);
        wn[23]=15311432;    //qpow(3,119)
        wr[23]=469870224;   //inv[wn[23]]
        for(int i=22;i>=1;--i)
        {
            wn[i]=1ll*wn[i+1]*wn[i+1]%mod;
            wr[i]=1ll*wr[i+1]*wr[i+1]%mod;
        }
    }

    ll a[__];int n;

    Polynomial() {}

    Polynomial(ll b[],int _n) {set(b,_n);}

    void set(int _n){n=_n;}

    void set(ll b[],int _n)
    {
        n=_n;
        for(int i=0;i<=n;++i)
            a[i]=b[i];
        simpfy();
    }

    void simpfy(){for(;n && !a[n];--n);}

    ll& operator[](int x){return a[x];}

    Polynomial operator+(const Polynomial &b)const
    {
        static Polynomial c;
        c.set(max(n,b.n));
        for(int i=0;i<=c.n;++i)
        {
            c[i]=(i<=n?a[i]:0)+(i<=b.n?b.a[i]:0);
            c[i]=md(c[i]);
        }
        c.simpfy();
        return c;
    }

    int m,logm;

    void ntt(ll xs[],bool dft=true)
    {
        int invm=inv[logm];
        for(int i=0,j=m>>1,k=0;i<m;++i)
        {
            if(k>i)swap(xs[k],xs[i]);
            while(k&j)k^=j,j>>=1;
            k|=j,j=m>>1;
        }
        for(int i=1,rat=2;rat<=m;++i,rat<<=1)
        {
            int l=rat>>1,w=dft?wn[i]:wr[i];
            for(int j=0,wx=1;j<l;++j,wx=md(1ll*wx*w))
                for(int k=j;k<m;k+=rat)
                {
                    int t=md(xs[k]-md(wx*xs[k+l]));
                    xs[k]=md(xs[k]+md(wx*xs[k+l]));
                    xs[k+l]=t;
                    if(m==rat && !dft)
                    {
                        xs[k]=md(xs[k]*invm);
                        xs[k+l]=md(xs[k+l]*invm);
                    }
                }
        }
    }

    //O((n+m)log(n+m))乘法
    Polynomial operator*(const Polynomial &b)
    {
        static Polynomial c;
        static ll d[__];
        for(m=1,logm=0;m<=n+b.n;)
            m<<=1,++logm;
        c.set(m-1);
        for(int i=0;i<m;++i)
        {
            c.a[i]=(i<=n)?a[i]:0;
            d[i]=(i<=b.n)?b.a[i]:0;
        }
        ntt(c.a),ntt(d);
        for(int i=0;i<m;++i)
            c.a[i]=md(c.a[i]*d[i]);
        ntt(c.a,false);
        c.set(n+b.n);
        //c.simpfy();
        return c;
    }

    void print()
    {
        for(int i=0;i<=n;++i)
            pf("%lld%c",a[i]," \n"[i==n]);
    }

    void clear()
    {
        for(int i=0;i<=n;++i)
            a[i]=0;
    }
}A,B,C;
int Polynomial::inv[25];
int Polynomial::wn[25];
int Polynomial::wr[25];

Polynomial::init();

相关文章

网友评论

      本文标题:Fast Fourier transform

      本文链接:https://www.haomeiwen.com/subject/gmfodxtx.html