[WC2019]数树(树形dp+多项式exp)

Part1

相同边连接的点同一颜色,直接模拟即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
namespace pt1{
int fa[N],sz[N];
map <int,int> M[N];
int Find(int x){ return fa[x]==x?x:fa[x]=Find(fa[x]); }
void Solve(){
rep(i,1,n) fa[i]=i;
rep(i,2,n){
int x=rd(),y=rd();
if(x>y) swap(x,y);
M[x][y]=1;
}
rep(i,2,n) {
int x=rd(),y=rd();
if(x>y) swap(x,y);
if(M[x][y]) fa[Find(x)]=Find(y);
}
int ans=1;
rep(i,1,n) if(Find(i)==i) ans=1ll*ans*y%P;
printf("%d\n",ans);
}
}

Part2

相同边连接的点同一颜色,即在相同边构成的树上形成了若干联通块

很容易想到可以强制一些边保留,设保留$i$条边的方案数是$F_i$,则答案就是$\sum_i F_i\cdot y^{n-i}$

考虑$dp$那些边相同,但是不好直接计算剩下边不同的方案,所以考虑计算最多有$i$条边相同的方案数,即

二项式反演得到$F_i=\sum_{j=i}(-1)^{j-i}C(j,i)G_j$

设分成了$m$个联通块,大小分别为$size_i$,则这些联通块随意构成树的方案数就是$n^{m-2}\cdot\prod size_i$

根据上述性质可以写出一个简单的$O(n^4)$树形dp求得$G_i$,即$dp[i][j][k]$表示在$i$的子树里,有$j$条边相同,当前还剩下一个大小为$k$的联通块,每多转移一条相同边,系数是$\frac{1} {ny}$

考虑优化$dp$

1.

联通块大小的问题,可以转化为每次在联通块里选择一个关键点的方案数,$dp$第三维$0/1$表示当前联通块里是否已经选出了关键点

每次断开一个联通块时必须已经存在关键点

2.

答案是

$\sum_i F_i\cdot y^{n-i}$

$=\sum_i y^{n-i} \sum_{j=i}(-1)^{j-i}C(j,i)G_j$

$=y^n G_j\sum_{i=0}^j(-1)^{j-i}C(j,i)y^{-i}$

发现右边的式子$\sum_0^j(-1)^{j-i}C(j,i)y^{-i}=(\frac{1} {y}-1)^j$

那么直接把$\frac{1} {y}-1$带入作为保留一条边的转移系数,消去了第二维

那么这个$\text{dp}$可以被优化到$O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
namespace pt2{
vector <int> G[N];
int dp[N][2],g[2],Inv;
void dfs(int u,int f){
dp[u][0]=dp[u][1]=1;
for(int v:G[u]) if(v!=f) {
dfs(v,u);
g[0]=g[1]=0;
rep(i,0,1) rep(j,0,1) {
if(!i||!j) g[i|j]=(g[i|j]+1ll*dp[u][i]*dp[v][j]%P*Inv)%P;
if(j) g[i]=(g[i]+1ll*dp[u][i]*dp[v][j])%P;
}
dp[u][0]=g[0],dp[u][1]=g[1];
}
}
void Solve() {
rep(i,2,n) {
int u=rd(),v=rd();
G[u].pb(v),G[v].pb(u);
}
Inv=(qpow(y)-1)*qpow(n)%P;
dfs(1,0);
ll res=dp[1][1]*qpow(y,n)%P*qpow(n,P+n-3)%P;
printf("%lld\n",res);
}
}

Part3

有了上面的$dp$,这一部分就简单多了,设分成了$m$个联通块,每个大小为$a_i$,则贡献为

$$\begin{aligned}\frac{n!\cdot a_i^{a_i-2}\cdot (n^{m-2})^2(\frac{1} {y}-1)^{n-m}(\frac{1} {n}^{n-m})^2\cdot a_i^2} {\prod a_i! m !}\end{aligned}$$

