树套树
顾名思义,就是一个树套着一个树。
例如:线段树套平衡树,线段树中的每个节点的区间用平衡树维护。
常用:
- 外层:线段树,树状数组
- 内层:平衡树,线段树。(一般可以用 STL,如 multiset)
例题:
- AcWing 2488
线段树套 multiset:线段树的每个节点用一个 multiset 保存其区间内的所有数。单点修改时,更新从根到叶子路径上每个节点的 multiset;区间查询前驱时,在覆盖查询区间的各个节点上分别求前驱,再取最大值。
#include <algorithm>
#include <iostream>
#include <set>
using namespace std;

const int N = 50005, M = N << 2;
const int INF = 0x3f3f3f3f;

int n, m;

// One segment-tree node: the closed index range [l, r] and a multiset holding
// every w[i] with l <= i <= r, plus the two sentinels -INF and INF.
struct Tree {
    int l, r;
    multiset<int> s;
} tr[M];

int w[N];

// Builds node u for the range [l, r] and, recursively, its two children.
// Every node stores its own copy of the values in its range.
void build(int u, int l, int r) {
    tr[u] = {l, r};
    // Sentinels keep the iterator arithmetic in query() inside the container.
    tr[u].s.insert(-INF);
    tr[u].s.insert(INF);
    for (int i = l; i <= r; i++) tr[u].s.insert(w[i]);
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
}

// Replaces the value at position p by x in every node on the root-to-leaf path.
// Reads the old value from w[p], so the caller must update w[p] only AFTER
// this call returns.
void change(int u, int p, int x) {
    // erase(iterator) removes exactly one occurrence, not all equal values.
    tr[u].s.erase(tr[u].s.find(w[p]));
    tr[u].s.insert(x);
    if (tr[u].l == tr[u].r) return;
    int mid = (tr[u].l + tr[u].r) >> 1;
    if (p <= mid)
        change(u << 1, p, x);
    else
        change(u << 1 | 1, p, x);
}

// Returns the largest stored value strictly less than x among positions
// [a, b]; returns -INF when no such value exists (only the sentinel is found).
int query(int u, int a, int b, int x) {
    if (tr[u].l >= a && tr[u].r <= b) {
        auto it = tr[u].s.lower_bound(x);
        --it;  // element just before the first one that is >= x
        return *it;
    }
    int mid = (tr[u].l + tr[u].r) >> 1;
    int res = -INF;
    if (a <= mid) res = max(res, query(u << 1, a, b, x));
    if (b > mid) res = max(res, query(u << 1 | 1, a, b, x));
    return res;
}

// Input: n m, then n values, then m operations:
//   1 pos x   -> set position pos to x
//   2 l r x   -> print the predecessor of x within positions [l, r]
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> w[i];
    build(1, 1, n);

    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {  // was `op == x`, which compared against an uninitialised variable
            int pos, x;
            cin >> pos >> x;
            change(1, pos, x);
            w[pos] = x;  // must follow change(): it looks up the old w[pos]
        } else {
            int a, b, x;
            cin >> a >> b >> x;
            cout << query(1, a, b, x) << '\n';
        }
    }
    return 0;
}
- P3380 【模板】树套树
#include <algorithm>
#include <iostream>
using namespace std;

const int N = 2000005, INF = 2147483647;

int n, m;

// Splay node pool shared by all inner trees. Index 0 is the null node
// (its sz stays 0, so it contributes nothing in pushup).
struct Node {
    int s[2], p, v;  // children, parent, key
    int sz;          // size of the subtree rooted here
    void init(int _v, int _p) {
        v = _v, p = _p;
        sz = 1;
    }
} tr[N];

// Outer segment tree: node u covers [L[u], R[u]] and owns the splay rooted at T[u].
int L[N], R[N], T[N], idx;
int w[N];

// Recomputes the subtree size of x from its two children.
void pushup(int x) {
    tr[x].sz = tr[tr[x].s[0]].sz + tr[tr[x].s[1]].sz + 1;
}

// Rotates x one level up, above its parent.
void rotate(int x) {
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;  // which child of y is x
    tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
    tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
    tr[x].s[k ^ 1] = y, tr[y].p = x;
    pushup(y), pushup(x);  // y is now below x, so refresh y first
}

// Splays x until its parent is k; when k == 0, x becomes the new root.
void splay(int &root, int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k) {
            if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y))
                rotate(x);  // zig-zag
            else
                rotate(y);  // zig-zig
        }
        rotate(x);
    }
    if (!k) root = x;
}

// Inserts key v as a new node and splays it to the root.
// Sizes along the search path are refreshed by the splay itself.
void insert(int &root, int v) {
    int u = root, p = 0;
    while (u) p = u, u = tr[u].s[v > tr[u].v];
    u = ++idx;
    if (p) tr[p].s[v > tr[p].v] = u;
    tr[u].init(v, p);
    splay(root, u, 0);
}

