美文网首页动态规划
「数据结构进阶」例题之树形数据结构

「数据结构进阶」例题之树形数据结构

作者: 云中翻月 | 来源:发表于2019-02-12 20:48 被阅读54次
    0x40「数据结构进阶」例题

    几点总结
    1 数据结构有两个用处:组织特定类型的数据方便调用高效的实现数据的增删改查。换句话说,这两个用处指向了思考数据结构问题时的逻辑过程,即先想到朴素算法(可能是模拟,贪心,复杂度较高的DP方程等等),然后我们往往时因为数据范围的限制而使用数据结构。也就是说,数据结构本质上是一种对算法实现过程的优化
    2 数据结构问题中,很多数据结构维护的信息有相似处,这时需要根据情况选取特定的数据结构,同时充分考虑代码量实现难度算法常数等因素,最终找到最合适的数据结构。
    3 所有树形数据结构还有一个明显的特征:擅于且仅擅于维护符合“区间加法”的信息。具体地说,就是两个小区间的最值再取最值可以得到一个大区间的最值,两个小区间的和再求和可以得到一个大区间的和。然而,若维护信息不满足“区间加法”,树形数据结构就会在时间复杂度上产生退化。例如,已知两个小区间的众数,难以在O(1)的时间内完成合并两个小区间得到大区间众数的过程。

    并查集

    当问题满足关系具有传递性无向性时,并查集时维护关系的一个有用工具。(PS:若是关系具有传递性和有向性,这时可以思考图论建模有向图强连通分量的tarjan算法以及k-sat问题
    并查集的实现非常简单,且属于基础内容,这里不再赘述。不过值得一提的是,并查集合并过程中优化时间复杂度的思想非常重要,很多复杂的问题都会用到。例如,我们可以通过并查集的路径压缩特性过滤很多无用信息。也可以将”按秩合并”推广到启发式合并(启发式合并指:将大的结构和小的结构合并时,将小的结构向大的结构合并,同时只增加小的结构的查询费用)
    同时使用路径压缩和按秩合并优化的并查集的单次查询/合并时间复杂度趋近反阿克曼函数(近似为常数),但我们遇到的问题中,往往只要其中的一种合并方式即可。
    首先给出路径压缩优化的并查集代码,单次查询/合并时间复杂度:O(logn)

    int fa[SIZE];
    
    for (int i = 0; i <= n; i++) fa[i] = i;
    
    int get(int x) {
        if (x == fa[x]) return x;
        return fa[x] = get(fa[x]);
    }
    
    void merge(int x, int y) {
        fa[get(x)] = get(y);
    }
    

    例题

    4101 银河英雄传说
    显然,最后的战舰排列是一条条“链”,即一种特殊形态的树,我们可以用并查集维护。
    本题中需要我们维护一些并查集森林中的特殊信息。具体地说,我们需要对于每个元素x,它到所在树的树根的距离d[x]和以它为根的子树的size[x](注意:由于路径压缩的存在,我们得到的并查集森林中并不是一条条的“链”),这样若x元素和y元素在同一列,那么他们之间的距离(不包括x和y)就是|d[y]-d[x]|-1
    d[x]和size[x]的具体计算过程代码如下

    int get(int x) {
        if (x == fa[x]) return x;
        int root = get(fa[x]);  
        d[x] += d[fa[x]];       
        return fa[x] = root;    
    }
    
    void merge(int x, int y) {
        x = get(x), y = get(y);
        fa[x] = y, d[x] = size[y];
        size[y] += size[x];
    }
    

    如果我们面对的问题中需要维护明显对立的两/三/更多的集合,同时“传递关系”可以相互导出时,那么我们会使用扩展域并查集。
    例如,若只有两个集合“敌人”和“朋友”,且如下关系可以相互导出:
    1 A和B是朋友,B和C是朋友,可推出A和C是朋友
    2 A和B是敌人,B和C是敌人,可推出A和C是朋友
    那么我们可以建立一个扩展域并查集来维护关系。
    具体地说,我们将一个节点x拆成两个节点x_{1}x_{2}x_{1}表示x是朋友,x_{2}表示x是敌人,同理,我们将另一个节点y拆成y_{1}y_{2}
    若题目输入x和y是朋友,那么这意味着两条信息:
    1 x_{1}y_{1}是朋友
    2 x_{2}y_{2}是朋友
    因此合并(x_{1}y_{1})和(x_{2}y_{2}
    若题目输入x和y是敌人,那么这意味着两条信息:
    1 x_{1}y_{2}是朋友
    2 x_{2}y_{1}是朋友
    因此合并(x_{1}y_{2})和(x_{2}y_{1}

    例题

    POJ1733 Parity Game
    用sum表示S序列的前缀和,则输入分两种情况
    1 S[l...r]有偶数个1,即sum[r]和sum[l-1]奇偶性相同
    2 S[l...r]有奇数个1,即sum[r]和sum[l-1]奇偶性不同
    显然,奇数域偶数域天然对立,也满足刚才“朋友”和“敌人”集合的特点:
    1 A和B奇偶性相同,B和C奇偶性相同,可推出A和C奇偶性相同
    2 A和B奇偶性不同,B和C奇偶性不同,可推出A和C奇偶性相同
    因此按照上面的方法处理输入即可。
    PS:由于l和r的范围可能很大,需要离散化处理
    代码如下

    /*
    
    */
    #define method_1
    #ifdef method_1
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    using namespace std;
    typedef long long ll;
    const int maxn=5000*4+5;
    const int INF=0x3f3f3f3f;
    int temp,n,f[maxn],q,x[maxn],y[maxn],a[maxn],cnt=0;
    string s[maxn];
    int find(int x){
        return f[x]==x?x:f[x]=find(f[x]);
    }
    int main() {
        ios::sync_with_stdio(false);
    //  freopen("Parity Game.in","r",stdin);
        cin>>temp;
        cin>>q;
        for(int i=1; i<=q; i++) {
            cin>>x[i]>>y[i]>>s[i];
            if(y[i]<x[i]) swap(x[i],y[i]);
            a[++cnt]=x[i];
            a[++cnt]=y[i];
        }
        sort(a+1,a+cnt+1);
        n=unique(a+1,a+cnt+1)-a-1;
        for(int i=1;i<=n*2;i++){
            f[i]=i;
        }
        for(int i=1; i<=q; i++) {
            int xx=lower_bound(a+1,a+n+1,x[i]-1)-a;
            int yy=lower_bound(a+1,a+n+1,y[i])-a;
            int fa=find(xx);
            int fa1=find(xx+n);
            int fb=find(yy);
            int fb1=find(yy+n);
            if(s[i]=="even") {
                if(fa==fb1){
                    cout<<i-1<<endl;
                    return 0;
                }
                f[fa]=fb;
                f[fa1]=fb1;
            }
            else{
                if(fa==fb){
                    cout<<i-1<<endl;
                    return 0;
                }
                f[fa1]=fb;
                f[fb1]=fa;
            }
        }
        cout<<q<<endl;
        return 0;
    }
    #endif
    #ifdef method_2
    /*
    
    */
    
    #endif
    #ifdef method_3
    /*
    
    */
    
    #endif
    

    POJ1182 食物链
    这里的两个对立集合变成了三个,但原理与上面相同。
    代码如下

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cmath>
    using namespace std;
    int n,k,ans=0,fa[3*5*10002],q,x,y;
    void init()
    {
        for(int i=1;i<=3*n;i++) fa[i]=i;
    }
    int f(int x)
    {
        if (x==fa[x]) return x;
        return fa[x]=f(fa[x]);
    }
    bool check1(int x,int y)
    {
        if (f(x+n)==f(y))
            return false;
        if (f(x)==f(y+n))
            return false;
        return true;
    }
    bool check2(int x,int y)
    {
        if (x==y) return false;
        if (f(x)==f(y))
            return false;
        if (f(x)==f(y+n))
            return false;
        return true;
    }
    int main()
    {
        scanf("%d%d",&n,&k);
        init();
        for(int i=1;i<=k;i++)
        {
            scanf("%d%d%d",&q,&x,&y);
            if (x>n||y>n)
            {
                ans++;
                continue;
            }
            if (q==1)
            {
                if (check1(x,y))
                {
                    fa[f(x)]=f(y);
                    fa[f(x+n)]=f(y+n);
                    fa[f(x+2*n)]=f(y+2*n);
                }
                else
                    ans++;
            }
            if (q==2)
            {
                if (check2(x,y))
                {
                    fa[f(x+n)]=f(y);
                    fa[f(x)]=f(y+2*n);
                    fa[f(x+2*n)]=f(y+n);
                }
                else
                    ans++;
            }
        }
        printf("%d",ans);
        return 0;
    }
    

    树状数组

    树状数组所有单次操作时间复杂度:O(nlogn)
    由于在设计时消去了冗余节点,并且将节点编号和二进制联系起来,树状数组的空间复杂度常数代码量都比线段树有了明显的优化。然而,这些节点的消去,也让树状数组难以维护求最值的操作
    基本功能:单点增加,区间求和,逆序对
    代码如下

    void add(int x, int y) { //单点增加
        for (; x <= N; x += x & -x) c[x] += y;
    }
    int ask(int x) { //区间求和
        int ans = 0;
        for (; x; x -= x & -x) ans += c[x];
        return ans;
    }
    for (int i = n; i; i--) { //逆序对,若a[i]范围较大,需要离散化
    //然而离散化需要排序,所以还不如直接用归并排序求逆序对
        ans += ask(a[i]-1);
        add(a[i], 1);
    }
    

    PS:树状数组求逆序对原理:逆序扫描整个序列,每次先查询小于a[i]的数的个数累加进入答案,然后将a[i]插入树状数组。这就保证了每次查询的范围都是i之后的数字值小于a[i]不包括a[i]

    例题

    4201 楼兰图腾
    记l[i]表示1...i-1中比a[i]大的数的个数
    r[i]表示i+1...n中比a[i]大的数的个数
    则“V”型图腾的数量就是\sum_{i=1}^{n}l[i]*r[i]
    另一种图腾计算同理
    代码如下

    /*
    
    */
    #define method_1
    #ifdef method_1
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    using namespace std;
    typedef long long ll;
    const int maxn=200000+5;
    const int INF=0x3f3f3f3f;
    ll n,a[maxn],c[maxn],l1[maxn],l2[maxn],r1[maxn],r2[maxn],ans1=0,ans2=0;
    void add(ll x){
        for(int i=x;i<=n;i+=i&-i){
            c[i]++;
        }
    }
    ll sum(ll x){
        ll sum=0;
        for(int i=x;i>0;i-=i&-i){
            sum+=c[i];
        }
        return sum;
    }
    int main() {
        ios::sync_with_stdio(false);
    //  freopen("楼兰图腾.in","r",stdin);
        cin>>n;
        for(int i=1;i<=n;i++){
            cin>>a[i];
            a[i]=n-a[i]+1;
        }
        memset(c,0,sizeof(c));
        for(int i=1;i<=n;i++){
            l1[i]=sum(a[i]);
            add(a[i]);
        }
        memset(c,0,sizeof(c));
        for(int i=n;i>=1;i--){
            r1[i]=sum(a[i]);
            add(a[i]);
        }
        for(int i=1;i<=n;i++){
    //      cout<<l1[i]<<" "<<r1[i]<<endl;
            ans1+=l1[i]*r1[i];
        }
        cout<<ans1<<" ";
        for(int i=1;i<=n;i++){
            a[i]=n-a[i]+1;
        }
        memset(c,0,sizeof(c));
        for(int i=1;i<=n;i++){
            l2[i]=sum(a[i]);
            add(a[i]);
        }
        memset(c,0,sizeof(c));
        for(int i=n;i>=1;i--){
            r2[i]=sum(a[i]);
            add(a[i]);
        }
        for(int i=1;i<=n;i++){
    //      cout<<l2[i]<<" "<<r2[i]<<endl;
            ans2+=l2[i]*r2[i];
        }
        cout<<ans2;
        return 0;
    }
    #endif
    #ifdef method_2
    /*
    
    */
    
    #endif
    #ifdef method_3
    /*
    
    */
    
    #endif
    

    当然,通过一些技巧性的转化,我们能将树状数组的适用范围拓宽。
    我们来看看如何通过前缀和和差分的转化实现区间增加和单点查询功能。
    题意:长度为n的数列,q个操作。C l r d 表示把第l~r个数+d。Q x表示求第x个数的值,n<1e5,q<1e5
    我们知道,树状数组的基本操作可以实现单点修改,那么为了将区间修改转化为单点修改,我们用树状数组维护序列的差分数组,同时单点查询也就变成了差分数组上的前缀和操作。
    我们继续来考虑一个根据有一般性的问题
    POJ3468 A Simple Problem with Integers
    题目要求实现区间修改和区间求和
    为了将区间修改转化为单点修改,我们仍用树状数组b维护序列变化的差分数组,具体地说,b数组初始全部为0,若输入指令C l r d,则令b[l]+d,b[r+1]-d。
    那么l~r上新增的量就是\sum_{i=l}^{r}\sum_{j=1}^{i}b[j]=\sum_{i=1}^{r}\sum_{j=1}^{i}b[j]-\sum_{i=1}^{l-1}\sum_{j=1}^{i}b[j]
    考虑如何求\sum_{i=1}^{x}\sum_{j=1}^{i}b[j]
    将上式展开\sum_{i=1}^{x}\sum_{j=1}^{i}b[j]=(b[1])+(b[1]+b[2])+(b[1]+b[2]+b[3])+...+(b[1]+b[2]+...+b[x])\\=(x)*b[1]+(x-1)*b[2]+...+(x-(x-1))*b[x]\\=\sum_{i=1}^{x}(x-i+1)*b[i]
    此时我们将三个变量x,i,j消去变量j,变成了两个变量。
    由于求和式中同时包含x和i,不易计算。而树状数组擅于计算\sum_{i=1}^{x}b[i],因此,我们将x和i分离,即将上式变形如下
    \sum_{i=1}^{x}(x-i+1)*b[i]=(x+1)\sum_{i=1}^{x}b[i]-\sum_{i=1}^{x}i*b[i]
    因此我们建立两个树状数组c_{0}c_{1},分别用于维护b[i]和i×b[i]。
    具体地说,若输入指令C l r d,则执行以下四个操作
    1 c_{0}的l位置+d
    2 c_{0}的r+1位置-d
    3 c_{1}的l位置+l×d
    4 c_{1}的r+1位置-(r+1)×d
    用sum数组维护原序列的前缀和,对于指令Q l r,只需要输出sum[r]+(r+1)*ask(c_{0},r)-ask(c_{1},r)-(sum[l-1]+(l)*ask(c_{0},l-1)-ask(c_{1},l-1))
    代码如下

    /*
    
    */
    #define method_1
    #ifdef method_1
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    using namespace std;
    typedef long long ll;
    const int maxn=100000+5;
    const int INF=0x3f3f3f3f;
    ll c[2][maxn],n,q,a[maxn],sum[maxn];
    void add(int k,int x,ll v){
        for(int i=x;i<=n;i+=i&-i){
            c[k][i]+=v;
        }
    }
    ll ask(int k,int x){
        ll ans=0; 
        for(int i=x;i>0;i-=i&-i){
            ans+=c[k][i];
        }
        return ans;
    }
    int main() {
        ios::sync_with_stdio(false);
    //  freopen("A Simple Problem with Integers.in","r",stdin);
        cin>>n>>q;
        for(int i=1;i<=n;i++){
            cin>>a[i];
            sum[i]=sum[i-1]+a[i];
        }
        char op;
        int l,r;
        while(q--){
            cin>>op>>l>>r;
            ll d;
            if(op=='C'){
                cin>>d;
                add(0,l,d);
                add(0,r+1,-d);
                add(1,l,(ll)l*d);
                add(1,r+1,-(ll)(r+1)*d);
            }
            else{
                cout<<sum[r]-sum[l-1]+(r+1)*ask(0,r)-ask(1,r)-(l*ask(0,l-1)-ask(1,l-1))<<endl;
            }
        }
    
        return 0;
    }
    #endif
    #ifdef method_2
    /*
    
    */
    
    #endif
    #ifdef method_3
    /*
    
    */
    
    #endif
    

    POJ2182 Lost Cows
    因为所有奶牛身高不同且是1~n的排列,所以逆序考虑每一头奶牛,若最后一头奶牛前面有A_{n}头奶牛比它矮,那么它的身高为H_{n}=A_{n}+1
    考虑倒数第二头奶牛,若它前面有A_{n-1}头奶牛比它矮,那么:
    1 若A_{n-1}<A_{n},则它的身高H_{n-1}=A_{n-1}+1
    2 若A_{n-1}\geq A_{n},则它的身高H_{n-1}=A_{n-1}+2
    以此类推,若第k头奶牛前面有A_{k}头奶牛比它矮,那么它的身高H_{k}就是1~n中第A_{k}+1小的,且没有在\left\{H_{k+1},H_{k+2},...,H_{n} \right\}中出现过的数。
    具体实现中,我们建立一个01序列b,初始全为1,表示所有身高都没有出现过。逆序扫描n头奶牛,对于每头奶牛执行如下操作:
    1 查询b中第A_{i}+1个1在哪里,其位置就是第i头奶牛的身高H_{i}
    2 将b[H_{i}]变成0(减去1)
    方法一:树状数组+二分
    每次二分答案,通过树状数组求前缀和寻找第A_{i}+1个1
    时间复杂度:O(nlog^{2}n)
    方法二:树状数组+倍增
    因为树状数组天然维护了区间长度为2的整数次幂的信息,于是可以在O(nlogn)的时间内求解
    附上两种方法的代码

    /*
    
    */
    #define method_2
    #ifdef method_1
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    using namespace std;
    typedef long long ll;
    const int maxn=8000+5;
    const int INF=0x3f3f3f3f;
    ll n,c[maxn],a[maxn],sum=0,ans[maxn];
    void add(ll x,ll val){
        for(int i=x;i<=n;i+=i&-i){
            c[i]+=val;
        }
    }
    ll ask(ll x){
        ll sum=0;
        for(int i=x;i>0;i-=i&-i){
            sum+=c[i];
        }
        return sum;
    }
    int main() {
        ios::sync_with_stdio(false);
        freopen("Lost Cows.in","r",stdin);
        memset(c,0,sizeof(c));
        cin>>n;
    //  sum=(1+n)*n/2;
        a[1]=0;
        for(int i=1;i<=n;i++){
            add(i,1);
        }
        for(int i=2;i<=n;i++){
            cin>>a[i];
        }
        for(int i=n;i>=1;i--){
            int l=1,r=n;
            while(l<=r){
                int mid=(l+r)>>1;
                if(ask(mid)<a[i]+1) l=mid+1;
                else r=mid-1;
            }
            ans[i]=l;
            add(l,-1);
    //      sum-=ans[i];
        }
        for(int i=1;i<=n;i++){
            cout<<ans[i]<<endl;
        }
        return 0;
    }
    #endif
    #ifdef method_2
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    using namespace std;
    typedef long long ll;
    const int maxn=8000+5;
    const int INF=0x3f3f3f3f;
    ll n,c[maxn],a[maxn],sum=0,ans[maxn];
    void add(ll x,ll val){
        for(int i=x;i<=n;i+=i&-i){
            c[i]+=val;
        }
    }
    ll ask(ll x){
        ll sum=0;
        for(int i=x;i>0;i-=i&-i){
            sum+=c[i];
        }
        return sum;
    }
    int main() {
        ios::sync_with_stdio(false);
    //  freopen("Lost Cows.in","r",stdin);
        memset(c,0,sizeof(c));
        cin>>n;
    //  sum=(1+n)*n/2;
        a[1]=0;
        for(int i=1;i<=n;i++){
            add(i,1);
        }
        for(int i=2;i<=n;i++){
            cin>>a[i];
        }
        for(int i=n;i>=1;i--){
            ll ans1=0,sum=0;
            for(int j=log2(n);j>=0;j--){
                if(ans1+(1<<j)<=n&&sum+c[ans1+(1<<j)]<a[i]+1) sum+=c[ans1+(1<<j)],ans1+=(1<<j);
            }
            ans[i]=ans1+1;
            add(ans1+1,-1);
        }
        for(int i=1;i<=n;i++){
            cout<<ans[i]<<endl;
        }
        return 0;
    }
    #endif
    #ifdef method_3
    /*
    
    */
    
    #endif
    

    线段树

    线段树的基本操作包括区间求和,区间求最值(这里的最值是广义的最值,不仅限于max和min),单点修改,单点查值。这些操作的时间复杂度均为O(nlogn)
    类型定义

    struct SegmentTree {
        int l, r;
        int dat;
    } t[SIZE * 4]; 
    

    建树

    void build(int p, int l, int r) {
        t[p].l = l, t[p].r = r;     
                   if (l == r) { t[p].dat = a[l]; return; } 
                   int mid = (l + r) / 2; 
        build(p*2, l, mid); 
        build(p*2+1, mid+1, r); 
        t[p].dat = max(t[p*2].dat, t[p*2+1].dat);
    }
    
    build(1, 1, n); 
    

    单点修改

    void change(int p, int x, int v) {
        if (t[p].l == t[p].r) { t[p].dat = v; return; } 
                  int mid = (t[p].l + t[p].r) / 2;
        if (x <= mid) change(p*2, x, v); 
        else change(p*2+1, x, v); 
        t[p].dat = max(t[p*2].dat, t[p*2+1].dat);
    }
    
    change(1, x, v); 
    

    区间求最值

    int ask(int p, int l, int r) {
        if (l <= t[p].l && r >= t[p].r) return t[p].dat; 
        int mid = (t[p].l + t[p].r) / 2;
        int val = 0;
        if (l <= mid) val = max(val, ask(p*2, l, r)); 
        if (r > mid) val = max(val, ask(p*2+1, l, r)); 
        return val;
    }
    
    cout << ask(1, l, r) << endl; 
    

    例题

    4301 Can you answer on these queries III
    考虑最大子段和的生成过程,在区间上维护六个信息:区间和sum,区间最大子段和dat,紧贴左端点的最大子段和lmax,紧贴右端点的最大子段和rmax。
    分三类讨论,得到转移如下
    tree[p].sum=tree[p<<1].sum+tree[p<<1|1].sum;
    tree[p].lmax=max(tree[p<<1].lmax,tree[p<<1].sum+tree[p<<1|1].lmax);
    tree[p].rmax=max(tree[p<<1|1].rmax,tree[p<<1|1].sum+tree[p<<1].rmax);
    tree[p].dat=max(tree[p<<1].dat,max(tree[p<<1|1].dat,tree[p<<1].rmax+tree[p<<1|1].lmax));
    代码如下

    /*
    
    */
    #define method_1
    #ifdef method_1
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    using namespace std;
    typedef long long ll;
    const int maxn=500000+5;
    const int INF=0x3f3f3f3f;
    int n,m,a[maxn];
    struct node {
        int l,r,sum,lmax,rmax,dat;
        node(int tmp=0) {
            sum=lmax=rmax=dat=tmp;
        }
    } tree[maxn<<2];
    void build(int p,int l,int r) {
        tree[p].l=l,tree[p].r=r;
        if(l==r) {
            tree[p].sum=tree[p].dat=tree[p].lmax=tree[p].rmax=a[l];
            return;
        }
        int mid=(l+r)>>1;
        build(p<<1,l,mid);
        build(p<<1|1,mid+1,r);
        tree[p].sum=tree[p<<1].sum+tree[p<<1|1].sum;
        tree[p].lmax=max(tree[p<<1].lmax,tree[p<<1].sum+tree[p<<1|1].lmax);
        tree[p].rmax=max(tree[p<<1|1].rmax,tree[p<<1|1].sum+tree[p<<1].rmax);
    //  tree[p].dat=max(tree[p<<1].dat,tree[p<<1|1].dat);
        tree[p].dat=max(tree[p<<1].dat,max(tree[p<<1|1].dat,tree[p<<1].rmax+tree[p<<1|1].lmax));
    }
    void change(int p,int x,int v) {
        if(tree[p].l==tree[p].r) {
            tree[p].sum=tree[p].dat=tree[p].lmax=tree[p].rmax=v;
            return;
        }
        int mid=(tree[p].l+tree[p].r)>>1;
        if(x<=mid) change(p<<1,x,v);
        else change(p<<1|1,x,v);
        tree[p].sum=tree[p<<1].sum+tree[p<<1|1].sum;
        tree[p].lmax=max(tree[p<<1].lmax,tree[p<<1].sum+tree[p<<1|1].lmax);
        tree[p].rmax=max(tree[p<<1|1].rmax,tree[p<<1|1].sum+tree[p<<1].rmax);
        tree[p].dat=max(tree[p<<1].dat,max(tree[p<<1|1].dat,tree[p<<1].rmax+tree[p<<1|1].lmax));
    }
    node ask(int p,int l,int r) {
        if(l<=tree[p].l&&r>=tree[p].r) {
            return tree[p];
        }
        int mid=(tree[p].l+tree[p].r)>>1;
    //  int val=-INF;
        node a,b,c;
        if(r<=mid) return ask(p<<1,l,r);
        if(l>=mid+1) return ask(p<<1|1,l,r);
        else{
            a=ask(p<<1,l,r);
        b=ask(p<<1|1,l,r);
        c.sum=a.sum+b.sum;
        c.lmax=max(a.lmax,a.sum+b.lmax);
        c.rmax=max(b.rmax,b.sum+a.rmax);
        c.dat=max(a.dat,max(b.dat,a.rmax+b.lmax));
        return c;
        }
    }
    void print(int p,int l,int r) {
        if(l==r) cout<<p<<" "<<tree[p].lmax<<" "<<tree[p].rmax<<" "<<tree[p].sum<<" "<<tree[p].dat<<endl;
        int mid=(l+r)>>1;
        print(p<<1,l,mid);
        print(p<<1|1,mid+1,r);
    }
    int main() {
        ios::sync_with_stdio(false);
    //  freopen("Can you answer on these queries III.in","r",stdin);
        cin>>n>>m;
        for(int i=1; i<=n; i++) {
            cin>>a[i];
        }
        build(1,1,n);
        int x,y,z;
        while(m--)
        {
            cin>>x>>y>>z;
    //      if(y>z) swap(y,z); //不能在这里swap 会影响命令2 
            if(x==1)cout<<ask(1,min(y,z),max(y,z)).dat<<endl;
            else change(1,y,z);
        }
        return 0;
    }
    #endif
    

    4302 Interval GCD
    因为gcd(x,y,z)=gcd(x,y-x,z-y)
    所以用线段树维护原序列的差分序列的gcd。
    又因为求区间gcd的时候,需要区间第一个数的原始值,所以额外用一个树状数组维护原值。
    代码如下

    /*
    
    */
    #define method_1
    #ifdef method_1
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    #include<iomanip>
    #define D(x) cout<<#x<<" = "<<x<<"  "
    #define E cout<<endl
    using namespace std;
    typedef long long ll;
    typedef pair<int,int>pii;
    const int maxn=5e5+5;
    const int INF=0x3f3f3f3f;
    ll n,m;
    struct SegmentTree {
        int l,r;
        ll dat;
    } t[maxn<<2];
    ll a[maxn],b[maxn],c[maxn];
    ll gcd(ll x,ll y) {
        return !y?x:gcd(y,x%y);
    }
    void build(int rt,int l,int r) {
        t[rt].l=l;
        t[rt].r=r;
        if(l==r) {
            t[rt].dat=b[l];
            return;
        }
        int mid=l+r>>1;
        build(rt<<1,l,mid);
        build(rt<<1|1,mid+1,r);
        t[rt].dat=gcd(t[rt<<1].dat,t[rt<<1|1].dat);
    }
    void change(int rt,int x,ll v) {
        if(t[rt].l==t[rt].r) {
            t[rt].dat+=v;
            return;
        }
        int mid=t[rt].l+t[rt].r>>1;
        if(mid>=x) change(rt<<1,x,v);
        else change(rt<<1|1,x,v);
        t[rt].dat=gcd(t[rt<<1].dat,t[rt<<1|1].dat);
    }
    ll ask(int rt,int l,int r) {
        if(t[rt].l>=l&&t[rt].r<=r) return abs(t[rt].dat);
        int mid=t[rt].l+t[rt].r>>1;
        ll val=0;
        if(l<=mid) val=gcd(val,ask(rt<<1,l,r));
        if(r>mid) val=gcd(val,ask(rt<<1|1,l,r));
        return abs(val);
    }
    void add(int x,ll v) {
        for(int i=x; i<=n; i+=i&-i) {
            c[i]+=v;
        }
    }
    ll sum(int x) {
        ll ans=0;
        for(int i=x; i>=1; i-=i&-i) {
            ans+=c[i];
        }
        return ans;
    }
    int main() {
        ios::sync_with_stdio(false);
    //  freopen("Interval GCD.in","r",stdin);
        cin>>n>>m;
        for(int i=1; i<=n; i++) {
            cin>>a[i];
        }
        b[1]=a[1];
        for(int i=2; i<=n; i++) {
            b[i]=a[i]-a[i-1];
        }
        build(1,1,n);
    //  cout<<t[1].l<<" "<<t[1].r<<endl;
        char op;
        int l,r;
        ll d;
        while(m--) {
            /*
            for(int i=1;i<=n;i++){
                D(a[i]+sum(i));
            }
            E;
            */
            cin>>op;
            if(op=='Q') {
                cin>>l>>r;
                ll al=a[l]+sum(l);
                ll val=l<r?ask(1,l+1,r):0;
                cout<<gcd(ask(1,l+1,r),al)<<endl;
            } else {
                cin>>l>>r>>d;
                change(1,l,d);
                if(r<n) change(1,r+1,-d);
                add(l,d);
                if(r<n) add(r+1,-d);
            }
        }
        return 0;
    }
    #endif
    

    通过“懒惰标记/延迟标记”的技巧,线段树也能够实现O(nlogn)内完成区间修改和区间查询。

    例题

    4301 Can you answer on these queries III
    关于延迟标记这里不做解释,代码如下
    类型定义

    struct SegmentTree{
        int l,r;
        long long sum,add;
        #define l(x) tree[x].l
        #define r(x) tree[x].r
        #define sum(x) tree[x].sum
        #define add(x) tree[x].add
    }tree[SIZE*4];
    

    建树

    void build(int p,int l,int r)
    {
        l(p)=l,r(p)=r;
        if(l==r) { sum(p)=a[l]; return; }
        int mid=(l+r)/2;
        build(p*2,l,mid);
        build(p*2+1,mid+1,r);
        sum(p)=sum(p*2)+sum(p*2+1);
    }
    

    传递标记

    void spread(int p) {
        if(add(p)) {
            sum(p*2)+=add(p)*(r(p*2)-l(p*2)+1);
            sum(p*2+1)+=add(p)*(r(p*2+1)-l(p*2+1)+1);
            add(p*2)+=add(p);
            add(p*2+1)+=add(p);
            add(p)=0;
        }
    }
    

    区间修改

    void change(int p,int l,int r,int z) {
        if(l<=l(p)&&r>=r(p)) {
            sum(p)+=(long long)z*(r(p)-l(p)+1);
            add(p)+=z;
            return;
        }
        spread(p);
        int mid=(l(p)+r(p))/2;
        if(l<=mid) change(p*2,l,r,z);
        if(r>mid) change(p*2+1,l,r,z);
        sum(p)=sum(p*2)+sum(p*2+1);
    }
    

    区间求和

    long long ask(int p,int l,int r) {
        if(l<=l(p)&&r>=r(p)) return sum(p);
        spread(p);
        int mid=(l(p)+r(p))/2;
        long long ans=0;
        if(l<=mid) ans+=ask(p*2,l,r);
        if(r>mid) ans+=ask(p*2+1,l,r);
        return ans;
    }
    

    另外,线段树可以处理多维偏序问题的一个维度。因此,线段树可以和排序结合求解二维偏序问题,可以和平衡树结合构成树套树求解三维/多维偏序问题。(类似的,树状数组结合排序可以求解某些不需要求最值的二维偏序问题,树状数组+排序+CDQ分治可以离线求解三维偏序问题)
    这里讲解线段树维护扫描线+排序离散化求解二维偏序问题

    例题

    POJ1151 Atlantis
    我们用一条竖直直线从左向右扫过整个坐标系,那么直线上被并集图形覆盖的长度只会在每个矩形的左右边界发生变化。
    因此,我们只要知道扫描线在每段上被覆盖的长度,乘以该段距离,对所有这样的值求和就是答案了。
    具体实现中,我们设每个矩形左上角为(x_{1},y_{1}),右下角为(x_{2},y_{2}),矩形的左边界四元组(x_{1},y_{1},y_{2},1),右边界四元组(x_{2},y_{1},y_{2},-1),c[i]表示扫描线上第i段被覆盖的次数。
    首先对y坐标离散化,val(y)表示y被离散化后的整数值,raw[i]表示离散化后的i对应到实际坐标中的y值。
    对于每个四元组(x,y_{1},y_{2},k),把c数组中的c[val(y_{1})],c[val(y_{1}+1)],c[val(y_{1}+2)],...,c[val(y_{2}-1)]全部+k,相当于覆盖了[y_{1},y_{2}]这个区间(因为c数组的含义是区间,所以y_{1}y_{2}的区间上限是val(y_{2}-1)而不是val(y_{2}))。
    若下一个四元组横坐标为x_{2},那么这段面积就是(x_{2}-x)*\sum_{c[i]>0}(raw[i+1]-raw[i]),对于每个四元素,朴素的在c数组上修改,时间复杂度为O(n^{2})(本题实现宽裕,这种方法已经可以AC。
    考虑用线段树维护c数组,把时间复杂度优化到O(nlogn)
    方法一:运用延迟标记实现区间修改和查询
    方法二:由于我们只关心线段树根节点(整个扫描线)上被矩形覆盖的长度,并且四元组总是成对出现,所以线段树上的区间修改也是成对的,这种情况下,没有必要向下传递延迟标记。换句话说,统计时只考虑根节点,所系不用向下传递标记
    具体实现上,我们除了维护区间左右端点外,我们在线段树的每个节点上维护两个值:节点代表区间被覆盖的长度len,节点被覆盖次数cnt,初始时两者均为0。
    对于四元组(x,y_{1},y_{2},k),我们在[val()y_{1},val(y_{2})]上区间修改。该区间被线段树划分成O(logn)个节点,我们将这些节点的cnt+=k。
    对于线段中任意一个节点[l,r],若cnt>0,则len=raw(r+1)-raw(l)(同样,由于线段树维护的是区间,所以这里是r+1,而不是r)。否则,该点的len为两个子节点的len之和(需特判叶子节点)。在一个节点的cnt被修改,以及线段树从下向上回溯时,我们这样更新len,最后根节点的len就是整个扫描线上被覆盖的长度。
    代码如下,其中method_1是朴素做法,method_2是优化做法中的方法二(spread函数和上面延迟标记中的spread函数不同,这里的spread只含有向上合并的操作)。

    /*
    
    */
    #define method_2
    #ifdef method_1
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    #include<iomanip>
    #define D(x) cout<<#x<<" = "<<x<<"  "
    #define E cout<<endl
    using namespace std;
    typedef long long ll;
    typedef pair<int,int>pii;
    const int maxn=200+5;
    const int INF=0x3f3f3f3f;
    int n,num,c[maxn];
    double raw[maxn];
    struct node{
        double X,Y1,Y2;
        int f;
        bool operator<(const node &h)const{
            return X<h.X;
        }
    }line[maxn];
    int val(double y){
        return lower_bound(raw+1,raw+num+1,y)-raw;
    }
    int main() {
        ios::sync_with_stdio(false);
        //freopen("Atlantis.in","r",stdin);
        int kase=0;
        while(scanf("%d",&n)&&n){
            printf("Test case #%d\n",++kase);
            memset(c,0,sizeof(c));
            memset(raw,0,sizeof(raw));
            num=0;
            double X1,Y1,X2,Y2;
            for(int i=1;i<=n;i++){
                scanf("%lf%lf%lf%lf",&X1,&Y1,&X2,&Y2);
                line[2*i-1].X=X1;
                line[2*i-1].Y1=Y1;
                line[2*i-1].Y2=Y2;
                line[2*i-1].f=1;
                line[2*i].X=X2;
                line[2*i].Y1=Y1;
                line[2*i].Y2=Y2;
                line[2*i].f=-1;
                raw[++num]=Y1;
                raw[++num]=Y2;
            }
            sort(line+1,line+2*n+1);
            sort(raw+1,raw+num+1);
            num=unique(raw+1,raw+num+1)-raw-1;
            double ans=0.0;
            for(int i=1;i<=2*n-1;i++){
                double X=line[i].X;
                double X2=line[i+1].X; 
                double Y1=line[i].Y1;
                double Y2=line[i].Y2;
        //      printf("%lf %lf %lf %lf\n",X,X2,Y1,Y2);
        //      printf("%d %d\n",val(Y1),val(Y2));
                
                int f=line[i].f;
                for(int j=val(Y1);j<=val(Y2)-1;j++){
                    c[j]+=f;
                }
                for(int j=1;j<=num-1;j++){
                    if(c[j]>0){
                //      printf("%lf %lf\n",raw[j+1],raw[j]);
                        ans+=(X2-X)*(raw[j+1]-raw[j]);
                    }
                }
            }
            printf("Total explored area: %.2lf\n\n",ans);
        }
        return 0;
    }
    #endif
    #ifdef method_2
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    #include<iomanip>
    #define D(x) cout<<#x<<" = "<<x<<"  "
    #define E cout<<endl
    using namespace std;
    typedef long long ll;
    typedef pair<int,int>pii;
    const int maxn=200+5;
    const int INF=0x3f3f3f3f;
    int n,num,c[maxn];
    double raw[maxn];
    struct node{
        double X,Y1,Y2;
        int f;
        bool operator<(const node &h)const{
            return X<h.X;
        }
    }line[maxn];
    struct SegmentTree{
        int l,r,cnt;
        double len;
    }tr[maxn<<2];
    int val(double y){
        return lower_bound(raw+1,raw+num+1,y)-raw;
    }
    void build(int rt,int l,int r){
        tr[rt].l=l,tr[rt].r=r;
        if(l==r){
            tr[rt].cnt=0;
            tr[rt].len=0;
            return;
        }
        int mid=l+r>>1;
        build(rt<<1,l,mid);
        build(rt<<1|1,mid+1,r);
        tr[rt].cnt=0;
        tr[rt].len=0;
    }
    void spread(int p){
        if(tr[p].cnt>0) tr[p].len=raw[tr[p].r+1]-raw[tr[p].l];
        else if(tr[p].l==tr[p].r) tr[p].len=0; //这句话一定要有 作为叶子节点的边界 否则会合并到未知位置 
        else tr[p].len=tr[p<<1].len+tr[p<<1|1].len;
    }
    void update(int rt,int l,int r,int v){
        if(tr[rt].l>=l&&tr[rt].r<=r){
            tr[rt].cnt+=v;
            spread(rt);
            return;
        }
        int mid=tr[rt].l+tr[rt].r>>1;
        if(l<=mid) update(rt<<1,l,r,v);
        if(r>mid) update(rt<<1|1,l,r,v);
        spread(rt);
    }
    int main() {
        ios::sync_with_stdio(false);
        //freopen("Atlantis.in","r",stdin);
        int kase=0;
        while(scanf("%d",&n)&&n){
            printf("Test case #%d\n",++kase);
            memset(c,0,sizeof(c));
            memset(raw,0,sizeof(raw));
            num=0;
            double X1,Y1,X2,Y2;
            for(int i=1;i<=n;i++){
                scanf("%lf%lf%lf%lf",&X1,&Y1,&X2,&Y2);
                line[2*i-1].X=X1;
                line[2*i-1].Y1=Y1;
                line[2*i-1].Y2=Y2;
                line[2*i-1].f=1;
                line[2*i].X=X2;
                line[2*i].Y1=Y1;
                line[2*i].Y2=Y2;
                line[2*i].f=-1;
                raw[++num]=Y1;
                raw[++num]=Y2;
            }
            sort(line+1,line+2*n+1);
            sort(raw+1,raw+num+1);
            num=unique(raw+1,raw+num+1)-raw-1;
            build(1,1,num-1);
            double ans=0.0;
            for(int i=1;i<=2*n-1;i++){
                double X=line[i].X;
                double X2=line[i+1].X; 
                double Y1=line[i].Y1;
                double Y2=line[i].Y2;
                update(1,val(Y1),val(Y2)-1,line[i].f);
                ans+=(X2-X)*tr[1].len;
            }
            printf("Total explored area: %.2lf\n\n",ans);
        }
        return 0;
    }
    #endif
    #ifdef method_3
    /*
    
    */
    
    #endif
    

    POJ2482 Stars in Your Window
    由于矩形长宽确定且不可旋转,我们讨论矩形右上角的位置。
    逆向思考,我们发现,对于星星(x,y),能够圈住它的矩形的右上角的坐标范围是(x,y)到(x+w,y+h),我们称其为一个区域,区域的权值就是星星的亮度。将所有星星对应的区域全部找出来后,问题转化为若干个区域中,求在那个坐标上区域重叠的权值和最大
    求区域重叠的权值和最大,换句话说,我们需要维护一个求二维区间最值的数据结构。类比上一题的做法,我们从左至右扫描,同时关于纵坐标建立一棵支持区间修改和维护区间最值的线段树,对于四元组(x,y_{1},y_{2},c),在线段树中,将对应的区间[y_{1},y_{2}]+c即可,最后根节点的dat用于更新答案。
    PS:由于本题认为矩形边界上的星星不算,我们将矩形做缩放,即每个星星限定的区域是(x,y)到(x+w-1,y+h-1)
    代码如下(这里用延迟标记实现了区间修改)

    /*
    
    */
    #define method_1
    #ifdef method_1
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    #include<iomanip>
    #define D(x) cout<<#x<<" = "<<x<<"  "
    #define E cout<<endl
    using namespace std;
    typedef long long ll;
    typedef pair<int,int>pii;
    const int maxn=20000+5;
    const int INF=0x3f3f3f3f;
    struct SegmentTree{
        ll l,r;
        ll dat,add;
    }t[maxn<<2];
    struct node{
        double X,Y1,Y2;
        ll f;
        bool operator<(const node &h)const{
            return X!=h.X?X<h.X:f<h.f;
        }
    }a[maxn];
    ll num,n,W,H,b[maxn];
    void build(ll rt,ll l,ll r){
        t[rt].l=l,t[rt].r=r;
        t[rt].dat=t[rt].add=0;
        if(l==r){
            return;
        }
        ll mid=l+r>>1;
        build(rt<<1,l,mid);
        build(rt<<1|1,mid+1,r);
    }
    void spread(ll rt){
        if(t[rt].add){
            t[rt<<1].dat+=t[rt].add;
            t[rt<<1].add+=t[rt].add;
            t[rt<<1|1].dat+=t[rt].add;
            t[rt<<1|1].add+=t[rt].add;
            t[rt].add=0;
        }
    }
    void change(ll rt,ll l,ll r,ll v){
        if(t[rt].l>=l&&t[rt].r<=r){
            t[rt].add+=v;
            t[rt].dat+=v;
            return;
        }
        spread(rt);
        ll mid=(t[rt].l+t[rt].r)>>1;
        if(mid>=l) change(rt<<1,l,r,v);
        if(mid<r) change(rt<<1|1,l,r,v);
        t[rt].dat=max(t[rt<<1].dat,t[rt<<1|1].dat);
    }
    int main() {
        ios::sync_with_stdio(false);
    //  freopen("Stars in Your Window.in","r",stdin);
        while(cin>>n>>W>>H){
            ll x,y,z;
            num=0;
            for(int i=1;i<=n;i++){
                cin>>x>>y>>z;
                a[2*i-1].Y1=a[2*i].Y1=y;
                a[2*i-1].Y2=a[2*i].Y2=y+H-1;
                a[2*i-1].X=x;
                a[2*i].X=x+W;
                a[2*i-1].f=z;
                a[2*i].f=-z;
                b[++num]=y;
                b[++num]=y+H-1; 
            }
            sort(b+1,b+num+1);
            num=unique(b+1,b+num+1)-b-1;
            for(int i=1;i<=2*n;i++){
                a[i].Y1=lower_bound(b+1,b+num+1,a[i].Y1)-b;
                a[i].Y2=lower_bound(b+1,b+num+1,a[i].Y2)-b;
            }
            sort(a+1,a+2*n+1);
            build(1,1,num);
            ll ans=0;
            for(int i=1;i<=2*n;i++){
                change(1,a[i].Y1,a[i].Y2,a[i].f);
                ans=max(ans,t[1].dat);
            }
            cout<<ans<<endl;
        }
        return 0;
    }
    #endif
    #ifdef method_2
    /*
    
    */
    
    #endif
    #ifdef method_3
    /*
    
    */
    
    #endif
    

    另外,有些计数问题中,线段树可用于维护值域(权值线段树),为了避免空间复杂度过高,权值线段树会运用动态开点的方式进行,这样它没有了完全二叉树父子节点二倍编号的特性,转而用指针变量记录左右儿子。(例题 6302 雨天的尾巴
    类型定义

    struct node1 {
        int lc,rc,dat,pos;
    } tr[maxn*20*4]; //不动态开点的话 需要的空间就是maxn*maxn 因为对于每个点都要维护1e5种类型 这里的20=log2(maxn)
    //也就是说动态开点的线段树空间大概为maxn log(maxn)*4 
    //每进行一次插入,会添加log级的点,因此开nlogn级数组即可。
    

    初始化(开始的时候,n棵权值线段树每棵线段树只有根节点)

    for(int i=1; i<=n; i++) {
        root[i]=++num;
    }
    

    插入节点(维护区间最值和区间最值取到的位置)

    void insert(int p,int l,int r,int val,int delta) {
        if(l==r) {
            tr[p].dat+=delta;
            if(tr[p].dat==0) tr[p].pos=0;
            else tr[p].pos=l;
            return;
        }
        int mid=(l+r)>>1;
        if(val<=mid) {
            if(!tr[p].lc) tr[p].lc=++num;
            insert(tr[p].lc,l,mid,val,delta);
        } else {
            if(!tr[p].rc) tr[p].rc=++num;
            insert(tr[p].rc,mid+1,r,val,delta);
        }
        tr[p].dat=max(tr[tr[p].lc].dat,tr[tr[p].rc].dat);
        tr[p].pos=tr[tr[p].lc].dat>=tr[tr[p].rc].dat?tr[tr[p].lc].pos:tr[tr[p].rc].pos;
    }
    

    线段树合并(对应计数问题中的值域合并过程)

    int merge(int p,int q,int l,int r) {
        if(!p) return q;
        if(!q) return p;
        if(l==r) {
            tr[p].dat+=tr[q].dat;
            if(tr[p].dat==0) tr[p].pos=0;
            else tr[p].pos=l;
            return p;
        }
        int mid=(l+r)>>1;
        tr[p].lc=merge(tr[p].lc,tr[q].lc,l,mid);
        tr[p].rc=merge(tr[p].rc,tr[q].rc,mid+1,r);
        tr[p].dat=max(tr[tr[p].lc].dat,tr[tr[p].rc].dat);
        tr[p].pos=tr[tr[p].lc].dat>=tr[tr[p].rc].dat?tr[tr[p].lc].pos:tr[tr[p].rc].pos;
        return p;
    }
    

    BST和平衡树

    BST的原理这里不做解释,直接上代码。
    建树

    struct BST {
        int l, r; 
        int val; 
    }a[SIZE]; 
    int tot, root, INF = 1<<30;
    
    int New(int val) {
        a[++tot].val = val;
        return tot;
    }
    
    void Build() {
        New(-INF), New(INF);
        root = 1, a[1].r = 2;
    }
    

    检索

    int Get(int p, int val) {
        if (p == 0) return 0; 
        if (val == a[p].val) return p; 
        return val < a[p].val ? Get(a[p].l, val) : Get(a[p].r, val);
    }
    

    插入

    void Insert(int &p, int val) {
        if (p == 0) {
            p = New(val); 
                    return;
        }
        if (val == a[p].val) return;
        if (val < a[p].val) Insert(a[p].l, val);
        else Insert(a[p].r, val);
    }
    

    求前驱/后继 这里以后继为例

    int GetNext(int val) {
        int ans = 2; // a[2].val==INF
        int p = root;
        while (p) {
            if (val == a[p].val) {
            if (a[p].r > 0) { 
                    p = a[p].r;
                    while (a[p].l > 0) p = a[p].l;
                    ans = p;
                }
                break;
            }
            if (a[p].val > val && a[p].val < a[ans].val) ans = p;
            p = val < a[p].val ? a[p].l : a[p].r;
        }
        return ans;
    }
    

    删除

    void Remove(int val) { //注意p是引用
        int &p = root; //由于是引用 所以需要保存root的副本
        while (p) {
            if (val == a[p].val) break;
            p = val < a[p].val ? a[p].l : a[p].r;
        }
        if (p == 0) return;
        if (a[p].l == 0) p = a[p].r; 
        else if (a[p].r == 0) p = a[p].l;   
        else { 
            int next = a[p].r;
            while (a[next].l > 0) next = a[next].l;
            Remove(a[next].val);
            a[next].l = a[p].l, a[next].r = a[p].r;
            p = next; 
        }
    }
    

    若BST维护的是随机序列,那么它的时间复杂度为O(nlogn),但是当序列有序时,它的复杂度会退化成O(n^{2})。由于满足BST性质且中序遍历相同的二叉树很多,而保证相同而形状改变的操作叫旋转。
    通过旋转,我们能让一个不平衡的二叉树逐渐变平衡。

    image
    左旋
    void zig(int &p) { //注意这里p是引用而q不是引用
        int q = a[p].l;
        a[p].l = a[q].r, a[q].r = p, p = q;
        Update(a[p].r), Update(p);
    }
    

    右旋

    void zag(int &p) {
        int q = a[p].r;
        a[p].r = a[q].l, a[q].l = p, p = q;
        Update(a[p].l), Update(p);
    }
    

    为了找到旋转的条件,我们在维护节点权值的同时,为每个节点分配一个随机生成的额外权值,时刻维护这个额外权值满足大根堆性质,若不满足时进行旋转即可。

    例题

    4601 普通平衡树
    平衡树模板题,这里不再赘述过程,详见代码
    PS:值得注意的是:
    1 可能有数值重复,这时我们给每个节点一个cnt变量表示这个节点数值出现的次数,当cnt=0时删除节点即可。
    2 同时这里需要维护排名。我们可以给每个节点维护一个size表示以该节点为根,其所有节点的cnt值之和,当不存在重复数值时,size即为子树大小。使用递归更新size即可。

    /*
    
    */
    #define method_1
    #ifdef method_1
    /*
    
    */
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<vector>
    #include<cstring>
    #include<cstdlib>
    #include<iomanip>
    #define D(x) cout<<#x<<" = "<<x<<"  "
    #define E cout<<endl
    using namespace std;
    typedef long long ll;
    typedef pair<int,int>pii;
    const int maxn=100000+5;
    const int INF=0x3f3f3f3f;
    int n;
    struct node {
        int l,r;
        int dat,val;
        int cnt,size;
    } a[maxn];
    int tot=0,root;
    int New(int val) {
        a[++tot].val=val;
        a[tot].dat=rand();
        a[tot].cnt=a[tot].size=1;
        return tot;
    }
    void Update(int p) {
        a[p].size=a[p].cnt+a[a[p].l].size+a[a[p].r].size;
    }
    void Build() {
        New(-INF*2);
        New(INF*2);
        root=1,a[1].r=2;
        Update(root);
    }
    int GetRankByVal(int p,int val) {
        if(p==0) return 0;
        if(a[p].val==val) return a[a[p].l].size+1;
        if(a[p].val<val) return GetRankByVal(a[p].r,val)+a[p].cnt+a[a[p].l].size;
        return GetRankByVal(a[p].l,val);
    }
    int GetValByRank(int p,int rank) {
        if(p==0) return INF;
        if(a[a[p].l].size>=rank) return GetValByRank(a[p].l,rank);
        if(a[a[p].l].size+a[p].cnt>=rank) return a[p].val;
        return GetValByRank(a[p].r,rank-(a[a[p].l].size+a[p].cnt));
    }
    void zig(int &p) {
        int q=a[p].l;
        a[p].l=a[q].r,a[q].r=p,p=q;
        Update(a[p].r);
        Update(p);
    }
    void zag(int &p) {
        int q=a[p].r;
        a[p].r=a[q].l,a[q].l=p,p=q;
        Update(a[p].l);
        Update(p);
    }
    void Insert(int &p,int val) {
        if(p==0) {
            p=New(val);
            return;
        }
        if(val==a[p].val) {
            a[p].cnt++,Update(p);
            return;
        }
        if(val<a[p].val) {
            Insert(a[p].l,val);
            if(a[p].dat<a[a[p].l].dat) zig(p);
        }
        if(val>a[p].val) {
            Insert(a[p].r,val);
            if(a[p].dat<a[a[p].r].dat) zag(p);
        }
        Update(p);
    }
    int GetPre(int val) {
        int ans=1;
        int p=root;
        while(p) {
            if(val==a[p].val) {
                if(a[p].l>0) {
                    p=a[p].l;
                    while(a[p].r>0) p=a[p].r;
                    ans=p;
                }
                break;
            }
            if(a[p].val<val&&a[p].val>a[ans].val) ans=p;
            p=val<a[p].val?a[p].l:a[p].r;
        }
        return a[ans].val;
    }
    int GetNext(int val) {
        int ans=2;
        int p=root;
        while(p) {
            if(val==a[p].val) {
                if(a[p].r>0) {
                    p=a[p].r;
                    while(a[p].l>0) p=a[p].l;
                    ans=p;
                }
                break;
            }
            if(a[p].val>val&&a[p].val<a[ans].val) ans=p;
            p=val<a[p].val?a[p].l:a[p].r;
        }
        return a[ans].val;
    }
    void Remove(int &p,int val) {
        if(p==0) return;
        if(val==a[p].val) {
            if(a[p].cnt>1) {
                a[p].cnt--;
                Update(p);
                return;
            }
            if(a[p].l||a[p].r) {
                if(a[p].r==0||a[a[p].l].dat>a[a[p].r].dat) { //zig之后  a[p].r是a[p].l的父节点
                    zig(p),Remove(a[p].r,val); //注意:这样传递不会改变root的值,除非对p做出赋值才会改变root
                } else {
                    zag(p),Remove(a[p].l,val);
                }
                Update(p);
            } else p=0;
            return;
        }
        val<a[p].val?Remove(a[p].l,val):Remove(a[p].r,val);
        Update(p);
    }
    int main() {
        ios::sync_with_stdio(false);
    //  freopen("普通平衡树.in","r",stdin);
        cin>>n;
        Build();
        int op,x;
        for(int i=1; i<=n; i++) {
            cin>>op>>x;
            if(op==1) {
                Insert(root,x);
            }
            if(op==2) {
                Remove(root,x);
            }
            if(op==3) {
                cout<<GetRankByVal(root,x)-1<<endl;
            }
            if(op==4) {
                cout<<GetValByRank(root,x+1)<<endl;
            }
            if(op==5) {
                cout<<GetPre(x)<<endl;
            }
            if(op==6) {
                cout<<GetNext(x)<<endl;
            }
        }
        return 0;
    }
    #endif
    

    相关文章

      网友评论

        本文标题:「数据结构进阶」例题之树形数据结构

        本文链接:https://www.haomeiwen.com/subject/mvzkeqtx.html