题目链接:https://atcoder.jp/contests/abc133/tasks/abc133_f
题目大意:
有一棵树,顶点编号从 \(1\) 到 \(N\)。
这棵树中第 \(i\) 条边连接着顶点 \(a_i\) 和顶点 \(b_i\),其颜色和长度分别为 \(c_i\) 和 \(d_i\)。
这里每条边的颜色用介于 \(1\) 和 \(N-1\)(包括边界值)之间的整数表示。相同的整数代表相同的颜色,不同的整数代表不同的颜色。
回答以下 \(Q\) 个查询:
查询 \(j\) (\(1 \leq j \leq Q\)): 假设颜色为 \(x_j\) 的边的长度都改变为 \(y_j\),求顶点 \(u_j\) 和顶点 \(v_j\) 之间的距离。(边的长度的改变不会影响后续的查询。)
解题思路完全参考自 Minecraft万岁 大佬的博客:https://www.luogu.com.cn/article/aw4dp6vd
我写代码的时候碰到一个比较脑抽的问题是:习惯用 d 表示深度,但是这里 edge 里也有一个 d,调了半天,然后把深度改成 depth 表示了囧
示例程序:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5, maxm = 2e6 + 5;int rt[maxn], idx, ls[maxm], rs[maxm];int tcnt[maxm], tsum[maxm];void push_up(int u) {tcnt[u] = tcnt[ls[u]] + tcnt[rs[u]];tsum[u] = tsum[ls[u]] + tsum[rs[u]];
}// 多了一条颜色为 c 长度为 d 的边
void add(int c, int d, int l, int r, int u, int uu) {tcnt[u] = tcnt[uu];tsum[u] = tsum[uu];ls[u] = ls[uu];rs[u] = rs[uu];if (l == r) {tcnt[u]++;tsum[u] += d;return;}int mid = (l + r) / 2;if (c <= mid) {ls[u] = ++idx;add(c, d, l, mid, ls[u], ls[uu]);}else {rs[u] = ++idx;add(c, d, mid+1, r, rs[u], rs[uu]);}push_up(u);
}pair<int, int> query(int c, int l, int r, int u) {if (!u) return {0, 0};if (l == r) {return { tcnt[u], tsum[u] };}int mid = (l + r) / 2;return (c <= mid) ? query(c, l, mid, ls[u]) : query(c, mid+1, r, rs[u]);
}int fa[maxn][17], dis[maxn][17], dep[maxn];
int n, Q;
struct Edge { int v, c, d; };
vector<Edge> g[maxn];void dfs(int u, int p, int depth) {fa[u][0] = p;dep[u] = depth;for (auto e : g[u]) {int v = e.v, c = e.c, d = e.d;if (v == p) continue;rt[v] = ++idx;add(c, d, 1, n-1, rt[v], rt[u]);dis[v][0] = d;dfs(v, u, depth+1);}
}void check_dfs(int u, int p) {for (auto e : g[u]) {int v = e.v;if (v != p)assert(dep[v] == dep[u] + 1),check_dfs(v, u);}
}int lca(int x, int y) {if (dep[x] < dep[y]) swap(x, y);for (int i = 16; i >= 0; i--)if (dep[ fa[x][i] ] >= dep[y])x = fa[x][i];if (x == y) return x;for (int i = 16; i >= 0; i--)if (fa[x][i] != fa[y][i])x = fa[x][i], y = fa[y][i];return fa[x][0];
}// 计算从节点 x 到它的祖先节点 z 的所有边的长度总和
int get_dis(int x, int z) {int res = 0;for (int i = 16; i >= 0; i--) {if (dep[ fa[x][i] ] >= dep[z]) {res += dis[x][i];x = fa[x][i];}}return res;
}int cal(int c, int w, int x, int y) {int z = lca(x, y);int cnt = 0, sum = 0;auto pi = query(c, 1, n-1, rt[x]);cnt += pi.first;sum += pi.second;pi = query(c, 1, n-1, rt[y]);cnt += pi.first;sum += pi.second;pi = query(c, 1, n-1, rt[z]);cnt -= 2 * pi.first;sum -= 2 * pi.second;int dis1 = get_dis(x, z), dis2 = get_dis(y, z);return dis1 + dis2 - sum + cnt * w;
}int main() {scanf("%d%d", &n, &Q);for (int i = 1; i < n; i++) {int a, b, c, d;scanf("%d%d%d%d", &a, &b, &c, &d);g[a].push_back({ b, c, d });g[b].push_back({ a, c, d });}rt[1] = ++idx;dfs(1, 0, 1);check_dfs(1, 1);for (int j = 1; j <= 16; j++) {for (int i = 1; i <= n; i++) {fa[i][j] = fa[ fa[i][j-1] ][j-1];dis[i][j] = dis[i][j-1] + dis[ fa[i][j-1] ][j-1];}}while (Q--) {int c, w, x, y;scanf("%d%d%d%d", &c, &w, &x, &y);printf("%d\n", cal(c, w, x, y));}return 0;
}