// Returns how many keys in the tree are strictly less than v
// (the -INF sentinel is included in that count).
int get_k(int root, int v) {
    int u = root, res = 0;
    while (u) {
        if (tr[u].v < v)
            res += tr[tr[u].s[0]].sz + 1, u = tr[u].s[1];
        else
            u = tr[u].s[0];
    }
    return res;
}

// Replaces one occurrence of key x by key y.
// x must be present; change() guarantees this by passing the current w[p].
void update(int &root, int x, int y) {
    int u = root;
    while (u) {
        if (tr[u].v == x) break;
        if (tr[u].v < x)
            u = tr[u].s[1];
        else
            u = tr[u].s[0];
    }
    splay(root, u, 0);
    // In-order neighbours of u; the +-INF sentinels guarantee both exist.
    int l = tr[u].s[0], r = tr[u].s[1];
    while (tr[l].s[1]) l = tr[l].s[1];
    while (tr[r].s[0]) r = tr[r].s[0];
    // With l at the root and r directly below it, r's left subtree is exactly u.
    splay(root, l, 0), splay(root, r, l);
    tr[r].s[0] = 0;
    // r is l's child, so refresh the child before the parent. The previous
    // order (l first) left l's size stale until the next splay touched it.
    pushup(r), pushup(l);
    insert(root, y);
}

// Builds outer node u over [l, r]; its splay holds w[l..r] plus two sentinels.
void build(int u, int l, int r) {
    L[u] = l, R[u] = r;
    insert(T[u], INF), insert(T[u], -INF);
    for (int i = l; i <= r; i++) insert(T[u], w[i]);
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}

// Returns the number of values at positions [a, b] that are strictly less than x.
int query(int u, int a, int b, int x) {
    if (L[u] >= a && R[u] <= b) return get_k(T[u], x) - 1;  // -1 drops the -INF sentinel
    int mid = (L[u] + R[u]) >> 1;
    int res = 0;
    if (a <= mid) res += query(u << 1, a, b, x);
    if (b > mid) res += query(u << 1 | 1, a, b, x);
    return res;
}

// Sets position p to x in every outer node covering p.
// Reads the old value from w[p]; the caller updates w[p] afterwards.
void change(int u, int p, int x) {
    update(T[u], w[p], x);
    if (L[u] == R[u]) return;
    int mid = (L[u] + R[u]) >> 1;
    if (p <= mid)
        change(u << 1, p, x);
    else
        change(u << 1 | 1, p, x);
}

// Largest key strictly less than v in one splay, or -INF if there is none.
int get_pre(int root, int v) {
    int u = root, res = -INF;
    while (u) {
        if (tr[u].v < v)
            res = max(res, tr[u].v), u = tr[u].s[1];
        else
            u = tr[u].s[0];
    }
    return res;
}

// Smallest key strictly greater than v in one splay, or INF if there is none.
int get_suc(int root, int v) {
    int u = root, res = INF;
    while (u) {
        if (tr[u].v > v)
            res = min(res, tr[u].v), u = tr[u].s[0];
        else
            u = tr[u].s[1];
    }
    return res;
}

// Predecessor of x among positions [a, b] (-INF when absent).
int query_pre(int u, int a, int b, int x) {
    if (L[u] >= a && R[u] <= b) return get_pre(T[u], x);
    int mid = (L[u] + R[u]) >> 1;
    int res = -INF;
    if (a <= mid) res = max(res, query_pre(u << 1, a, b, x));
    if (b > mid) res = max(res, query_pre(u << 1 | 1, a, b, x));
    return res;
}

// Successor of x among positions [a, b] (INF when absent).
int query_suc(int u, int a, int b, int x) {
    if (L[u] >= a && R[u] <= b) return get_suc(T[u], x);
    int mid = (L[u] + R[u]) >> 1;
    int res = INF;
    if (a <= mid) res = min(res, query_suc(u << 1, a, b, x));
    if (b > mid) res = min(res, query_suc(u << 1 | 1, a, b, x));
    return res;
}

// Operations:
//   1 l r x  -> rank of x in [l, r]
//   2 l r k  -> value with rank k in [l, r]
//   3 pos x  -> set position pos to x
//   4 l r x  -> predecessor of x in [l, r]
//   5 l r x  -> successor of x in [l, r]
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> w[i];
    build(1, 1, n);

    while (m--) {
        int op, a, b, x;
        cin >> op;
        if (op == 1) {
            cin >> a >> b >> x;
            cout << query(1, a, b, x) + 1 << '\n';
        } else if (op == 2) {
            cin >> a >> b >> x;
            // Binary search for the largest value whose rank is still <= x.
            int lo = 0, hi = 100000000;
            while (lo < hi) {
                int mid = (lo + hi + 1) >> 1;
                if (query(1, a, b, mid) + 1 <= x)
                    lo = mid;
                else
                    hi = mid - 1;
            }
            cout << hi << '\n';
        } else if (op == 3) {
            cin >> a >> x;
            change(1, a, x);
            w[a] = x;  // must follow change(): it looks up the old w[a]
        } else if (op == 4) {
            cin >> a >> b >> x;
            cout << query_pre(1, a, b, x) << '\n';
        } else {
            cin >> a >> b >> x;
            cout << query_suc(1, a, b, x) << '\n';
        }
    }
    return 0;
}
- P3332 [ZJOI2013] K大数查询
考虑值域线段树套线段树。
Tips:标记永久化,动态开点线段树。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>

