Bzoj 3991(路径并 + 倍增LCA + 虚树)

BZOJ 3991
题意:$n$个点的树,$m$次变动使得某个点有宝物或没宝物,询问每次变动后集齐所有宝物并返回原点的最小距离。
转化成有根树,求路径的并。路径并就是 DFS 序下两个点之间的距离和。并且最后一个点和第一个点的距离要贡献。那么这题相当于插入一个点然后找到他 DFS 序前面后面的点加上贡献即可。删除同理。注意如果是插在中间要减掉前面后面的距离,因为这个贡献在点加入后已经不存在了,留下来会重复。
对于这个的维护,我们用 set 就行了。
set的一个小Trick: set 里加入 $INF$ 和 $-INF$,然后就不用考虑是不是set.begin()或者set.end()了。

知识点
1、 set 里加入 $INF$ 和 $-INF$,然后就不用考虑是不是set.begin()或者set.end()
2、树上路径 / 边 - 端点LCA相关
3、路径并,路径交

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<set>
#define LL long long
#define db double
#define ms(i, j) memset(i, j, sizeof i)
#define fir first
#define sec second
using namespace std;
namespace flyinthesky {
const int MAXN = 100000 + 5, ZINF = 2000000000;
const int LOGS = 20;
struct edge {int v, w, nxt;} ed[MAXN * 2];
int n, m, en, sz, hd[MAXN], dep[MAXN], dfn[MAXN], pre[MAXN][30], vis[MAXN];
LL dis[MAXN], ans;
struct node {
int u;
bool operator < (const node &rhs) const {return dfn[u] < dfn[rhs.u];};
};
set<node > s;
void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1, pre[u][0] = fa, dfn[u] = ++sz;
for (int i = 1; i <= LOGS; ++i) pre[u][i] = pre[pre[u][i - 1]][i - 1];
for (int i = hd[u]; i > 0; i = ed[i].nxt) {
edge &e = ed[i];
if (e.v != fa) dis[e.v] = dis[u] + (LL)e.w, dfs(e.v, u);
}
}
int LCA(int a, int b) {
if (dep[a] < dep[b]) swap(a, b);
for (int i = LOGS; i >= 0; --i) if (dep[pre[a][i]] >= dep[b]) a = pre[a][i];
if (a == b) return a;
for (int i = LOGS; i >= 0; --i) if (pre[a][i] != pre[b][i]) a = pre[a][i], b = pre[b][i];
return pre[a][0];
}
void clean() {
ans = 0ll, sz = en = 0, ms(dfn, 0), ms(hd, -1), ms(dis, 0), ms(dep, 0), ms(pre, 0), ms(vis, 0);
}
int solve() {
scanf("%d%d", &n, &m);
for (int u, v, w, i = 1; i < n; ++i) scanf("%d%d%d", &u, &v, &w), ins(u, v, w), ins(v, u, w);
dfs(1, 0);
// for (int i = 1; i <= n; ++i) cerr << dep[i] << endl;
// for (int i = 1; i <= n; ++i) cerr << dis[i] << endl;
// for (int i = 1; i <= n; ++i) cerr << dfn[i] << endl;
// for (int i = 1; i <= n; ++i) for (int j = 1; j <= n; ++j) cerr << i << " " << j << " " << LCA(i, j) << endl;
dfn[n + 1] = ZINF, dfn[n + 2] = -ZINF;
s.insert((node){n + 1}), s.insert((node){n + 2});
while (m--) {
int x; scanf("%d", &x);
if (!vis[x]) { // add
vis[x] = 1, s.insert((node){x});
set<node >::iterator it1 = s.lower_bound((node){x});
set<node >::iterator it2 = s.upper_bound((node){x});
int flag = 0;
if (dfn[(--it1)->u] != -ZINF) ++flag, ans += dis[it1->u] + dis[x] - 2ll * dis[LCA(it1->u, x)];
if (dfn[it2->u] != ZINF) ++flag, ans += dis[it2->u] + dis[x] - 2ll * dis[LCA(it2->u, x)];
// cerr << "!!!" << it1->u << " " << it2->u << endl;
LL etr = 0ll;
if (s.size() >= 4) {
set<node >::iterator it3 = s.find((node){n + 1});
set<node >::iterator it4 = s.find((node){n + 2});
--it3, ++it4;
etr = dis[it3->u] + dis[it4->u] - 2ll * dis[LCA(it3->u, it4->u)];
// cerr << "!!!" << it3->u << " " << it4->u << endl;
}
// cerr << "???" << etr << " " << ans << endl;
if (flag == 2) ans -= dis[it1->u] + dis[it2->u] - 2ll * dis[LCA(it2->u, it1->u)];
printf("%lld\n", ans + etr);//%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d
} else { // delete
vis[x] = 0, s.erase(s.find((node){x}));
set<node >::iterator it1 = s.lower_bound((node){x});
set<node >::iterator it2 = s.upper_bound((node){x});
int flag = 0;
if (dfn[(--it1)->u] != -ZINF) ++flag, ans -= dis[it1->u] + dis[x] - 2ll * dis[LCA(it1->u, x)];
if (dfn[it2->u] != ZINF) ++flag, ans -= dis[it2->u] + dis[x] - 2ll * dis[LCA(it2->u, x)];
LL etr = 0ll;
if (s.size() >= 4) {
set<node >::iterator it3 = s.find((node){n + 1});
set<node >::iterator it4 = s.find((node){n + 2});
--it3, ++it4;
etr = dis[it3->u] + dis[it4->u] - 2ll * dis[LCA(it3->u, it4->u)];
}
if (flag == 2) ans += dis[it1->u] + dis[it2->u] - 2ll * dis[LCA(it2->u, it1->u)];
printf("%lld\n", ans + etr);//%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d%I64d
}
}
return 0;
}
};
int main() {
flyinthesky::solve();
return 0;
}
/*
7 10
1 2 33
2 3 11
2 4 22
2 5 88
3 6 77
1 7 55
*/

------ 本文结束 ------