TheSky233's Blog

𝐴𝑛𝑑 𝑖𝑛 𝑡ℎ𝑎𝑡 𝑙𝑖𝑔ℎ𝑡...

0%

【学习笔记】平衡树Splay

【学习笔记】Splay

$\texttt{0x01}$ 前言

Splay 树(伸展树)由 Daniel SleatorRobert Tarjan 于 1985 年发明。它凭借旋转可以有 $O(\log n) $ 插入,删除等的较优秀的时间复杂度。

前置芝士:普通二叉排序树

推荐博客:

$\texttt{0x02}$ 如何构造一棵 Splay

我们定义一个结构体:

1
2
3
4
5
6
7
8
9
10
11
#define val(x) t[x].val
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
#define son(x,nxt) t[x].ch[nxt]
#define fa(x) t[x].fa
#define cnt(x) t[x].cnt
#define siz(x) t[x].siz
struct node{
int val,fa,ch[2],siz,cnt;
}t[N];
int root,tot;

其中构造一个新节点的函数长这样:

1
2
3
4
5
void newPoint(int val,int fa,int nxt){ //值为val,父节点为fa,为fa的nxt儿子
tot++;
fa(tot)=fa; cnt(tot)=siz(tot)=1; val(tot)=val;
son(fa,nxt)=tot;
}

$\texttt{0x03}$ which / pushup / connect

which 的作用是判断 $x$ 是其父节点的左节点($0$)还是右节点($1$),代码很好写:

1
2
3
bool which(int x){
return rs(fa(x))==x;
}

pushup 的作用是维护当前节点的 $siz$ 信息,和线段树的 pushup 性质差不多,代码:

1
2
3
void pushup(int x){
siz(x)=siz(ls(x))+siz(rs(x))+cnt(x);//记得加上当前节点的cnt
}

connect 的作用是把 $x$ 变成 $y$ 的 $nxt$ 儿子,无需考虑覆盖的问题,代码也很简洁:

1
2
3
4
void connect(int x,int y,int nxt){
son(y,nxt)=x;
fa(x)=y;
}

$\texttt{0x04}$ rotate

Splay 的核心操作:旋转。

放两张动图:

我们会发现:右旋时,E 节点要到 S 节点的位置上,那么 E 节点的右儿子因为它 $\ge E$ 且 $\le S$,所以只能放在 S 节点的左儿子,然后要改变 E 和 S 的父子关系。最后别忘了因为有旋转,所以要自下而上更新节点信息。

左旋同理。

代码:

1
2
3
4
5
6
7
8
9
10
void rotate(int x){
int y=fa(x),z=fa(y);
int fx=which(x),fy=which(y);

connect(son(x,fx^1),y,fx); //如果x是左儿子,改变它右儿子的位置,反之同理
connect(y,x,fx^1); //把y接到x的缺失的那一棵子树上
connect(x,z,fy); //把x接到y的父节点上去

pushup(y); pushup(x); //别搞错顺序
}

$\texttt{0x05}$ splay

Splay 树保证时间复杂度正确的核心操作,把 $x$ 转到 $y$ 的位置($y$ 通常为 $root$)。

有几点结论,难证但好记:

  • 若 $fa(x)=y$,则单旋 $x$。
  • 若 $x$、$fa(x)$、$fa\left(fa(x)\right)$ 不在一条线上,则先单旋 $fa(x)$,再单旋 $x$。
  • 否则旋转两次 $x$。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void splay(int x,int y){
y=fa(y); //避免x=y时出现的错误
while(fa(x)!=y){
if(fa(fa(x))==y) // Case 1
rotate(x);
else if(which(x)==which(fa(x))) // Case 2
rotate(fa(x)), rotate(x);
else // Case 3
rotate(x), rotate(x);
}
if(y==0){ // 如果y是根,把根变为x
root=x;
connect(x,0,1);
}
}

$\texttt{0x06}$ insert

与普通的二叉排序树基本一致。

  • 如果树中已经有值了,则 $cnt \gets cnt+1$。
  • 如果找到最后都没有值,建个新节点。

记得最后要 splay 一下,把这个点转到根节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
void insert(int val){
if(root==0){
newPoint(val,0,1);
root=tot;
return;
}
int now=root;
while(1){
siz(now)++;
if(val(now)==val){
cnt(now)++;
splay(now,root);
return;
}
int nxt=val(now)<val, son=son(now,nxt);
if(!son){
newPoint(val,now,nxt);
splay(tot,root);
return;
}
now=son;
}
}

