[学军20201104CSP模拟赛] 二维码

简要题意:

对于$n\times m$的网格图,初始时全部为白色,现在 通过下面的方法染色

每次选择一个行或者列,把它全部染成黑色或者全部染成白色

求任意操作的情况下,可以得到的不同网格图的数量$\mod 998244353$

判定一个染色方案是否有解的条件是:

染色完成的矩阵不包含一个子矩阵满足四个角分别为

01

10

或者

10

01

但是这样看这个条件似乎比较抽象,如果具体对于一个行上考虑,就是满足

每一行所包含的1的位置的集合之间 互为子集

显然一个方案可以任意交换行/列,不妨把按照每一行1的个数将每一行排序,设每一行有$a_i$个1,边界条件为$a_0=0$

那么对于行上的1考虑排列,方案数为$\begin{aligned} \prod \binom{m-a_{i-1} } {a_i-a_{i-1} }\end{aligned}$,即从空的$m-a_{i-1}$个位置里选出多出的$a_i-a_{i-1}$个位置

而对于列之间的排列需要考虑$a_i$与$a_{i+1}$的关系,因为如果$a_i=a_{i+1}$时,必然满足这两行相同

设所有的$a_i$构成若干个相同的组,每一组包含$b_i(i\in[1,k])$个元素,则方案数显然为$\begin{aligned} \frac{n!} {\prod b_i!}\end{aligned}$

而组内的$a_i$之间显然是没有$\begin{aligned} \sum \binom{m-a_{i-1} } {a_i-a_{i-1} }\end{aligned}$的贡献的,可以跳过

由此,不妨令$dp_{i,j}$表示$dp$了前$i$行,最后一行$a_i=j$,每次枚举每个组$b_i$转移

复杂度为$O(n^4)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)
const int N=2010,P=998244353;
int n,m,C[N][N],dp[N][N],I[N],J[N];
ll qpow(ll x,ll k=P-2) {
ll res=1;
for(;k;k>>=1,x=x*x%P) if(k&1) res=res*x%P;
return res;
}
int main(){
rep(i,J[0]=1,N-1) J[i]=1ll*J[i-1]*i%P;
I[N-1]=qpow(J[N-1]);
drep(i,N-2,0) I[i]=1ll*I[i+1]*(i+1)%P;
rep(i,0,N-1) rep(j,C[i][0]=1,i) C[i][j]=(C[i-1][j]+C[i-1][j-1])%P;

scanf("%d%d",&n,&m);
rep(i,0,n) dp[i][0]=I[i]; // 第一个块
rep(i,1,n) rep(j,1,m) {
rep(a,0,i-1) rep(b,0,j-1) {
dp[i][j]=(dp[i][j]+1ll*dp[a][b]*I[i-a]%P*C[m-b][j-b])%P;
}
}
int ans=0;
rep(i,0,m) ans=(ans+dp[n][i])%P;
ans=1ll*ans*J[n]%P;
printf("%d\n",ans);
}

优化:

发现两维$dp$之间是相互独立的,分别是把$n,m$分组

令$F_{i,j}$表示把$n$分成了$i$个组,当前总和为$j$的方案数,$G_{i,j}$表示把$m$分成$i$组,当前总和为$j$

按照上面的系数转移,最后$O(n)$合并,复杂度为$O(n^3)$

进一步优化:

为了方便下面的叙述,不妨先整理一下$a_i$之间转移的系数,不妨设边界$a_{n+1}=m$

$\begin{aligned} \prod \binom{m-a_{i-1} } {a_i-a_{i-1} }=\prod_{i=1}^n \frac{(m-a_{i-1})!} {(a_i-a_{i-1})!(m-a_{i})!}= \frac{m!} {\prod_{i=1}^{n+1} (a_{i}-a_{i-1})!}\end{aligned}$

发现实际上和列之间的系数是类似的,每次枚举$a_i-a_{i-1}$即可

