主席树 学习笔记

模板及讲解

什么是主席树

主席树也称为函数式线段树、可持久化线段树,主要是利用动态开点每个点建线段树(维护$[1,i]$的区间)、线段树可加可减性($[x,y]=[1,y]-[1,x-1]$)来解决如区间内的某些问题。主席树实际上是树套树,最普通的主席树问题就是前缀和套线段树。

主席树的实现

例题

caioj 1441

给$n$($1 \leq n \leq 100000$)个数字,
$a[1],a[2],……,a[n](0 \leq a[i]<=1000000000),m(1 \leq m \leq 100000)$次询问$l$到$r$之间的第$k$小的值。

由题不需要修改操作,就是最普通的主席树问题。

从全局入手

对于整个区间的$k$小,我们可以开权值线段树记录每个值的大小,然后查询时仿造平衡树的方法可以找到第$k$大值。

线段树可加可减性

那么对于区间$[x,y]$,我们怎么办?
想到每个点$i$开$[1,i]$的线段树(整个线段树维护区间不变,只是每个数值的范围,不然不能满足加减性), 则$[x,y]=[1,y]-[1,x-1]$

这样可以看出我们要研究线段树是否可加可减,看下面的例子(借用了 caioj 的图片)

两棵线段树显然可加,并且对应位置上的和相加(维护区间和)。

主席树实现

首先要对每个点开$[1,i]$的线段树,先开一条只包含$i$点信息的链,再与前面一棵线段树合并(相加)。合并线段树也很方便,只要加上$i$点的信息,合并$[1,i-1]$( $merge$ 操作,代码中的$mge$)

查询的时候类似平衡树的查询,例如求$k$小,因为权值线段树,所以左边点都小于这个点,右边点都大于这个点,判断一下第$k$小在左边还是右边,就可以找到了。

代码

注意要离散化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 100000 + 5;
int n, m, whw, tax[MAXN], ai[MAXN], rt[MAXN];
#define M ((l + r) >> 1)
int sumv[MAXN * 20], lc[MAXN * 20], rc[MAXN * 20], sz;
int getPos(int x) {return lower_bound(tax + 1, tax + 1 + whw, x) - tax;}
int mge(int &x, int y) {//合并
if (y == 0) return 0;
if (x == 0) return x = y, 0;//x及其子树与y一致,直接使用
sumv[x] += sumv[y];//合并信息
mge(lc[x], lc[y]), mge(rc[x], rc[y]);
return 0;
}
void build(int l, int r, int &x, int pos) {//建一条链
if (x == 0) x = ++sz, sumv[x] = 0, lc[x] = rc[x] = 0;//动态开点
sumv[x]++;
if (l == r) return ;
if (pos <= M) build(l, M, lc[x], pos); else build(M + 1, r, rc[x], pos);
}
int query(int l, int r, int x, int y, int kth) {//查询
if (l == r) return l;
int dlt = sumv[lc[y]] - sumv[lc[x]];
if (kth <= dlt) return query(l, M, lc[x], lc[y], kth);
else return query(M + 1, r, rc[x], rc[y], kth - dlt);//类似平衡树查询
}
void clean() {
sz = 0;
for (int i = 1; i <= 2000001; i++) sumv[i] = lc[i] = rc[i] = 0;
for (int i = 1; i <= 100001; i++) tax[i] = ai[i] = rt[i] = 0;
}
int solve() {
clean();
for (int i = 1; i <= n; i++) scanf("%d", &ai[i]), tax[i] = ai[i];
sort(tax + 1, tax + 1 + n), whw = unique(tax + 1, tax + 1 + n) - tax - 1;//离散化
for (int i = 1; i <= n; i++) build(1, whw, rt[i], getPos(ai[i])), mge(rt[i], rt[i - 1]);//建链、合并
for (int x, y, k, i = 1; i <= m; i++) {
scanf("%d%d%d", &x, &y, &k);
printf("%d\n", tax[query(1, whw, rt[x - 1], rt[y], k)]);
}
return 0;
}
int main() {
scanf("%d%d", &n, &m), solve();
return 0;
}

