Minimax:线段树合并优化 dp 好题。
树形 dp
因为要求出每一个值的出现概率,首先我们可以想到一个很暴力的 dp 式子。
定义 \(dp_{i,j}\) 表示在节点 \(i\) 时,权值 \(j\) 的出现概率,设 \(l\) 表示左儿子,\(r\) 表示右儿子,则有如下转移:
- 当 \(j\) 在左儿子中时,\(dp_{i,j}\gets dp_{l,j}\times(p_i\times\sum_{k=1}^{j-1}dp_{r,k}+(1-p_i)\times\sum_{k=j+1}^{V}dp_{r,k})\),理解的话就是对父亲节点选大的还是选小的进行分讨。
- 当 \(j\) 在右儿子中时,\(dp_{i,j}\gets dp_{r,j}\times(p_i\times\sum_{k=1}^{j-1}dp_{l,k}+(1-p_i)\times\sum_{k=j+1}^Vdp_{l,k})\)。
直接转移即可,时间复杂度 \(O(nV)\)。
线段树合并优化
显然原来的时间复杂度会炸掉,但是我们发现每个节点最开始时最多只有一个 dp 位置是有值的,所以我们考虑用这种均摊复杂度的线段树合并来优化这个 dp。
因为 dp 转移的时候需要用到前缀和后缀和,所以我们进行 merge 的时候记录节点 \(x,y\) 的前缀和 \(px,py\) 与后缀和 \(sx,sy\) 以及父亲节点的概率 \(p\)。
梳理一下 merge 的流程:
- 进入节点 \(x,y\)。
- 如果 \(x,y\) 其中之一是空树,则说明直接更新 dp 值即可。
- 若 \(x\) 是空树,对应着上述 \(j\) 在右儿子中的转移方式,则我们对 \(y\) 的整颗子树内的 dp 值全部乘上 \((p\times\sum_{k=1}^{j-1}dp_{l,k}+(1-p)\times\sum_{k=j+1}^Vdp_{l,k})=(p\times px+(1-p)\times sx)\) 即可。这个可以用懒标记实现区间乘。
- 若 \(y\) 是空树,对应着上述 \(j\) 在左儿子中的转移方式,则我们对 \(x\) 的整颗子树内的 dp 值全部乘上 \((p\times\sum_{k=1}^{j-1}dp_{r,k}+(1-p)\times\sum_{k=j+1}^Vdp_{r,k})=(p\times py+(1-p)\times sy)\) 即可。这个可以用懒标记实现区间乘。
- 否则就说明要递归合并,递归左右儿子的时候记得更新 \(sx,sy,px,py\) 的值。
- 最后将左右儿子的 dp 值加起来就是这个区间的 dp 值。
时间复杂度 \(O(n\log n)\)。
代码
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi=pair<int,int>;
const int N=300005;
const ll mod=998244353;
int n,fa[N],m=0,b[N],son[N][2],cd[N],p[N],ans[N];
ll qpow(ll a,ll b)
{ll res=1;while(b){if(b&1)res=(res*a)%mod;b>>=1;a=(a*a)%mod;}return res;
}
int getrk(int x)
{return (lower_bound(b+1,b+m+1,x)-b);
}
struct Node{int ls,rs;ll dp,tag=1;
};
struct Segtree{Node tr[20*N];int root[N],tot=0;void pushup(int p){tr[p].dp=(tr[lc(p)].dp+tr[rc(p)].dp)%mod;}void pushdown(int p){if(tr[p].tag!=1){tr[lc(p)].tag=(tr[lc(p)].tag*tr[p].tag)%mod;tr[rc(p)].tag=(tr[rc(p)].tag*tr[p].tag)%mod;tr[lc(p)].dp=(tr[lc(p)].dp*tr[p].tag)%mod;tr[rc(p)].dp=(tr[rc(p)].dp*tr[p].tag)%mod;}tr[p].tag=1;}void modify(int p,int v){tr[p].dp=(tr[p].dp*1ll*v)%mod;tr[p].tag=(tr[p].tag*1ll*v)%mod;}void update(int &u,int ln,int rn,int x,ll k){if(u==0)u=++tot;if(ln==rn){tr[u].dp+=k;return;}int mid=(ln+rn)>>1;if(x<=mid)update(lc(u),ln,mid,x,k);else update(rc(u),mid+1,rn,x,k);pushup(u);}int merge(int x,int y,int px,int py,int sx,int sy,int p){if(x==0&&y==0)return 0;if(x==0){modify(y,(1ll*p*px%mod+1ll*((1-p)%mod+mod)%mod*sx)%mod);return y;}if(y==0){modify(x,(1ll*p*py%mod+1ll*((1-p)%mod+mod)%mod*sy)%mod);return x;}pushdown(x);pushdown(y);int lx=tr[lc(x)].dp,rx=tr[rc(x)].dp,ly=tr[lc(y)].dp,ry=tr[rc(y)].dp;tr[x].ls=merge(lc(x),lc(y),px,py,(sx+rx)%mod,(sy+ry)%mod,p);tr[x].rs=merge(rc(x),rc(y),(px+lx)%mod,(py+ly)%mod,sx,sy,p);pushup(x);return x;}void query(int u,int ln,int rn){if(ln==rn){ans[ln]=tr[u].dp;return;}int mid=(ln+rn)>>1;pushdown(u);query(lc(u),ln,mid);query(rc(u),mid+1,rn);}
}tr1;
void dfs1(int u)
{if(son[u][0]==0){tr1.update(tr1.root[u],1,m,getrk(p[u]),1);return;}if(son[u][1]==0){dfs1(son[u][0]);tr1.root[u]=tr1.root[son[u][0]];return;}dfs1(son[u][0]);dfs1(son[u][1]);tr1.root[u]=tr1.merge(tr1.root[son[u][0]],tr1.root[son[u][1]],0,0,0,0,p[u]);
}
int main()
{//freopen("sample.in","r",stdin);//freopen("sample.out","w",stdout);ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);cin>>n;for(int i=1;i<=n;i++)cin>>fa[i];for(int i=1;i<=n;i++){son[fa[i]][cd[fa[i]]]=i;cd[fa[i]]++;}for(int i=1;i<=n;i++){cin>>p[i];if(cd[i])p[i]=p[i]*1ll*qpow(10000,mod-2)%mod;else b[++m]=p[i];}sort(b+1,b+m+1);m=unique(b+1,b+m+1)-b-1;dfs1(1);tr1.query(tr1.root[1],1,m);ll res=0;for(int i=1;i<=m;i++)res=(res+1ll*i*b[i]%mod*ans[i]%mod*ans[i]%mod)%mod;cout<<res;return 0;
}