splay学习笔记(上)

作者: MasterHZJ | 来源:发表于2018-08-20 16:53 被阅读0次

    前不久其实已经学过splay了,但是总觉得似乎不能灵活地改造它,于是重新学习了一波。

    感谢https://oi.men.ci/splay-notes-1/
    关于splay的解释这里说的也比较清楚
    下面我们分成单点版和区间版讨论

    入门题 Tyvj 1728 普通平衡树

    您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
    1. 插入x数
    2. 删除x数(若有多个相同的数,因只删除一个)
    3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
    4. 查询排名为x的数
    5. 求x的前驱(前驱定义为小于x,且最大的数)
    6. 求x的后继(后继定义为大于x,且最小的数)

    #include <bits/stdc++.h>
    #define rep(i, a, b) for (int i=a; i<=b; i++)
    #define drep(i, a, b) for (int i=b; i>=a; i--)
    typedef long long LL;
    using namespace std;
    
    struct Splay {
        struct node {
            node *fa, *ch[2], **root;
            int x, size, cnt;
            node(node **root, node *fa, int x): root(root), fa(fa), x(x), cnt(1), size(1) {
                ch[0]=ch[1]=NULL;
            }
            inline int relation() {
                return this == fa->ch[0] ? 0 : 1;
            }
            inline void maintain() {
                size = cnt;
                if (ch[0]) size += ch[0]->size;
                if (ch[1]) size += ch[1]->size;
            }
    
            void rotate() {
                node *old=fa;
                int r=relation();
    
                fa=old->fa;
                if (old->fa) old->fa->ch[old->relation()]=this;
                if (ch[r^1]) ch[r^1]->fa=old;
                old->ch[r]=ch[r^1];
                old->fa=this;
                ch[r^1]=old;
    
                old->maintain();
                maintain();
                if (fa==NULL) *root=this;
            }
            void splay(node *target=NULL) {
                while (fa!=target) {
                    if (fa->fa==target) rotate();
                    else if (fa->relation()==relation()) {
                        fa->rotate();
                        rotate();
                    }else {
                        rotate();
                        rotate();
                    }
                }
            }
            node *pred() {
                node *v=ch[0];
                while (v->ch[1]) v=v->ch[1];
                return v;
            }//前驱precursor
            node *succ() {
                node *v = ch[1];
                while (v->ch[0]) v=v->ch[0];
                return v;
            }
    
            int rank() {
                return ch[0] ? ch[0]->size : 0;
            }
        } *root;
        Splay():root(NULL) {
            insert(INT_MAX);
            insert(INT_MIN);
        }
        node *insert(int x) {
            node **v = &root, *fa=NULL;
            while (*v!=NULL && (*v)->x!=x) {
                fa=*v;
                fa->size++;
                if (x<fa->x) v=&fa->ch[0]; else v=&fa->ch[1];
            }
            if (*v!=NULL) {
                (*v)->cnt++;
                (*v)->size++;
            }else (*v) = new node(&root, fa, x);
            (*v)->splay();
            return root;
        }
        node *find(int x) {
            node *v=root;
            while (v!=NULL && v->x != x) if (x<v->x) v=v->ch[0]; else v=v->ch[1];
            if (v) v->splay();
            return v;
        }
        void erase(node *v) {
            node *pred=v->pred(), *succ=v->succ();
            pred->splay();
            succ->splay(pred);
            if (v->size>1) {
                v->size--, v->cnt--;
            }else {
                delete succ->ch[0];
                succ->ch[0]=NULL;
            }
            succ->size--, pred->size--;
        }
        void erase(int x) {
            node *v=find(x);
            if (!v) return;
            erase(v);
        }
        int pred(int x) {
            node *v = find(x);
            if (v==NULL) {
                v=insert(x);
                int res=v->pred()->x;
                erase(v);
                return res;
            }else return v->pred()->x;
        }
        int succ(int x) {
            node *v=find(x);
            if (v==NULL) {
                v=insert(x);
                int res=v->succ()->x;
                erase(v);
                return res;
            }else return v->succ()->x;
        }
        int rank(int x) {
            node *v=find(x);
            if (v==NULL) {
                v=insert(x);
                int res=v->rank();
                erase(v);
                return res;
            }else return v->rank();
        }
        int select(int k) {
            node *v = root;
            while (!(k >= v->rank() && k < v->rank() + v->cnt)){
                if (k<v->rank()) v=v->ch[0]; else {
                    k-=v->rank()+v->cnt;
                    v=v->ch[1];
                }
            }
            v->splay();
            return v->x;
        }
    }splay;
    
    int main() {
        int n;
        scanf("%d", &n);
        while (n--) {
            int opt, x;
            scanf("%d %d", &opt, &x);
            if (opt==1) splay.insert(x);
            if (opt==2) splay.erase(x);
            if (opt==3) printf("%d\n", splay.rank(x));
            if (opt==4) printf("%d\n", splay.select(x));
            if (opt==5) printf("%d\n", splay.pred(x));
            if (opt==6) printf("%d\n", splay.succ(x));
        }
        return 0;
    }
    

    如果是打acm的话不太懂splay的原理其实没有太大关系,这个板子已经把splay的基本操作封装再node结构体里面了,可以理解成splay是一个一直在维护平衡的名次树。所以起码要理解名次树的性质:

    左子树的值<根节点的值<右子树的值

    上面的splay只支持单点操作,其实用线段树也可以实现(强烈推荐zkw的论文《统计的力量》,很精妙),下面我们来讨论区间操作的splay。

    splay的区间操作对比线段树/数状数组,支持:

    • 区间删除
    • 区间翻转

    区间splay重要的操作是选择区间,比如要对区间[l,r]进行操作,我们要做的是将节点 l-1 Splay到根,再讲节点 r-1 splay到根节点的右儿子,那么根节点的右儿子的左子树就是区间[l, r] (根据 左子树的值<根节点的值<右子树的值 的性质)

    其它区间求和,区间最小值,区间修改之类的类似与线段树,通过lazy标记来实现

    我们来看一下这题

    Tyvj1729 文艺平衡树

    您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:翻转一个区间,例如原有序序列是5 4 3 2 1,翻转区间是[2,4]的话,结果是5 2 3 4 1

    需要注意的是,在上一题中,我们节点的权值是数的大小,在这一题中,我们的节点的权值是数的位置

    #include <bits/stdc++.h>
    #define rep(i, a, b) for (int i=a; i<=b; i++)
    #define drep(i, a, b) for (int i=b; i>=a; i--)
    typedef long long LL;
    using namespace std;
    
    template <typename T>
    struct Splay {
        struct node{
            node *ch[2], *parent, **root;
            T value;
            int size;
            bool bound, reverse;
            node(node *parent, node **root, const T &value, bool bound=false, bool reverse=false):parent(parent), root(root), value(value), reverse(false), size(1), bound(bound){
                ch[0]=ch[1]=NULL;
            }
            ~node(){
                if (ch[0]) delete(ch[0]);
                if (ch[1]) delete(ch[1]);
            }
            inline int relation(){return this==parent->ch[0]?0:1;}
            inline int lsize(){return ch[0] ? ch[0]->size : 0;}
            inline int rsize(){return ch[1] ? ch[1]->size : 0;}
            inline void maintain(){size = lsize() + rsize() +1;}
            inline node *grandparent(){return !parent ? NULL : parent->parent;}
            void *pushdown(){
                if (reverse){
                    //swap(ch[0], ch[1]);
                    node *tmp=ch[0];
                    ch[0]=ch[1];
                    ch[1]=tmp;
                    if (ch[0]) ch[0]->reverse^=1;
                    if (ch[1]) ch[1]->reverse^=1;
                    reverse = false;
                }
            }
            void rotate(){
                parent->pushdown(), pushdown();
                node *old=parent;
                int x=relation();
                if (grandparent()) grandparent()->ch[old->relation()] = this;
                parent=grandparent();
                old->ch[x] = ch[x^1];
                if (ch[x^1]) ch[x^1]->parent = old;
                ch[x^1]=old;
                old->parent=this;
                old->maintain(), maintain();
                if (!parent) *root=this;
            }
            node *splay(node **target=NULL){
                if (!target) target=root;
                while (this!=*target){
                    parent->pushdown();
                    if (parent == *target) rotate();
                    else if (parent->relation() == relation()) parent->rotate(), rotate();
                    else rotate(), rotate();
                }
                return *target;
            }
        }*root;
        ~Splay() {
            if (root) delete root;
        }
        void build(const T *a, int n){
            root = build(a, 1, n, NULL);
            rep(i, 0, 1){
                node *bound_parent=NULL, **bound=&root;
                while (*bound){
                    bound_parent = *bound;
                    bound_parent->size++;
                    bound = &(*bound)->ch[i];
                }
                *bound=new node(bound_parent, &root, 0, true);
            }
        }//插入边界值
        node *build (const T *a, int l, int r, node *parent){
            if (l>r) return NULL;
            int mid=(l+r)>>1;
            node *v=new node(parent, &root, a[mid-1]);
            v->ch[0] = build(a, l, mid - 1, v);
            v->ch[1] = build(a, mid + 1, r, v);
            v->maintain();
            return v;
        }
        node *select(int k) {
            k++;
            node *v = root;
            while (v->pushdown(), k != v->lsize() + 1)
                if (k < v->lsize() + 1) v = v->ch[0]; else k -= v->lsize() + 1, v = v->ch[1];
            return v->splay();
        }
        node *&select(int l, int r) {
            node *lbound=select(l-1), *rbound=select(r+1);
            lbound->splay();
            rbound->splay(&lbound->ch[1]);
            return rbound->ch[0];
        }
        void reverse(int l, int r) {
            node *range = select(l, r);
            range->reverse ^= 1;
        }
        void fetch(T *a) {
            dfs(a, root);
        }
        void dfs(T *&a, node *v) {
            if (v) {
                v->pushdown();
                dfs(a, v->ch[0]);
                if (!v->bound) *a++=v->value;
                dfs(a, v->ch[1]);
            }
        }
    };
    Splay<int>splay;
    const int MAXN=101000;
    int n, m;
    int a[MAXN];
    int main() {
        scanf("%d%d", &n, &m);
        for (int i=0; i<n; i++) a[i]=i+1;
        splay.build(a, n);
    
        for (int i=0; i<m; i++) {
            int l, r;
            scanf("%d%d", &l, &r);
            splay.reverse(l, r);
        }
        splay.fetch(a);
        for (int i=0; i<n; i++) printf("%d ", a[i]);
        return 0;
    }
    

    相关文章

      网友评论

        本文标题:splay学习笔记(上)

        本文链接:https://www.haomeiwen.com/subject/ypmliftx.html