树上主席树

caioj 1443

给定一棵$N(1 \leq N \leq 100000)$个节点的树,每个点有一个权值,对于$M(1 \leq M \leq 100000)$个询问$(x,y,k)$,你需要回答$x$和$y$这两个节点间第$k$小的点权。

我们对于每个点建主席树维护$(u,rt)​$路径链,$rt​$为根,合并时与他的父亲节点合并,计算$(u,v)​$信息线段树时使用
$(u,v)=(u, rt)+(v,rt)-(lca,rt)-(fa[lca], rt)$, $lca=LCA(u,v), fa[lca]$为$lca$的父亲节点$rt$为根,画图理解
然后按照普通的主席树做就行了

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 100000 + 5, logs = 18;
int n, q, whw, ai[MAXN], tax[MAXN], rt[MAXN], sz, dep[MAXN], pre[MAXN][25];
vector<int> G[MAXN];
int getPos(int x) {return lower_bound(tax + 1, tax + 1 + whw, x) - tax;}
void ins(int a, int b) {G[a].push_back(b), G[b].push_back(a);}
#define M ((l + r) >> 1)
int sumv[MAXN * 20], lc[MAXN * 20], rc[MAXN * 20];
int mge(int &x, int y) {
if (y == 0) return 0;
if (x == 0) return x = y, 0;
sumv[x] += sumv[y];
mge(lc[x], lc[y]), mge(rc[x], rc[y]);
return 0;
}
void build(int l, int r, int &x, int pos) {
if (x == 0) x = ++sz, sumv[x] = lc[x] = rc[x] = 0;
sumv[x]++;
if (l == r) return ;
if (pos <= M) build(l, M, lc[x], pos); else build(M + 1, r, rc[x], pos);
}
int query(int l, int r, int x, int y, int lca, int flca, int kth) {
if (l == r) return tax[l];
int sum = sumv[lc[x]] + sumv[lc[y]] - sumv[lc[lca]] - sumv[lc[flca]];
if (sum >= kth) return query(l, M, lc[x], lc[y], lc[lca], lc[flca], kth);
else return query(M + 1, r, rc[x], rc[y], rc[lca], rc[flca], kth - sum);
}
void dfs(int u, int pa) {
dep[u] = dep[pa] + 1, pre[u][0] = pa, mge(rt[u], rt[pa]);
for (int i = 1; i <= logs; i++) pre[u][i] = pre[pre[u][i - 1]][i - 1];
for (int i = 0; i < (int)G[u].size(); i++) {
int v = G[u][i];
if (v != pa) dfs(v, u);
}
}
int LCA(int a, int b) {
if (dep[a] < dep[b]) swap(a, b);
for (int i = logs; i >= 0; i--) if (dep[pre[a][i]] >= dep[b]) a = pre[a][i];
if (a == b) return a;
for (int i = logs; i >= 0; i--) if (pre[a][i] != pre[b][i]) a = pre[a][i], b = pre[b][i];
return pre[a][0];
}
void clean() {
sz = 0;
for (int i = 0; i <= 100001; i++) {
G[i].clear(), dep[i] = tax[i] = ai[i] = rt[i] = 0;
for (int j = 0; j <= 19; j++) pre[i][j] = 0;
}
for (int i = 0; i <= 2000001; i++) sumv[i] = lc[i] = rc[i] = 0;
}
int solve() {
clean();
for (int i = 1; i <= n; i++) scanf("%d", &ai[i]), tax[i] = ai[i];
sort(tax + 1, tax + 1 + n), whw = unique(tax + 1, tax + 1 + n) - tax - 1;
for (int x, y, i = 1; i < n; i++) scanf("%d%d", &x, &y), ins(x, y);
for (int i = 1; i <= n; i++) build(1, whw, rt[i], getPos(ai[i]));
dfs(1, 0);
while (q--) {
int x, y, k, lca;
scanf("%d%d%d", &x, &y, &k);
lca = LCA(x, y);
printf("%d\n", query(1, whw, rt[x], rt[y], rt[lca], rt[pre[lca][0]], k));
}
return 0;
}
int main() {
scanf("%d%d", &n, &q), solve();
return 0;
}
/*
13 100
3 4 1 2 3 2 4 5 3 2 1 1 3
1 2
2 3
3 4
4 5
5 6
5 7
2 8
8 9
9 10
10 11
10 12
11 13
7 13 8
*/

