「Bzoj 1500」「NOI2005」维修数列 (Splay维护序列)

BZOJ 1500
题意:请写一个程序,要求维护一个数列,支持以下 $6$ 种操作:
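(原题的操作表为图片,已丢失;以下根据代码整理,posi 表示位置,tot 表示个数)
1. 插入 `INSERT posi tot c1 c2 ... ctot`:在当前数列第 posi 个数之后插入 tot 个数 c1…ctot(posi = 0 表示插在开头);
2. 删除 `DELETE posi tot`:从第 posi 个数开始,删除连续 tot 个数;
3. 修改 `MAKE-SAME posi tot c`:把从第 posi 个数开始的连续 tot 个数统一改为 c;
4. 翻转 `REVERSE posi tot`:把从第 posi 个数开始的连续 tot 个数翻转;
5. 求和 `GET-SUM posi tot`:输出从第 posi 个数开始的连续 tot 个数之和;
6. 求最大子段和 `MAX-SUM`:输出当前数列中和最大的(非空)子段之和。

请注意,格式栏中的下划线表示实际输入文件中的空格。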

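用 splay 维护序列(舍弃二叉排序树按键值比较左右儿子的做法,改为以中序遍历顺序表示数列中的位置:第 $k$ 个数就是中序第 $k$ 个节点。另外在首尾各加一个值为 $-\infty$(代码中为 $-10^9$)的哨兵节点,这样提取区间 $[l, r]$ 时,只需把第 $l-1$ 个数转到根、第 $r+1$ 个数转到根的右儿子,它的左子树就是所求区间):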
插入:新建一棵 splay 插到原树上
删除:提取区间以后直接把该子树从树上摘下;注意题目卡内存,需要把删掉的节点回收再利用(垃圾回收)
修改:提取区间以后打标记,注意这里的标记都是"打标记时立即修改当前节点",标记只表示其儿子尚未修改
翻转:提取区间以后打标记;若该区间已有修改标记,则区间内所有数都相同,翻转后不变,无需做任何事情
求和:提取区间输出区间和
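求最大子段和:维护 $ls, rs$ 分别表示最大前缀和、最大后缀和(允许取空,因此非负),$gss$ 表示最大(非空)子段和,$sum$ 为子树和。合并时有
$$ls_x=\max(ls_l,\ sum_l+val_x+ls_r),\qquad rs_x=\max(rs_r,\ sum_r+val_x+rs_l),$$
$$gss_x=\max(gss_l,\ gss_r,\ rs_l+val_x+ls_r),$$
根节点的 $gss$ 即为答案。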

具体看Splay 学习笔记

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<queue>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

// Splay tree maintaining a sequence (BZOJ 1500 / NOI2005 "维修数列").
// Positions are in-order ranks; two -INF sentinels frame the real elements so
// that any range [l, r] can be isolated as one subtree.
namespace flyinthesky {

    // Node-pool capacity. New ids are taken from `ncnt` only when the free list
    // `q` is empty, so this bounds the peak number of simultaneously allocated nodes.
    const int MAXN = 1000000 + 5, INF = 1000000000;

    int n, m, arr[MAXN];
    // Per-node splay data; index 0 is the shared empty/null node (its ch stays 0).
    //   ch, fa : children / parent          siz : subtree size
    //   val    : element value              gss : best non-empty segment sum
    //   ls, rs : best prefix / suffix sum (empty allowed, so always >= 0)
    //   lazy   : children still need reversing (this node is already flipped)
    //   upd    : children still need "make same" (this node is already set)
    int ch[MAXN][2], fa[MAXN], siz[MAXN], val[MAXN], ls[MAXN], rs[MAXN], gss[MAXN], lazy[MAXN], upd[MAXN], rt, ncnt;
    // Subtree sums are 64-bit: the root's subtree holds both -INF sentinels, and
    // -2 * INF plus enough negative elements is below INT_MIN (signed overflow
    // is undefined behaviour for int).
    long long sum[MAXN];
    char s[16]; // command name; the longest one, "MAKE-SAME", has 9 characters

    std::queue<int> q; // free list of recycled node ids

    // 1 if x is its parent's right child, 0 otherwise.
    int rel(int x) {return ch[fa[x]][1] == x;}

    // Recompute x's aggregates from its (already up-to-date) children.
    void pushup(int x) {
        const int l = ch[x][0], r = ch[x][1];
        siz[x] = siz[l] + siz[r] + 1;
        sum[x] = sum[l] + sum[r] + val[x];
        ls[x] = static_cast<int>(std::max<long long>(ls[l], sum[l] + val[x] + ls[r]));
        rs[x] = static_cast<int>(std::max<long long>(rs[r], sum[r] + val[x] + rs[l]));
        gss[x] = std::max(std::max(gss[l], gss[r]), rs[l] + val[x] + ls[r]);
    }

    // Set every element of x's subtree to v. x's own fields are fixed right away;
    // its children are updated later by pushdown. No-op for the null node.
    void makeSame(int x, int v) {
        if (!x) return;
        upd[x] = 1, val[x] = v;
        sum[x] = 1LL * v * siz[x];
        ls[x] = rs[x] = static_cast<int>(std::max(sum[x], 0LL));
        // v > 0: the whole range is best; otherwise a single element is.
        gss[x] = static_cast<int>(std::max<long long>(sum[x], v));
    }

    // Reverse x's subtree: swap x's children and prefix/suffix maxima now and
    // leave the grandchildren to pushdown. No-op for the null node.
    void flip(int x) {
        if (!x) return;
        lazy[x] ^= 1;
        std::swap(ch[x][0], ch[x][1]);
        std::swap(ls[x], rs[x]);
    }

    // Push x's pending tags one level down. A pending "make same" cancels a
    // pending reversal, because reversing equal values changes nothing.
    void pushdown(int x) {
        const int l = ch[x][0], r = ch[x][1];
        if (upd[x]) {
            makeSame(l, val[x]), makeSame(r, val[x]);
            upd[x] = lazy[x] = 0;
        }
        if (lazy[x]) {
            flip(l), flip(r);
            lazy[x] = 0;
        }
    }

    // Rotate x above its parent, preserving in-order (sequence) order.
    void rotate(int x) {
        pushdown(fa[x]), pushdown(x);
        const int y = fa[x], z = fa[y], k = rel(x), w = ch[x][k ^ 1];
        if (z) ch[z][rel(y)] = x; // never write into the null node's children
        fa[x] = z;
        ch[y][k] = w;
        if (w) fa[w] = y;
        ch[x][k ^ 1] = y, fa[y] = x;
        pushup(y), pushup(x);
    }

    // Splay x until its parent is gl (gl == 0: x becomes the root).
    void splay(int x, int gl = 0) {
        while (fa[x] != gl) {
            pushdown(fa[x]), pushdown(x);
            const int y = fa[x], z = fa[y];
            if (z != gl) rotate(rel(x) == rel(y) ? y : x);
            rotate(x);
        }
        if (!gl) rt = x;
    }

    // Node at in-order rank k (1-based, sentinels included).
    // Requires 1 <= k <= siz[rt]; otherwise the loop does not terminate.
    int kth(int k) {
        int cur = rt;
        while (true) {
            pushdown(cur);
            const int lsz = siz[ch[cur][0]];
            if (k <= lsz) cur = ch[cur][0];
            else if (k > lsz + 1) k -= lsz + 1, cur = ch[cur][1];
            else return cur;
        }
    }

    // Allocate a fresh single-element node holding v (reusing a recycled id if any).
    int newNode(int v) {
        int cur;
        if (q.empty()) cur = ++ncnt; else cur = q.front(), q.pop();
        ch[cur][0] = ch[cur][1] = fa[cur] = 0;
        siz[cur] = 1, val[cur] = v, sum[cur] = v;
        ls[cur] = rs[cur] = std::max(v, 0), gss[cur] = v;
        lazy[cur] = upd[cur] = 0;
        return cur;
    }

    // Build a balanced subtree over arr[l..r]; returns its root (0 if l > r).
    int build(int l, int r) {
        if (l > r) return 0;
        const int mid = (l + r) >> 1, cur = newNode(arr[mid]);
        if (l == r) return cur;
        if ((ch[cur][0] = build(l, mid - 1))) fa[ch[cur][0]] = cur;
        if ((ch[cur][1] = build(mid + 1, r))) fa[ch[cur][1]] = cur;
        pushup(cur);
        return cur;
    }

    // Return every node of the detached subtree x to the free list. Iterative,
    // because a splay subtree can be far deeper than log(n) and recursion could
    // overflow the call stack.
    void recycle(int x) {
        if (!x) return;
        std::vector<int> stk(1, x);
        while (!stk.empty()) {
            const int cur = stk.back();
            stk.pop_back();
            if (ch[cur][0]) stk.push_back(ch[cur][0]);
            if (ch[cur][1]) stk.push_back(ch[cur][1]);
            q.push(cur);
        }
    }

    // Splay element l-1 to the root and element r+1 just below it (the
    // sentinel shifts kth ranks by one). The left child of the returned node
    // is then exactly elements [l, r]. Requires l <= r + 1.
    int isolate(int l, int r) {
        const int lb = kth(l), rb = kth(r + 2);
        splay(lb), splay(rb, lb);
        return rb;
    }

    // Attach the already-built subtree gg after the x-th element (x == 0: front).
    void insert(int x, int gg) {
        if (!gg) return; // empty batch: nothing to attach
        const int rb = isolate(x + 1, x); // empty range between elements x and x+1
        fa[gg] = rb, ch[rb][0] = gg;
        pushup(rb), pushup(rt);
    }

    // Delete elements [l, r]; an empty range (l > r) is a no-op.
    void remove(int l, int r) {
        if (l > r) return;
        const int rb = isolate(l, r);
        recycle(ch[rb][0]);
        ch[rb][0] = 0;
        pushup(rb), pushup(rt);
    }

    // Reverse elements [l, r]; an empty range (l > r) is a no-op.
    void rev(int l, int r) {
        if (l > r) return;
        const int rb = isolate(l, r), gg = ch[rb][0];
        if (upd[gg]) return; // all values equal: reversing changes nothing
        flip(gg);
        pushup(rb), pushup(rt);
    }

    // Set elements [l, r] to v; an empty range (l > r) is a no-op.
    void update(int l, int r, int v) {
        if (l > r) return;
        const int rb = isolate(l, r);
        makeSame(ch[rb][0], v);
        pushup(rb), pushup(rt);
    }

    // Sum of elements [l, r]; 0 for an empty range (l > r).
    int qsum(int l, int r) {
        if (l > r) return 0;
        return static_cast<int>(sum[ch[isolate(l, r)][0]]);
    }

    // Best non-empty segment sum of the whole sequence (sentinels never win).
    int qmax() {return gss[rt];}

    // Reset all node storage and the free list.
    void clean() {
        rt = ncnt = 0;
        std::memset(ch, 0, sizeof ch);
        std::memset(fa, 0, sizeof fa);
        std::memset(siz, 0, sizeof siz);
        std::memset(val, 0, sizeof val);
        std::memset(ls, 0, sizeof ls);
        std::memset(rs, 0, sizeof rs);
        std::memset(gss, 0, sizeof gss);
        std::memset(sum, 0, sizeof sum);
        std::memset(lazy, 0, sizeof lazy);
        std::memset(upd, 0, sizeof upd);
        q = std::queue<int>();
        // The null node must never be chosen as a best segment.
        gss[0] = val[0] = -INF;
    }

    // Read the sequence and commands from stdin; answers go to stdout.
    int solve() {
        clean();
        if (scanf("%d%d", &n, &m) != 2) return 0;
        for (int i = 1; i <= n; ++i) scanf("%d", &arr[i + 1]);
        arr[1] = arr[n + 2] = -INF; // sentinels around the real elements
        rt = build(1, n + 2);
        while (m-- > 0 && scanf("%15s", s) == 1) {
            int pos = 0, tot = 0, c = 0;
            if (s[0] == 'I') { // INSERT posi tot c1 .. ctot
                scanf("%d%d", &pos, &tot);
                for (int i = 1; i <= tot; ++i) scanf("%d", &arr[i]);
                insert(pos, build(1, tot));
            } else if (s[0] == 'D') { // DELETE posi tot
                scanf("%d%d", &pos, &tot);
                remove(pos, pos + tot - 1);
            } else if (s[0] == 'M' && s[5] == 'S') { // MAKE-SAME posi tot c
                scanf("%d%d%d", &pos, &tot, &c);
                update(pos, pos + tot - 1, c);
            } else if (s[0] == 'R') { // REVERSE posi tot
                scanf("%d%d", &pos, &tot);
                rev(pos, pos + tot - 1);
            } else if (s[0] == 'G') { // GET-SUM posi tot
                scanf("%d%d", &pos, &tot);
                printf("%d\n", qsum(pos, pos + tot - 1));
            } else if (s[0] == 'M' && s[4] == 'S') { // MAX-SUM
                printf("%d\n", qmax());
            }
        }
        return 0;
    }
}
int main() {
    // All state lives in namespace flyinthesky; its solve() yields the exit status (0).
    return flyinthesky::solve();
}
------ 本文结束 ------