caioj 1442(带修主席树)

caioj 1447
题意:给$n(1 \leq n \leq 50000)$个数字,进行$m(1 \leq m \leq 10000)$次操作,有两种操作:

$Q,l,r,k$:询问$l$到$r$第$k$小的数。
$C,x,k$:改变第$x$个数的值为$k$。

因为普通的主席树是前缀和套线段树,所以不能修改。那么我们想到修改,就发现可以用树状数组/线段树套线段树,由于此题单点修改,所以用树状数组。
对于前缀和套线段树先建主席树,然后再建树状数组套线段树的,修改在树状数组上操作,原数组在前缀和中,综合可以得到修改后的信息,要注意树状数组上的点在线段树上跳动(jump函数调节,存在$ust$数组),查询就用$ust$数组即可

实际上可以不必要建$2n$棵线段树,原数组直接加到树状数组中即可,不过会慢一点,代码在下面

建$2n$棵线段树的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 50000 + 5, MV = 1000000000;
int n, q, sz, rt[MAXN * 2], ai[MAXN];
#define M ((l + r) >> 1)
int sumv[MAXN * 100], lc[MAXN * 100], rc[MAXN * 100], ust[MAXN * 100];
int lowbit(int x) {return x & (-x);}
int mge(int &x, int y) {//线段树 合并 ->将线段树y合并至线段树x
if (x == 0) return x = y, 0;
if (y == 0) return 0;
sumv[x] += sumv[y];
mge(lc[x], lc[y]), mge(rc[x], rc[y]);
return 0;
}
void build(int l, int r, int &x, int pos, int v) {//线段树 建链 -> 维护[l,r]区间,当前线段树上点x,修改位置为pos=v
if (x == 0) x = ++sz, lc[x] = rc[x] = sumv[x] = 0;
sumv[x] += v;
if (l == r) return ;
if (pos <= M) build(l, M, lc[x], pos, v); else build(M + 1, r, rc[x], pos, v);
}
void add(int u, int x, int c) {//Bit 加 -> bit上u点,x位置加c
for (int i = u; i <= 2 * n; i += lowbit(i)) build(1, MV, rt[i], x, c);
}
void jump(int u, int tp) {//Bit 更新 -> bit上u跳
for (int i = u; i > n; i -= lowbit(i)) {
if (tp == -1) ust[i] = rt[i];
if (tp == 0) ust[i] = lc[ust[i]];
if (tp == 1) ust[i] = rc[ust[i]];
}
}
int getBitSum(int u) {//Bit 查询 -> bit上u查询
int ret = 0;
for (int i = u; i > n; i -= lowbit(i)) {
ret += sumv[lc[ust[i]]];
}
return ret;
}
int query(int l, int r, int x, int y, int x_2, int y_2, int kth) {//线段树 查询 -> 维护[l,r]区间,当前线段树上点x,y, 位置x_2, y_2, 查询第kth大
if (l == r) return l;
int sum = sumv[lc[y]] - sumv[lc[x]] + getBitSum(y_2 + n) - getBitSum(x_2 + n);
if (sum >= kth) {
jump(x_2 + n, 0), jump(y_2 + n, 0);
return query(l, M, lc[x], lc[y], x_2, y_2, kth);
} else {
jump(x_2 + n, 1), jump(y_2 + n, 1);
return query(M + 1, r, rc[x], rc[y], x_2, y_2, kth - sum);
}
}
void clean() {
sz = 0;
for (int i = 0; i <= 100000 + 5; i++) rt[i] = 0;
for (int i = 0; i <= 5000000 + 5; i++) sumv[i] = lc[i] = rc[i] = ust[i] = 0;
}
int solve() {
clean();
for (int i = 1; i <= n; i++) scanf("%d", &ai[i]);
for (int i = 1; i <= n; i++) build(1, MV, rt[i], ai[i], 1), mge(rt[i], rt[i - 1]);
char s[5];
while (q--) {
scanf("%s", s);
if (s[0] == 'C') {
int x, k; scanf("%d%d", &x, &k);
add(x + n, ai[x], -1), ai[x] = k, add(x + n, ai[x], 1);
} else {
int l, r, k; scanf("%d%d%d", &l, &r, &k);
jump(r + n, -1), jump(l - 1 + n, -1);
printf("%d\n", query(1, MV, rt[l - 1], rt[r], l - 1, r, k));
}
}
return 0;
}
int main() {
scanf("%d%d", &n, &q), solve();
return 0;
}

建$n$棵线段树直接维护树状数组:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 50000 + 5, MV = 1000000000;
int n, q, sz, rt[MAXN], ai[MAXN];
#define M ((l + r) >> 1)
int sumv[MAXN * 100], lc[MAXN * 100], rc[MAXN * 100], ust[MAXN * 100];
int lowbit(int x) {return x & (-x);}
int mge(int &x, int y) {
if (x == 0) return x = y, 0;
if (y == 0) return 0;
sumv[x] += sumv[y];
mge(lc[x], lc[y]), mge(rc[x], rc[y]);
return 0;
}
void build(int l, int r, int &x, int pos, int v) {
if (x == 0) x = ++sz, lc[x] = rc[x] = sumv[x] = 0;
sumv[x] += v;
if (l == r) return ;
if (pos <= M) build(l, M, lc[x], pos, v); else build(M + 1, r, rc[x], pos, v);
}
void add(int u, int x, int c) {
for (int i = u; i <= n; i += lowbit(i)) build(1, MV, rt[i], x, c);
}
void jump(int u, int tp) {
for (int i = u; i > 0; i -= lowbit(i)) {
if (tp == -1) ust[i] = rt[i];
if (tp == 0) ust[i] = lc[ust[i]];
if (tp == 1) ust[i] = rc[ust[i]];
}
}
int getBitSum(int u) {
int ret = 0;
for (int i = u; i > 0; i -= lowbit(i)) {
ret += sumv[lc[ust[i]]];
}
return ret;
}
int query(int l, int r, int x, int y, int x_2, int y_2, int kth) {
if (l == r) return l;
int sum = getBitSum(y_2) - getBitSum(x_2);
if (sum >= kth) {
jump(x_2, 0), jump(y_2, 0);
return query(l, M, lc[x], lc[y], x_2, y_2, kth);
} else {
jump(x_2, 1), jump(y_2, 1);
return query(M + 1, r, rc[x], rc[y], x_2, y_2, kth - sum);
}
}
void clean() {
sz = 0;
for (int i = 0; i <= 50000 + 5; i++) rt[i] = 0;
for (int i = 0; i <= 5000000 + 5; i++) sumv[i] = lc[i] = rc[i] = ust[i] = 0;
}
int solve() {
clean();
for (int i = 1; i <= n; i++) scanf("%d", &ai[i]);
for (int i = 1; i <= n; i++) add(i, ai[i], 1);
char s[5];
while (q--) {
scanf("%s", s);
if (s[0] == 'C') {
int x, k; scanf("%d%d", &x, &k);
add(x, ai[x], -1), ai[x] = k, add(x, ai[x], 1);
} else {
int l, r, k; scanf("%d%d%d", &l, &r, &k);
jump(r, -1), jump(l - 1, -1);
printf("%d\n", query(1, MV, rt[l - 1], rt[r], l - 1, r, k));
}
}
return 0;
}
int main() {
scanf("%d%d", &n, &q), solve();
return 0;
}

------ 本文结束 ------