您现在的位置是:首页 >其他 >再谈分拆数网站首页其他

再谈分拆数

dygxczn 2024-06-17 11:19:00
简介再谈分拆数

建议先阅读我的博客 分拆数简介

上一篇博客讲了时间复杂度为 O ( n log ⁡ n ) O(nlog n) O(nlogn) 的求分拆数的方法,但是这个方法有弊端,它的常数由于求多项式 exp ⁡ exp exp,导致常数巨大,如果模数不是 NTT ext{NTT} NTT 模数,就要用 MTT ext{MTT} MTT,常数更大。因此,这里再介绍一种更简单常数更小的方法求分拆数,时间复杂度 O ( n n ) O(nsqrt n) O(nn )

做法

方法是 dp ext{dp} dp,但是需要一个小技巧:根号分治。

T = n T=sqrt n T=n

如果选的数 ≤ T le T T,直接做背包,时间复杂度 O ( n n ) O(nsqrt n) O(nn )

如果选的数 > T >T >T,我们选的数肯定不会超过 T T T 个。设 dp i , j ext{dp}_{i,j} dpi,j 表示选了 i i i 个数,总和为 j j j 的方案数。此时有两种选择。

  1. 新添加一个数 T + 1 T+1 T+1,转移方程: dp i , j + = dp i − 1 , j − T − 1 ext{dp}_{i,j}+= ext{dp}_{i-1,j-T-1} dpi,j+=dpi1,jT1
  2. 给现有的数都加 1 1 1,转移方程: dp i , j + = dp i , j − i ext{dp}_{i,j}+= ext{dp}_{i,j-i} dpi,j+=dpi,ji

初始状态 dp 0 , 0 = 1 ext{dp}_{0,0}=1 dp0,0=1

可以证明,这样的选择能不重不漏地计算结果。

这样,第一维状态是 n sqrt n n ,第二维是 n n n,时间复杂度 O ( n n ) O(nsqrt n) O(nn )

然后又可以算出选大于 T T T 的数总和为 i i i 的方案数了。

上面只考虑了大于或小于等于 T T T 的情况。要求总的方案数,我们可以枚举 i i i,表示选了大于 T T T 的数之和为 i i i,小于等于 T T T 的数之和为 n − i n-i ni,这时把二者的方案乘起来,对于每个 i i i 再求和就是答案。

预处理时间复杂度 O ( n n ) O(nsqrt n) O(nn ),单次求答案时间复杂度 O ( n ) O(n) O(n)

