一. 题面:点这里
二. 思路:
首先有一个比较常用的 trick 就是对于任意一颗带权树,对于题目中所求的两点对之间的距离和有如下公式:
\[Sum_{dis}=\sum_{e\in E}w(e)\times size_v\times(n-size_v)
\]
这个公式应当是容易理解的,那么最后的贡献就可以这么算,但是现在有一个连边的问题。这和最优化问题有关,可以 dp 吗?好像不太现实,那我们考虑贪心。一步步来,首先你肯定会有一个疑问就是如果要连边,应当连到这个连通块的哪一个点呢。因为连通块本身的点对之间距离是确定的,考虑其他连通块连过来时造成的贡献,显然我们应当让这个连接点到这个连通块其他部分的距离和最小(很好证明)。关于距离和最小这个问题是经典的,可以其中一个解法是可以考虑换根dp 。解决这个问题之后,我们就要考虑我们分别连接哪些联通块是最优的呢。让我们回到算贡献的式子,那么我们应当是希望这个图是菊花图(假设我们把所有连通块内部的点缩成一个点之后),这样可以最小化每一个形如 \(x(n-x)\) 的乘积式。问题是菊花图的中心是什么呢。我们猜想是连通块大小最大的那一个,证明考虑反证法,假设调换最大块和某一块的值,一定不会使答案变小,手搓一个二次函数图像之后就会发现这个问题等价于证明:任意一个正整数 \(t\) 表示不是最大块的大小都有:
\[t\leq n-\max\{size\}
\]
那么这是显然成立的(请读者自行思考,这是简单的。)
所以最后只需要以最大块向外连边即可,至于连边的顺序,考虑将边权大的连向乘积贡献小的连通块即可。
三. Code:
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<vector>
#define ll __int128
inline int read() {int x=0,f=1;char ch=getchar();while(ch > '9' || ch < '0'){if(ch == '-'){f = -1;}ch = getchar();}while(ch >= '0'&&ch <= '9'){x = x * 10 + ch - 48; ch = getchar();}return x * f;
}inline void write(ll x)
{if (x < 0) putchar('-'), x = -x;if (x > 9) write(x / 10);putchar(x % 10 + 48);return;
}
const int N = 1e6 + 5,MOD = 1e9 + 7;
int siz[N],a[N];
std::vector<std::pair<int,int> > G[N];ll dfs(int u,int fa)
{siz[u] = 1;ll res = 0;for(auto now : G[u]){int v = now.first,w = now.second;if(v == fa) continue;res += dfs(v,u) + 1ll * siz[v] * w;siz[u] += siz[v];}return res;
}ll tmp_res = 0,mn = 0;
void chgrt_dp(int u,int fa,ll mn_now,int now_siz)
{mn = std::min(mn,mn_now);tmp_res += mn_now;for(auto now : G[u]){int v = now.first,w = now.second;if(v == fa) continue;chgrt_dp(v,u,mn_now + 1ll * (now_siz - siz[v]) * w - 1ll * siz[v] * w,now_siz);}
}bool cmp(int a,int b){return a > b;}
int main()
{int n,m;n = read(),m = read();for(int i = 1;i <= m;++i){int u,v,w;u = read(),v = read(),w = read();G[u].push_back({v,w});G[v].push_back({u,w});}for(int i = 1;i <= n - m - 1;++i) a[i] = read();std::sort(a + 1,a + n - m);std::vector<int> p;ll ans = 0;for(int i = 1;i <= n;++i){if(!siz[i]){mn = dfs(i,0);tmp_res = 0;chgrt_dp(i,0,mn,siz[i]);ans += tmp_res / 2 + 1ll * (n - siz[i]) * mn;p.push_back(siz[i]);}}std::sort(p.begin(),p.end(),cmp);for(int i = 1;i <= n - m - 1;++i) ans += 1ll * (n - p[i]) * p[i] * a[i];write(ans % MOD);return 0;
}