// Solution program: reads its input from "magic.in" and writes the answer (mod 998244353) to "magic.out".
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Inclusive ascending / descending loops over int i; the end bound (b) is
// evaluated once and stored in i##end.
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define drep(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pb push_back
// Modular fix-ups used after a single add/subtract of a value in [0, P):
// Mod1 subtracts P once if x >= P; Mod2 adds P once if x < 0.
// P is resolved at the point of use, so it may be declared after these macros.
#define Mod1(x) (((x) >= P) && ((x) -= P))
#define Mod2(x) (((x) < 0) && ((x) += P))
char IO;  // last character consumed by rd() (EOF stored as a char)

/// Reads the next non-negative decimal integer from stdin.
/// Any non-digit characters before it are skipped (including '-', so negative
/// numbers are not supported). Returns 0 if end of input is reached before a
/// digit is seen, instead of looping forever.
int rd() {
    // Keep the character as int: getchar() returns an unsigned-char value or
    // EOF, which is exactly the domain isdigit() accepts.
    int c = getchar();
    while (c != EOF && !isdigit(c)) c = getchar();
    if (c == EOF) {
        IO = static_cast<char>(c);
        return 0;
    }
    int s = 0;
    do {
        s = s * 10 + (c - '0');
        c = getchar();
    } while (c != EOF && isdigit(c));
    IO = static_cast<char>(c);
    return s;
}
// N: size of the precomputed factorial/root/bit-reversal tables, and the
//    largest transform length the NTT supports (must be a power of two).
// P: prime modulus used for all arithmetic and as the NTT modulus.
constexpr int N = 1 << 17;
constexpr int P = 998244353;
int n,m,k; int Inv[N+1],Fac[N+1],FInv[N+1]; ll C(int n,int m){ return n<0||m<0||n<m? 0 : 1ll*Fac[n]*FInv[m]%P*FInv[n-m]%P; } ll qpow(ll x,ll k=P-2){ ll res=1; for(;k;k>>=1,x=x*x%P) if(k&1) res=res*x%P; return res; } typedef vector <int> Poly; int w[N|10],rev[N]; void Init(){ Inv[0]=Inv[1]=Fac[0]=Fac[1]=FInv[0]=FInv[1]=1; rep(i,2,N){ Fac[i]=1ll*Fac[i-1]*i%P; Inv[i]=1ll*(P-P/i)*Inv[P%i]%P; FInv[i]=1ll*FInv[i-1]*Inv[i]%P; } w[N>>1]=1; ll t=qpow(3,(P-1)/N); rep(i,(N>>1)+1,N-1) w[i]=w[i-1]*t%P; drep(i,(N>>1)-1,1) w[i]=w[i<<1]; }
int Init(int n){ int R=1,cc=-1; while(R<n) R<<=1,cc++; rep(i,1,R-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<cc); return R; } void NTT(int n,Poly &a,int f){ if((int)a.size()<n) a.resize(n); rep(i,1,n-1) if(i<rev[i]) swap(a[i],a[rev[i]]); for(int i=1;i<n;i<<=1){ int *e=w+i; for(int l=0;l<n;l+=i*2){ for(int j=l;j<l+i;++j) { int t=1ll*a[j+i]*e[j-l]%P; a[j+i]=a[j]-t; Mod2(a[j+i]); a[j]+=t; Mod1(a[j]); } } } if(f==-1){ reverse(a.begin()+1,a.end()); rep(i,0,n-1) a[i]=1ll*a[i]*Inv[n]%P; } }
/// Multiplies two polynomials modulo P via NTT.
/// Returns a vector of size a.size() + b.size() - 1; if either operand is
/// empty, returns an empty polynomial (previously the size computation
/// wrapped to -1 and the final resize requested a huge length).
Poly operator*(Poly a, Poly b) {
    if (a.empty() || b.empty()) return Poly();
    const int n = static_cast<int>(a.size() + b.size() - 1);
    const int R = Init(n);
    NTT(R, a, 1);
    NTT(R, b, 1);
    for (int i = 0; i < R; ++i) a[i] = static_cast<int>(1ll * a[i] * b[i] % P);
    NTT(R, a, -1);
    a.resize(n);
    return a;
}
/// Reads one group size x per leaf l..r (via rd(), in increasing leaf order)
/// and returns the product of the per-group polynomials
///   F[y] = C(x-1, y-1) * (y!)^{-1}  for 1 <= y <= x,  F[0] = 0,
/// where C(x-1, y-1) is the number of compositions of x into y positive parts.
/// An empty range (l > r, e.g. n == 0) yields the empty product {1}; the
/// original recursed forever in that case.
/// NOTE(review): a group with x == 0 produces the zero polynomial, which zeroes
/// the whole product; presumably inputs guarantee x >= 1 — confirm.
Poly Solve(int l, int r) {
    if (l > r) return Poly(1, 1);
    if (l == r) {
        const int x = rd();
        Poly F(x + 1, 0);
        for (int y = 1; y <= x; ++y) F[y] = static_cast<int>(C(x - 1, y - 1) * FInv[y] % P);
        return F;
    }
    const int mid = (l + r) >> 1;
    // Evaluate the halves in separate statements: the evaluation order of the
    // two operands of an overloaded operator* is unspecified, so the one-line
    // form left the order in which input values are consumed unspecified.
    Poly left = Solve(l, mid);
    Poly right = Solve(mid + 1, r);
    return left * right;
}
int main(){ freopen("magic.in","r",stdin),freopen("magic.out","w",stdout); Init(),n=rd(),m=rd(),k=rd(); Poly dp=Solve(1,n); rep(i,n,m) dp[i]=1ll*dp[i]*Fac[i]%P; int i=m-k; rep(j,n,i-1) dp[i]=(dp[i]+(((i-j)&1)?-1:1)*dp[j]*C(m-j,m-i)%P+P)%P; ll ans=(dp[i]%P+P)%P; printf("%lld\n",ans); }
|