美文网首页
Fast Fourier transform

Fast Fourier transform (implemented here as the number-theoretic transform, NTT, modulo 998244353)

作者: fo0Old | 来源:发表于2017-08-29 10:15 被阅读0次
    // Number-theoretic transform (NTT) modulo 998244353 = 119 * 2^23 + 1,
    // supporting power-of-two lengths up to 2^23.
    // NOTE(review): this namespace relies on `ll`, `md` (reduce a value into
    // [0, mod)) and an unqualified `swap` being declared before it; none of
    // them is defined in this snippet.
    namespace NTT
    {
        const int mod=119<<23|1; // 998244353
        const int g=3;           // primitive root of mod
        int wn[25];              // wn[i]: primitive 2^i-th root of unity (forward transform)
        int wr[25];              // wr[i]: inverse of wn[i] (inverse transform)
        int inv[25];             // inv[i]: modular inverse of 2^i

        // Fills inv[0..23], wn[1..23] and wr[1..23]. Must run before ntt().
        void init()
        {
            inv[0]=1;
            inv[1]=mod-mod/2;    // (mod+1)/2, the inverse of 2
            for(int i=2;i<=23;++i)
                inv[i]=md(1ll*inv[i-1]*inv[1]);
            wn[23]=15311432;     // 3^119 mod p: primitive 2^23-th root of unity
            wr[23]=469870224;    // modular inverse of wn[23]
            // Squaring a primitive 2^(i+1)-th root yields a primitive 2^i-th root.
            // (This loop also sets wn[1] and wr[1], so no separate seeding is needed.)
            for(int i=22;i>=1;--i)
            {
                wn[i]=1ll*wn[i+1]*wn[i+1]%mod;
                wr[i]=1ll*wr[i+1]*wr[i+1]%mod;
            }
        }

        // In-place iterative radix-2 NTT of xs[0..m-1].
        //   m    : transform length, a power of two with m == 2^logm, logm <= 23
        //   dft  : true -> forward transform; false -> inverse transform,
        //          including the final division by m
        // NOTE(review): assumes the entries of xs lie in (-mod, mod), so that
        // wx*xs[...] cannot overflow a long long. Confirm at the call sites.
        void ntt(ll xs[],int m,int logm,bool dft=true)
        {
            // Bit-reversal permutation: k tracks the bit-reversed counterpart of i.
            for(int i=0,j=m>>1,k=0;i<m;++i)
            {
                if(k>i)swap(xs[k],xs[i]);
                while(k&j)k^=j,j>>=1;
                k|=j,j=m>>1;
            }
            // Butterfly stages: stage i merges halves of length l into blocks of length rat=2^i.
            for(int i=1,rat=2;rat<=m;++i,rat<<=1)
            {
                int l=rat>>1,w=dft?wn[i]:wr[i];
                for(int j=0,wx=1;j<l;++j,wx=md(1ll*wx*w))
                    for(int k=j;k<m;k+=rat)
                    {
                        ll y=md(wx*xs[k+l]);    // twiddled odd element, computed once
                        xs[k+l]=md(xs[k]-y);
                        xs[k]=md(xs[k]+y);
                    }
            }
            // Inverse transform: scale every entry by 1/m in a single pass,
            // instead of testing for the last stage inside the butterfly loop.
            if(!dft)
            {
                const int invm=inv[logm];
                for(int i=0;i<m;++i)
                    xs[i]=md(xs[i]*invm);
            }
        }
    }
    

    多项式乘法

    struct Polynomial
    {
        static const int __=2.7e5;        //>2^(logn+2);
        static const int mod=119<<23|1; //998244353
        static const int g=3;           //原根
        static int wn[25];              //顺时针旋转因子
        static int wr[25];              //逆时针旋转因子
        static int inv[25];             //2^i的逆元
    
        static ll md(ll x)
        {
            if(x<=-mod || x>=mod)
                x%=mod;
            if(x<0)x+=mod;
            return x;
        }
    
        static ll qpow(ll x,ll y)
        {
            ll res=1;
            for(;y;y>>=1,x=md(x*x))
                if(y&1)res=md(res*x);
            return res;
        }
    
        static void init()
        {
            inv[0]=wr[1]=1;
            inv[1]=mod-mod/2;
            for(int i=2;i<=23;++i)
                inv[i]=md(1ll*inv[i-1]*inv[1]);
            wn[23]=15311432;    //qpow(3,119)
            wr[23]=469870224;   //inv[wn[23]]
            for(int i=22;i>=1;--i)
            {
                wn[i]=1ll*wn[i+1]*wn[i+1]%mod;
                wr[i]=1ll*wr[i+1]*wr[i+1]%mod;
            }
        }
    
        ll a[__];int n;
    
        Polynomial() {}
    
        Polynomial(ll b[],int _n) {set(b,_n);}
    
        void set(int _n){n=_n;}
    
        void set(ll b[],int _n)
        {
            n=_n;
            for(int i=0;i<=n;++i)
                a[i]=b[i];
            simpfy();
        }
    
        void simpfy(){for(;n && !a[n];--n);}
    
        ll& operator[](int x){return a[x];}
    
        Polynomial operator+(const Polynomial &b)const
        {
            static Polynomial c;
            c.set(max(n,b.n));
            for(int i=0;i<=c.n;++i)
            {
                c[i]=(i<=n?a[i]:0)+(i<=b.n?b.a[i]:0);
                c[i]=md(c[i]);
            }
            c.simpfy();
            return c;
        }
    
        int m,logm;
    
        void ntt(ll xs[],bool dft=true)
        {
            int invm=inv[logm];
            for(int i=0,j=m>>1,k=0;i<m;++i)
            {
                if(k>i)swap(xs[k],xs[i]);
                while(k&j)k^=j,j>>=1;
                k|=j,j=m>>1;
            }
            for(int i=1,rat=2;rat<=m;++i,rat<<=1)
            {
                int l=rat>>1,w=dft?wn[i]:wr[i];
                for(int j=0,wx=1;j<l;++j,wx=md(1ll*wx*w))
                    for(int k=j;k<m;k+=rat)
                    {
                        int t=md(xs[k]-md(wx*xs[k+l]));
                        xs[k]=md(xs[k]+md(wx*xs[k+l]));
                        xs[k+l]=t;
                        if(m==rat && !dft)
                        {
                            xs[k]=md(xs[k]*invm);
                            xs[k+l]=md(xs[k+l]*invm);
                        }
                    }
            }
        }
    
        //O((n+m)log(n+m))乘法
        Polynomial operator*(const Polynomial &b)
        {
            static Polynomial c;
            static ll d[__];
            for(m=1,logm=0;m<=n+b.n;)
                m<<=1,++logm;
            c.set(m-1);
            for(int i=0;i<m;++i)
            {
                c.a[i]=(i<=n)?a[i]:0;
                d[i]=(i<=b.n)?b.a[i]:0;
            }
            ntt(c.a),ntt(d);
            for(int i=0;i<m;++i)
                c.a[i]=md(c.a[i]*d[i]);
            ntt(c.a,false);
            c.set(n+b.n);
            //c.simpfy();
            return c;
        }
    
        void print()
        {
            for(int i=0;i<=n;++i)
                pf("%lld%c",a[i]," \n"[i==n]);
        }
    
        void clear()
        {
            for(int i=0;i<=n;++i)
                a[i]=0;
        }
    }A,B,C;
    // Out-of-class definitions of Polynomial's static lookup tables.
    int Polynomial::inv[25];
    int Polynomial::wn[25];
    int Polynomial::wr[25];

    // Build the tables during static initialization. A bare call statement such as
    // `Polynomial::init();` is ill-formed at namespace scope. The tables above are
    // zero-initialized before any dynamic initializer runs, and calling init()
    // again later (e.g. from main) is harmless.
    static const bool polynomial_tables_ready=(Polynomial::init(),true);
    

    相关文章

      网友评论

          本文标题:Fast Fourier transform

          本文链接:https://www.haomeiwen.com/subject/gmfodxtx.html