带修主席树

caioj 1442

给$n(1 \leq n \leq 50000)$个数字,进行$m(1 \leq m \leq 10000)$次操作,有两种操作:
$Q,l,r,k$:询问$l$到$r$第$k$小的数。
$C,x,k$:改变第$x$个数的值为$k$。

因为普通的主席树是前缀和套线段树,所以不能修改。那么我们想到修改,就发现可以用树状数组/线段树套线段树,由于此题单点修改,所以用树状数组。
对于前缀和套线段树先建主席树,然后再建树状数组套线段树的,修改在树状数组上操作,原数组在前缀和中,综合可以得到修改后的信息,要注意树状数组上的点在线段树上跳动(jump函数调节,存在$ust$数组),查询就用$ust$数组即可

实际上可以不必要建$2n$棵线段树,原数组直接加到树状数组中即可,不过会慢一点,参见此处

代码

建$2n$棵线段树的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 50000 + 5, MV = 1000000000;
int n, q, sz, rt[MAXN * 2], ai[MAXN];
#define M ((l + r) >> 1)
int sumv[MAXN * 100], lc[MAXN * 100], rc[MAXN * 100], ust[MAXN * 100];
int lowbit(int x) {return x & (-x);}
int mge(int &x, int y) {//线段树 合并 ->将线段树y合并至线段树x
if (x == 0) return x = y, 0;
if (y == 0) return 0;
sumv[x] += sumv[y];
mge(lc[x], lc[y]), mge(rc[x], rc[y]);
return 0;
}
void build(int l, int r, int &x, int pos, int v) {//线段树 建链 -> 维护[l,r]区间,当前线段树上点x,修改位置为pos=v
if (x == 0) x = ++sz, lc[x] = rc[x] = sumv[x] = 0;
sumv[x] += v;
if (l == r) return ;
if (pos <= M) build(l, M, lc[x], pos, v); else build(M + 1, r, rc[x], pos, v);
}
void add(int u, int x, int c) {//Bit 加 -> bit上u点,x位置加c
for (int i = u; i <= 2 * n; i += lowbit(i)) build(1, MV, rt[i], x, c);
}
void jump(int u, int tp) {//Bit 更新 -> bit上u跳
for (int i = u; i > n; i -= lowbit(i)) {
if (tp == -1) ust[i] = rt[i];
if (tp == 0) ust[i] = lc[ust[i]];
if (tp == 1) ust[i] = rc[ust[i]];
}
}
int getBitSum(int u) {//Bit 查询 -> bit上u查询
int ret = 0;
for (int i = u; i > n; i -= lowbit(i)) {
ret += sumv[lc[ust[i]]];
}
return ret;
}
int query(int l, int r, int x, int y, int x_2, int y_2, int kth) {//线段树 查询 -> 维护[l,r]区间,当前线段树上点x,y, 位置x_2, y_2, 查询第kth大
if (l == r) return l;
int sum = sumv[lc[y]] - sumv[lc[x]] + getBitSum(y_2 + n) - getBitSum(x_2 + n);
if (sum >= kth) {
jump(x_2 + n, 0), jump(y_2 + n, 0);
return query(l, M, lc[x], lc[y], x_2, y_2, kth);
} else {
jump(x_2 + n, 1), jump(y_2 + n, 1);
return query(M + 1, r, rc[x], rc[y], x_2, y_2, kth - sum);
}
}
void clean() {
sz = 0;
for (int i = 0; i <= 100000 + 5; i++) rt[i] = 0;
for (int i = 0; i <= 5000000 + 5; i++) sumv[i] = lc[i] = rc[i] = ust[i] = 0;
}
int solve() {
clean();
for (int i = 1; i <= n; i++) scanf("%d", &ai[i]);
for (int i = 1; i <= n; i++) build(1, MV, rt[i], ai[i], 1), mge(rt[i], rt[i - 1]);
char s[5];
while (q--) {
scanf("%s", s);
if (s[0] == 'C') {
int x, k; scanf("%d%d", &x, &k);
add(x + n, ai[x], -1), ai[x] = k, add(x + n, ai[x], 1);
} else {
int l, r, k; scanf("%d%d%d", &l, &r, &k);
jump(r + n, -1), jump(l - 1 + n, -1);
printf("%d\n", query(1, MV, rt[l - 1], rt[r], l - 1, r, k));
}
}
return 0;
}
int main() {
scanf("%d%d", &n, &q), solve();
return 0;
}