$\texttt{0x07}$ find

这一步操作是找到树中值为 $val$ 的节点,并把它旋转到根节点,为 delete 操作做准备。

与普通二叉排序树也基本一致。

1
2
3
4
5
6
7
8
9
10
11
12
13
int find(int val){
int now=root;
while(1){
if(!now)
return 0;
if(val(now)==val){
splay(now,root);
return now;
}
int nxt=val(now)<val, son=son(now,nxt);
now=son;
}
}

$\texttt{0x08}$ delete

目的是删除树中值为 $val$ 的节点。

先 find 这个节点,让他转到根,然后分类讨论。

  1. 树中没有值为 $val$ 的节点,删了个寂寞。

  2. 树中值为 $val$ 的节点有不止一个(即 $cnt \ge 2$),让 $cnt \gets cnt-1$ 即可。

  3. 这个节点没有左儿子(即根节点只有右子树),把右儿子设为根就行了。

  4. 这个节点没有右儿子(即根节点只有左子树),把左儿子设为根就行了。

  5. 这个节点(设为 $x$)有左右儿子,把它的左子树中值最大的(设为 $y$)splay 到根,然后现在的 Splay 树的根就是 $y$,左子树是原来的除 $y$ 之外的左子树,右子树是 $x$ 和之前的右子树。把之前的右子树 connect 到根就行了。

注意:这里的删除操作都没有回收编号

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
void delet(int val){
int now=find(val);
if(!now) return;
if(cnt(now)>1){
cnt(now)--; siz(now)--;
return;
}
if(!ls(now) && !rs(now)){
root=0;
}
else if(!ls(now)){
root=rs(root);
fa(root)=0;
}
else if(!rs(now)){
root=ls(root);
fa(root)=0;
}
else{
int pos=ls(now);
while(rs(pos)) pos=rs(pos);

splay(pos,root);
connect(rs(now),pos,1);
pushup(pos);
}
}

$\texttt{0x09}$ rnk & find_k

rnk 是返回值为 $val$ 的数在树中的排名,find_k 是找到树中排名为 k 的数。

与二叉排序树基本完全相同。记得最后要把节点 splay 到根。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
int rnk(int val){
int now=root,s=0;
while(now){
if(val(now)==val){
splay(now,root);
return siz(ls(now))+1;
}
if(val(now)<val){
s+=siz(ls(now))+cnt(now);
now=rs(now);
}
else{
now=ls(now);
}
}
return s+1;
}
int find_k(int k){
int now=root;
while(1){
int used=siz(now)-siz(rs(now));
if(k>siz(ls(now)) && k<=used){
break;
}
if(k>=used){
k-=used;
now=rs(now);
}
else{
now=ls(now);
}
}
splay(now,root);
return val(now);
}

$\texttt{0x0A}$ lower & upper

返回值为 $val$ 的数的前驱和后继。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
int lower(int val){
int ans=-2147483647;
int now=root;
while(now){
if(val(now)<val && val(now)>ans){
ans=val(now);
}
if(val>val(now)){
now=rs(now);
}
else{
now=ls(now);
}
}
return ans;
}
int upper(int val){
int ans=2147483647;
int now=root;
while(now){
if(val(now)>val && val(now)<ans){
ans=val(now);
}
if(val<val(now)){
now=ls(now);
}
else{
now=rs(now);
}
}
return ans;
}

$\texttt{0x0B}$ 完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#include<bits/stdc++.h>
using namespace std;
constexpr int N = 5e5+5;

template <typename T> void read(T &x){x=0; T f(0); char ch=getchar(); while(ch<'0'||ch>'9'){f|=ch=='-';ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=getchar();} x=f?-x:x;}
template <typename T,typename ...Arg>void read(T& x,Arg& ...arg){read(x);read(arg...);}
template <typename T> inline void write(T x){static char buf[64]; static int tot(0); if(x<0) putchar('-'),x=-x; do buf[++tot]=(x%10)+48,x/=10; while(x); do putchar(buf[tot--]); while(tot);}
template <typename T> void write(T x,char c){static char buf[64]; static int tot(0); if(x<0) putchar('-'),x=-x; do buf[++tot]=(x%10)+48,x/=10; while(x); do putchar(buf[tot--]); while(tot); putchar(c);}

