题目大意
如果一个 N × N 的矩阵满足:
- 矩阵每行均为 [1, N] 的正整数的一个排列
- 矩阵内所有元素与其上方的元素不同
那么这个矩阵便是美丽的。
现给定一个 N × N 的美丽矩阵,求有多少个 N × N 的美丽矩阵比它小。(矩阵从上到下按行比较)
题目保证 N 不超过 2000
分析
这个题的切入点在于美丽矩阵的定义。如果我们把当前行看作待排序的一个序列,上面一行当成排序基准,则这个问题可以转化成错位排序问题。但是不同的是,在排了一部分数字以后,剩下的部分的排序标准就不那么严苛了(即存在一些可行数没有禁止位置)。
若 i 表示序列的长度, j 表示存在禁止位置的元素个数,则由容斥原理易得:
这个表达式非常优美,但是我们需要求 O(N2) 个 dp 值,如果直接计算的话需要 O(N3) 。不能承受。考虑到组合递推关系:
我们猜想 dp[i][j] 可以由 dp[i][j - 1] 和 dp[i - 1][j - 1] 推出。果然,我们有:
现在我们来解决这个问题。根据题目的定义,两个矩阵的比较与两个字符串的比较方式类似,如果 A 矩阵小于 B 矩阵,那么 A 矩阵的任意“前缀”小于等于 B 矩阵的对应“前缀”。如果两个矩阵的第一个不相同元素的位置为 (i, j) ,那么对于给定的 B 矩阵,这样的 A 矩阵共有
其中 way0 和 way1 分别表示有多少种选法使得 A[i][j] < B[i][j] 且是否选取 A[i - 1] 中在 j 位置以前出现过的元素; cnt 表示 A[i - 1] 的前 j 个元素与 A[i] 的前 (j - 1) 个元素的相同个数。
如果我们用树状数组或名次树来滑动地维护 way0 和 way1 ,则均摊时间复杂度可降为每个位置 O(logN) 。剪枝以后可以接受。
代码
总复杂度为 O(n2log(n))
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template <typename T>
using ordered_set = tree<T, null_type, less<T>, rb_tree_tag,
tree_order_statistics_node_update>;
typedef long long ll;
typedef pair<int, int> pii;
#define FOR(i, a, b) for (int (i) = (a); (i) <= (b); (i)++)
#define ROF(i, a, b) for (int (i) = (a); (i) >= (b); (i)--)
#define REP(i, n) FOR(i, 0, (n)-1)
#define sqr(x) ((x) * (x))
#define all(x) (x).begin(), (x).end()
#define reset(x, y) memset(x, y, sizeof(x))
#define uni(x) (x).erase(unique(all(x)), (x).end());
#define BUG(x) cerr << #x << " = " << (x) << endl
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define _1 first
#define _2 second
const int maxn = 2123;
const ll MOD = 998244353;
ll fac[maxn], dp[maxn][maxn], ans, D[maxn];
int n, a[maxn][maxn];
pii way[maxn][maxn];
int main() {
scanf("%d", &n);
fac[0] = 1;
FOR(i, 1, n) fac[i] = fac[i - 1] * i % MOD;
dp[0][0] = 1;
FOR(i, 1, n) {
dp[i][0] = fac[i];
FOR(j, 1, i) {
dp[i][j] = (dp[i][j - 1] - dp[i - 1][j - 1]) % MOD;
if (dp[i][j] < 0) dp[i][j] += MOD;
}
}
D[0] = 1;
FOR(i, 1, n) D[i] = D[i - 1] * dp[n][n] % MOD;
FOR(i, 1, n) FOR(j, 1, n) scanf("%d", &a[i][j]);
FOR(i, 1, n) {
ordered_set<int> s[2];
FOR(j, 1, n) s[1].insert(j);
FOR(j, 1, n) {
way[i][j]._1 = s[0].order_of_key(a[i][j]);
if (a[i - 1][j] < a[i][j] && s[0].find(a[i - 1][j]) != s[0].end())
way[i][j]._1--;
way[i][j]._2 = s[1].order_of_key(a[i][j]);
if (a[i - 1][j] < a[i][j] && s[1].find(a[i - 1][j]) != s[1].end())
way[i][j]._2--;
s[0].erase(a[i][j]), s[1].erase(a[i][j]);
if (s[1].find(a[i - 1][j]) != s[1].end()) {
s[1].erase(a[i - 1][j]);
s[0].insert(a[i - 1][j]);
}
}
}
FOR(i, 1, n)
ans = (ans + way[1][i]._2 * fac[n - i] % MOD * D[n - 1]) % MOD;
FOR(i, 2, n) {
unordered_map<int, int> m;
FOR(j, 1, n) {
m[a[i - 1][j]]++;
int cnt = 2 * j - 1 - m.size();
ans = (ans + way[i][j]._1 * dp[n - j][n - 2 * j + cnt + 1]
% MOD * D[n - i]) % MOD;
ans = (ans + way[i][j]._2 * dp[n - j][n - 2 * j + cnt]
% MOD * D[n - i]) % MOD;
m[a[i][j]]++;
}
}
printf("%lld", ans);
}
网友评论