P5664 Emiya 今天家里的饭
2023-10-18 19:42:47 # OI # Problem

CSP 2019 D2 T1

题面

题目描述

Emiya 是个擅长做菜的高中生,他共掌握 nn烹饪方法,且会使用 mm主要食材做菜。为了方便叙述,我们对烹饪方法从 1n1 \sim n 编号,对主要食材从 1m1 \sim m 编号。

Emiya 做的每道菜都将使用恰好一种烹饪方法与恰好一种主要食材。更具体地,Emiya 会做 ai,ja_{i,j} 道不同的使用烹饪方法 ii 和主要食材 jj 的菜(1in1 \leq i \leq n1jm1 \leq j \leq m),这也意味着 Emiya 总共会做 i=1nj=1mai,j\sum\limits_{i=1}^{n} \sum\limits_{j=1}^{m} a_{i,j} 道不同的菜。

Emiya 今天要准备一桌饭招待 Yazid 和 Rin 这对好朋友,然而三个人对菜的搭配有不同的要求,更具体地,对于一种包含 kk 道菜的搭配方案而言:

  • Emiya 不会让大家饿肚子,所以将做至少一道菜,即 k1k \geq 1
  • Rin 希望品尝不同烹饪方法做出的菜,因此她要求每道菜的烹饪方法互不相同
  • Yazid 不希望品尝太多同一食材做出的菜,因此他要求每种主要食材至多在一半的菜(即 k2\lfloor \frac{k}{2} \rfloor 道菜)中被使用

这里的 x\lfloor x \rfloor 为下取整函数,表示不超过 xx 的最大整数。

这些要求难不倒 Emiya,但他想知道共有多少种不同的符合要求的搭配方案。两种方案不同,当且仅当存在至少一道菜在一种方案中出现,而不在另一种方案中出现。

Emiya 找到了你,请你帮他计算,你只需要告诉他符合所有要求的搭配方案数对质数 998,244,353998,244,353 取模的结果。

输入格式

第 1 行两个用单个空格隔开的整数 n,mn,m

第 2 行至第 n+1n + 1 行,每行 mm 个用单个空格隔开的整数,其中第 i+1i + 1 行的 mm 个数依次为 ai,1,ai,2,,ai,ma_{i,1}, a_{i,2}, \cdots, a_{i,m}

输出格式

仅一行一个整数,表示所求方案数对 998,244,353998,244,353 取模的结果。

样例 #1

样例输入 #1

2 3 
1 0 1
0 1 1

样例输出 #1

3

样例 #2

样例输入 #2

3 3
1 2 3
4 5 0
6 0 0

样例输出 #2

190

样例 #3

样例输入 #3

5 5
1 0 0 1 1
0 1 0 1 0
1 1 1 1 0
1 0 1 0 1
0 1 1 0 1

样例输出 #3

742

提示

【样例 1 解释】

由于在这个样例中,对于每组 i,ji, j,Emiya 都最多只会做一道菜,因此我们直接通过给出烹饪方法、主要食材的编号来描述一道菜。

符合要求的方案包括:

  • 做一道用烹饪方法 1、主要食材 1 的菜和一道用烹饪方法 2、主要食材 2 的菜
  • 做一道用烹饪方法 1、主要食材 1 的菜和一道用烹饪方法 2、主要食材 3 的菜
  • 做一道用烹饪方法 1、主要食材 3 的菜和一道用烹饪方法 2、主要食材 2 的菜

因此输出结果为 3mod998,244,353=33 \bmod 998,244,353 = 3。 需要注意的是,所有只包含一道菜的方案都是不符合要求的,因为唯一的主要食材在超过一半的菜中出现,这不满足 Yazid 的要求。

【样例 2 解释】

Emiya 必须至少做 2 道菜。

做 2 道菜的符合要求的方案数为 100。

做 3 道菜的符合要求的方案数为 90。

因此符合要求的方案数为 100 + 90 = 190。

【数据范围】

测试点编号 n=n= m=m= ai,j<a_{i,j}< 测试点编号 n=n= m=m= ai,j<a_{i,j}<
11 22 22 22 77 1010 22 10310^3
22 22 33 22 88 1010 33 10310^3
33 55 22 22 9129\sim 12 4040 22 10310^3
44 55 33 22 131613\sim 16 4040 33 10310^3
55 1010 22 22 172117\sim 21 4040 500500 10310^3
66 1010 33 22 222522\sim 25 100100 2×1032\times 10^3 998244353998244353

对于所有测试点,保证 1n1001 \leq n \leq 1001m20001 \leq m \leq 20000ai,j<998,244,3530 \leq a_{i,j} \lt 998,244,353

思路

本题要求满足两个条件:每行选不超过一个,每列选不超过当前选择的一半。

