平衡树-Treap 及其拓展
平衡树分 leafy 和非 leafy。leafy 平衡树在叶子节点上存储原始信息,像线段树。非 leafy 则是每个节点上都有信息。实际上各种树型 $\log$ 之类的数据结构本质上都类似(比如平衡树、线段树、跳表),只是具体实现不同。
讲平衡树前,首先要讲一下排序二叉树(BST)是啥。排序二叉树,显然要是一个二叉树。每个节点上有一个 $val$,这棵树的中序遍历是不降的。也就是说,一个节点的左子树中所有的节点的 $val$ 都小于等于这个节点的 $val$,而它的右子树中所有节点的 $val$ 都大于这个节点的 $val$。于是,可以在 ${O}(\texttt{树高})$ 的复杂度内完成许多操作,比如查询第 $k$ 大的 $val$、查询小于 $k$ 的 $val$ 的个数。
考虑如何做这些操作(均设当前节点为 $p$):
- 插入 $val$:从根开始向下找,若 $p_{val}<val$,则递归至右孩子。否则递归至左孩子。当 $p$ 为空时,新建节点。
- 找到第 $k$ 大的 $val$:维护前每个节点的子树大小。从根开始向下找,如果左孩子子树大小大于等于 $k$ 时,递归至左孩子。如果左孩子子树大小等于 $k-1$,则返回当前节点的 $val$。否则,将 $k$ 减去左孩子子树大小再减去 $1$(当前节点),并递归右孩子。
- 其它也类似,大概就是从根开始,根据一些信息来判断向左或向右递归,再去做些操作。
具体实现时,可以用指针表示左右孩子,可以用成员函数来写,也可以用下标表示,可以用普通函数写。建议用下标表示孩子,并用结构体封装节点,用引用来操作(即代码中的 node &p = tr[rt];
)。
可以做 Luogu P5076。
一部分 Code
struct node {
int val;
int siz; // 本子树大小
int lc, rc; // 左右孩子编号,0 为无孩子
} tr[N];
void upd(int rt) {
if(!rt) return;
node &p = tr[rt];
p.siz = tr[p.lc].siz + 1 + tr[p.rc].siz;
}
int find_val(int rt, int k) { // 找到第 k 大的 val
if(!rt) return -1;
node &p = tr[rt];
if(tr[p.lc].siz + 1 > k) return find_val(p.lc, k);
else if(tr[p.lc].siz + 1 == k) return p.val;
else return find_val(p.rc, k - tr[p.lc].siz - 1);
}
int find_num(int rt, int k) { // 小于 k 的 val 的个数
if(!rt) return 0;
node &p = tr[rt];
if(p.val < k) return find_num(p.rc, k) + tr[p.lc].siz + 1;
else return find_num(p.lc, k);
}
int find_pre(int rt, int k) { // 小于 k 的最大值
if(!rt) return −2147483647;
node &p = tr[rt];
if(p.val < k) return max(p.val, find_pre(p.rs, k));
else return find_pre(p.ls, k);
}
int insert(int rt, int val) {
if(!rt) return tr[++cnt] = node{val,1,0,0}, cnt;
node &p = tr[rt];
if(p.val < val) p.rc = insert(p.rc, val);
else p.lc = insert(p.lc, val);
upd(rt);
return rt;
}
其它的也都易于实现。但显然的,如果连续插入严格升序的 $val$,复杂度是 $O(n^2)$ 的。于是,需要降低树高。
平衡树是什么?平衡树就是比较平衡的 BST,每个节点左子树大小和右子树大小差不多(非严格定义)。设一个平衡树共 $n$ 个节点,那如果左右子树大小差不多,显然树高是 $\mathrm{O}(\log n)$ 级别的。那么如何把 BST 搞平衡呢?很简单,将每个节点再赋一个随机的 $pri$,使得每个节点的父节点的 $pri$ 比它大,也就是说按 $pri$ 就是一个大根堆(小根堆也可以)。这样的树高期望就是 $O(\log n)$ 的了。作者不太会证明。
然后思考如何插入。
(以下为 Treap 的分裂合并式讲解)
此时,可以发现,这颗树可以分裂成两棵有序的树。可以按 $val$ 分裂(小于 $val$ 的在左,其余的在右),也可以按树的大小分裂。
Code
struct pir { int a, b; }; // a 和 b 含义:把一个树裂成两个之后的树根编号,0 则为空
pir split_val(int rt, int val) {
if(!rt) return pir{0, 0};
node &p = tr[rt];
if(p.val < val) {
pir o = split_val(p.rs, val);
p.rs = o.a; o.a = rt; upd(rt);
return o;
} else {
pir o = split_val(p.ls, val);
p.ls = o.b; o.b = rt; upd(rt);
return o;
}
}
pir split_rk(int rt, int k) { // 分裂后左树的大小为 k
if(!rt) return pir{0, 0};
node &p = tr[rt];
if(tr[p.ls].siz + 1 <= k) {
pir o = split_rk(p.rs, k - tr[p.ls].siz - 1);
p.rs = o.a; o.a = rt; upd(rt);
return o;
} else {
pir o = split_rk(p.ls, k);
p.ls = o.b; o.b = rt; upd(rt);
return o;
}
}
也可以发现,如果有两个有序的树(左边的树中最大值小于等于右边的最小值),可以把它们合并。由于树有序了,所以合并时只需要考虑 $pri$ 的大小。$pri$ 较大的放在根,并将某一边的子树和另一个合并。
Code
int merge(int a, int b) {
if(!a || !b) return a + b;
node &pa = tr[a], &pb = tr[b];
if(pa.pri > pb.pri) {
pa.rs = merge(pa.rs, b);
upd(a); return a;
} else {
pb.ls = merge(a, pb.ls);
upd(b); return b;
}
}
如果要插入一个值,就直接将树按这个值裂开为两个,设较小的为 $a$,较大的为 $b$,再新建一个单独的节点表示所插入的值。之后将 $a$ 和这个节点合并,再合并上 $b$ 就行。
Code
int insert(int rt, int val){
pir o = split_val(rt, val);
tr[++cnt] = node{val, rand(), 1, 0, 0};// val, pri, siz, lc, rc
o.a = merge(o.a, cnt);
return merge(o.a, o.b);
}
其余的没什么变化。也可以将各个查询都按分裂,取出节点再合并的方式进行,但是常数会大些,也不很好写。
此时的这棵树就已经可以通过普通平衡树了(Luogu 模板:P6136)。
同时,由于可以分裂合并,Treap 可以维护(单个或多个)序列,这时 $val$ 就是下标。但是为了在中间插入元素,可以直接舍弃掉 $val$ 而按照大小分裂。这时某个节点的排名就是它的下标。序列的信息可以储存在每个节点中,由此也可以实现区间加,区间翻转,区间平移之类的操作。类似于线段树,都是打一个 tag,并下传。
区间翻转 Code
// 需要一个tag表示是否翻转,同时 pd 函数应该在所有向下找的函数中使用,类似于线段树的 push_down
void pd(int rt) {
if(!rt) return;
node &p = tr[rt];
if(!p.tg) return;
if(p.ls) { node &ls = tr[p.ls]; ls.tg ^= 1; swap(ls.ls, ls.rs); }
if(p.rs) { node &rs = tr[p.rs]; rs.tg ^= 1; swap(rs.ls, rs.rs); }
p.tg = 0;
}
void revers(int l, int r) {
pir o = split_rk(rt, r), v = split_rk(o.a, l - 1);
node &p = tr[v.b];
p.tg ^= 1; swap(p.ls, p.rs);
o.a = merge(o2.a, o2.b);
rt = merge(o.a, o.b);
}
维护序列的典型题:文艺平衡树,【NOI2005】维护数列,【NOI2021】密码箱。
也可以使用 Treap 来维护环。它与维护序列类似,但需要分类讨论,断开哪里连接哪里,画画图会好很多。也需要维护一级祖先节点,用来找到某个节点的根和下标。
维护环的例题:【蓝桥杯 2016 国 A】圆圈舞(巨大难写,感觉写完了就真的明白它是如何做的了)。
可能需要推一下式子()
放一个这道题的代码,如果维护环讲的不清楚可以看代码()
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned int uint;
ll read(){
ll a = 0, b = 0; char c = getchar();
while (c < '0' || c > '9') b ^= (c == '-'), c = getchar();
while (c >= '0' && c <= '9') a = a * 10 - 48 + c, c = getchar();
return b ? -a : a;
}
const int N = 100005, NS = N, mod = 1000000007;
mt19937 rnd(20220921);
int n, m;
ll ans;
struct node {
uint pri;
int ls, rs, pa, siz;
ll H, F, sH, sF, pH, pF, c;
ll calc() { return ((sF * pH - pF * sH + siz * c) % mod + mod) % mod; }
void prin(){
cerr<<"node print: "<<ls<<' '<<rs<<' '<<pa<<' '<<siz<<" "
<<H<<' '<<F<<' '<<sH<<' '<<sF<<' '<<pH<<' '<<pF<<' '<<c<<'\n';
}
};
struct pir { int a ,b; };
struct treap {
node tr[NS];
void upd(int rt) {
if (!rt) return;
node &p = tr[rt], &l = tr[p.ls], &r = tr[p.rs];
p.sH = l.sH + p.H + r.sH;
p.sF = l.sF + p.F + r.sF;
p.pH = l.pH + l.siz*p.H + (l.siz+1)*r.sH + r.pH;
p.pF = l.pF + l.siz*p.F + (l.siz+1)*r.sF + r.pF;
p.siz = l.siz + 1 + r.siz;
p.c = l.c + r.c + l.sH*p.F + p.H*r.sF + l.sH*r.sF;
p.sH %= mod, p.sF %= mod, p.pH %= mod, p.pF %= mod, p.c %= mod;
l.pa = r.pa = rt;
}
int merge(int a, int b) {
if (!a || !b) return a + b;
node &u = tr[a], &v = tr[b];
if (u.pri > v.pri) {
u.rs = merge(u.rs, b);
upd(a); return a;
} else {
v.ls = merge(a, v.ls);
upd(b); return b;
}
}
pir find(int a) { // 找到点 a 的根和下标
pir s = pir{a, tr[tr[a].ls].siz+1};
while (tr[a].pa != a) {
node &p = tr[a], &q = tr[p.pa];
int pa = p.pa;
if (q.ls == a) {
s.a = pa; a = pa;
} else {
s.a = pa, s.b += tr[q.ls].siz+1;
a = pa;
}
}
return s;
}
pir split2(int rt, int k) { // 按下标分裂(左侧有 k 个元素)
if (!rt) return pir{0, 0};
node &p = tr[rt];
if (k <= tr[p.ls].siz) {
pir o = split2(p.ls, k);
tr[p.ls].pa = p.ls;
p.ls = o.b, o.b = rt, upd(rt);
return o;
} else {
pir o = split2(p.rs, k - tr[p.ls].siz - 1);
tr[p.rs].pa = p.rs;
p.rs = o.a, o.a = rt, upd(rt);
return o;
}
}
ll calc(int rt){
if(!rt)return 0;
return tr[rt].calc();
}
void init(int n) {
int rt = 0;
for (int i = 1; i <= n; i++) {
ll H = read(), F = read();
tr[i] = node{(uint)rnd(), 0, 0, i, 1, H, F, H, F, 0, 0, 0};
upd(i);
rt = merge(rt, i);
}
ans = calc(rt);
}
void swp(int a, int b) {
pir ra = find(a), rb = find(b);
if (ra.a == rb.a) { // 根相同,在同一环内,将会分裂成两个环
ans -= calc(rb.a);
if (ra.b < rb.b) {
pir sb = split2(rb.a, rb.b - 1), sa = split2(sb.a, ra.b);
int r1 = merge(sa.a, sb.b);
ans += calc(r1) + calc(sa.b);
} else {
pir sa = split2(rb.a, ra.b), sb = split2(sa.a, rb.b - 1);
int r1 = merge(sb.a, sa.b);
ans += calc(r1) + calc(sb.b);
}
} else { //将会合并成一个环
ans -= calc(ra.a) + calc(rb.a);
pir sa = split2(ra.a, ra.b), sb = split2(rb.a, rb.b - 1);
int r1 = merge(sa.a, sb.b);
r1 = merge(sa.b, r1);
r1 = merge(r1, sb.a);
ans += calc(r1);
}
}
void uupd(int a) { // 单点更新
if (tr[a].pa == a) ans -= calc(a);
upd(a);
while (tr[a].pa != a) {
node &p = tr[a];
int pa = p.pa;
if(tr[pa].pa == pa) ans -= calc(pa);
upd(pa); a = pa;
}
ans += calc(a);
}
void chg(int op, int pos, int q) {
node &p = tr[pos];
if (op == 1) p.H = q;
else p.F = q;
uupd(pos);
}
}s;
int main(){
n = read();
s.init(n);
m = read();
for (int i = 1; i <= m; i++) {
int k = read(), p = read(), q = read();
if (k == 1) s.swp(p, q);
else if (k == 2) s.chg(1, p, q);
else s.chg(2, p, q);
ans = (ans % mod + mod) % mod;
cout << ans << '\n';
}
return 0;
}