「学习笔记」二项式反演

概念

二项式反演其实就是利用容斥的思想处理一些通过求“至少或至多”来解决“恰好”的问题。

形式

f(n)=i=0n(1)i(ni)g(i)    g(n)=i=0n(1)i(ni)f(i)f(n)=i=0n(ni)g(i)    g(n)=i=0n(1)ni(ni)f(i)f(n)=i=nm(ni)g(i)    g(n)=i=nm(1)in(ni)f(i)\begin{aligned} f(n)=\sum_{i=0}^n(-1)^i\binom n i g(i)&\iff g(n)=\sum_{i=0}^n(-1)^i\binom n i f(i) \\ f(n)=\sum_{i=0}^n\binom n i g(i)&\iff g(n)=\sum_{i=0}^n(-1)^{n-i}\binom n i f(i) \\ f(n)=\sum_{i=n}^m \binom n i g(i)&\iff g(n)=\sum_{i=n}^m(-1)^{i-n}\binom n i f(i) \end{aligned}

其中,形式三比较常用,组合意义f(n)f(n) 表示“至少选 nn 个”,g(n)g(n) 表示“恰好选 nn 个”。

例题

Luogu P4859 已经没有什么好害怕的了

Link

Description

给定两个长为 nn 的序列 a,ba,b,它们两两配对,求配对后 a>ba>b 的组数比 b>ab>a 的组数恰好多 kk 组的方案数。

1n2000,0kn1\le n \le 2000,0\le k\le n

Solution

题目要求“恰好多 kk 组”,共有 nn 组,所以相当于 a>ba>b 恰好 n+k2\dfrac {n+k}2 组。

dpi,jdp_{i,j} 表示前 ii 个数中,有 jja>ba>b 的方案数,转移方程为

dpi,j=dpi1,j+dpi1,j1×(cnti(j1))dp_{i,j}=dp_{i-1,j}+dp_{i-1,j-1}\times(cnt_i-(j-1))

其中,cnticnt_i 表示 bb 中比 aia_i 小的数的个数,这个可以将 a,ba,b 排序后双指针扫

接下来,记 fi=dpn,i×(ni)!f_i=dp_{n,i}\times (n-i)!,也就是至少 ii 组的方案数

然后根据二项式反演就可以得到恰好 kk 组的方案数 gkg_k

gk=i=kn(1)ni(ik)fig_k=\sum_{i=k}^n(-1)^{n-i}\binom i k f_i

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
int n, k, a[N], b[N], cnt[N];
ll fac[N], dp[N][N], f[N], g[N];

ll qpow(ll a, int b)
{
ll res = 1;
while(b)
{
if(b & 1) res = res * a % mod;
a = a * a % mod, b >>= 1;
}
return res;
}
ll add(ll x) {return x < mod ? x : x - mod;}
ll inv(ll x) {return qpow(x, mod - 2);}
ll C(int n, int m) {return n < m ? 0 : fac[n] * inv(fac[m]) % mod * inv(fac[n - m]) % mod;}

int main()
{
read(n), read(k);

if((n + k) & 1)
{
puts("0");
return 0;
}
k = (n + k) >> 1;

for(int i = 1; i <= n; i++) read(a[i]);
for(int i = 1; i <= n; i++) read(b[i]);
sort(a + 1, a + 1 + n);
sort(b + 1, b + 1 + n);

fac[0] = 1;
for(int i = 1, j = 1; i <= n; i++)
{
while(j <= n && a[i] > b[j]) j++;
cnt[i] = j - 1;
fac[i] = fac[i - 1] * i % mod;
}

dp[0][0] = 1;
for(int i = 1; i <= n; i++)
for(int j = 0; j <= i; j++)
dp[i][j] = add(dp[i - 1][j] + (!j ? 0 : dp[i - 1][j - 1] * (cnt[i] - j + 1) % mod));
for(int i = 0; i <= n; i++) f[i] = dp[n][i] * fac[n - i] % mod;
for(int i = 1; i <= n; i++)
for(int j = k; j <= n; j++)
g[i] = add(g[i] + add((((j - k) & 1) ? -1 : 1) * f[j] * C(j, k) % mod + mod));

write(g[k]), pc('\n');
return 0;
}
// A.S.

Luogu P4491 [HAOI2018]染色

Link

Description

有一个长为 nn 的序列,每个位置都可以是 [1,m][1,m] 中的某一个数,若这 nn 个数中恰好出现了 ss 次的数有 kk 个,那么会得到 wkw_k 的贡献。

求对于所有可能的情况,能获得的权值的和对 10045358091004535809 取模的结果是多少。