观察到最多有一列超过已选择的一半(显然)。

容斥:每行不超过一个的总方案 - 每行不超过一个且有一列不满足条件的方案数 = 总方案。

有一列不满足条件的方案数:

枚举每一列,设当前枚举到 curcur,状态设计为 fi,j,kf_{i, j, k} 代表前 iicurcur 列选了 jj 个其它列选了 kk 个的方案数,那么转移就是:

fi,j,k=fi1,j,k+fi1,j1,k×ai,cur+fi1,j,k1×(sumiai,cur)f_{i, j, k} = f_{i-1,j,k} + f_{i-1,j-1,k}\times a_{i,cur}+f_{i-1,j,k-1}\times(sum_i-a_{i,cur})

统计:j>kfn,j,k\sum_{j>k}f_{n,j,k}

每行不超过一个的总方案:

gi,jg_{i,j} 代表前 ii 行选了 jj 个的方案数,转移:

gi,j=gi1,j+gi1,j1×sumig_{i,j}=g_{i-1,j}+g_{i-1,j-1}\times sum_i

统计:i=1ngn,i\sum_{i=1}^{n} g_{n,i}

ans=i=1ngn,ij>kfn,j,kans = \sum_{i=1}^{n} g_{n,i}-\sum_{j>k}f_{n,j,k}

复杂度:O(n3m)O(n^3m)

优化:

由于我们统计答案时只关注 j,kj,k 之间的大小关系,所以我们记录当前列与其他列选的个数之差即可。

fi,jf_{i,j} 代表前 ii 行 当前列 - 其他列 所选 = jj 的方案数。

那么转移就是:

fi,j=fi1,j+fi1,j1×ai,cur+fi1,j+1×(sumiai,cur)f_{i,j}=f_{i-1,j}+f_{i-1,j-1}\times a_{i,cur} + f_{i-1,j+1}\times(sum_i-a_{i,cur})

由于 jj 可能为负数,所以我们统一把 ff 的第二维加上 nn,统计时也是如此。

初始化为 f0,0(f0,n)=0f_{0,0}(f_{0,n})=0,代表什么也不选的方案数为 1.

初始化 gi,0=1g_{i,0}=1 代表什么也不选方案数为 1.

代码

#include <bits/stdc++.h>
using namespace std;
const int N = 205, M = 2005; // 由于第二维同时加了 n,所以 N 开两倍
const int MOD = 998244353;
typedef long long LL;

// 前 i 行在当前列选了 j 个,在其他列选了 k 个的方案数
// int f[N][N][N]; // 用于求不合法方案(\sum_{j > k} f_{n, j, k})
// f[i][j][k] = f[i - 1][j][k] + a[i][cur] * f[i - 1][j - 1][k] + (sum[i] - a[i][cur]) * f[i - 1][j][k - 1];

// 进一步优化,只需要记录当前列和其他列的差值即可。
// 前 i 行当前列选的比其他列多了 j 个的方案数
// 由于第二维可能为负数,所以同时加 n 
LL f[N][N];
// f[i][j] = f[i - 1][j] + a[i][cur] * f[i - 1][j - 1] + (sum[i] - a[i][cur]) * f[i - 1][j + 1];

// 用于求总方案(\sum_{i = 1}^{n}g_{n,i})
LL g[N][N]; // 前 i 行选了 j 个的方案数
// g[i][j] = g[i - 1][j] + sum[i] * g[i - 1][j - 1];

// ans = g - f;

int n, m;
int a[N][M];
LL sum[N];
LL ans;

int main()
{
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= m; j++)
        {
            scanf("%d", &a[i][j]);
            sum[i] = (sum[i] + a[i][j]) % MOD;
        }
    for(int cur = 1; cur <= m; cur++)
    {
        memset(f, 0, sizeof(f));
        f[0][n] = 1; // 什么都不选
        for(int i = 1; i <= n; i++)
            for(int j = n - i; j <= n + i; j++)
                f[i][j] = (f[i - 1][j] + f[i - 1][j - 1] * a[i][cur] % MOD + ((sum[i] - a[i][cur] + MOD) % MOD) * f[i - 1][j + 1] % MOD) % MOD;
        for(int j = 1; j <= n; j++)
            ans = (ans + f[n][n + j]) % MOD;
    }
    for(int i = 0; i <= n; i++) g[i][0] = 1;
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
        {
            g[i][j] = (g[i - 1][j] + g[i - 1][j - 1] * sum[i] % MOD) % MOD;
        }
    LL alsum = 0;
    for(int i = 1; i <= n; i++)
        alsum = (alsum + g[n][i]) % MOD;
    cout << (alsum - ans + MOD) % MOD << endl;
    return 0;
}