优缺点分析

  • 优点:思维难度小,代码实现简单,常数小,不受模数限制。
  • 缺点:时间复杂度较大,在某些毒瘤题目可能过不了(比如 n ≤ 1 0 6 nle10^6 n106

例题

HDU4651 Partition
板题,求 n n n 的分拆数,答案模 1 e 9 + 7 1e9+7 1e9+7。多次询问,总数不超过 100 100 100 次。

对于这题,模数对 NTT ext{NTT} NTT 很不友好,所以选择 O ( n n ) O(nsqrt n) O(nn ) dp ext{dp} dp

由于询问次数有限,所以单次查询 O ( n ) O(n) O(n) 可过。

这道题空间 32 M B 32MB 32MB,所以 dp ext{dp} dp 数组要滚动。

代码如下

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=1e9+7;
const int N=1e5,T=316;
ll f1[N+1],f2[2][N+1],sum[N+1];
int n;
int main()
{
    f1[0]=f2[0][0]=sum[0]=1;
    for(int i=1;i<=T;i++){
        for(int j=i;j<=N;j++){
            (f1[j]+=f1[j-i])%=mod;
        }
    }
    for(int i=1;i<=(N+T)/(T+1);i++){
        memset(f2[i&1],0,sizeof(f2[i&1]));
        for(int j=i*(T+1);j<=N;j++){
            (f2[i&1][j]+=f2[(i-1)&1][j-T-1]+f2[i&1][j-i])%=mod;
            (sum[j]+=f2[i&1][j])%=mod;
        }
    }
    int t;
    cin>>t;
    while(t--){
        scanf("%d",&n);
        ll ans=0;
        for(int i=0;i<=n;i++) (ans+=sum[i]*f1[n-i])%=mod;
        cout<<ans<<endl;
    }
}

LOJ6268 分拆数
也是一道板题,但是模数 998244353 998244353 998244353,对 N T T NTT NTT 很友好,所以使用多项式方法。

代码如下:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=(1<<18)+1;
const ll mod=998244353,g=3,inv2=499122177;
int len=1,n;
ll a1[N],w,wn,a[N],ans[N],invans[N],lnans[N],da[N],inva[N],omg[N],inv[N];
ll ksm(ll a,ll b)
{
    ll ans=1;
    while(b){
        if(b&1) ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}
void change(ll num[])
{
    for(int i=1,j=len/2;i<len-1;i++){
        if(i<j) swap(num[i],num[j]);
        int k=len/2;
        while(j>=k) j-=k,k>>=1;
        if(j<k) j+=k;
    }
}
void ntt(ll num[],int fl)
{
    for(int i=2;i<=len;i<<=1){
        if(fl==1) wn=ksm(g,(mod-1)/i);
        else wn=ksm(g,mod-1-(mod-1)/i);
        for(int j=0;j<len;j+=i){
            w=1;
            for(int k=j;k<j+i/2;k++){
                ll u=w*num[k+i/2]%mod,t=num[k];
                num[k]=(t+u)%mod;
                num[k+i/2]=(t-u+mod)%mod;
                w=w*wn%mod;
            }
        }
    }
    if(fl==-1){
        ll inv=ksm(len,mod-2);
        for(int i=0;i<len;i++) num[i]=num[i]*inv%mod;
    }
}
int read()
{
    int sum=0,c=getchar();
    while(c<48||c>57) c=getchar();
    while(c>=48&&c<=57) sum=sum*10+c-48,c=getchar();
    return sum;
}
void getinv(int n,ll a[],ll ans[])
{
	if(n==1){ans[0]=ksm(a[0],mod-2);return;}
	getinv((n+1)/2,a,ans);
	len=1;
	while(len<2*n) len*=2;
	for(int i=0;i<n;i++) a1[i]=a[i];
	for(int i=n;i<len;i++) a1[i]=0;
	change(a1),change(ans);
	ntt(a1,1),ntt(ans,1);
	for(int i=0;i<len;i++) ans[i]=ans[i]*(2-ans[i]*a1[i]%mod+mod)%mod;
	change(ans),ntt(ans,-1);
	for(int i=n;i<len;i++) ans[i]=0;
}
void getln(int n,ll a[],ll ln[])
{
	for(int i=1;i<n;i++) da[i-1]=a[i]*i;
    da[n-1]=0;
    memset(inva,0,sizeof(inva));
	getinv(n,a,inva);
	len=1;
	while(len<2*n) len*=2;
	change(da),change(inva);
	ntt(da,1),ntt(inva,1);
	for(int i=0;i<len;i++) ln[i]=da[i]*inva[i]%mod;
	change(ln),ntt(ln,-1);
	for(int i=len-1;i>=0;i--) ln[i+1]=ksm(i+1,mod-2)*ln[i]%mod;
    for(int i=n;i<len;i++) ln[i]=0;
	ln[0]=0;
}
void getexp(int n,ll a[],ll ans[])
{
    if(n==1){ans[0]=1;return;}
    getexp((n+1)/2,a,ans);
    len=1;
    while(len<2*n) len*=2;
    // memset(lnans,0,sizeof(lnans));
    getln(n,ans,lnans);
    for(int i=0;i<n;i++) lnans[i]=(-lnans[i]+a[i]+mod)%mod;
    lnans[0]++;
    change(ans),change(lnans);
    ntt(ans,1),ntt(lnans,1);
    for(int i=0;i<len;i++) ans[i]=ans[i]*lnans[i]%mod;
    change(ans),ntt(ans,-1);
    for(int i=n;i<len;i++) ans[i]=0;
}
void init(int n)
{
    inv[1]=1;
    for(int i=2;i<=n;i++) inv[i]=inv[mod%i]*(mod-mod/i)%mod;
}
int main()
{
    init(1e5);
	scanf("%d",&n);
    for(int i=1;i<=n;i++){
        for(int j=1;i*j<=n;j++){
            (a[i*j]+=inv[j])%=mod;
        }
    }
	getexp(n+1,a,ans);
	for(int i=1;i<=n;i++) printf("%lld
",ans[i]);

}
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。