主席树维护区间问题

spoj DQUERY

给出一个$n$个数的序列,求区间$[l,r]$里有多少种不同数字。

与树状数组类似,主席树维护区间,相当于可持久化维护每次加点后的情况。每个点按顺序建树,如果这个点的数之前没有出现过,直接在本棵主席树该位置加$1$。否则就把之前出现这个值的位置减$1$,再重复做没有出现的情况。为的是把数尽可能放到右边,因为记录值中位置不影响答案。这样就方便求解$[l,r]$的信息。
询问直接用右端点的主席树,由于上述操作后答案可减,所以直接把右端点的主席树左端点以右的值求和即可

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 30000 + 5;
int n, q, sz, ai[MAXN], rt[MAXN], lst[1000000 + 5];
#define M ((l + r) >> 1)
int sumv[MAXN * 20], lc[MAXN * 20], rc[MAXN * 20];
int mge(int &x, int y) {//主席树x合并主席树y
if (y == 0) return 0;
if (x == 0) return x = y, 0;
sumv[x] += sumv[y];
mge(lc[x], lc[y]), mge(rc[x], rc[y]);
return 0;
}
void build(int l, int r, int &x, int pos, int v) {//建链维护[l,r], 主席树上x点,修改位置和值
if (x == 0) x = ++sz, lc[x] = rc[x] = sumv[x] = 0;
sumv[x] += v;
if (l == r) return ;
if (pos <= M) build(l, M, lc[x], pos, v); else build(M + 1, r, rc[x], pos, v);
}
int query(int l, int r, int x, int u) {//查询[l,r]答案,主席树上x点,左边临界点u
if (l == r) return 0;
if (u <= M) return sumv[rc[x]] + query(l, M, lc[x], u); //加上右边,查询左边
else return query(M + 1, r, rc[x], u); //不要加左,左边有临界点
}
void clean() {
sz = 0, ms(lst, -1);
}
int solve() {
clean();
for (int i = 1; i <= n; i++) scanf("%d", &ai[i]);
for (int i = 1; i <= n; i++) {//维护 [0, n] 区间,因为l - 1可能为 0
if (lst[ai[i]] < 0) build(0, n, rt[i], i, 1), mge(rt[i], rt[i - 1]);//之前没有
else {
build(0, n, rt[i], lst[ai[i]], -1), build(0, n, rt[i], i, 1);
mge(rt[i], rt[i - 1]);
}//之前有
lst[ai[i]] = i;
}
scanf("%d", &q);
while (q--) {
int l, r; scanf("%d%d",&l, &r);
printf("%d\n", query(0, n, rt[r], l - 1));
}
return 0;
}
int main() {
scanf("%d", &n), solve();
return 0;
}

