「Bzoj 3991」「SDOI2015」寻宝游戏 (路径并 + 倍增LCA + 虚树)

BZOJ 3991
题意:$n$个点的树,$m$次变动使得某个点有宝物或没宝物,询问每次变动后集齐所有宝物并返回原点的最小距离。
转化成有根树,求路径的并。路径并就是 DFS 序下两个点之间的距离和。并且最后一个点和第一个点的距离要贡献。那么这题相当于插入一个点然后找到他 DFS 序前面后面的点加上贡献即可。删除同理。注意如果是插在中间要减掉前面后面的距离,因为这个贡献在点加入后已经不存在了,留下来会重复。
对于这个的维护,我们用 set 就行了。
set的一个小Trick: set 里加入 $INF$ 和 $-INF$,然后就不用考虑是不是set.begin()或者set.end()了。

知识点
1、 set 里加入 $INF$ 和 $-INF$,然后就不用考虑是不是set.begin()或者set.end()
2、树上路径 / 边 - 端点LCA相关
3、路径并,路径交

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<set>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky { 

    const int MAXN = 100000 + 5, LOGS = 20;
    const LL INF = 10000000;

    struct edge {
        int v, w, nxt;
    } ed[MAXN * 2];

    int n, m, en, hd[MAXN], pre[MAXN][25], dfn[MAXN], dep[MAXN], sz, vis[MAXN];
    LL dis[MAXN];

    struct node {
        int u;
        bool operator < (const node &rhs) const {return dfn[u] < dfn[rhs.u];}
    };

    set<node > s;

    void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}

    void dfs(int u, int fa) {
        pre[u][0] = fa, dfn[u] = ++sz, dep[u] = dep[fa] + 1;
        for (int i = 1; i <= LOGS; ++i) pre[u][i] = pre[pre[u][i - 1]][i - 1];
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa) dis[e.v] = dis[u] + (LL)e.w, dfs(e.v, u);
        }
    }
    int LCA(int a, int b) {
        if (dep[a] < dep[b]) swap(a, b);
        for (int i = LOGS; i >= 0; i--) if (dep[pre[a][i]] >= dep[b]) a = pre[a][i];
        if (a == b) return a;
        for (int i = LOGS; i >= 0; i--) if (pre[a][i] != pre[b][i]) a = pre[a][i], b = pre[b][i];
        return pre[a][0];
    }
    LL dist(int x, int y) {return dis[x] + dis[y] - 2ll * dis[LCA(x, y)];}

    void clean() {
        sz = en = 0, ms(hd, -1), ms(vis, 0);
    }
    int solve() {
        clean();
        scanf("%d%d", &n, &m);
        for (int x, y, w, i = 1; i < n; ++i) {
            scanf("%d%d%d", &x, &y, &w);
            ins(x, y, w), ins(y, x, w);
        }
        dfs(1, 0);
        dfn[n + 1] = -INF, dfn[n + 2] = INF;
        s.insert((node){n + 1}), s.insert((node){n + 2});
        LL ans = 0ll;
        while (m--) {
            int t; scanf("%d", &t);
            if (vis[t]) {
                set<node >::iterator nxt = s.upper_bound((node){t});
                set<node >::iterator pre = s.lower_bound((node){t}); // 注意 set 里有 t
                --pre;
                int fl = 0;
                if (pre->u <= n && pre->u != t) {
                    ans -= dist(pre->u, t);
                    ++fl;
                } 
                if (nxt->u <= n) {
                    ans -= dist(nxt->u, t);
                    ++fl;
                }
                if (fl == 2) ans += dist(nxt->u, pre->u);
                s.erase(s.find((node){t}));
            } else {
                set<node >::iterator nxt = s.upper_bound((node){t});
                set<node >::iterator pre = --nxt;
                ++nxt;
                int fl = 0;
                if (pre->u <= n && pre->u != t) {
                    ans += dist(pre->u, t);
                    ++fl;
                } 
                if (nxt->u <= n) {
                    ans += dist(nxt->u, t);
                    ++fl;
                }
                if (fl == 2) ans -= dist(nxt->u, pre->u);
                s.insert((node){t});
            }
            vis[t] ^= 1;
            set<node >::iterator it = s.upper_bound((node){n + 1});
            set<node >::iterator it2 = s.lower_bound((node){n + 2});
            it2--;
            LL gg = dist(it->u, it2->u);
            printf("%lld\n", ans + gg);
        }
        return 0; 
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}
------ 本文结束 ------