using namespace std;

typedef long long LL;

const int N = 50010, P = N * 17 * 17, M = N * 4;

int n, m;

// Inner (position) segment tree node, allocated on demand from the pool tr[].
// l / r are child indices (0 = not created yet); sum is the number of inserted
// elements inside this node's range; add is a mark that is never pushed down.
struct Tree {
    int l, r;
    LL sum, add;
} tr[P];

// Outer (value) segment tree: node u covers ranks [L[u], R[u]] of the
// discretised values and owns the inner tree rooted at T[u].
int L[M], R[M], T[M], idx;

struct Query {
    int op, a, b;
    LL c;  // inserted value for op 1; rank k for op 2, which can exceed the int range
} q[N];

vector<int> nums;  // sorted, de-duplicated values of all insertions

// Maps a value to its index in nums.
int get(int x) {
    return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}

// Builds the outer tree over value ranks [l, r], giving each node an inner root.
void build(int u, int l, int r) {
    L[u] = l, R[u] = r, T[u] = ++idx;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}

// Length of the overlap of [a, b] and [c, d]; callers only pass overlapping ranges.
int intersection(int a, int b, int c, int d) {
    return min(b, d) - max(a, c) + 1;
}

// Adds one element to every position of [pl, pr] in the inner tree node u,
// which covers [l, r]. Children are created lazily.
void update(int u, int l, int r, int pl, int pr) {
    tr[u].sum += intersection(l, r, pl, pr);
    if (l >= pl && r <= pr) {
        tr[u].add++;  // permanent mark: applied on the way down in get_sum
        return;
    }
    int mid = (l + r) >> 1;
    if (pl <= mid) {
        if (!tr[u].l) tr[u].l = ++idx;
        update(tr[u].l, l, mid, pl, pr);
    }
    if (pr > mid) {
        if (!tr[u].r) tr[u].r = ++idx;
        update(tr[u].r, mid + 1, r, pl, pr);
    }
}

// Records "value rank c inserted at positions [a, b]" in every outer node
// on the path from the root to leaf c.
void change(int u, int a, int b, int c) {
    update(T[u], 1, n, a, b);
    if (L[u] == R[u]) return;
    int mid = (L[u] + R[u]) >> 1;
    if (c <= mid)
        change(u << 1, a, b, c);
    else
        change(u << 1 | 1, a, b, c);
}

// Number of elements stored at positions [pl, pr] of inner node u ([l, r]).
// add is the total of the permanent marks on u's ancestors; it is 64-bit so the
// products below cannot overflow.
LL get_sum(int u, int l, int r, int pl, int pr, LL add) {
    if (l >= pl && r <= pr) return tr[u].sum + (r - l + 1LL) * add;
    int mid = (l + r) >> 1;
    LL res = 0;
    add += tr[u].add;
    if (pl <= mid) {
        if (tr[u].l)
            res += get_sum(tr[u].l, l, mid, pl, pr, add);
        else
            res += intersection(l, mid, pl, pr) * add;  // missing child: only marks count
    }
    if (pr > mid) {
        if (tr[u].r)
            res += get_sum(tr[u].r, mid + 1, r, pl, pr, add);
        else
            res += intersection(mid + 1, r, pl, pr) * add;
    }
    return res;
}

// Returns the rank (index into nums) of the c-th largest element among
// positions [a, b]. c is 64-bit: the element count can reach n * m, beyond int.
int query(int u, int a, int b, LL c) {
    if (L[u] == R[u]) return R[u];
    // k = how many elements in [a, b] fall in the larger half of the values.
    LL k = get_sum(T[u << 1 | 1], 1, n, a, b, 0);
    if (k >= c) return query(u << 1 | 1, a, b, c);
    return query(u << 1, a, b, c - k);
}

// Operations are read up front so the inserted values can be discretised:
//   1 l r c -> insert c into every position of [l, r]
//   2 l r c -> print the c-th largest element among positions [l, r]
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i++) {
        scanf("%d%d%d%lld", &q[i].op, &q[i].a, &q[i].b, &q[i].c);
        if (q[i].op == 1) nums.push_back(static_cast<int>(q[i].c));
    }

    sort(nums.begin(), nums.end());
    nums.erase(unique(nums.begin(), nums.end()), nums.end());

    // With no insertions the range would be [0, -1] and build() would never terminate.
    if (!nums.empty()) build(1, 0, static_cast<int>(nums.size()) - 1);

    for (int i = 0; i < m; i++) {
        int op = q[i].op, a = q[i].a, b = q[i].b;
        LL c = q[i].c;
        if (op == 1)
            change(1, a, b, get(static_cast<int>(c)));
        else
            printf("%d\n", nums[query(1, a, b, c)]);
    }
    return 0;
}