而实际上只有$k$个$b_i$直接相交的位置$a_{i}-a_{i-1}$有效,因此行和列实际上分别是将$n,m$分成了$k$组

观察上面的转移系数,行构成的块,首个块大小可以为$0$,而列构成的块最后一个块大小可以为$0$,所以这个并不是严格分成$k$组,下面会讨论这个问题

我们计算答案的复杂度消耗在计算分成若干块的方案,而实际上,把$n$分成$k$块的方案数就是$\begin{Bmatrix} n\\ k\end{Bmatrix}\cdot k!$

用$n^2$递推第二类斯特林数的方法即可计算

对于并不是严格分成$k$组的问题,可以考虑把开头/结尾那一个大小为$0$的块删掉,即同时还要考虑$\begin{Bmatrix}n \\ k-1\end{Bmatrix}(k-1)!$

最后再枚举$k$,复杂度为$O(n^2)$

更优化的就是$\text{NTT}$计算斯特林数,带入通项公式

$\begin{aligned} \begin{Bmatrix}n\\ m\end{Bmatrix}m!=\sum_{i=0}^m i^n(-1)^{m-i}\binom{m} {i} \end{aligned}$

显然把组合数拆开$\text{NTT}$即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
#define Mod1(x) ((x>=P)&&(x-=P))
#define Mod2(x) ((x<0)&&(x+=P))
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)

const int N=1<<18|10,P=998244353;

int n,m,ans;
ll qpow(ll x,ll k=P-2) {
ll res=1;
for(;k;k>>=1,x=x*x%P) if(k&1) res=res*x%P;
return res;
}

int I[N],J[N];

int rev[N];
int Init(int n){
int R=2,c=0;
while(R<=n) R<<=1,c++;
rep(i,0,R-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<c);
return R;
}

void NTT(int n,int *a,int f){
static int e[N>>1];
rep(i,1,n-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=e[0]=1;i<n;i<<=1) {
ll t=qpow(f==1?3:(P+1)/3,(P-1)/i/2);
for(int j=i-2;j>=0;j-=2) e[j+1]=(e[j]=e[j>>1])*t%P;
for(int l=0;l<n;l+=i*2) {
for(int j=l;j<l+i;++j) {
int t=1ll*a[j+i]*e[j-l]%P;
a[j+i]=a[j]-t,Mod2(a[j+i]);
a[j]+=t,Mod1(a[j]);
}
}
}
if(f==-1) {
ll base=qpow(n);
rep(i,0,n-1) a[i]=a[i]*base%P;
}
}

int A[N],B[N],C[N];

int main(){
scanf("%d%d",&n,&m);
if(n<m) swap(n,m);
rep(i,J[0]=1,N-1) J[i]=1ll*J[i-1]*i%P;
I[N-1]=qpow(J[N-1]);
drep(i,N-2,0) I[i]=1ll*I[i+1]*(i+1)%P;

int R=Init(n+n+2);
rep(i,0,n) A[i]=qpow(i,n)*I[i]%P;
rep(i,0,n) B[i]=(i&1)?P-I[i]:I[i];
NTT(R,A,1),NTT(R,B,1);
rep(i,0,R-1) A[i]=1ll*A[i]*B[i]%P;
NTT(R,A,-1);
rep(i,n+1,R) A[i]=0;

rep(i,0,m) C[i]=qpow(i,m)*I[i]%P;
NTT(R,C,1);
rep(i,0,R-1) C[i]=1ll*C[i]*B[i]%P;
NTT(R,C,-1);
rep(i,m+1,R) C[i]=0;

int ans=0;
rep(i,0,min(n,m)) ans=(ans+1ll*(1ll*A[i]*J[i]%P+1ll*A[i+1]*J[i+1]%P)*(1ll*C[i]%P*J[i]%P+1ll*C[i+1]*J[i+1]%P))%P;
printf("%d\n",ans);
}