题目
思路
为每个宗教维护一棵线段树(下标为树链剖分得到的 dfn 序,存放该宗教城市的评级)。查询时,用树链剖分把路径拆成若干段连续的 dfn 区间,在起点所属宗教的线段树上逐段查询区间和 / 区间最大值并合并即可。
使用动态开点线段树:每次单点修改最多新建 \(O(\log n)\) 个节点(值域为 \([1, 100000]\) 时至多 18 个),总节点数为 \(O((n+q)\log n)\),不会 MLE。
代码
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>

// Position range covered by every per-religion segment tree (dfn values are 1-based).
// Kept as a macro because call sites pass it as the two arguments `l, r`.
#define range 1, 100000

using namespace std;

const int N = 100010;

// ---------- Tree stored as an adjacency list (linked edge lists) ----------

// Edge pool: every undirected tree edge is added twice (once per direction).
struct edge {
    int to, next;
} e[N * 2];
// head[u] is the index of u's most recently added edge (0 = none).
// idx starts at 1 and is pre-incremented, so the first real edge gets index 2.
int head[N], idx = 1;

// Append the directed edge u -> v to u's edge list.
void add(int u, int v) {
    idx++;
    e[idx].to = v;
    e[idx].next = head[u];
    head[u] = idx;
}

// ---------- Heavy-light decomposition ----------

int dep[N], fa[N], sz[N], son[N];

// First pass: fill depth, parent, subtree size and heavy child for u's subtree.
// f is u's parent; called with f = 0 for the root, and dep[0] == 0 gives the root depth 1.
// NOTE(review): recursive; a path-shaped tree of ~1e5 vertices needs a deep call stack.
void dfs1(int u, int f) {
    dep[u] = dep[f] + 1;
    sz[u] = 1;
    fa[u] = f;
    for (int i = head[u]; i; i = e[i].next) {
        int to = e[i].to;
        if (to == f) continue;
        dfs1(to, u);
        sz[u] += sz[to];
        // sz[0] == 0, so any child beats "no heavy child yet".
        if (sz[to] > sz[son[u]]) son[u] = to;
    }
}

int top[N], dfn[N], rk[N], dfn_cnt;

// Second pass: assign DFS order so every heavy chain occupies a contiguous dfn range.
// t is the topmost vertex of the chain containing u. rk is the inverse of dfn.
void dfs2(int u, int t) {
    dfn[u] = ++dfn_cnt;
    rk[dfn_cnt] = u;
    top[u] = t;
    // Visit the heavy child first so it continues the current chain.
    if (son[u]) dfs2(son[u], t);
    for (int i = head[u]; i; i = e[i].next) {
        int to = e[i].to;
        if (to == son[u] || to == fa[u]) continue;
        dfs2(to, to);  // every light child starts a new chain
    }
}

// ---------- Dynamically allocated segment trees, one root per religion ----------

struct node {
    int l, r;  // child indices into tr; 0 = absent (node 0 is an all-zero sentinel)
    int sum;   // sum of the values stored in this node's range
    int mx;    // maximum of the values stored in this node's range
};

// Pool sizing: one modify() on a fresh path allocates at most 18 nodes, the root-to-leaf
// path length of a segment tree over [1, 100000] (2^17 >= 100000). New paths are created
// by the n initial insertions plus at most one re-insertion per later update, so 2 * N
// paths suffice. This assumes the number of updates is <= N; TODO confirm against the
// problem limits. The former size N * 31 could overflow for n = q = 1e5.
node tr[2 * N * 18 + 1];
// root[c] is the tree root for religion c (0 = empty tree); assumes religion ids < N.
// seg_idx is the last allocated node index.
int root[N], seg_idx;

// Recompute u's aggregates from its children; absent children read the sentinel node 0.
void pushup(int u) {
    tr[u].sum = tr[tr[u].l].sum + tr[tr[u].r].sum;
    tr[u].mx = max(tr[tr[u].l].mx, tr[tr[u].r].mx);
}

// Set position x to value v in the tree rooted at u, which covers [l, r].
// Missing nodes are created on demand. u is taken by reference so a newly created
// root or child index is written back into root[] / the parent's l or r field.
void modify(int& u, int l, int r, int x, int v) {
    if (!u) u = ++seg_idx;
    if (l == r) {
        tr[u].sum = tr[u].mx = v;
        return;
    }
    int mid = (l + r) >> 1;
    if (x <= mid) modify(tr[u].l, l, mid, x, v);
    else modify(tr[u].r, mid + 1, r, x, v);
    pushup(u);
}

// Return {max, sum} over positions [pl, pr] within the tree rooted at u (covering [l, r]).
// Expects [pl, pr] to intersect [l, r]. An absent subtree contributes {0, 0}; using 0 as
// the identity for max assumes stored values are non-negative (TODO confirm).
// Sums are accumulated in int; this assumes path totals fit (TODO confirm weight limits).
pair<int, int> query(int u, int l, int r, int pl, int pr) {
    if (!u) return {0, 0};
    if (pl <= l && r <= pr) return {tr[u].mx, tr[u].sum};
    int mid = (l + r) >> 1;
    if (pr <= mid) return query(tr[u].l, l, mid, pl, pr);
    if (pl > mid) return query(tr[u].r, mid + 1, r, pl, pr);
    pair<int, int> lhs = query(tr[u].l, l, mid, pl, pr);
    pair<int, int> rhs = query(tr[u].r, mid + 1, r, pl, pr);
    return {max(lhs.first, rhs.first), lhs.second + rhs.second};
}

int n, q;
int w[N], c[N];pair<int, int> ask(int u, int v) {pair<int, int> ans = {0, 0};int rt = root[c[u]];while (top[u] != top[v]) {if (dep[top[u]] < dep[top[v]]) swap(u, v);auto q = query(rt, range, dfn[top[u]], dfn[u]);ans.first = max(ans.first, q.first);ans.second += q.second;u = fa[top[u]];}if (dep[u] > dep[v]) swap(u, v);auto q = query(rt, range, dfn[u], dfn[v]);ans.first = max(ans.first, q.first);ans.second += q.second;return ans;
}int main() {ios::sync_with_stdio(false);cin.tie(nullptr);cin >> n >> q;for (int i = 1; i <= n; i++) cin >> w[i] >> c[i];for (int i = 1; i < n; i++) {int u, v;cin >> u >> v;add(u, v), add(v, u);}dfs1(1, 0);dfs2(1, 1);for (int i = 1; i <= n; i++) modify(root[c[i]], range, dfn[i], w[i]);string opt;int x, y;while (q--) {cin >> opt >> x >> y;if (opt == "CC") {modify(root[c[x]], range, dfn[x], 0);c[x] = y;modify(root[c[x]], range, dfn[x], w[x]);}else if (opt == "CW") {w[x] = y;modify(root[c[x]], range, dfn[x], w[x]);}else if (opt == "QS") cout << ask(x, y).second << '\n';else cout << ask(x, y).first << '\n';}return 0;
}