可持久化

caioj 1447

维护区间和,有区间增加,要求可持久化 (回退、查询某个版本)

每个询问开一棵线段树,回退直接删掉中间的线段树即可。由于是主席树不能pushdown,pushup,所以增加的时候直接更新sumv的值,查询时lazy值直接累加乘以查询区间长度即可,具体操作可以看代码

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const LL MAXN = 100000 + 5;
LL n, m, ai[MAXN], qzh[MAXN], rt[MAXN], now, sz;
#define M ((l + r) >> 1)
LL sumv[MAXN * 40], lc[MAXN * 40], rc[MAXN * 40], lazy[MAXN * 40];
int mge(LL &x, LL y) {
if (y == 0) return 0;
if (x == 0) return x = y, 0;
sumv[x] += sumv[y], lazy[x] += lazy[y];
mge(lc[x], lc[y]), mge(rc[x], rc[y]);
return 0;
}
void build(LL l, LL r, LL &x, LL cl, LL cr, LL v) {
if (x == 0) x = ++sz;
sumv[x] += (cr - cl + 1) * v;//直接加,免去pushup
if (l == cl && r == cr) {
lazy[x] += v;
return ;
}
if (cr <= M) build(l, M, lc[x], cl, cr, v); else if (cl > M) build(M + 1, r, rc[x], cl, cr, v);
else build(l, M, lc[x], cl, M, v), build(M + 1, r, rc[x], M + 1, cr, v);
//整个区间在左边、右边、分开两边
}
LL query(LL l, LL r, LL x, LL cl, LL cr, LL tmp) {
if (l == cl && r == cr) return (r - l + 1) * tmp + sumv[x];
if (cr <= M) return query(l, M, lc[x], cl, cr, tmp + lazy[x]);
else if (cl > M) return query(M + 1, r, rc[x], cl, cr, tmp + lazy[x]);
else return query(l, M, lc[x], cl, M, tmp + lazy[x]) + query(M + 1, r, rc[x], M + 1, cr, tmp + lazy[x]);
//整个查询区间在左边、右边、分开两边,和普通线段树不同,相当于用 M 分离查询区间
//直接累加lazy最后乘查询区间长度
}
void clean() {
now = sz = 0;
for (LL i = 0; i <= 100000 + 3; i++) rt[i] = qzh[i] = 0;
for (LL i = 0; i <= 4000000 + 3; i++) sumv[i] = lc[i] = rc[i] = lazy[i] = 0;
}
int solve() {
clean();
for (LL i = 1; i <= n; i++) scanf("%lld", &ai[i]), qzh[i] = qzh[i - 1] + ai[i];
for (LL i = 1; i <= m; i++) {
LL tp; scanf("%lld", &tp);
if (tp == 1) {
LL l, r, k; scanf("%lld%lld%lld", &l, &r, &k);
build(1, n, rt[++now], l, r, k), mge(rt[now], rt[now - 1]);
}
if (tp == 2) {
LL l, r; scanf("%lld%lld", &l, &r);
printf("%lld\n", qzh[r] - qzh[l - 1] + query(1, n, rt[now], l, r, 0));
}
if (tp == 3) {
LL l, r, h; scanf("%lld%lld%lld", &l, &r, &h);
printf("%lld\n", qzh[r] - qzh[l - 1] + query(1, n, rt[h], l, r, 0));
}
if (tp == 4) {
LL h; scanf("%lld", &h);
for (LL i = h + 1; i <= now; i++) rt[i] = 0;
now = h;
}
}
return 0;
}
int main() {
scanf("%lld%lld", &n, &m), solve();
return 0;
}

常见题型

1、开权值线段树维护
2、树上主席树
3、带修主席树
4、开普通线段树维护
5、可持久化线段树

------ 本文结束 ------