题目大意
给定一棵 N 个节点的树,每个点有各自的点权 pi ,但是边的权值都是 1 (也可以认为没有边权)。保证所有点的点权是所有小于 N 的非负整数的一个排列。
定义函数 y = mex(S) ,其中 S 表示一个非负整数集, y 表示不在这个集合里的最小非负整数。
现有 Q 组询问。询问有两个类型:
- 交换两个给定点的点权
- 如果将树上的任意一条简单路径 L 经过的点的点权看作一个非负整数集,则可求得其对应的 mex(L) 。现要求输出当前树上所有路径的最大 mex 值,即 max(mex(L))
题目保证 N 和 Q 不超过 2 × 105
分析
这个题刚看到的时候感觉有两个可能的切入点。题目中间定义的这个 mex 函数可以算是对于前缀的查询,然后交换点权可以看成两次连续的点修改。这也就是说这题可能可以用树状数组或者线段树解决。第二个点就是题目中多次提到的路径,让我有一种树剖的感觉。但是鉴于我不会树剖所以这条路就没往下想...
尝试在线段树的框架下考虑这个问题,如果线段 [L, R] 对应的线段树节点表示对于闭区间 [L, R] 中的所有数都出现的最短简单路径(当然也可能无解),假如我们可以快速维护这棵线段树,那么每次二型询问就可以在 O(T × log2n) 内通过二分地询问线段树完成(其中 T 表示单次维护时间);每次一型询问可以分拆成两个点修改,修改叶子节点以后维护它的所有祖先,复杂度 O(T × logn) 。看起来算是一个不那么爆炸的复杂度了。
下面的问题就是如何维护(合并)两个邻接区间了。首先,线段树的所有叶子区间都有解(宽度为 1)。其次,父节点有解的必要(但是不充分)条件是它的两个儿子都有解。然后我仔细想了一下,树上两条路径可以合并的充要条件应该是这些:
- 两条路径满足包含关系,即某条路径的两个端点都在另一条路径上
- 两条路径各有一个端点在另一条路径上
- 可以从两条路径中各选一个端点使得剩下的两个端点在被选的两个端点确定的路径上
大概感受了一下这些条件跟存在四个端点中的两个使得剩下两个在它们确定的路径上是等价的。于是就转化成了判断点是否在路径上的问题。
这个问题可以用 LCA 来解决, 假设一条路径的端点为 E1 和 E2 ,点 M 在这条路径上的充要条件是:
记 A = LCA(E1, E2) 则
(LCA(E1, M) == M || LCA(E2, M) == M) && LCA(A, M) == A
LCA 的话如果用离散表预处理的话可以做到 O(1) 查询。这样就差不多想清楚了。
然后交了,TLE...
仔细算了一下复杂度,这个 LCA 虽然是 O(1) 的查询,但是在极端情况下一次维护竟然能查询 48 次 LCA 。加上 LCA 本身常数也不小,复杂度会非常爆炸...
所幸我很快就发现,二分在最坏情况下会有非常多的重复询问(因为每次询问都是 [1, x] 的区间)而且询问主要慢在合并区间的操作上...不过这似乎可以一次完成。即先找到最左边的合法线段,然后对于该线段的兄弟节点,尝试将其左边部分二分地合并到可行解尾部并更新答案。这样修改以后单次二型操作的理论复杂度也下降到 O(logn)。然后剪了剪枝就过了...
代码
总复杂度为 O((n + q) × logn)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define FOR(i, a, b) for (int (i) = (a); (i) <= (b); (i)++)
#define ROF(i, a, b) for (int (i) = (a); (i) >= (b); (i)--)
#define REP(i, n) FOR(i, 0, (n)-1)
#define sqr(x) ((x) * (x))
#define all(x) (x).begin(), (x).end()
#define reset(x, y) memset(x, y, sizeof(x))
#define uni(x) (x).erase(unique(all(x)), (x).end());
#define BUG(x) cerr << #x << " = " << (x) << endl
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define _1 first
#define _2 second
const int maxn = 212345;
struct Seg {
int l, r, lv, rv;
} T[maxn << 2];
vector<int> G[maxn];
int n, q, clk, st[maxn << 1][19], in[maxn], out[maxn], dep[maxn], w[maxn];
void dfs(int u, int p) {
in[u] = ++clk;
st[clk][0] = u;
for (int v : G[u]) {
if (v == p) continue;
dep[v] = dep[u] + 1;
dfs(v, u);
st[++clk][0] = u;
}
out[u] = clk;
}
void rmq_init() {
FOR(i, 1, clk) FOR(j, 1, 18) {
st[i][j] = st[i][j - 1];
int val = i - (1 << j - 1);
if (val > 0 && dep[st[val][j - 1]] < dep[st[i][j]])
st[i][j] = st[val][j - 1];
}
}
inline int rmq(int l, int r) {
int val = floor(log(r - l + 1) / log(2));
int u = st[l + (1 << val) - 1][val], v = st[r][val];
return dep[u] < dep[v] ? u : v;
}
inline int lca(int u, int v) {
if (in[u] > in[v]) swap(u, v);
return rmq(in[u], in[v]);
}
inline bool between(int e1, int e2, int m) {
int a = lca(e1, e2);
return lca(a, m) == a && (lca(e1, m) == m || lca(e2, m) == m);
}
inline pii get(const int *buff) {
if (between(buff[0], buff[1], buff[2]) && between(buff[0], buff[1], buff[3]))
return mp(buff[0], buff[1]);
if (between(buff[0], buff[2], buff[1]) && between(buff[0], buff[2], buff[3]))
return mp(buff[0], buff[2]);
if (between(buff[0], buff[3], buff[1]) && between(buff[0], buff[3], buff[2]))
return mp(buff[0], buff[3]);
if (between(buff[1], buff[2], buff[0]) && between(buff[1], buff[2], buff[3]))
return mp(buff[1], buff[2]);
if (between(buff[1], buff[3], buff[0]) && between(buff[1], buff[3], buff[2]))
return mp(buff[1], buff[3]);
if (between(buff[2], buff[3], buff[0]) && between(buff[2], buff[3], buff[1]))
return mp(buff[2], buff[3]);
return mp(-1, -1);
}
void maintain(int o) {
int lson = o << 1, rson = o << 1 | 1;
if (T[lson].lv == -1 || T[rson].lv == -1) {
T[o].lv = T[o].rv = -1;
return;
}
if (T[lson].lv == T[lson].rv) {
T[o].lv = T[lson].lv;
T[o].rv = T[rson].lv;
return;
}
int buff[] = {T[lson].lv, T[lson].rv, T[rson].lv, T[rson].rv};
auto tmp = get(buff);
T[o].lv = tmp._1;
T[o].rv = tmp._2;
}
int solve(int o, int &p1, int &p2) {
if (T[o].lv == -1) {
if (T[o].l == T[o].r) return T[o].l - 1;
int flag = solve(o << 1, p1, p2);
return flag == T[o << 1].r ? solve(o << 1 | 1, p1, p2) : flag;
}
int buff[] = {p1, p2, T[o].lv, T[o].rv};
auto check = get(buff);
if (check._1 != -1) {
p1 = check._1, p2 = check._2;
return T[o].r;
}
if (T[o].lv == T[o].rv) return T[o].l - 1;
int flag = solve(o << 1, p1, p2);
return flag == T[o << 1].r ? solve(o << 1 | 1, p1, p2) : flag;
}
int query(int o) {
while (T[o].lv == -1) o <<= 1;
pii tmp = mp(T[o].lv, T[o].rv);
return solve(o | 1, tmp._1, tmp._2);
}
void update(int o, int pos, int u) {
if (T[o].l == T[o].r) {
T[o].lv = T[o].rv = u;
return;
}
int M = T[o].l + T[o].r >> 1;
pos <= M ? update(o << 1, pos, u) : update(o << 1 | 1, pos, u);
maintain(o);
}
void build(int o, int l, int r) {
T[o].l = l, T[o].r = r;
T[o].lv = T[o].rv = -1;
if (l != r) {
int m = l + r >> 1;
build(o << 1, l, m);
build(o << 1 | 1, m + 1, r);
}
}
int main() {
scanf("%d", &n);
build(1, 1, 1 << 18);
FOR(i, 1, n) {
scanf("%d", w + i);
w[i]++;
}
FOR(i, 2, n) {
int d;
scanf("%d", &d);
G[i].pb(d);
G[d].pb(i);
}
dfs(1, 0);
rmq_init();
FOR(i, 1, n) update(1, w[i], i);
scanf("%d", &q);
while (q--) {
int t;
scanf("%d", &t);
if (t == 1) {
int i, j;
scanf("%d%d", &i, &j);
update(1, w[i], j);
update(1, w[j], i);
swap(w[i], w[j]);
} else {
printf("%d\n", query(1));
}
}
}
网友评论