Segment Tree Beats
区间最值操作
先来一道例题。 这道题要求维护区间和,区间最大值以及区间最值操作。 区间和与区间最大值可以轻易地使用普通线段树维护,但这个区间最值操作是我们要研究的重点。 区间最值操作,说人话就是给出一个值 \(t\),将区间内各元素的值限制在 \(t\) 以下(或以上)。也就是令区间内所有值对 \(t\) 取 \(\max\)(或 \(\min\))。例题是区间 \(\min\)。 对于线段树的每个节点,维护区间和 \(sum\),区间最大值 \(maxn\),区间严格次大值 \(se\) 和最大值个数 \(cnt\)。 接下来,对于区间 \(\min\) 操作 - 若 \(maxn<t\),则显然操作无效,直接退出。 - 若 \(se<t<maxn\),则修改只会影响到区间最大值,令 \(sum\gets sum-cnt \cdot (maxn-t)\),\(maxn \gets t\),打标记然后退出。 - 若 \(se\ge t\),进入左右儿子递归搜索,然后上传信息。
可以证明,此法时间复杂度为 \(O(m\log
n)\)。 证明详见 P104~105。 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
using namespace std;
const int N=1000010;
int t,n,m,a[N];
struct segtree{
int maxn,se,cnt;
long long sum;
}tree[N*4];
int tag[N*4];
void push_up(int u){
tree[u].sum=tree[ls].sum+tree[rs].sum;
tree[u].maxn=max(tree[ls].maxn,tree[rs].maxn);
tree[u].cnt=0;
if(tree[u].maxn==tree[ls].maxn)
tree[u].cnt+=tree[ls].cnt;
if(tree[u].maxn==tree[rs].maxn)
tree[u].cnt+=tree[rs].cnt;
tree[u].se=max(tree[ls].se,tree[rs].se);
if(tree[ls].maxn!=tree[rs].maxn)
tree[u].se=max(tree[u].se,min(tree[ls].maxn,tree[rs].maxn));
return;
}
void build(int u,int l,int r){
tag[u]=-1;
if(l==r){
tree[u].sum=tree[u].maxn=a[l];
tree[u].cnt=1;
tree[u].se=-1;
return;
}
int mid=(l+r)/2;
build(ls,l,mid);
build(rs,mid+1,r);
push_up(u);
return;
}
void update_tag(int u,int k){
if(tree[u].maxn<=k)
return;
tree[u].sum-=(long long)tree[u].cnt*(tree[u].maxn-k);
tree[u].maxn=tag[u]=k;
return;
}
void push_down(int u){
if(tag[u]==-1)
return;
update_tag(ls,tag[u]);
update_tag(rs,tag[u]);
tag[u]=-1;
return;
}
void modify(int u,int l,int r,int x,int y,int k){
if(tree[u].maxn<=k)
return;
if(x<=l&&y>=r&&tree[u].se<k){
update_tag(u,k);
return;
}
int mid=(l+r)/2;
push_down(u);
if(x<=mid)
modify(ls,l,mid,x,y,k);
if(y>mid)
modify(rs,mid+1,r,x,y,k);
push_up(u);
return;
}
int query_max(int u,int l,int r,int x,int y){
if(x<=l&&y>=r)
return tree[u].maxn;
int mid=(l+r)/2,res=-1;
push_down(u);
if(x<=mid)
res=max(res,query_max(ls,l,mid,x,y));
if(y>mid)
res=max(res,query_max(rs,mid+1,r,x,y));
return res;
}
long long query_sum(int u,int l,int r,int x,int y){
if(x<=l&&y>=r)
return tree[u].sum;
int mid=(l+r)/2;
long long res=0;
push_down(u);
if(x<=mid)
res+=query_sum(ls,l,mid,x,y);
if(y>mid)
res+=query_sum(rs,mid+1,r,x,y);
return res;
}
int main(){
scanf("%d",&t);
while(t--){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
build(1,1,n);
while(m--){
int opt,x,y,t;
scanf("%d%d%d",&opt,&x,&y);
if(opt==0){
scanf("%d",&t);
modify(1,1,n,x,y,t);
}
else if(opt==1)
printf("%d\n",query_max(1,1,n,x,y));
else
printf("%lld\n",query_sum(1,1,n,x,y));
}
}
return 0;
}
区间最值&区间加
例题
同时有区间加后,原本的势能分析就不适用了。复杂度变为 \(O(m\log^2n)\)。对于区间加,区间 \(\max\) 和区间 \(\min\) 各维护一个标记。
然后你就会发现真的很难写。
注意,区间加的标记优先级大于区间最值操作。还有,在进行更新时要考虑到
\(maxn\),\(minn\),\(max\_se\),\(min\_se\),\(min\_tag\) 和 \(max\_tag\) 相互重叠的情况。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
using namespace std;
const int N=5e5+10;
int n,a[N],m;
struct segment_tree{
long long sum;
int maxn,minn,max_se,min_se,max_cnt,min_cnt,add_tag,max_tag,min_tag;
}tree[N*4];
void push_up(int u){
tree[u].sum=tree[ls].sum+tree[rs].sum;
if(tree[ls].maxn==tree[rs].maxn){
tree[u].maxn=tree[ls].maxn;
tree[u].max_cnt=tree[ls].max_cnt+tree[rs].max_cnt;
tree[u].max_se=max(tree[ls].max_se,tree[rs].max_se);
}
else if(tree[ls].maxn<tree[rs].maxn){
tree[u].maxn=tree[rs].maxn;
tree[u].max_cnt=tree[rs].max_cnt;
tree[u].max_se=max(tree[ls].maxn,tree[rs].max_se);
}
else{
tree[u].maxn=tree[ls].maxn;
tree[u].max_cnt=tree[ls].max_cnt;
tree[u].max_se=max(tree[rs].maxn,tree[ls].max_se);
}
if(tree[ls].minn==tree[rs].minn){
tree[u].minn=tree[ls].minn;
tree[u].min_cnt=tree[ls].min_cnt+tree[rs].min_cnt;
tree[u].min_se=min(tree[ls].min_se,tree[rs].min_se);
}
else if(tree[ls].minn<tree[rs].minn){
tree[u].minn=tree[ls].minn;
tree[u].min_cnt=tree[ls].min_cnt;
tree[u].min_se=min(tree[ls].min_se,tree[rs].minn);
}
else{
tree[u].minn=tree[rs].minn;
tree[u].min_cnt=tree[rs].min_cnt;
tree[u].min_se=min(tree[rs].min_se,tree[ls].minn);
}
return;
}
void build(int u,int l,int r){
tree[u].max_tag=-inf,tree[u].min_tag=inf;
if(l==r){
tree[u].sum=tree[u].maxn=tree[u].minn=a[l];
tree[u].max_cnt=tree[u].min_cnt=1;
tree[u].max_se=-inf;
tree[u].min_se=inf;
return;
}
int mid=(l+r)/2;
build(ls,l,mid);
build(rs,mid+1,r);
push_up(u);
return;
}
void update_add(int u,int l,int r,int k){
tree[u].sum+=(r-l+1ll)*k;
tree[u].maxn+=k,tree[u].minn+=k,tree[u].add_tag+=k;
if(tree[u].max_se!=-inf)
tree[u].max_se+=k;
if(tree[u].min_se!=inf)
tree[u].min_se+=k;
if(tree[u].max_tag!=-inf)
tree[u].max_tag+=k;
if(tree[u].min_tag!=inf)
tree[u].min_tag+=k;
return;
}
void update_min_tag(int u,int k){
if(tree[u].maxn<=k)
return;
tree[u].sum+=(1ll*k-tree[u].maxn)*tree[u].max_cnt;
if(tree[u].min_se==tree[u].maxn)
tree[u].min_se=k;
if(tree[u].minn==tree[u].maxn)
tree[u].minn=k;
if(tree[u].max_tag>k)
tree[u].max_tag=k;
tree[u].maxn=k,tree[u].min_tag=k;
return;
}
void update_max_tag(int u,int k){
if(tree[u].minn>=k)
return;
tree[u].sum+=(1ll*k-tree[u].minn)*tree[u].min_cnt;
if(tree[u].max_se==tree[u].minn)
tree[u].max_se=k;
if(tree[u].maxn==tree[u].minn)
tree[u].maxn=k;
if(tree[u].min_tag<k)
tree[u].min_tag=k;
tree[u].minn=k,tree[u].max_tag=k;
return;
}
void push_down(int u,int l,int r){
int mid=(l+r)/2;
if(tree[u].add_tag)
update_add(ls,l,mid,tree[u].add_tag),update_add(rs,mid+1,r,tree[u].add_tag);
if(tree[u].max_tag!=-inf)
update_max_tag(ls,tree[u].max_tag),update_max_tag(rs,tree[u].max_tag);
if(tree[u].min_tag!=inf)
update_min_tag(ls,tree[u].min_tag),update_min_tag(rs,tree[u].min_tag);
tree[u].add_tag=0,tree[u].max_tag=-inf,tree[u].min_tag=inf;
return;
}
void modify_add(int u,int l,int r,int x,int y,int k){
if(x<=l&&y>=r){
update_add(u,l,r,k);
return;
}
int mid=(l+r)/2;
push_down(u,l,r);
if(x<=mid)
modify_add(ls,l,mid,x,y,k);
if(y>mid)
modify_add(rs,mid+1,r,x,y,k);
push_up(u);
return;
}
void modify_max(int u,int l,int r,int x,int y,int k){
if(x<=l&&y>=r&&tree[u].min_se>k){
update_max_tag(u,k);
return;
}
int mid=(l+r)/2;
push_down(u,l,r);
if(x<=mid)
modify_max(ls,l,mid,x,y,k);
if(y>mid)
modify_max(rs,mid+1,r,x,y,k);
push_up(u);
return;
}
void modify_min(int u,int l,int r,int x,int y,int k){
if(x<=l&&y>=r&&tree[u].max_se<k){
update_min_tag(u,k);
return;
}
int mid=(l+r)/2;
push_down(u,l,r);
if(x<=mid)
modify_min(ls,l,mid,x,y,k);
if(y>mid)
modify_min(rs,mid+1,r,x,y,k);
push_up(u);
return;
}
long long query_sum(int u,int l,int r,int x,int y){
if(x<=l&&y>=r)
return tree[u].sum;
long long res=0;
int mid=(l+r)/2;
push_down(u,l,r);
if(x<=mid)
res+=query_sum(ls,l,mid,x,y);
if(y>mid)
res+=query_sum(rs,mid+1,r,x,y);
return res;
}
int query_max(int u,int l,int r,int x,int y){
if(x<=l&&y>=r)
return tree[u].maxn;
int res=-inf,mid=(l+r)/2;
push_down(u,l,r);
if(x<=mid)
res=max(res,query_max(ls,l,mid,x,y));
if(y>mid)
res=max(res,query_max(rs,mid+1,r,x,y));
return res;
}
int query_min(int u,int l,int r,int x,int y){
if(x<=l&&y>=r)
return tree[u].minn;
int res=inf,mid=(l+r)/2;
push_down(u,l,r);
if(x<=mid)
res=min(res,query_min(ls,l,mid,x,y));
if(y>mid)
res=min(res,query_min(rs,mid+1,r,x,y));
return res;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",a+i);
build(1,1,n);
scanf("%d",&m);
int opt,l,r,x;
while(m--){
scanf("%d%d%d",&opt,&l,&r);
if(opt==1){
scanf("%d",&x);
modify_add(1,1,n,l,r,x);
}
else if(opt==2){
scanf("%d",&x);
modify_max(1,1,n,l,r,x);
}
else if(opt==3){
scanf("%d",&x);
modify_min(1,1,n,l,r,x);
}
else if(opt==4)
printf("%lld\n",query_sum(1,1,n,l,r));
else if(opt==5)
printf("%d\n",query_max(1,1,n,l,r));
else
printf("%d\n",query_min(1,1,n,l,r));
}
return 0;
}
区间历史最值
例题
这里的“历史”不同于可持久化,以历史最大值为例,我们称原数组为 \(A\),定义数组 \(B\),则 \(B_i\) 表示所有历史版本中最大的那个 \(A_i\),形式化地,每次操作之后,都令 \(B_i=\max(A_i,B_i)\)。初始时 \(B\) 与 \(A\) 相同。 1
//代码先咕着