一起看很美的日落!
题目描述
牛可乐有一棵由 $n$ 个结点构成的树,第 $i$ 个节点的权值为 $a_i$。
我们定义一个连通块 $\mathbb{V}$ 的权值为:
- 当前连通块中两两结点的权值异或和,即 $\sum\limits_{i,j \in \mathbb{V}}{a_i \oplus a_j}$;
你需要计算全部连通块的权值之和。由于答案可能很大,请将答案对 $(10^9 +7)$ 取模后输出。
此题中的连通块定义为:对于树上的任意一个点集 $\mathbb{S}$,如果 $\mathbb{S}$ 中的任意两点 $u,v$ 之间存在一条路径,且路径上的所有点都在 $\mathbb{S}$ 中,则称 $\mathbb{S}$ 是一个连通块。
输入描述:
第一行输入一个整数 $n \, (1 \leq n \leq 10^5 )$ 代表树上的节点数量。
第二行输入 $n$ 个整数 $a_1,a_2, \ldots ,a_n \, (1 \leq a_i \leq 10^9)$ 代表每个节点的权值。
此后 $n−1$ 行,第 $i$ 行输入两个整数 $u_i,v_i \, (1 \leq u_i,v_i \leq n,u_i \ne v_i)$ 代表树上第 $i$ 条边连接节点 $u_i$ 和 $v_i$。
输出描述:
输出一个整数,代表全部连通块的权值和。答案可能很大,请将答案对 $(10^9+7)$ 取模后输出。
示例1
输入
3
5 2 1
1 2
1 3
输出
50
说明
在这个样例中,一共有 $6$ 个连通块,每一个连通块的权值依次为:
- $\{1\}$,记为 $\mathbb{V}_1$,权值为 $\sum\limits_{i,j \in \Bbb V_1}(a_i \oplus a_j)=a_1\oplus a_1=0$;
- $\{1,2\}$,记为 $\mathbb{V}_2$,权值为 $\sum\limits_{i,j \in \Bbb V_2}(a_i \oplus a_j)=(a_1\oplus a_1)+(a_1\oplus a_2)+(a_2\oplus a_1)+(a_2\oplus a_2)=14$;
- $\{1,3\}$,记为 $\mathbb{V}_3$,权值为 $\sum\limits_{i,j \in \Bbb V_3}(a_i \oplus a_j)=(a_1\oplus a_1)+(a_1\oplus a_3)+(a_3\oplus a_1)+(a_3\oplus a_3)=8$;
- $\{1,2,3\}$,记为 $\mathbb{V}_4$,权值为 $\sum\limits_{i,j \in \Bbb V_4}(a_i \oplus a_j)=28$;
- $\{2\}$,记为 $\mathbb{V}_5$,权值为 $\sum\limits_{i,j \in \Bbb V_5}(a_i \oplus a_j)=a_2\oplus a_2=0$;
- $\{3\}$,记为 $\mathbb{V}_6$,权值为 $\sum\limits_{i,j \in \Bbb V_6}(a_i \oplus a_j)=a_3\oplus a_3=0$;
示例2
输入
4
5 6 3 1
1 2
1 3
2 4
输出
142
解题思路
大概想了下发现很复杂要维护很多东西,以为错了就直接看题解,没想到还真这么做。
假设以节点 $1$ 作为整颗树的根。题目中的连通块其实就是以任意节点为根的子树,因此我们可以按子树的根把所有的连通块(子树)分成 $n$ 类,然后分别求每一类连通块的结果最后累加即可。由于涉及到位运算,我们可以对节点的权值进行拆位分别处理,这样每个节点的权值只有 $0$ 和 $1$ 两种情况,最后将算出的结果乘以相应的 $2$ 的次幂,就是该位对答案的贡献。
现在有一个集合 $s$,元素只有 $0$ 和 $1$,数量分别记为 $c(s,0)$ 和 $c(s,1)$。根据异或的性质,该集合中任意两个元素的异或值的和就是 $c(s,0) \times c(s,1)$,记为 $f(s)$。现在有集合 $t$ 要与 $s$ 进行合并,$0$ 和 $1$ 的数量就变成了 $c(s,0) + c(t,0)$ 和 $c(s,1) + c(t,1)$,异或值的和就是
\begin{align*}
f(s \cup t) &= (c(s,0) + c(t,0)) \times (c(s,1) + c(t,1)) \\
&= c(s,0) \times c(s,1) + c(t,0) \times c(t,1) + c(s,0) \times c(t,1) + c(s,1) \times c(t,0) \\
&= f(s) + f(t) + c(s,0) \times c(t,1) + c(s,1) \times c(t,0)
\end{align*}
假设有 $g(s)$ 个集合 $s_i$ 要与 $g(t)$ 个集合 $t_j$ 进行合并,那么所有合并结果的异或值的和就是
\begin{align*}
\sum\limits_i{\sum\limits_j{f(s_i \cup t_j)}} &= \sum\limits_i{\sum\limits_j{(c(s_i,0) + c(t_j,0)) \times (c(s_i,1) + c(t_j,1))}} \\
&= \sum\limits_j{\sum\limits_i{f(s_i)}} + \sum\limits_i{\sum\limits_j{f(t_j)}} + \sum\limits_i{c(s_i,0)} \times \sum\limits_j{c(t_j,1)} + \sum\limits_i{c(s_i,1)} \times \sum\limits_j{c(t_j,0)} \\
&= g(t)\sum\limits_i{f(s_i)} + g(s)\sum\limits_j{f(t_j)} + \sum\limits_i{c(s_i,0)} \times \sum\limits_j{c(t_j,1)} + \sum\limits_i{c(s_i,1)} \times \sum\limits_j{c(t_j,0)} \\
\end{align*}
同理合并后的集合数量就是 $g(s) \times g(t)$,所有集合 $0$ 的数量就是 $g(t)\sum\limits_i{c(s_i,0)} +g(s)\sum\limits_j{c(t_j,0)}$,所有集合 $1$ 的数量就是 $g(t)\sum\limits_i{c(s_i,1)} +g(s)\sum\limits_j{c(t_j,1)}$。
假设我们现在处理的是第 $k$ 位,用 $a_{u,k}$ 表示二进制下节点 $u$ 权值第 $k$ 位的值。假设 $u$ 有 $m_u$ 个子节点,记为 $v_1, v_2, \ldots, v_{m_u}$。那么以 $u$ 为根的子树就可以通过以 $v_i$ 为根的所有子树进行任意组合得到,可以用 dp 求出以 $u$ 为根的所有子树的任意两个节点异或值的和。
定义 $f(u,i)$ 表示以 $v_1, v_2, \ldots, v_i$ 为根的子树组合得到的所有以 $u$ 为根的子树的任意两个节点异或值的和,$g(u,i)$ 表示以 $v_1, v_2, \ldots, v_i$ 为根的子树组合得到的所有以 $u$ 为根的子树的数量,$c(u,0,i)$ 和 $c(u,1,i)$ 分别表示以 $v_1, v_2, \ldots, v_i$ 为根的子树组合得到的所有以 $u$ 为根的子树的权值为 $0$ 和 $1$ 的节点数量。根据以 $v_i$ 为根的子树进行状态划分,可以把以 $v_i$ 为根的子树看成上述若干个集合 $t_j$,把以 $v_1, v_2, \ldots, v_{i-1}$ 为根的子树组合得到的所有以 $u$ 为根的子树看成若干个集合 $s_i$,$s_i$ 与 $t_j$ 两两进行合并。
初始只有一个节点 $u$,此时有 $f(u,0) = a_{u,k}$,$g(u,0) = 1$,$c(u,a_{u,k},0) = 1$,$c(u,\neg a_{u,k},0) = 0$。参考上面的式子,状态转移方程就是
\begin{cases}
\begin{align*}
f(u,i) &= f(u,i-1) + g(v_i, m_{v_i}) \cdot f(u,i-1) + g(u,i-1) \cdot f(v_i,m_{v_i}) + c(u,0,i-1) \cdot c(v_i,1,m_{v_i}) + c(u,1,i-1) \cdot c(v_i,0,m_{v_i}) \\
c(u,0,i) &= c(u,0,i-1) + g(v_i,m_{v_i}) \cdot c(u,0,i-1) + g(u, i-1) \cdot c(u,0,i-1) \\
c(u,1,i) &= c(u,1,i-1) + g(v_i,m_{v_i}) \cdot c(u,1,i-1) + g(u, i-1) \cdot c(u,1,i-1) \\
g(u,i) &= g(u,i-1) + g(u,i-1) \cdot g(v_i,m_{v_i})
\end{align*}
\end{cases}
在代码实现中省略了所有状态的最后一维。
那么第 $k$ 位对答案的贡献就是 $2^k\sum\limits_{i=1}^{n}{f(i,m_i)}$。最后要求的答案就是 $2 \cdot \sum\limits_{k=0}^{29}2^k\sum\limits_{i=1}^{n}{f(i,m_i)}$,多乘一个 $2$ 是因为上述考虑的是两个节点的组合,而题目要求考虑两个节点的排列。
AC 代码如下,时间复杂度为 $O(n \log{A})$:
#include <bits/stdc++.h>
using namespace std;typedef long long LL;const int N = 1e5 + 5, M = N * 2, mod = 1e9 + 7;int a[N];
int h[N], e[M], ne[M], idx;
int f[N], g[N], c[N][2];void add(int u, int v) {e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}void dfs(int u, int p, int k) {f[u] = 0, g[u] = 1, c[u][a[u] >> k & 1] = 1, c[u][~a[u] >> k & 1] = 0;for (int i = h[u]; i != -1; i = ne[i]) {int v = e[i];if (v == p) continue;dfs(v, u, k);f[u] = (f[u] + 1ll * f[u] * g[v] + 1ll * f[v] * g[u] + 1ll * c[u][0] * c[v][1] + 1ll * c[u][1] * c[v][0]) % mod;c[u][0] = (c[u][0] + 1ll * c[u][0] * g[v] + 1ll * c[v][0] * g[u]) % mod;c[u][1] = (c[u][1] + 1ll * c[u][1] * g[v] + 1ll * c[v][1] * g[u]) % mod;g[u] = (g[u] + 1ll * g[u] * g[v]) % mod;}
}int main() {int n;cin >> n;for (int i = 1; i <= n; i++) {cin >> a[i];}memset(h, -1, sizeof(h));for (int i = 0; i < n - 1; i++) {int u, v;cin >> u >> v;add(u, v), add(v, u);}int ret = 0;for (int i = 0; i < 30; i++) {dfs(1, 0, i);for (int j = 1; j <= n; j++) {ret = (ret + f[j] * (1ll << i)) % mod;}}cout << ret * 2 % mod;return 0;
}
参考资料
2025牛客寒假算法基础集训营2 出题人题解:https://blog.nowcoder.net/n/906fd00ff386438b9d63013a3760e73a