树状数组 学习笔记

模板及讲解

维护区间$[1, n]$的数据结构,可以用$[1, b]-[1, a]$来求$[a,b]$
可以$add(a, x), add(b+1, -x)$实现区间修改单点查询(差分序列求前缀和优化到$logn$)

二维树状数组(容斥原理)

区间修改区间查询
设原数组为$a_i​$,原数组差分序列为$d_i​$,$x​$为查询区间$[1,x]​$,则
$$a_x=\sum_{i=1}^x d_i$$

$$\sum_{i=1}^x a_i= \sum_{i=1}^x \sum_{j=1}^i d_j =\sum_{i=1}^x(x-i+1)d_i$$
那么
$$\sum_{i=1}^x a_i=(x+1)\sum_{i=1}^x d_i-\sum_{i=1}^x d_i \times i$$
这样我们维护两个树状数组,一个维护$d_i$,一个维护$d_i \times i$,每次查询修改对两个树状数组进行操作即可。(常数比线段树小)

常见题型

1、单点修改区间查询/区间修改单点查询
解:直接套用模板即可,见下面的相关代码
2、开多棵树状数组解决问题
Q:一个区间(矩阵)有多种颜色,每个点有一个权值,每次修改(查询)指定颜色上的权值
解:对于每一种颜色(类型)都开一棵树状数组。
例题:BZOJ 1452
3、二维树状数组
Q:在矩阵上查询某个子矩阵的值。
解:建二维树状数组,见模板讲解。
例题:BZOJ 1452
4、求逆序对
Q:求逆序对。
解:类似权值线段树,每次使$[1, i] + 1$, 然后$i$对答案的贡献为$[1, i]$的值(即小于i的元素个数)
例题:NOIP2013 D1T2
5、区间修改区间查询
解:推公式,开两个树状数组求值,见模板讲解。
例题:BZOJ 2017-07-20集训-t2
5、树状数组离线排序右端点
解:离线,删点加点
例题:spoj DQUERY

相关代码

1 点修改,求x~y区间值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<queue>
#define ms(i,j) memset(i,j, sizeof i);
using namespace std;
int a[500005];
int n,m;
int abss(int x){return x>=0 ? x : -x;}
int lowbit(int x)
{
return x&(-x);
}
int getsum(int x)//求1~x的和
{
int ret = 0;
for (int i=x;i>0;i-=lowbit(i))
{
ret += a[i];
}
return ret;
}
void addsum(int x, int y)//1~x加y
{
for (int i=x;i<=n;i+=lowbit(i))
{
a[i] += y;
}
}
int main()
{
a[0] = 0;
scanf("%d%d", &n, &m);
for (int i=1;i<=n;i++)
{
int x;
scanf("%d", &x);
addsum(i,x);
}
for (int i=1;i<=m;i++)
{
int ty;
scanf("%d", &ty);
if (ty==1)
{
int x,k;
scanf("%d%d", &x, &k);
addsum(x,k);
} else
{
int x,y;
scanf("%d%d", &x, &y);
printf("%d\n", abss(getsum(y)-getsum(x-1)));
}
}
return 0;
}

2 区间修改,求某一点值(差分序列)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
#define ms(i,j) memset(i,j,sizeof i);
int n,m;
const int maxn = 500005;
int a[maxn];//a记录的是比i-lowbit(i)多的值
int lowbit(int x)
{
return x&(-x);
}
int add(int x, int v)
{
for (int i=x;i<=n;i+=lowbit(i))
{
a[i] += v;
}
}
int sub(int x)
{
int ret = 0;
for (int i=x;i>0;i-=lowbit(i))
{
ret += a[i];
}
return ret;
}
int main()
{
scanf("%d%d", &n ,&m);
ms(a,0);
for (int i=1;i<=n;i++)
{
int x;
scanf("%d", &x);
add(i,x);
add(i+1,-x);
}
for (int i=1;i<=m;i++)
{
int ty;
scanf("%d", &ty);
if(ty==1)
{
int x,y,k;
scanf("%d%d%d", &x,&y,&k);
add(x,k); add(y+1,-k);
} else
{
int x;
scanf("%d", &x);
printf("%d\n", sub(x));
}
}
system("pause");
return 0;
}

------ 本文结束 ------