class Splay{
#define val(x) t[x].val
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
#define son(x,nxt) t[x].ch[nxt]
#define fa(x) t[x].fa
#define cnt(x) t[x].cnt
#define siz(x) t[x].siz
private:
struct node{
int val,fa,ch[2],siz,cnt;
}t[N];
int root,tot;
public:
bool which(int x){
return rs(fa(x))==x;
}
void pushup(int x){
siz(x)=siz(ls(x))+siz(rs(x))+cnt(x);
}
void connect(int x,int y,int nxt){
son(y,nxt)=x;
fa(x)=y;
}
void rotate(int x){
int y=fa(x),z=fa(y);
int fx=which(x),fy=which(y);

connect(son(x,fx^1),y,fx);
connect(y,x,fx^1);
connect(x,z,fy);

pushup(y); pushup(x);
}
void splay(int x,int y){
y=fa(y);
while(fa(x)!=y){
if(fa(fa(x))==y)
rotate(x);
else if(which(x)==which(fa(x)))
rotate(fa(x)), rotate(x);
else
rotate(x), rotate(x);
}
if(y==0){
root=x;
connect(x,0,1);
}
}
void newPoint(int val,int fa,int nxt){
tot++;
fa(tot)=fa; cnt(tot)=siz(tot)=1; val(tot)=val;
son(fa,nxt)=tot;
}
void insert(int val){
if(root==0){
newPoint(val,0,1);
root=tot;
return;
}
int now=root;
while(1){
siz(now)++;
if(val(now)==val){
cnt(now)++;
splay(now,root);
return;
}
int nxt=val(now)<val, son=son(now,nxt);
if(!son){
newPoint(val,now,nxt);
splay(tot,root);
return;
}
now=son;
}
}
int find(int val){
int now=root;
while(1){
if(!now)
return 0;
if(val(now)==val){
splay(now,root);
return now;
}
int nxt=val(now)<val, son=son(now,nxt);
now=son;
}
}
void delet(int val){
int now=find(val);
if(!now) return;
if(cnt(now)>1){
cnt(now)--; siz(now)--;
return;
}
if(!ls(now) && !rs(now)){
root=0;
}
else if(!ls(now)){
root=rs(root);
fa(root)=0;
}
else if(!rs(now)){
root=ls(root);
fa(root)=0;
}
else{
int pos=ls(now);
while(rs(pos)) pos=rs(pos);

splay(pos,root);
connect(rs(now),pos,1);
pushup(pos);
}
}
int rnk(int val){
int now=root,s=0;
while(now){
if(val(now)==val){
splay(now,root);
return siz(ls(now))+1;
}
if(val(now)<val){
s+=siz(ls(now))+cnt(now);
now=rs(now);
}
else{
now=ls(now);
}
}
return s+1;
}
int find_k(int k){
int now=root;
while(1){
int used=siz(now)-siz(rs(now));
if(k>siz(ls(now)) && k<=used){
break;
}
if(k>=used){
k-=used;
now=rs(now);
}
else{
now=ls(now);
}
}
splay(now,root);
return val(now);
}
int lower(int val){
int ans=-2147483647;
int now=root;
while(now){
if(val(now)<val && val(now)>ans){
ans=val(now);
}
if(val>val(now)){
now=rs(now);
}
else{
now=ls(now);
}
}
return ans;
}
int upper(int val){
int ans=2147483647;
int now=root;
while(now){
if(val(now)>val && val(now)<ans){
ans=val(now);
}
if(val<val(now)){
now=ls(now);
}
else{
now=rs(now);
}
}
return ans;
}
}tr;

int n,opt,val;

int main(){
read(n);
while(n--){
read(opt,val);
int ans;
switch(opt){
case 1:{
tr.insert(val);
break;
}
case 2:{
tr.delet(val);
break;
}
case 3:{
ans=tr.rnk(val);
break;
}
case 4:{
ans=tr.find_k(val);
break;
}
case 5:{
ans=tr.lower(val);
break;
}
case 6:{
ans=tr.upper(val);
break;
}
}
if(opt>2) write(ans,'\n');
}
}