BZOJ-3451: Tyvj1953 Normal（FFT + 点分治）

作者: AmadeusChan | 来源:发表于2019-03-13 12:57 被阅读1次

    题目:http://www.lydsy.com/JudgeOnline/problem.php?id=3451

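    这道题非常巧妙！考虑点 y 被选为分治中心的时刻：若此时点 x 仍与 y 处于同一连通块（即 x 位于 y 在点分树中的子树内），则 x 为这一步贡献 1 的代价。这种情况当且仅当 y 是路径 (x,y) 上最早被选中的点；而路径上每个点最先被选中的概率相同，故其概率为 $\frac{1}{\mathrm{dist}(x,y)+1}$，其中 $\mathrm{dist}(x,y)$ 为路径上的边数。由期望的线性性，答案为 $\sum_{x}\sum_{y}\frac{1}{\mathrm{dist}(x,y)+1}$（对所有有序点对求和，包括 $x=y$）。于是只需用点分治 + FFT 统计每种距离的有序点对个数，总复杂度 $O(n\log^2 n)$。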

    代码:

    #include <cstdio>
    
    #include <cstring>
    
    #include <cmath>
    
    #include <cstdlib>
    
    #include <vector>
    
     
    
    using namespace std ;

    // Iterate over every edge leaving vertex x; the loop variable is `p`.
    #define travel( x ) for ( edge *p = head[ x ] ; p ; p = p -> next )
    // i = 1 .. x
    #define rep( i , x ) for ( int i = 0 ; i ++ < x ; )
    // i = l .. r (inclusive)
    #define REP( i , l , r ) for ( int i = l ; i <= r ; ++ i )
    // i = 0 .. x - 1
    #define Rep( i , x ) for ( int i = 0 ; i < x ; ++ i )
    // Builds a complex number `com` from (real, imaginary).
    // Uses standard C++11 list-initialisation. The former `( com ) { a , b }` spelling
    // was a C99 compound literal, which C++ accepts only as a compiler extension.
    // Inside the expansion `com` is not re-expanded (a macro never expands itself,
    // and here it is followed by `{`), so it names the struct.
    #define com( a , b ) com { a , b }

    typedef long long ll ;

    // Capacity of every per-vertex array and FFT buffer in this program.
    const int maxn = 101000 ;
    const double PI = acos( -1.0 ) ;
    
     
    
    // Minimal complex number for the FFT: `a` is the real part, `b` the imaginary part.
    // Kept as a plain aggregate (no constructors) so it can be brace-initialised.
    struct com {

        double a , b ;

        // (a + bi)(c + di) = (ac - bd) + (bc + ad)i
        com operator * ( const com &rhs ) const {
            com prod ;
            prod.a = a * rhs.a - b * rhs.b ;
            prod.b = b * rhs.a + a * rhs.b ;
            return prod ;
        }

        // Component-wise sum.
        com operator + ( const com &rhs ) const {
            com sum ;
            sum.a = a + rhs.a ;
            sum.b = b + rhs.b ;
            return sum ;
        }

        // Component-wise difference.
        com operator - ( const com &rhs ) const {
            com diff ;
            diff.a = a - rhs.a ;
            diff.b = b - rhs.b ;
            return diff ;
        }

    } A[ maxn ] ;   // scratch buffer used by FFT()
    
     
    
    // Bit-reversal permutation for the current transform length; FFT() rebuilds it on every call.
    int tra[ maxn ] ;

    // In-place iterative radix-2 FFT of a[0 .. N-1].
    //   flag == true : forward transform, twiddle angle +2*pi/len.
    //   flag == false: inverse transform, angle -2*pi/len, then divides by N.
    // N is assumed to be a power of two no larger than maxn; this is not checked here.
    // Uses the global arrays tra[] and A[] as scratch, so calls must not overlap.
    // NOTE: the inverse rescales only the real part (.a). The imaginary parts of the
    // result are left unscaled, so callers should read .a only.
    inline void FFT( com *a , int N , bool flag ) {
        // Build tra[k] = bit-reversal of k over log2(N) bits.
        // First, each k in [i, 2i) gets the reversed image of its top bit i ...
        Rep( i , N ) tra[ i ] = 0 ;
        for ( int i = 1 , j = N >> 1 ; i < N ; i <<= 1 , j >>= 1 ) for ( int k = i ; k < ( i << 1 ) ; ++ k ) tra[ k ] = j ;
        // ... then OR in the reversal of its low bits (tra[j] for j < i is already complete).
        for ( int i = 1 ; i < N ; i <<= 1 ) for ( int j = 0 ; j < i ; ++ j ) tra[ j + i ] |= tra[ j ] ;
        // Load the input into A in bit-reversed order.
        Rep( i , N ) A[ i ] = a[ tra[ i ] ] ;
        double pi = flag ? PI : ( - PI ) ;
        com e , w , rec , ret ;
        // Butterfly stages: blocks of length 2i are combined from two halves of length i.
        for ( int i = 1 ; i < N ; i <<= 1 ) {
            // e is the primitive (2i)-th root of unity. w walks through its powers by
            // repeated multiplication, so rounding error grows slightly with j.
            e = com( cos( ( 2.0 * pi ) / double( i << 1 ) ) , sin( ( 2.0 * pi ) / double( i << 1 ) ) ) , w = com( 1 , 0 ) ;
            for ( int j = 0 ; j < i ; ++ j , w = w * e ) {
                for ( int k = j ; k < N ; k += ( i << 1 ) ) {
                    rec = A[ k ] , ret = w * A[ k + i ] ;
                    A[ k ] = rec + ret , A[ k + i ] = rec - ret ;
                }
            }
        }
        if ( ! flag ) Rep( i , N ) A[ i ].a /= double( N ) ;
        // Copy the transformed sequence back into the caller's buffer.
        Rep( i , N ) a[ i ] = A[ i ] ;
    }
    
     
    
    // Convolution buffers used by solve():
    //   a - depth histogram of the subtrees merged so far,
    //   b - depth histogram of the subtree being merged,
    //   c - pointwise product, inverse-transformed into pair counts.
    com a[ maxn ] ;
    com b[ maxn ] ;
    com c[ maxn ] ;

    // ans[d] accumulates the number of ordered vertex pairs at distance d (in edges).
    ll ans[ maxn ] ;
    
      
    
    // Adjacency lists stored as singly linked lists inside a static edge pool.
    struct edge {
        edge *next ;   // next edge leaving the same vertex
        int t ;        // head (target) vertex of this edge
    } E[ maxn << 1 ] ; // each undirected edge occupies two slots

    // pt: next free slot of E; head[v]: first edge leaving v (null if none).
    edge *pt = E , *head[ maxn ] ;

    // Pushes the directed edge s -> t onto the front of s's list.
    inline void add( int s , int t ) {
        edge *slot = pt ++ ;
        slot -> t = t ;
        slot -> next = head[ s ] ;
        head[ s ] = slot ;
    }

    // Inserts the undirected edge {s, t} as two directed edges.
    inline void addedge( int s , int t ) {
        add( s , t ) ;
        add( t , s ) ;
    }
    
     
    
    bool del[ maxn ] ;
    
    int n , size[ maxn ] , root , rt ;
    
     
    
    void gets( int now , int fa ) {
    
        size[ now ] = 1 ;
    
        travel( now ) if ( ! del[ p -> t ] && p -> t != fa ) {
    
            gets( p -> t , now ) ;
    
            size[ now ] += size[ p -> t ] ;
    
        }
    
    }
    
     
    
    void getrt( int now , int fa ) {
    
        if ( root ) return ;
    
        int ret = size[ rt ] / 2 ;
    
        bool flag = ( size[ rt ] - size[ now ] ) <= ret ;
    
        travel( now ) if ( p -> t != fa && ! del[ p -> t ] ) {
    
            if ( size[ p -> t ] > ret ) flag = false ;
    
            getrt( p -> t , now ) ;
    
        }
    
        if ( flag ) root = now ;
    
    }
    
     
    
    int h[ maxn ] , cnt[ maxn ] , mh , m , H[ maxn ] ;
    
     
    
    void geth( int now , int fa ) {
    
        if ( h[ now ] > mh ) mh = h[ now ] ;
    
        travel( now ) if ( p -> t != fa && ! del[ p -> t ] ) {
    
            h[ p -> t ] = h[ now ] + 1 ;
    
            geth( p -> t , now ) ;
    
        }
    
    }
    
     
    
    // sub[x]: vertices of the subtree hanging off centroid-child x.
    // tak[d]: centroid-children whose subtree has maximum depth d (bucket sort by depth).
    vector < int > sub[ maxn ] , tak[ maxn ] ;

    // Appends every vertex reachable from `now` (without crossing `fa` or a deleted
    // vertex) to sub[num], and raises mh to the largest h[] encountered.
    void getsub( int now , int fa , int num ) {
        sub[ num ].push_back( now ) ;
        // Plain comparison instead of std::max: this file never includes <algorithm>.
        if ( h[ now ] > mh ) mh = h[ now ] ;
        travel( now ) if ( p -> t != fa && ! del[ p -> t ] ) getsub( p -> t , now , num ) ;
    }
    
     
    
    void solve( int now ) {
    
        gets( now , 0 ) ;
    
        root = 0 , rt = now ; getrt( now , 0 ) ;
    
        h[ root ] = 0 ; geth( root , 0 ) ;
    
        REP( i , 0 , mh ) tak[ i ].clear(  ) , cnt[ i ] = 0 ;
    
        int Mh = mh ;
    
        cnt[ 0 ] = 1 , ans[ 0 ] ++ ;
    
        travel( root ) if ( ! del[ p -> t ] ) {
    
            sub[ p -> t ].clear(  ) ;
    
            mh = 0 ; getsub( p -> t , root , p -> t ) ;
    
            tak[ mh ].push_back( p -> t ) ;
    
            H[ p -> t ] = mh ;
    
        }
    
        REP( i , 0 , Mh ) Rep( j , tak[ i ].size(  ) ) {
    
            int x = tak[ i ][ j ] , m ;
    
            for ( m = 1 ; m <= H[ x ] ; m <<= 1 ) ; m <<= 1 ;
    
            Rep( k , m ) a[ k ] = b[ k ] = com( 0 , 0 ) ;
    
            REP( k , 0 , H[ x ] ) a[ k ].a = double( cnt[ k ] ) ;
    
            Rep( k , sub[ x ].size(  ) ) b[ h[ sub[ x ][ k ] ] ].a += 1.0 ;
    
            FFT( a , m , true ) , FFT( b , m , true ) ;
    
            Rep( k , m ) c[ k ] = a[ k ] * b[ k ] ;
    
            FFT( c , m , false ) ;
    
            Rep( k , m ) ans[ k ] += int( c[ k ].a + 0.5 ) * 2 ;
    
            Rep( k , sub[ x ].size(  ) ) cnt[ h[ sub[ x ][ k ] ] ] ++ ;
    
        }
    
        del[ root ] = true ;
    
        travel( root ) if ( ! del[ p -> t ] ) solve( p -> t ) ;
    
    }
    
     
    
    // Reads n and the n - 1 tree edges, counts ordered vertex pairs by distance with
    // solve(), and prints the expected cost sum_d ans[d] / (d + 1) to 4 decimals.
    int main(  ) {
        // Globals start zeroed already; the explicit resets keep main self-contained.
        memset( head , 0 , sizeof( head ) ) , memset( ans , 0 , sizeof( ans ) ) ;
        memset( del , false , sizeof( del ) ) ;
        if ( scanf( "%d" , &n ) != 1 ) return 1 ;
        // An empty tree costs nothing. Without this check, solve(1) would count a
        // phantom vertex and print 1.0000.
        if ( n <= 0 ) {
            printf( "%.4f\n" , 0.0 ) ;
            return 0 ;
        }
        REP( i , 2 , n ) {
            int s , t ;
            if ( scanf( "%d%d" , &s , &t ) != 2 ) return 1 ;
            // Input vertex ids appear to be 0-based; they are shifted so vertex 1
            // (where solve starts) exists.
            addedge( ++ s , ++ t ) ;
        }
        solve( 1 ) ;
        // Each ordered pair at distance d contributes 1 / (d + 1).
        double Ans = 0 ;
        REP( i , 0 , n ) Ans += double( ans[ i ] ) / double( i + 1 ) ;
        printf( "%.4f\n" , Ans ) ;
        return 0 ;
    }
    

    相关文章

      网友评论

        本文标题:BZOJ-3451: Tyvj1953 Normal（FFT + 点分治）

        本文链接:https://www.haomeiwen.com/subject/qlmzpqtx.html