即枚举每个联通块生成树的数量,且需要考虑两棵树分别的联通块之间的连边数量,这一部分需要平方

很显然,可以直接对于$[x^i]F(x)=\frac{1} {i!}\cdot (\frac{1} {n^2}\cdot (\frac{1} {y}-1))^{i-1} i^2i^{i-2}$这个多项式求exp得到

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
const int M=1<<18|10,K=17;
typedef vector <int> Poly;

int w[M],rev[M],Inv[M];
void Init(){
ll t=qpow(3,(P-1)>>K>>1);
w[1<<K]=1;
rep(i,(1<<K)+1,(1<<(K+1))-1) w[i]=w[i-1]*t%P;
drep(i,(1<<K)-1,1) w[i]=w[i<<1];
Inv[0]=Inv[1]=1;
rep(i,2,M-1) Inv[i]=1ll*(P-P/i)*Inv[P%i]%P;
}
int Init(int n){
int R=1,cc=-1;
while(R<n) R<<=1,cc++;
rep(i,1,R-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<cc);
return R;
}

void NTT(int n,Poly &a,int f){
if((int)a.size()<n) a.resize(n);
rep(i,1,n-1) if(rev[i]<i) swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1) {
int *e=w+i;
for(int l=0;l<n;l+=i*2){
for(int j=l;j<l+i;++j){
int t=1ll*a[j+i]*e[j-l]%P;
a[j+i]=a[j]-t,Mod2(a[j+i]);
a[j]+=t,Mod1(a[j]);
}
}
}
if(f==-1) {
reverse(a.begin()+1,a.end());
rep(i,0,n-1) a[i]=1ll*a[i]*Inv[n]%P;
}
}

Poly operator * (Poly a,Poly b){
int n=a.size(),m=b.size();
int R=Init(n+m-1);
NTT(R,a,1),NTT(R,b,1);
rep(i,0,R-1) a[i]=1ll*a[i]*b[i]%P;
NTT(R,a,-1),a.resize(n+m-1);
return a;
}

Poly Poly_Inv(Poly a){
int n=a.size();
if(n==1) return {(int)qpow(a[0])};
Poly b=a; b.resize((n+1)/2),b=Poly_Inv(b);
int R=Init(n*2);
NTT(R,a,1),NTT(R,b,1);
rep(i,0,R-1) a[i]=1ll*b[i]*(2-1ll*a[i]*b[i]%P+P)%P;
NTT(R,a,-1); a.resize(n);
return a;
}

Poly Deri(Poly a){
rep(i,1,a.size()-1) a[i-1]=1ll*i*a[i]%P;
a.pop_back();
return a;
}
Poly IDeri(Poly a){
a.pb(0);
drep(i,a.size()-2,0) a[i+1]=1ll*a[i]*Inv[i+1]%P;
a[0]=0;
return a;
}

Poly Ln(Poly a){
int n=a.size();
a=Deri(a)*Poly_Inv(a),a.resize(n+1);
return IDeri(a);
}

Poly Exp(Poly a){
int n=a.size();
if(n==1) return Poly{1};
Poly b=a; b.resize((n+1)/2),b=Exp(b);
b.resize(n); Poly c=Ln(b);
rep(i,0,n-1) c[i]=a[i]-c[i],Mod2(c[i]);
c[0]++,c=c*b;
c.resize(n);
return c;
}

void Solve() {
int I=(qpow(y)-1)*qpow(1ll*n*n%P)%P;
Init();
Poly F(n+1);
for(int i=1,FInv=1;i<=n;FInv=1ll*FInv*Inv[++i]%P){
F[i]=qpow(I,(i-1)) * // 保留i-1条边
(i==1?1:qpow(i,i-2))%P // i个点生成树
* i%P * i%P //
* FInv%P; // 阶乘常数
}
F=Exp(F);
rep(i,1,n) F[n]=1ll*F[n]*i%P;
ll res=F[n]*qpow(y,n)%P*qpow(n,2*(P+n-3))%P;
printf("%lld\n",res);
}