简介
首先要知道 \(KD-Tree\) 是干什么的,它最广泛的用法便是维护 \(k\) 维最近点对(大部分时候是二维)。
先来讲没有插入,直接建树的。
它的每个结点维护这样子的数据,其中 \(lc\) 和 \(rc\) 代表左右儿子,\(v[i]\) 代表第 \(i\) 维当前点的取值,\(L[i]\) 和 \(U[i]\) 分别代表第 \(i\) 维上当前点对应的子树中所有点的范围,它其实对应了一个矩形。
比如一棵子树中根为 \((114514,114514)\),左儿子为 \((1919, 1919)\) 右儿子为 \((810,810)\),那么对于根来说 \(L[1] = 810, R[1] = 114514\)。
struct KD {int lc, rc;double v[2], L[2], U[2];bool operator < (const KD &t) const {return v[K] < t.v[K];}
} tr[N];
然后是建树过程,其实和替罪羊树的重构过程很像,都是拍扁再拎起来。具体地我们使用 nth_element
函数,将一排数根据 \(mid\) 分成两半。同时,为了保证复杂度,我们要轮流对第 \(k\) 维排序,即 \(0,1\) 循环,上文的比较函数也是为了这个所定义的。值得注意的是 \(k\) 是函数内的 \(K\) 是一个全局变量。
pushup
操作也比较简朴,看看就懂了,就是用儿子更新父亲。
void pushup(int u) {rep(i, 0, 1) {tr[u].L[i] = tr[u].U[i] = tr[u].v[i];if (tr[u].lc) {tr[u].L[i] = min(tr[u].L[i], tr[tr[u].lc].L[i]);tr[u].U[i] = max(tr[u].U[i], tr[tr[u].lc].U[i]);} if (tr[u].rc) {tr[u].L[i] = min(tr[u].L[i], tr[tr[u].rc].L[i]);tr[u].U[i] = max(tr[u].U[i], tr[tr[u].rc].U[i]);}}
}
int build(int l, int r, int k) {if (l > r) return 0;int mid = l + r >> 1;K = k; nth_element(tr + l, tr + mid, tr + r);tr[mid].lc = build(l, mid - 1, k ^ 1);tr[mid].rc = build(mid + 1, r, k ^ 1);pushup(mid);return mid;
}
查询也不算非常困难,我们从根开始,分别计算要查的结点 \(cur\) 到根的距离,以及到左右儿子所在范围的最近距离。由于上文的 \(L\) 和 \(U\) 在两维状态下可以看做是一个矩形,所以相当于一个点到矩形的最短距离。可以看看代码画图理解一下。
回到 query
中,我们得知 \(dist\) 后,可以贪心地在左右儿子中选择 \(dist\) 小的来优先更新,然后在考虑另一侧。同时左右边的最优答案一定得小于当前全局最优值,否则不用更新。
inline double sq(double x) {return x * x;
}
inline double dis1(int x) {double res = 0;rep(i, 0, 1) res += sq(tr[cur].v[i] - tr[x].v[i]);return res;
}
inline double dis2(int x) {if (!x) return 2e18;double res = 0;rep(i, 0, 1) res += sq(max(0.0, tr[cur].v[i] - tr[x].U[i])) + sq(max(0.0, tr[x].L[i] - tr[cur].v[i]));return res;
}
void query(int u) {if (!u) return;if (u != cur) ans = min(ans, dis1(u));double d1 = dis2(tr[u].lc), d2 = dis2(tr[u].rc);if (d1 < d2) {if (d1 < ans) query(tr[u].lc);if (d2 < ans) query(tr[u].rc);} else {if (d2 < ans) query(tr[u].rc);if (d1 < ans) query(tr[u].lc);}
}
经过证明(我不会),在处理二维时的复杂度是根号的,\(build\) 是 \(O(nlogn)\),\(k\) 维是 \(O(n^{1 - \frac{1}{k}})\) 的。
当然你也可以动态差点不 \(build\),同样类似于替罪羊数当 \(A * sz[root] \geq max(sz[lc], sz[rc])\) 时就直接重构。\(A\) 我一般取 \(0.7\)。
bool cmp(int a, int b) {return tr[a].v[K] < tr[b].v[K];
}
int rebuild(int l, int r, int k) {if (l > r) return 0;int mid = l + r >> 1;K = k; nth_element(g + l, g + mid, g + r + 1, cmp);tr[g[mid]].lc = rebuild(l, mid - 1, k ^ 1);tr[g[mid]].rc = rebuild(mid + 1, r, k ^ 1);pushup(g[mid]);return g[mid];
}
void dfs(int u) {if (!u) return;g[++ cnt] = u;dfs(tr[u].lc);dfs(tr[u].rc);
}
void check(int &u, int k) {if (tr[u].sz * A < max(tr[tr[u].lc].sz, tr[tr[u].rc].sz)) cnt = 0, dfs(u), u = rebuild(1, cnt, k);
}
void insert(int &u, int k) {if (!u) { u = cur; pushup(u); return; }insert(tr[cur].v[k] <= tr[u].v[k] ? tr[u].lc : tr[u].rc, k ^ 1);pushup(u);check(u, k);
}
模板
以P1429 平面最近点对(加强版)为例,贴个板子。
#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = (a); i <= (b); i ++)
#define fro(i, a, b) for (int i = (a); i >= b; i --)
#define INF 0x3f3f3f3f
#define eps 1e-6
#define lowbit(x) (x & (-x))
#define initrand srand((unsigned)time(0))
#define random(x) ((LL)rand() * rand() % (x))
#define eb emplace_back
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<double, int> PDI;
inline int read() {int x = 0, f = 1;char ch = getchar();while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }while (ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }return x * f;
}const int N = 200010;
int n, K, cur;
double ans = 2e18;struct KD {int lc, rc;double v[2], L[2], U[2];bool operator < (const KD &t) const {return v[K] < t.v[K];}
} tr[N]; void pushup(int u) {rep(i, 0, 1) {tr[u].L[i] = tr[u].U[i] = tr[u].v[i];if (tr[u].lc) {tr[u].L[i] = min(tr[u].L[i], tr[tr[u].lc].L[i]);tr[u].U[i] = max(tr[u].U[i], tr[tr[u].lc].U[i]);} if (tr[u].rc) {tr[u].L[i] = min(tr[u].L[i], tr[tr[u].rc].L[i]);tr[u].U[i] = max(tr[u].U[i], tr[tr[u].rc].U[i]);}}
}int build(int l, int r, int k) {if (l > r) return 0;int mid = l + r >> 1;K = k; nth_element(tr + l, tr + mid, tr + r + 1);tr[mid].lc = build(l, mid - 1, k ^ 1);tr[mid].rc = build(mid + 1, r, k ^ 1);pushup(mid);return mid;
}inline double sq(double x) {return x * x;
}inline double dis1(int x) {double res = 0;rep(i, 0, 1) res += sq(tr[cur].v[i] - tr[x].v[i]);return res;
}inline double dis2(int x) {if (!x) return 2e18;double res = 0;rep(i, 0, 1) res += sq(max(0.0, tr[cur].v[i] - tr[x].U[i])) + sq(max(0.0, tr[x].L[i] - tr[cur].v[i]));return res;
}void query(int u) {if (!u) return;if (u != cur) ans = min(ans, dis1(u));double d1 = dis2(tr[u].lc), d2 = dis2(tr[u].rc);if (d1 < d2) {if (d1 < ans) query(tr[u].lc);if (d2 < ans) query(tr[u].rc);} else {if (d2 < ans) query(tr[u].rc);if (d1 < ans) query(tr[u].lc);}
}int main() {n = read();rep(i, 1, n) scanf("%lf%lf", &tr[i].v[0], &tr[i].v[1]);int root = build(1, n, 0);for (cur = 1; cur <= n; cur ++) query(root); printf("%.4lf\n", sqrt(ans));return 0;
}
#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = (a); i <= (b); i ++)
#define fro(i, a, b) for (int i = (a); i >= b; i --)
#define INF 0x3f3f3f3f
#define eps 1e-6
#define lowbit(x) (x & (-x))
#define initrand srand((unsigned)time(0))
#define random(x) ((LL)rand() * rand() % (x))
#define eb emplace_back
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<double, int> PDI;
inline int read() {int x = 0, f = 1;char ch = getchar();while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }while (ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }return x * f;
}const int N = 200010;
const double A = 0.7;
int n, K, root, cur;
double ans = 2e18;int idx;
struct KD {int lc, rc, sz;double v[2], L[2], U[2];bool operator < (const KD &t) const {return v[K] < t.v[K];}
} tr[N]; void pushup(int u) {tr[u].sz = tr[tr[u].lc].sz + tr[tr[u].rc].sz;rep(i, 0, 1) {tr[u].L[i] = tr[u].U[i] = tr[u].v[i];if (tr[u].lc) {tr[u].L[i] = min(tr[u].L[i], tr[tr[u].lc].L[i]);tr[u].U[i] = max(tr[u].U[i], tr[tr[u].lc].U[i]);} if (tr[u].rc) {tr[u].L[i] = min(tr[u].L[i], tr[tr[u].rc].L[i]);tr[u].U[i] = max(tr[u].U[i], tr[tr[u].rc].U[i]);}}
}int g[N], cnt;bool cmp(int a, int b) {return tr[a].v[K] < tr[b].v[K];
}int rebuild(int l, int r, int k) {if (l > r) return 0;int mid = l + r >> 1;K = k; nth_element(g + l, g + mid, g + r + 1, cmp);tr[g[mid]].lc = rebuild(l, mid - 1, k ^ 1);tr[g[mid]].rc = rebuild(mid + 1, r, k ^ 1);pushup(g[mid]);return g[mid];
}void dfs(int u) {if (!u) return;g[++ cnt] = u;dfs(tr[u].lc);dfs(tr[u].rc);
}void check(int &u, int k) {if (tr[u].sz * A < max(tr[tr[u].lc].sz, tr[tr[u].rc].sz)) cnt = 0, dfs(u), u = rebuild(1, cnt, k);
}void insert(int &u, int k) {if (!u) { u = cur; pushup(u); return; }insert(tr[cur].v[k] <= tr[u].v[k] ? tr[u].lc : tr[u].rc, k ^ 1);pushup(u);check(u, k);
}inline double sq(double x) {return x * x;
}inline double dis1(int x) {double res = 0;rep(i, 0, 1) res += sq(tr[cur].v[i] - tr[x].v[i]);return res;
}inline double dis2(int x) {if (!x) return 2e18;double res = 0;rep(i, 0, 1) res += sq(max(0.0, tr[cur].v[i] - tr[x].U[i])) + sq(max(0.0, tr[x].L[i] - tr[cur].v[i]));return res;
}void query(int u) {if (!u) return;if (u != cur) ans = min(ans, dis1(u));double d1 = dis2(tr[u].lc), d2 = dis2(tr[u].rc);if (d1 < d2) {if (d1 < ans) query(tr[u].lc);if (d2 < ans) query(tr[u].rc);} else {if (d2 < ans) query(tr[u].rc);if (d1 < ans) query(tr[u].lc);}
}int main() {n = read();rep(i, 1, n) scanf("%lf%lf", &tr[i].v[0], &tr[i].v[1]);for (cur = 1; cur <= n; cur ++) insert(root, 0);for (cur = 1; cur <= n; cur ++) query(root); printf("%.4lf\n", sqrt(ans));return 0;
}
模板题
P2479 [SDOI2010] 捉迷藏
距离计算更改为曼哈顿距离,然后再增加统计一下最大值即可。
直接建树会比动态插入快非常多。
代码
P4148 简单题
操作 \(1\) 可以看作动态插点,操作 \(2\) 可以直接用类似线段树查询的方式,对于每个结点分类讨论三种(我们将一个 \(KDT\) 结点表示范围看作矩形):
- 该节点所对矩形完全不包含于询问矩形
- 该节点所对矩形完全包含于询问矩形
- 部分包含
对于 \(1\) 和 \(2\) 来说是简单的,对于 \(3\),我们可以判断一下当前节点的 \(v\) 是否在矩形内,然后递归左右子树最后加上根的贡献。
时间复杂度被证明是根号的。
代码
困难一点的题
P5471 [NOI2019] 弹跳
考虑 \(KDT\) 优化建图后跑 \(dijkstra\),我们把 \(1\sim n\) 记为原始的点(下文称为实点),\(n + 1\sim 2n\) 记为 \(KDT\) 建出来的点(虚点)。显然一个实点 \(u\) 对应虚点 \(u + n\)。对于一个虚点 \(u\),它显然可以向 \(u - n\) 连边,也可以向它在 \(KDT\) 中的左右儿子连边;对于一个实点,我们遍历从该点出发的弹跳装置,分类讨论(以下点均为实点所对虚点):
- 一个虚点 \(x\) 对应的矩阵完全包含在弹跳装置内,直接将 \(u\) 向 \(x\) 连边
- 完全不包含,直接返回
- 部分包含的话,如果该虚点的坐标能够包含于弹跳装置内就让 \(u\) 向 \(x - n\) 连边
然而如果真的连边的话会被卡爆,我们考虑边做 \(dij\) 边跑上面过程,每次不建边直接用需要的点进行更新操作。
于是这道题就做完了,代码稍微有点难写。
代码