三棵树就很毒瘤了,我们一棵一棵看。
关于第一棵树的路径,经典解法就是点分治和边分治,考虑哪种更加简单。
设 \(dis1/2/3(x)\) 表示 \(x\) 在第 \(1/2/3\) 棵树中的深度(第一棵树的深度当然是点到重心或重边的距离),\(lca2/3(x,y)\) 表示在第 \(2/3\) 棵树中的最近公共祖先。不管怎么说,问题转化为求解:
\[\min_{1\le x<y\le n}dis1(x)+dis1(y)+dis2(x)+dis2(y)+dis3(x)+dis3(y)-2\times dis2(lca2(x,y))-2\times dis3(lca3(x,y))
\]
假如是边分治,还需要加上中间边的长度。设 \(w=dis1(x)+dis2(x)+dis3(x)\),则简化为:
\[\min_{1\le x<y\le n}w(x)+w(y)-2\times dis2(lca2(x,y))-2\times dis3(lca3(x,y))
\]
处理第二棵树,既然涉及到 \(lca\),自然想到虚树,问题转化为求解:
\[\min_{1\le x<y\le n}w(x)+w(y)-2\times dis3(lca3(x,y))
\]
然后发现这就是在求第三棵树的点集直径(很非典型就是了),直接 \(lca\) 即可。第二棵树中只需要维护删去重点或重边后每棵子树的节点在第三棵树中的直径即可。
那么此时用点分还是边分就一目了然了。边分子树个数只有两只,而点分很不好说,所以边分会更加简单(当然点分实际上也可以)。
时间复杂度 \(O(n\log n)-O(n\log^2n)\),取决于建立虚树和 \(lca\) 的时间复杂度。这里用了 \(lca\) 建立虚树和欧拉序求解 \(lca\)。时间复杂度是常数超小的 \(O(n\log^2n)\)。
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2e5+5;
int n,ans;
struct edge{int to,cs;};
namespace tr3{vector<edge>g[N];int ij,fs[N];int dep[N],dis[N],st[N][20];void get3(){for(int i=1,x,y,z;i<n;i++){cin>>x>>y>>z;g[x].push_back({y,z});g[y].push_back({x,z});}}void dfs(int x,int fa){dep[x]=dep[fa]+1,st[fs[x]=++ij][0]=x;for(auto y:g[x]) if(y.to!=fa)dis[y.to]=dis[x]+y.cs,dfs(y.to,x),st[++ij][0]=x;}int minn(int x,int y){return dep[x]>dep[y]?y:x;}void ST(){for(int i=0;i<19;i++)for(int j=1;j<=ij-(1<<(i+1))+1;j++)st[j][i+1]=minn(st[j][i],st[j+(1<<i)][i]);}int rmq(int l,int r){int k=log2(r-l+1),x=r-(1<<k)+1;return minn(st[l][k],st[x][k]);}int lca(int x,int y){if(fs[x]>fs[y]) swap(x,y);return rmq(fs[x],fs[y]);}void init(){dfs(1,0),ST();}
}namespace tr2{vector<edge>g[N];vector<int>ve[N];int dep[N],dis[N],st[N][20],fs[N],idx[N];int a[N],dfn[N],w[N],f[N][2][2],m,k,idc,ij;int dis3(int x,int y){if(!x||!y) return -1e18;return w[x]+w[y]-tr3::dis[tr3::lca(x,y)]*2;}int cmp(int x,int y){return dfn[x]<dfn[y];}void get2(){for(int i=1,x,y,z;i<n;i++){cin>>x>>y>>z;g[x].push_back({y,z});g[y].push_back({x,z});}for(int i=1;i<=n;i++) idx[i]=-1;}void dfs(int x,int fa){st[fs[x]=++ij][0]=x,dfn[x]=++idc,dep[x]=dep[fa]+1;for(auto y:g[x]) if(y.to!=fa)dis[y.to]=dis[x]+y.cs,dfs(y.to,x),st[++ij][0]=x;}int minn(int x,int y){return dep[x]>dep[y]?y:x;}void ST(){for(int i=0;i<19;i++)for(int j=1;j<=ij-(1<<(i+1))+1;j++)st[j][i+1]=minn(st[j][i],st[j+(1<<i)][i]);}int rmq(int l,int r){int k=log2(r-l+1),x=r-(1<<k)+1;return minn(st[l][k],st[x][k]);}int lca(int x,int y){if(fs[x]>fs[y]) swap(x,y);return rmq(fs[x],fs[y]);}void init(){dfs(1,0),ST();}void merge(int x,int y,int id){if(!f[x][id][0]){f[x][id][0]=f[y][id][0];f[x][id][1]=f[y][id][1];return;}int d=0,fa=f[x][id][0];int fb=f[x][id][1],mx=dis3(fa,fb);if((d=dis3(f[y][id][0],f[y][id][1]))>mx)mx=d,fa=f[y][id][0],fb=f[y][id][1];for(int i=0;i<2;i++) for(int j=0;j<2;j++)if((d=dis3(f[x][id][i],f[y][id][j]))>mx)mx=d,fa=f[x][id][i],fb=f[y][id][j];f[x][id][0]=fa,f[x][id][1]=fb;}void getans(int x,int fa,int cc){if(idx[x]<2) f[x][idx[x]][0]=x;int mx=-1e18;for(auto y:ve[x]){if(y==fa) continue;getans(y,x,cc);for(int i=0;i<2;i++) for(int j=0;j<2;j++){mx=max(mx,dis3(f[x][0][i],f[y][1][j]));mx=max(mx,dis3(f[x][1][i],f[y][0][j]));}merge(x,y,0),merge(x,y,1);}ans=max(ans,mx+cc-dis[x]*2);}void build(vector<int>c0,vector<int>c1,int cc){m=0;for(int i=1;i<=k;i++){ve[a[i]].clear(),idx[a[i]]=-1;f[a[i]][0][0]=f[a[i]][0][1]=0;f[a[i]][1][0]=f[a[i]][1][1]=0,a[i]=0;}for(auto y:c0) a[++m]=y,idx[y]=0;for(auto y:c1) a[++m]=y,idx[y]=1;sort(a+1,a+m+1,cmp),k=m;for(int i=1;i<m;i++){int lc=lca(a[i],a[i+1]);if(idx[lc]<0) idx[a[++k]=lc]=2;}sort(a+1,a+k+1,cmp);for(int i=2;i<=k;i++){int lc=lca(a[i],a[i-1]);ve[a[i]].push_back(lc);ve[lc].push_back(a[i]);}getans(lca(a[1],a[2]),0,cc);}
}namespace tr1{vector<edge>g[N],ve[N];vector<int>cl[2];int tot,cc;unordered_map<int,int>mp[N];int rtx,rty,ls,dis[N],sz[N];void getve(){for(int i=1,x,y,z;i<n;i++){cin>>x>>y>>z;ve[x].push_back({y,z});ve[y].push_back({x,z});}tot=n;}void get1(int x,int fa){int lst=0;for(auto y:ve[x]){if(y.to==fa) continue;if(!lst){g[x].push_back({++tot,0});g[lst=tot].push_back({x,0});}else{g[lst].push_back({++tot,0});g[tot].push_back({lst,0}),lst=tot;}g[tot].push_back({y.to,y.cs});g[y.to].push_back({tot,y.cs}),get1(y.to,x);}}void getrt(int x,int fa,int sm){sz[x]=1,dis[x]=0;for(auto y:g[x])if(y.to!=fa&&!mp[x][y.to])getrt(y.to,x,sm),sz[x]+=sz[y.to],sz[y.to]=1;if(ls>max(sz[x],sm-sz[x])&&fa)rtx=x,rty=fa,ls=max(sz[x],sm-sz[x]);}void getsz(int x,int fa){if(x<=n) tr2::w[x]=dis[x]+tr2::dis[x]+tr3::dis[x];for(auto y:g[x]) if(y.to!=fa&&!mp[x][y.to])dis[y.to]=dis[x]+y.cs,getsz(y.to,x),sz[x]+=sz[y.to];}void adc(int x,int fa,int id){if(x<=n) cl[id].push_back(x);for(auto y:g[x])if(y.to!=fa&&!mp[x][y.to]) adc(y.to,x,id);}void solve(int x,int sm){if(sm==1) return;ls=1e18,getrt(x,0,sm);sz[x]=mp[rtx][rty]=mp[rty][rtx]=1;getsz(rtx,rty),getsz(rty,rtx);cl[0].clear(),cl[1].clear();adc(rtx,rty,0),adc(rty,rtx,1);for(auto y:g[rtx]) if(y.to==rty){cc=y.cs;break;}tr2::build(cl[0],cl[1],cc);int nw=rty;solve(rtx,sz[rtx]),solve(nw,sz[nw]);}
}signed main(){ios::sync_with_stdio(0);cin.tie(0),cout.tie(0);cin>>n,tr1::getve();tr3::get3(),tr2::get2();tr1::get1(1,0),tr2::init();tr3::init(),tr1::solve(1,n*2-1);return cout<<ans,0;
}