美文网首页ACM
ACM数据结构(一)——主席树

ACM数据结构(一)——主席树

作者: hymscott | 来源:发表于2018-01-17 22:14 被阅读383次

    让我们来看一个经典的问题吧:

    给定一个[1,n]的区间,m次操作,操作种类如下:

    1 L R:查询[L,R]的区间和

    2 L R X:将[L,R]的值加上X

    这种经典问题,想必大家学过线段树后都可以轻松解决。然而如果再增加一种操作:

    3 K:回退到第K次修改操作的结果

    可见,如果题目要求回溯到历史版本,那么普通的线段树就不能解决了,因为在每次更新操作后,线段树存储的内容就发生了改变,如果不进行特殊记录,那么这种改变将是永久的。因此,对于这种类型的题目,我们可以用到今天要讨论的数据结构——主席树来进行解决。


    主席树,严格来讲应该叫:函数式线段树,是基于线段树的一种数据结构,常用于处理一些在线问题,关于在线离线的概念参考上一篇文章:在线和离线算法。事实上,主席树就是多个线段树的集合体。

    主席树的实质,就是以最初的线段树作为模板,通过"结点复用“的方式,实现存储多个线段树。

    对于文章开始的问题,观察后可以发现,在2操作进行后,在上一次修改后的线段树上,最多修改O(logn)个结点(最远是从根节点到叶子节点)。如果每次单独新建一个线段树,则会造成重复存储,如图所示:

    原始线段树 修改[6,7] 修改[3,5] 修改[1,9]

    浅蓝色的结点是当前修改操作时访问的结点,白色结点为上一棵线段树的结点。

    如果对每次修改操作无差别复制一棵线段树,那么用于存储节点的开销是巨大的,因为对于单次修改,大部分结点都不曾被访问修改。

    通过“结点复用”的方式,我们可以将这多棵线段树压缩成如下形式:


    开辟新结点 结点复用

    因此第i个线段树只要通过保留除修改路径外的第i-1棵线段树的结点,再新增加至多O(logn)个结点。

    rt[i]保存第i次操作的线段树的根节点,这样,回退到第k次操作等价于rt[i]=rt[k],我们的问题就迎刃而解啦。


    那么,怎么来建立一棵主席树呢?针对文章开始的题目,下面给出实现步骤:

    1. 创建根节点、左右儿子结点数组

    int tot=0,rt[maxn*20],lson[maxn*20],rson[maxn*20],v[maxn*20],lz[maxn*20],a[maxn];
    

    tot是每次新建的结点编号。
    rt[i]是第i棵线段树的根节点的编号。
    lson[x]和rson[x]是结点x的左右儿子结点的编号。
    v[x]是结点x代表的区间的和。
    lz[x]是结点x的懒惰(lazy)值。
    a[i]是初始的第i个位置的值。
    因为结点每次至多更新O(logn)个,所以数组范围应该在原来的20-50倍左右。

    2.区间更新的pushup和pushdown

    void push_up(int x){
        v[x]=v[lson[x]]+v[rson[x]];
    }
    
    void push_down(int x,int len){
        if(lz[x]){
            v[lson[x]]+=(len>>1)*lz[x];
            v[rson[x]]+=(len-(len>>1))*lz[x];
            lz[lson[x]]+=(len>>1)*lz[x];
            lz[rson[x]]+=(len-(len>>1))*lz[x];
            lz[x]=0;
        }
    }
    

    区间更新基础,不会的可以先了解线段树的区间更新写法。

    3. 建树

    void build(int &x,int l,int r){
        x=++tot;
        lz[x]=0;
        if(l==r){
            v[x]=a[l];
            return;
        }
        int mid=l+r>>1;
        build(lson[x],l,mid);
        build(rson[x],mid+1,r);
        push_up(x);
    }
    

    和线段树的思想是一样的,只是在调用过程中,我们以引用的形式,实现对rt,lson,rson的更新。
    建树的调用如下:

    build(rt[0],1,n);
    

    3. 更新

    void update(int L,int R,int l,int r,int &x,int last,int val){
        x=++tot;
        lson[x]=lson[last];rson[x]=rson[last];lz[x]=lz[last];v[x]=v[last];
        if(L<=l&&R>=r){
            v[x]+=(r-l+1)*val;lz[x]+=val;
            return;
        }
        push_down(x,r-l+1);
        int mid=l+r>>1;
        if(L<=mid) update(L,R,l,mid,lson[x],lson[last],val);
        if(R>mid) update(L,R,mid+1,r,rson[x],rson[last],val);
        push_up(x);
    }
    

    第1行开辟了新的结点,第2行进行了结点复用,last就是上一棵线段树的结点,从根节点向下更新。
    更新的调用如下:

    update(x,y,1,n,rt[i],rt[i-1],w);
    

    4. 查询

    int query(int L,int R,int l,int r,int x){
        if(L<=l&&R>=r){
            return v[x];
        }
        push_down(x,r-l+1);
        int mid=l+r>>1,sum=0;
        if(L<=mid) sum+=query(L,R,l,mid,lson[x]);
        if(R>mid) sum+=query(L,R,mid+1,r,rson[x]);
        push_up(x);
        return sum;
    }
    

    查询就是简单的区间查询。
    查询的调用如下:

    query(x,y,1,n,rt[i]);
    

    5. 实现

    #include <iostream>
    #include <cstdio>
    #include <string>
    #include <cstring>
    #include <algorithm>
    #include <queue>
    #include <vector>
    #include <cmath>
    #include <functional>
    #include <map>
    #include <stack>
    #include <ctime>
    #include <sstream>
    #include <bitset>
    
    //#include<ext/pb_ds/assoc_container.hpp>
    
    //#include <bits/stdc++.h>
    
    #define REP(i,j,k) for(int (i)=(j);(i)<=(k);(i)++)
    #define ERP(i,j,k) for(int (i)=(j);(i)>=(k);(i)--)
    #define MEM(a,b) memset(a,b,sizeof(a))
    #define NE putchar('\n')
    #define SP putchar(' ')
    #define fi first
    #define sc second
    #define mkp make_pair
    #define pb push_back
    #define all(a) a.begin(),a.end()
    //#define lson l,mid,x<<1
    //#define rson mid+1,r,x<<1|1
    #define lowbit(x) ((x)&(-(x)))
    #define lc(a) ch[(a)][0]
    #define mod_add(a,b,m) (a+b>m?a+b-m:a+b)
    #define mod_sub(a,b,m) (a-b<0?a-b+m:a-b)
    
    using namespace std;
    //using namespace __gnu_pbds;
    typedef double DB;
    typedef long double LDB;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef pair<int,int> PI;
    typedef pair<ll,ll> PLL;
    
    const DB eps=1e-6;
    const DB Pi=acos(-1.0);
    const ll mod=1e9+7;
    const ull ha1=1e9+7;
    const ull ha2=1e9+9;
    const int maxn=1e5+10;
    const int maxm=1e6+10;
    const int inf=1e9+10;
    
    //IO挂
    template<typename Type>inline void read(Type&in){
        in=0;Type f=1;char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
        while(ch>='0'&&ch<='9'){in=in*10+ch-'0';ch=getchar();}
        in*=f;
    }
    
    template<typename Type>inline void out(Type o){
        if(o<0){putchar('-');o=-o;}
        if(o>=10) out(o/10);
        putchar(o%10+'0');
    }
    
    /*Header*/
    //printf("%d%c",a[i]," \n"[i==n]);
    
    int tot=0,rt[maxn*20],lson[maxn*20],rson[maxn*20],v[maxn*20],lz[maxn*20],a[maxn];
    
    void push_up(int x){
        v[x]=v[lson[x]]+v[rson[x]];
    }
    
    void push_down(int x,int len){
        if(lz[x]){
            v[lson[x]]+=(len>>1)*lz[x];
            v[rson[x]]+=(len-(len>>1))*lz[x];
            lz[lson[x]]+=(len>>1)*lz[x];
            lz[rson[x]]+=(len-(len>>1))*lz[x];
            lz[x]=0;
        }
    }
    
    void build(int &x,int l,int r){
        x=++tot;
        lz[x]=0;
        if(l==r){
            v[x]=a[l];
            return;
        }
        int mid=l+r>>1;
        build(lson[x],l,mid);
        build(rson[x],mid+1,r);
        push_up(x);
    }
    
    void update(int L,int R,int l,int r,int &x,int last,int val){
        x=++tot;
        lson[x]=lson[last];rson[x]=rson[last];lz[x]=lz[last];v[x]=v[last];
        if(L<=l&&R>=r){
            v[x]+=(r-l+1)*val;lz[x]+=val;
            return;
        }
        push_down(x,r-l+1);
        int mid=l+r>>1;
        if(L<=mid) update(L,R,l,mid,lson[x],lson[last],val);
        if(R>mid) update(L,R,mid+1,r,rson[x],rson[last],val);
        push_up(x);
    }
    
    int query(int L,int R,int l,int r,int x){
        if(L<=l&&R>=r){
            return v[x];
        }
        push_down(x,r-l+1);
        int mid=l+r>>1,sum=0;
        if(L<=mid) sum+=query(L,R,l,mid,lson[x]);
        if(R>mid) sum+=query(L,R,mid+1,r,rson[x]);
        push_up(x);
        return sum;
    }
    
    int x,y,w;
    
    int main(){
        int n,k,opt;
        cin>>n>>k;
        for(int i=1;i<=n;i++){
            cin>>a[i];
        }
        build(rt[0],1,n);
        for(int i=1;i<=k;i++){
            cin>>opt;
            if(opt==1){
                rt[i]=rt[i-1];
                cin>>x>>y;
                cout<<query(x,y,1,n,rt[i])<<endl;
            }
            else if(opt==2){
                cin>>x>>y>>w;
                update(x,y,1,n,rt[i],rt[i-1],w);
            }
            else{
                cin>>x;
                rt[i]=rt[x];
            }
        }
        return 0;
    }
    

    对于第i个操作,方式1通过rt[i-1]更新rt[i],方式2通过引用更新rt[i],方式3通过rt[x]更新rt[i]。

    6. 测试一下~

    input.txt
    10 8
    1 2 3 4 5 6 7 8 9 10
    2 6 7 2
    1 6 7
    2 3 5 4
    1 3 5
    2 1 9 5
    1 1 9
    3 3
    1 1 10
    
    output.txt
    17
    24
    106
    71
    

    正确无误~(blink)


    那么,主席树的入门就到这里了,下面给出poj 2104(静态区间求第K大)的主席树代码,作为参考啦~

    #include <bits/stdc++.h>
    #include <cstdio>
    
    #define fi first
    #define sc second
    #define mkp make_pair
    #define pb push_back
    #define all(a) a.begin(),a.end()
    
    using namespace std;
    typedef long long ll;
    typedef pair<int,int> PI;
    typedef pair<ll,ll> PLL;
    
    const double eps=1e-8;
    const double pi=acos(-1);
    const int mod=1e9+7;
    
    /*Header*/
    
    const int maxn=1e5+10;
    
    int rt[maxn*20],lson[maxn*20],rson[maxn*20],sum[maxn*20];
    int a[maxn],b[maxn];
    int tot;
    
    int n,q;
    
    void build(int &x,int l,int r){
        x=++tot;
        sum[x]=0;
        if(l==r) return;
        int mid=(l+r)>>1;
        build(lson[x],l,mid);
        build(rson[x],mid+1,r);
    }
    
    void update(int &x,int last,int l,int r,int pos){
        x=++tot;
        lson[x]=lson[last];
        rson[x]=rson[last];
        sum[x]=sum[last]+1;
        if(l==r) return;
        int mid=(l+r)>>1;
        if(pos<=mid) update(lson[x],lson[last],l,mid,pos);
        else update(rson[x],rson[last],mid+1,r,pos);
    }
    
    int query(int L,int R,int l,int r,int k){
        if(l==r) return l;
        int mid=(l+r)>>1;
        int cnt=sum[lson[R]]-sum[lson[L]];
        if(k<=cnt) return query(lson[L],lson[R],l,mid,k);
        else return query(rson[L],rson[R],mid+1,r,k-cnt);
    }
    
    int main(){
        int T;
        scanf("%d",&T);
        while(T--){
            scanf("%d %d",&n,&q);
            for(int i=1;i<=n;i++){
                scanf("%d",&a[i]);
                b[i]=a[i];
            }
            sort(b+1,b+1+n);
            int m=unique(b+1,b+1+n)-(b+1);
            tot=0;
            build(rt[0],1,m);
            for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b;
            for(int i=1;i<=n;i++) update(rt[i],rt[i-1],1,m,a[i]);
            int x,y,k,ans;
            while(q--){
                scanf("%d %d %d",&x,&y,&k);
                ans=query(rt[x-1],rt[y],1,m,k);
                printf("%d\n",b[ans]);
            }
        }
        return 0;
    }
    

    相关文章

      网友评论

        本文标题:ACM数据结构(一)——主席树

        本文链接:https://www.haomeiwen.com/subject/jgtpoxtx.html