CF1111E - Tree

题目大意

给定一棵无根树$T$,$q$次查询每次查询一个给定一个根$r$,点集$S$和限制$m$

求将$S$分成不超过$m$个非空集合,使得最终每个集合内不存在两点为祖先关系


分析

容易发现题目是一个给定部分点集的树形$dp$,因此需要用虚树来处理

将$r$也加入虚树,从$r$开始$\text{dfs}$即确定了根为$r$

dp部分

一种思路是树形背包,计算子树内分为$i$个集合的方案数,枚举在$\text{LCA}$处合并两个集合

但是由于要枚举合并的个数,难以写出优秀的复杂度

由于一条祖先链上点之间的集合独立,容易描述,因此可以考虑$\text{dfs}$序dp

按照$\text{dfs}$依次加入每一个点$u$,令$dp_i$表示当前有$i$个集合的方案数

则在$i$个集合中包含$dep_u$个集合$u$无法加入

枚举$i$,加入一个点$O(m)$转移,滚动数组即可

复杂度为$O(n\log n+\sum |S|\cdot m)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
const int N=1e5+10,P=1e9+7;

int n,m;
int L[N],top[N],son[N],sz[N],dfn,fa[N],dep[N];
struct Edge{
int to,nxt;
} e[N<<1];
int head[N],ecnt;
void AddEdge(int u,int v){
e[++ecnt]=(Edge){v,head[u]};
head[u]=ecnt;
}

void dfs(int u){
sz[u]=1;
for(int i=head[u];i;i=e[i].nxt) {
int v=e[i].to;
if(v==fa[u]) continue;
dep[v]=dep[fa[v]=u]+1,dfs(v);
if(sz[v]>sz[son[u]]) son[u]=v;
sz[u]+=sz[v];
}
}
void dfs(int u,int t){
top[u]=t,L[u]=++dfn;
if(son[u]) dfs(son[u],t);
for(int i=head[u];i;i=e[i].nxt) {
int v=e[i].to;
if(v==son[u] || v==fa[u]) continue;
dfs(v,v);
}
}
int LCA(int x,int y){
while(top[x]!=top[y]) dep[top[x]]>dep[top[y]]?x=fa[top[x]]:y=fa[top[y]];
return dep[x]<dep[y]?x:y;
}

vector <int> G[N];
int stk[N],T;
void Link(int u,int v){ G[u].pb(v),G[v].pb(u); }
void Ins(int u){
if(T<=1) return void(stk[++T]=u);
int lca=LCA(u,stk[T]);
if(lca==stk[T]) return void(stk[++T]=u);
while(T>1 && L[stk[T-1]]>=L[lca]) Link(stk[T],stk[T-1]),T--;
if(stk[T]!=lca) Link(stk[T],lca),stk[T]=lca;
stk[++T]=u;
}
int dis[N],mk[N];

int dp[310],a[N],c,rt;
void dfs_dp(int u,int f){
dis[u]=dis[f]+mk[u];
if(mk[u]) drep(i,m,0) dp[i]=((i?dp[i-1]:0)+1ll*(i-dis[f])*dp[i])%P;
for(int v:G[u]) if(v!=f) dfs_dp(v,u);
mk[u]=0,G[u].clear();
}

int main(){
n=rd(),m=rd();
rep(i,2,n){
int u=rd(),v=rd();
AddEdge(u,v),AddEdge(v,u);
}
dfs(1),dfs(1,1);
rep(_,1,m) {
c=rd(),m=rd(),rt=rd();
rep(i,1,c) mk[a[i]=rd()]=1;
a[++c]=1,a[++c]=rt;
sort(a+1,a+c+1,[&](int x,int y){ return L[x]<L[y]; });
T=0;
rep(i,1,c) if(a[i]!=a[i-1]) Ins(a[i]);
while(T>1) Link(stk[T-1],stk[T]),T--;
rep(i,0,m) dp[i]=0;
dp[0]=1;
dfs_dp(rt,0);
int ans=0;
rep(i,1,m) ans+=dp[i],Mod1(ans);
printf("%d\n",ans);
}
}