原题链接:https://www.luogu.com.cn/problem/P3369
题意解读:平衡树的基本操作,模版题。
解题思路:
1、二叉搜索树-BST
二叉搜索树满足这样的性质:每一个节点的权值大于它的左儿子,小于它的右儿子。
对BST进行中序遍历,将得到一个从小到大的有序序列,因此BST是为了维护一个有序序列的动态添加、删除、查找。
随机情况下,对树进行插入、查找、删除等操作的时间复杂度都是O(logN),
但是如果插入顺序是一个已经有序的序列,将退化成一条链,时间复杂度变成O(N)。
2、平衡树
平衡树就是为了解决BST中高度不均衡导致时间复杂度上升的问题,
为了使某个节点左右子树高度尽可能差距小,需要进行两个重要的操作:左旋、右旋
左旋:将以E为根的子树左旋,先令S = E->right,再E->right = S->left,然后S->left = E
右旋:将以S为根的子树右旋,先令E = S->left,再S->left = E->right,然后E->right = S
平衡树的具体实现方式有多种,如AVL、红黑树、Treap、Splay、Trie等等,本文主要介绍最好写的两种:Trie、Treap。
3、用01-Trie平替平衡树
为什么01-Trie可以平替平衡树?
首先,01-Trie高度是固定的,显然满足平衡的特点。
其次,01-Trie也满足左子树对应的值更小,右子树对应的值更大,能够维护序列的有序性。
最后,01-Trie实现平衡树,需要记录一些额外的信息:每个结点所在子树一共有有多个元素
但是,01-Trie作为平衡树也有一些缺点,比如:
占用空间较大,每个整数都拆成二进制作为树的节点。
不能处理负数,但是可以加上一个较大的数将负数转正。
在数据量不太大的情况下,还是可以使用的。
Trie实现平衡树的基本操作:
本题元素大小|x|<=10^7
设int trie[N * 26][2], idx表示Trie树,int siz[N * 26]记录每个节点所在子树的元素个数。
a、插入
void add(int val)
{int u = 0;for(int i = 25; i >= 0; i--){int v = val >> i & 1;if(!trie[u][v]) trie[u][v] = ++idx;u = trie[u][v];siz[u]++;}
}
b、删除
void del(int val)
{int u = 0;for(int i = 25; i >= 0; i--){int v = val >> i & 1;if(!trie[u][v]) return;u = trie[u][v];siz[u]--;}
}
c、查找小于x的元素个数
int get_less(int val)
{int res = 0;int u = 0;for(int i = 25; i >= 0; i--){int v = val >> i & 1;if(v == 1) res += siz[trie[u][0]]; //如果val在右子树,则左子树所有数都是小于val的,要累加u = trie[u][v];if(!u) break; //如果val不存在,到这里就可以结束}return res;
}
d、查找第k个数
int get_kth(int k)
{int res = 0;int u = 0;for(int i = 25; i >= 0; i--){if(siz[trie[u][0]] < k) //左子树数量不足k,在右子树找{k -= siz[trie[u][0]]; //k要减去左子树的数量u = trie[u][1];res = res * 2 + 1;} else {u = trie[u][0];res = res * 2;}if(!u) break;}return res - INF;
}
100分代码:
#include <bits/stdc++.h>
using namespace std;const int N = 100005, INF = 1e7;int trie[N * 26][2], idx;
int siz[N * 26]; //siz[i]表示节点i所在子树中数的个数,根节点不需要记录//插入元素到trie
void add(int val)
{int u = 0;for(int i = 25; i >= 0; i--){int v = val >> i & 1;if(!trie[u][v]) trie[u][v] = ++idx;u = trie[u][v];siz[u]++;}
}//从trie中删除元素
void del(int val)
{int u = 0;for(int i = 25; i >= 0; i--){int v = val >> i & 1;if(!trie[u][v]) return;u = trie[u][v];siz[u]--;}
}//获取小于val的元素数量
int get_less(int val)
{int res = 0;int u = 0;for(int i = 25; i >= 0; i--){int v = val >> i & 1;if(v == 1) res += siz[trie[u][0]];u = trie[u][v];if(!u) break; //如果val不存在,到这里就可以结束}return res;
}//获取排名第k的元素
int get_kth(int k)
{int res = 0;int u = 0;for(int i = 25; i >= 0; i--){if(siz[trie[u][0]] < k) //左子树数量不足k,在右子树找{k -= siz[trie[u][0]];u = trie[u][1];res = res * 2 + 1;} else {u = trie[u][0];res = res * 2;}if(!u) break;}return res - INF;
}int main()
{int n;cin >> n;int opt, x;while(n--){cin >> opt >> x;if(opt == 1) x += INF, add(x); //元素值+INF,使得必然为非负数,才能加入01trieelse if(opt == 2) x += INF, del(x);else if(opt == 3) x += INF, cout << get_less(x) + 1 << endl;else if(opt == 4) cout << get_kth(x) << endl;else if(opt == 5) x += INF, cout << get_kth(get_less(x)) << endl;else x += INF, cout << get_kth(get_less(x + 1) + 1) << endl;}return 0;
}
4、Treap平衡树
Treap是Tree+Heap,也就是树+堆,通过树来维护BST结构,通过堆的性质来保证尽可能平衡。
具体来说,树的节点定义为:
struct Node
{int l, r; //l左子树,r右子树int val, pri; //val是节点权值,pri是随机数用来维护堆的性质int siz, cnt; //siz是节点为根的子树大小,cnt是节点重复元素的个数
} tr[N];
int idx; //树节点编号
int root; //根节点
通过val来维护BST的性质,如果val严格有序将导致树退化成链,因此引入一个随机数pri,并强制父节点的pri大于子节点pri(大根堆性质),通过维护此性质即可保持树的平衡。
通过siz,cnt这些附加信息,就可以实现查元素排名、查第k个元素、找前驱、找后继等操作。
Treap维护树的平衡只需要在插入元素的时候判断,如果插入元素后,导致子节点的pri大于父节点的pri,则进行相应的旋转操作(左旋or右旋)。
100分代码:
#include <bits/stdc++.h>
using namespace std;const int N = 100005, INF = 1e8;struct Node
{int l, r; //l左子树,r右子树int val, pri; //val是节点权值,pri是随机数用来维护堆的性质int siz, cnt; //siz是节点为根的子树大小,cnt是节点重复元素的个数
} tr[N];
int idx; //树节点编号
int root; //根节点//生成一个新节点
int get_node(int val)
{tr[++idx].val = val;tr[idx].pri = rand(); //随机值,通过维护大根堆特性确保尽量平衡tr[idx].siz = tr[idx].cnt = 1;return idx;
}//计算子树siz
void pushup(int &p)
{tr[p].siz = tr[tr[p].l].siz + tr[tr[p].r].siz + tr[p].cnt;
}//右旋
void rotate_to_r(int &p)
{int t = tr[p].l; tr[p].l = tr[t].r;tr[t].r = p;p = t;pushup(tr[p].r);pushup(p);
}//左旋
void rotate_to_l(int &p)
{int t = tr[p].r;tr[p].r = tr[t].l;tr[t].l = p;p = t;pushup(tr[p].l);pushup(p);
} //初始化树
void build_tree()
{ //树中添加两个初始节点:极大值和极小值,避免出现边界问题get_node(-INF); get_node(INF);root = 1;tr[root].r = 2;pushup(root);if(tr[1].pri < tr[2].pri) rotate_to_l(root);
}void insert(int &p, int val)
{if(!p) p = get_node(val);else if(tr[p].val == val) tr[p].cnt++;else if(tr[p].val > val) {insert(tr[p].l, val);if(tr[tr[p].l].pri > tr[p].pri) rotate_to_r(p); //插入左子树后对不满足堆性质进行调整}else{insert(tr[p].r, val);if(tr[tr[p].r].pri > tr[p].pri) rotate_to_l(p); //插入右子树后对不满足堆性质进行调整}pushup(p);
}void erase(int &p, int val)
{if(!p) return;else if(tr[p].val == val){if(tr[p].cnt > 1) tr[p].cnt--; //找到有多个,减一个else if(!tr[p].l && !tr[p].r) //叶子节点,直接删除{p = 0;}else if(!tr[p].r || tr[tr[p].l].pri > tr[tr[p].r].pri) { //如果只有左子树,或者左子树pri大于右子树,则右旋,然后去右子树删除rotate_to_r(p);erase(tr[p].r, val);}else if(!tr[p].l || tr[tr[p].r].pri > tr[tr[p].l].pri){ //如果只有右子树,或者右子树pri大于左子树,则左旋,然后去左子树删除rotate_to_l(p);erase(tr[p].l, val);}}else if(tr[p].val > val) erase(tr[p].l, val);else erase(tr[p].r, val);pushup(p);
}//查询比val小的数的个数,由于第一个节点是-INF,因此比val小的数的个数就是排名
int get_less(int p, int val)
{if(!p) return 0; else if(tr[p].val == val) return tr[tr[p].l].siz; //p就是val,则p左子树大小就是比val小的数的个数else if(tr[p].val > val) return get_less(tr[p].l, val); //到左子树找else if(tr[p].val < val) return tr[p].cnt + tr[tr[p].l].siz + get_less(tr[p].r, val); //到右子树找,p和p的左子树都比val小,要累加
}//查询第k个数
int get_kth(int p, int k)
{if(!p) return 0; //没有找到else if(tr[tr[p].l].siz >= k) return get_kth(tr[p].l, k); //到左子树找else if(tr[tr[p].l].siz + tr[p].cnt >= k) return tr[p].val; //p就是第k个else if(tr[tr[p].l].siz < k) return get_kth(tr[p].r, k - tr[tr[p].l].siz - tr[p].cnt); //到右子树找
}//查找val的前驱,比val小的最大数
int get_prev(int p, int val)
{if(!p) return -INF;else if(tr[p].val >= val) return get_prev(tr[p].l, val);else return max(tr[p].val, get_prev(tr[p].r, val));
}//查找val的后继,比val大的最小数
int get_next(int p, int val)
{if(!p) return INF;else if(tr[p].val <= val) return get_next(tr[p].r, val);else return min(tr[p].val, get_next(tr[p].l, val));
}int main()
{int n;cin >> n;int opt, x;build_tree();while(n--){cin >> opt >> x;if(opt == 1) insert(root, x);else if(opt == 2) erase(root, x);else if(opt == 3) cout << get_less(root, x) << endl;else if(opt == 4) cout << get_kth(root, x + 1) << endl;else if(opt == 5) cout << get_prev(root, x) << endl;else if(opt == 6) cout << get_next(root, x) << endl;}return 0;
}