1n107,1m105,0s150,0wi10045358091\le n\le 10^7,1\le m \le 10^5,0\le s\le 150,0\le w_i\le 1004535809

Solution

显然数的个数不会超过 cnt=min(m,n/s)cnt=\min(m,n/s)

依然是恰好出现 ss 次,考虑计算有 ii 个数至少出现 ss 次的方案数 fif_i

钦定 ii 个数出现了 ss 次,剩下的 nisn-is 个位置在 mim-i 个数中随便选

fi=(mi)n!(s!)i(nis)!(mi)nisf_i=\binom m i \dfrac{n!}{(s!)^i(n-is)!}(m-i)^{n-is}

然后进行二项式反演,设 gkg_k 表示有 kk 个数恰好出现 ss

gk=i=km(1)ik(ik)figk×k!=(1)iki!(ik)!fi\begin{aligned} g_k&=\sum_{i=k}^m(-1)^{i-k}\binom i k f_i \\ g_k\times k!&=\sum(-1)^{i-k}\dfrac{i!}{(i-k)!}f_i \end{aligned}

到这里就能看出来卷积的形式了

F(x)=i=0mfi×i!G(x)=i=0m(1)ii!F(x)=\sum_{i=0}^mf_i\times i! \\ G(x)=\sum_{i=0}^m\dfrac{(-1)^i}{i!}

那么 gi=(FG)(i)i!g_i=\dfrac{(F*G)(i)}{i!}

NTT 计算卷积即可

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar

using namespace std;

namespace IO
{
template <typename T>
void read(T &x)
{
x = 0; bool f = 0; char c = gc();
while(!isdigit(c)) f |= c == '-', c = gc();
while(isdigit(c)) x = x * 10 + c - '0', c = gc();
if(f) x = -x;
}

template <typename T>
void write(T x)
{
if(x < 0) pc('-'), x = -x;
if(x > 9) write(x / 10);
pc('0' + x % 10);
}
}
using namespace IO;

const int MAXN = 1e7 + 5;
const int N = 1e5 + 5;
const int mod = 1004535809;
const int G = 3;
const int Gi = 334845270;

ll add(ll x) {return x < mod ? x : x - mod;}
ll sub(ll x) {return x < 0 ? x + mod : x;}
ll qpow(ll a, int b)
{
ll res = 1;
while(b)
{
if(b & 1) res = res * a % mod;
a = a * a % mod, b >>= 1;
}
return res;
}
ll inv(ll x) {return qpow(x, mod - 2);}

ll fac[MAXN], f[N << 2], g[N << 2];

ll C(int n, int m)
{
return n < m ? 0 : fac[n] * inv(fac[m]) % mod * inv(fac[n - m]) % mod;
}

int rev[N << 2];

int calclim(int n)
{
int lim = 1;
while(lim < n) lim <<= 1;
return lim;
}

void calcrev(int lim)
{
for(int i = 0; i < lim; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (lim >> 1));
}

void NTT(ll *a, int lim, int type)
{
for(int i = 0; i < lim; i++)
if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int mid = 1; mid < lim; mid <<= 1)
{
ll wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
for(int i = 0; i < lim; i += (mid << 1))
{
ll w = 1;
for(int j = 0; j < mid; j++, w = w * wn % mod)
{
ll x = a[i + j], y = w * a[i + mid + j] % mod;
a[i + j] = add(x + y);
a[i + mid + j] = sub(x - y);
}
}
}
if(type == -1)
{
ll limi = qpow(lim, mod - 2);
for(int i = 0; i < lim; i++) a[i] = a[i] * limi % mod;
}
return;
}

int main()
{
int n, m, s;
read(n), read(m), read(s);
int cnt = min(m, n / s) + 1;
fac[0] = 1;
for(int i = 1; i < MAXN; i++)
fac[i] = fac[i - 1] * i % mod;
for(int i = 0; i < cnt; i++)
{
f[i] = fac[i] * C(m, i) % mod * fac[n] % mod * inv(qpow(fac[s], i)) % mod * inv(fac[n - s * i]) % mod * qpow(m - i, n - s * i) % mod;
g[i] = (i & 1) ? mod - inv(fac[i]) : inv(fac[i]);
}

reverse(f, f + cnt);
int lim = calclim(cnt << 1);
calcrev(lim);
NTT(f, lim, 1), NTT(g, lim, 1);
for(int i = 0; i < lim; i++) f[i] = f[i] * g[i] % mod;
NTT(f, lim, -1);
reverse(f, f + cnt);

ll ans = 0;
for(int i = 0, w; i < cnt; i++)
read(w), ans = add(ans + inv(fac[i]) * f[i] % mod * w % mod);
write(ans), pc('\n');

return 0;
}
// A.S.