题意
请自行移步洛谷
分析
每次询问有两种情况:
- a 为 (a,b,c) 中深度第二大的节点,即 b 在 a 到根的路径上 (不含 a )。
对答案的贡献即为 $ ( siz_a -1 ) * \min ( dep_a -1 , k) $ 。
- a 为 (a,b,c) 中深度最小的节点,即 b 是 a 的儿子。
记 $ son_{a,d} $ 为 a 点的 d 级儿子,此处的 d 级儿子是指在 a 子树内并且深度差等于 d 的点。
对答案的贡献即为 $$ \sum_{d=1}^k \sum_{u \in son_{a,d}} siz_u $$
所以得出,维护第二种情况较为困难。但是这是可以解决的,将子树大小记在对应深度的线段树的节点上,遍历时向上线段树合并即可。
codes
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+100;
typedef long long ll;
inline int read()
{register char c=getchar();int x=0;while(!isdigit(c))c=getchar();while(isdigit(c)){x=(x<<1)+(x<<3)+c-48;c=getchar();}return x;
}struct edge{int y,n;}e[N<<1];struct ques
{int id;int x;
}g[N];int n,q,head[N],cnt;
int siz[N],dep[N],deg[N];
ll ans[N];
vector<ques>h[N];struct segtree
{struct node{ll val;int lc,rc;}s[N<<5];int tot,root[N<<5];void pushup(int i){s[i].val=s[ s[i].lc ].val+s[ s[i].rc ].val;}void upd(int &i,int l,int r,int x,ll z){if(!i)i=++tot;s[i].val+=z;if(l==r)return ;int mid=(l+r)>>1;if(x<=mid)upd(s[i].lc,l,mid,x,z);else upd(s[i].rc,mid+1,r,x,z);pushup(i);}void merg(int &i,int &j,int l,int r)// i <-- j{if(!i || !j){i+=j;return ;}if(l==r){s[i].val+=s[j].val;return ;}int mid=(l+r)>>1;merg(s[i].lc,s[j].lc,l,mid);merg(s[i].rc,s[j].rc,mid+1,r);pushup(i);}ll que(int i,int l,int r,int x,int y){if(!i)return 0;if(l>=x && r<=y)return s[i].val;int mid=(l+r)>>1;ll sum=0;if(x<=mid)sum+=que(s[i].lc,l,mid,x,y);if(y>mid)sum+=que(s[i].rc,mid+1,r,x,y);return sum;}void modi(int u){upd(root[u],1,n,dep[u],siz[u]-1);}void mer(int x,int y){merg(root[x],root[y],1,n);}ll ask(int x,int y){if(dep[x]==n)return 0;return que(root[x],1,n,dep[x]+1,min(dep[x]+y,n));}
}T;void ad(int x,int y)
{e[++cnt].n=head[x];e[cnt].y=y;head[x]=cnt;
}void loading(int u,int fa)
{siz[u]=1;dep[u]=dep[fa]+1;for(int i=head[u],v;i;i=e[i].n){v=e[i].y;if(v==fa)continue;loading(v,u);siz[u]+=siz[v];}T.modi(u);
}void init()
{n=read();q=read();for(int i=1,x,y;i<n;++i){x=read();y=read();ad(x,y);ad(y,x);}loading(1,0);ques node;for(int i=1,x,y;i<=q;++i){x=read();y=read();node.id=i;node.x=y;++deg[x];h[x].push_back(node);}
}void collect(int u,int fa)
{for(int i=head[u],v;i;i=e[i].n){v=e[i].y;if(v==fa)continue;collect(v,u);T.mer(u,v);}for(int i=0;i<deg[u];++i){int id=h[u][i].id;ll nw=h[u][i].x;ans[id]+=(siz[u]-1ll)*min(dep[u]-1ll,nw);ans[id]+=(T.ask(u,nw));}
}void work()
{collect(1,0);for(int i=1;i<=q;++i)printf("%lld\n",ans[i]);
}int main()
{init();work();return 0;
}