让我们来看一个经典的问题吧:
给定一个[1,n]的区间,m次操作,操作种类如下:
1 L R:查询[L,R]的区间和
2 L R X:将[L,R]的值加上X
这种经典问题,想必大家学过线段树后都可以轻松解决。然而如果再增加一种操作:
3 K:回退到第K次修改操作的结果
可见,如果题目要求回溯到历史版本,那么普通的线段树就不能解决了,因为在每次更新操作后,线段树存储的内容就发生了改变,如果不进行特殊记录,那么这种改变将是永久的。因此,对于这种类型的题目,我们可以用到今天要讨论的数据结构——主席树来进行解决。
主席树,严格来讲应该叫:函数式线段树,是基于线段树的一种数据结构,常用于处理一些在线问题,关于在线离线的概念参考上一篇文章:在线和离线算法。事实上,主席树就是多个线段树的集合体。
主席树的实质,就是以最初的线段树作为模板,通过"结点复用“的方式,实现存储多个线段树。
对于文章开始的问题,观察后可以发现,在2操作进行后,在上一次修改后的线段树上,最多修改O(logn)个结点(最远是从根节点到叶子节点)。如果每次单独新建一个线段树,则会造成重复存储,如图所示:
浅蓝色的结点是当前修改操作时访问的结点,白色结点为上一棵线段树的结点。
如果对每次修改操作无差别复制一棵线段树,那么用于存储节点的开销是巨大的,因为对于单次修改,大部分结点都不曾被访问修改。
通过“结点复用”的方式,我们可以将这多棵线段树压缩成如下形式:
开辟新结点 结点复用
因此第i个线段树只要通过保留除修改路径外的第i-1棵线段树的结点,再新增加至多O(logn)个结点。
rt[i]保存第i次操作的线段树的根节点,这样,回退到第k次操作等价于rt[i]=rt[k],我们的问题就迎刃而解啦。
那么,怎么来建立一棵主席树呢?针对文章开始的题目,下面给出实现步骤:
1. 创建根节点、左右儿子结点数组
int tot=0,rt[maxn*20],lson[maxn*20],rson[maxn*20],v[maxn*20],lz[maxn*20],a[maxn];
tot是每次新建的结点编号。
rt[i]是第i棵线段树的根节点的编号。
lson[x]和rson[x]是结点x的左右儿子结点的编号。
v[x]是结点x代表的区间的和。
lz[x]是结点x的懒惰(lazy)值。
a[i]是初始的第i个位置的值。
因为结点每次至多更新O(logn)个,所以数组范围应该在原来的20-50倍左右。
2.区间更新的pushup和pushdown
void push_up(int x){
v[x]=v[lson[x]]+v[rson[x]];
}
void push_down(int x,int len){
if(lz[x]){
v[lson[x]]+=(len>>1)*lz[x];
v[rson[x]]+=(len-(len>>1))*lz[x];
lz[lson[x]]+=(len>>1)*lz[x];
lz[rson[x]]+=(len-(len>>1))*lz[x];
lz[x]=0;
}
}
区间更新基础,不会的可以先了解线段树的区间更新写法。
3. 建树
void build(int &x,int l,int r){
x=++tot;
lz[x]=0;
if(l==r){
v[x]=a[l];
return;
}
int mid=l+r>>1;
build(lson[x],l,mid);
build(rson[x],mid+1,r);
push_up(x);
}
和线段树的思想是一样的,只是在调用过程中,我们以引用的形式,实现对rt,lson,rson的更新。
建树的调用如下:
build(rt[0],1,n);
3. 更新
void update(int L,int R,int l,int r,int &x,int last,int val){
x=++tot;
lson[x]=lson[last];rson[x]=rson[last];lz[x]=lz[last];v[x]=v[last];
if(L<=l&&R>=r){
v[x]+=(r-l+1)*val;lz[x]+=val;
return;
}
push_down(x,r-l+1);
int mid=l+r>>1;
if(L<=mid) update(L,R,l,mid,lson[x],lson[last],val);
if(R>mid) update(L,R,mid+1,r,rson[x],rson[last],val);
push_up(x);
}
第1行开辟了新的结点,第2行进行了结点复用,last就是上一棵线段树的结点,从根节点向下更新。
更新的调用如下:
update(x,y,1,n,rt[i],rt[i-1],w);
4. 查询
int query(int L,int R,int l,int r,int x){
if(L<=l&&R>=r){
return v[x];
}
push_down(x,r-l+1);
int mid=l+r>>1,sum=0;
if(L<=mid) sum+=query(L,R,l,mid,lson[x]);
if(R>mid) sum+=query(L,R,mid+1,r,rson[x]);
push_up(x);
return sum;
}
查询就是简单的区间查询。
查询的调用如下:
query(x,y,1,n,rt[i]);
5. 实现
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <cmath>
#include <functional>
#include <map>
#include <stack>
#include <ctime>
#include <sstream>
#include <bitset>
//#include<ext/pb_ds/assoc_container.hpp>
//#include <bits/stdc++.h>
#define REP(i,j,k) for(int (i)=(j);(i)<=(k);(i)++)
#define ERP(i,j,k) for(int (i)=(j);(i)>=(k);(i)--)
#define MEM(a,b) memset(a,b,sizeof(a))
#define NE putchar('\n')
#define SP putchar(' ')
#define fi first
#define sc second
#define mkp make_pair
#define pb push_back
#define all(a) a.begin(),a.end()
//#define lson l,mid,x<<1
//#define rson mid+1,r,x<<1|1
#define lowbit(x) ((x)&(-(x)))
#define lc(a) ch[(a)][0]
#define mod_add(a,b,m) (a+b>m?a+b-m:a+b)
#define mod_sub(a,b,m) (a-b<0?a-b+m:a-b)
using namespace std;
//using namespace __gnu_pbds;
typedef double DB;
typedef long double LDB;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;
const DB eps=1e-6;
const DB Pi=acos(-1.0);
const ll mod=1e9+7;
const ull ha1=1e9+7;
const ull ha2=1e9+9;
const int maxn=1e5+10;
const int maxm=1e6+10;
const int inf=1e9+10;
//IO挂
template<typename Type>inline void read(Type&in){
in=0;Type f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){in=in*10+ch-'0';ch=getchar();}
in*=f;
}
template<typename Type>inline void out(Type o){
if(o<0){putchar('-');o=-o;}
if(o>=10) out(o/10);
putchar(o%10+'0');
}
/*Header*/
//printf("%d%c",a[i]," \n"[i==n]);
int tot=0,rt[maxn*20],lson[maxn*20],rson[maxn*20],v[maxn*20],lz[maxn*20],a[maxn];
void push_up(int x){
v[x]=v[lson[x]]+v[rson[x]];
}
void push_down(int x,int len){
if(lz[x]){
v[lson[x]]+=(len>>1)*lz[x];
v[rson[x]]+=(len-(len>>1))*lz[x];
lz[lson[x]]+=(len>>1)*lz[x];
lz[rson[x]]+=(len-(len>>1))*lz[x];
lz[x]=0;
}
}
void build(int &x,int l,int r){
x=++tot;
lz[x]=0;
if(l==r){
v[x]=a[l];
return;
}
int mid=l+r>>1;
build(lson[x],l,mid);
build(rson[x],mid+1,r);
push_up(x);
}
void update(int L,int R,int l,int r,int &x,int last,int val){
x=++tot;
lson[x]=lson[last];rson[x]=rson[last];lz[x]=lz[last];v[x]=v[last];
if(L<=l&&R>=r){
v[x]+=(r-l+1)*val;lz[x]+=val;
return;
}
push_down(x,r-l+1);
int mid=l+r>>1;
if(L<=mid) update(L,R,l,mid,lson[x],lson[last],val);
if(R>mid) update(L,R,mid+1,r,rson[x],rson[last],val);
push_up(x);
}
int query(int L,int R,int l,int r,int x){
if(L<=l&&R>=r){
return v[x];
}
push_down(x,r-l+1);
int mid=l+r>>1,sum=0;
if(L<=mid) sum+=query(L,R,l,mid,lson[x]);
if(R>mid) sum+=query(L,R,mid+1,r,rson[x]);
push_up(x);
return sum;
}
int x,y,w;
int main(){
int n,k,opt;
cin>>n>>k;
for(int i=1;i<=n;i++){
cin>>a[i];
}
build(rt[0],1,n);
for(int i=1;i<=k;i++){
cin>>opt;
if(opt==1){
rt[i]=rt[i-1];
cin>>x>>y;
cout<<query(x,y,1,n,rt[i])<<endl;
}
else if(opt==2){
cin>>x>>y>>w;
update(x,y,1,n,rt[i],rt[i-1],w);
}
else{
cin>>x;
rt[i]=rt[x];
}
}
return 0;
}
对于第i个操作,方式1通过rt[i-1]更新rt[i],方式2通过引用更新rt[i],方式3通过rt[x]更新rt[i]。
6. 测试一下~
input.txt
10 8
1 2 3 4 5 6 7 8 9 10
2 6 7 2
1 6 7
2 3 5 4
1 3 5
2 1 9 5
1 1 9
3 3
1 1 10
output.txt
17
24
106
71
正确无误~(blink)
那么,主席树的入门就到这里了,下面给出poj 2104(静态区间求第K大)的主席树代码,作为参考啦~
#include <bits/stdc++.h>
#include <cstdio>
#define fi first
#define sc second
#define mkp make_pair
#define pb push_back
#define all(a) a.begin(),a.end()
using namespace std;
typedef long long ll;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;
const double eps=1e-8;
const double pi=acos(-1);
const int mod=1e9+7;
/*Header*/
const int maxn=1e5+10;
int rt[maxn*20],lson[maxn*20],rson[maxn*20],sum[maxn*20];
int a[maxn],b[maxn];
int tot;
int n,q;
void build(int &x,int l,int r){
x=++tot;
sum[x]=0;
if(l==r) return;
int mid=(l+r)>>1;
build(lson[x],l,mid);
build(rson[x],mid+1,r);
}
void update(int &x,int last,int l,int r,int pos){
x=++tot;
lson[x]=lson[last];
rson[x]=rson[last];
sum[x]=sum[last]+1;
if(l==r) return;
int mid=(l+r)>>1;
if(pos<=mid) update(lson[x],lson[last],l,mid,pos);
else update(rson[x],rson[last],mid+1,r,pos);
}
int query(int L,int R,int l,int r,int k){
if(l==r) return l;
int mid=(l+r)>>1;
int cnt=sum[lson[R]]-sum[lson[L]];
if(k<=cnt) return query(lson[L],lson[R],l,mid,k);
else return query(rson[L],rson[R],mid+1,r,k-cnt);
}
int main(){
int T;
scanf("%d",&T);
while(T--){
scanf("%d %d",&n,&q);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
b[i]=a[i];
}
sort(b+1,b+1+n);
int m=unique(b+1,b+1+n)-(b+1);
tot=0;
build(rt[0],1,m);
for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b;
for(int i=1;i<=n;i++) update(rt[i],rt[i-1],1,m,a[i]);
int x,y,k,ans;
while(q--){
scanf("%d %d %d",&x,&y,&k);
ans=query(rt[x-1],rt[y],1,m,k);
printf("%d\n",b[ans]);
}
}
return 0;
}
网友评论