QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#924895 | #9572. Bingo | Hiraethsoul | WA | 0ms | 3712kb | C++20 | 14.1kb | 2025-03-03 19:34:16 | 2025-03-03 19:34:17 |
Judging History
answer
#include <bits/stdc++.h>
#define int long long
// Prime modulus 998244353 = 119 * 2^23 + 1 (3 is a primitive root), so NTT
// lengths up to 2^23 are available. On a Wrong Answer, check whether the
// problem expects a different modulus.
constexpr int P = 998244353;
/// Dense polynomial over Z_P: element i is the coefficient of x^i.
/// Coefficients are expected to lie in [0, P).
/// Multiplication switches to an NTT when both operands have at least 128
/// coefficients and the padded transform length divides P - 1. Otherwise it
/// uses the schoolbook product.
/// The series operations inv/ln/exp/power work modulo x^n, where n is the
/// input's size.
class Polynomial : public std::vector<int>
{
    /// a * b mod P. Expects a, b in [0, P). The product fits in 64 bits
    /// because `int` is macro-expanded to long long in this file.
    static int multip(int a, int b)
    {
        return a * b % P;
    }
    // Alternative for moduli too large for a plain 64-bit product: it
    // estimates the quotient in long double. Unused with the current P; kept
    // for switching moduli.
    // static int multip(int a, int b)
    // {
    //     int res = a * b - (int)(1.L * a * b / P) * P;
    //     res %= P;
    //     if (res < 0)
    //     {
    //         res += P;
    //     }
    //     return res;
    // }

    /// (a + b) mod P for a, b in [0, P).
    static int add(int a, int b)
    {
        a += b;
        a -= (a >= P ? P : 0);
        return a;
    }
    /// (a - b) mod P for a, b in [0, P).
    static int sub(int a, int b)
    {
        a -= b;
        a += (a < 0 ? P : 0);
        return a;
    }
    /// a^b mod P by square-and-multiply.
    static int qmi(int a, int b)
    {
        int res = 1;
        while (b)
        {
            if (b & 1)
            {
                res = res * a % P;
            }
            a = a * a % P;
            b >>= 1;
        }
        return res;
    }

    /// Twiddle factors shared by dft/idft. After initW(L), with r = 2^L:
    /// - w[r/2 + j] = s^j for 0 <= j < r/2, where s = 3^((P-1)/r) is a
    ///   primitive r-th root of unity (3 is a primitive root of P);
    /// - w[i] = w[2i] for 0 < i < r/2.
    /// Hence w[k + j] is the j-th power of a primitive (2k)-th root of unity
    /// for every power of two k < r. w[0] is unused.
    static std::vector<int> w;
    /// Builds the table for length 2^_log unless a table at least that large
    /// already exists. A larger table also serves smaller lengths (see above).
    static void initW(int _log)
    {
        const int r = 1 << _log;
        if (w.size() >= r)
        {
            return;
        }
        w.assign(r, 0);
        w[r >> 1] = 1;
        int s = qmi(3, (P - 1) >> _log); // 3 is a primitive root of P
        for (int i = r / 2 + 1; i < r; i++)
        {
            w[i] = w[i - 1] * s % P;
        }
        for (int i = r / 2 - 1; i > 0; i--)
        {
            w[i] = w[i * 2];
        }
    }

public:
    using std::vector<int>::vector;

    /// In-place forward NTT (decimation in frequency).
    /// a.size() must be a power of two no larger than the current twiddle
    /// table. The output is left in bit-reversed order, which idft consumes
    /// directly.
    friend void dft(Polynomial &a)
    {
        const int n = a.size();
        for (int k = n >> 1; k; k >>= 1)
        {
            for (int i = 0; i < n; i += k << 1)
            {
                for (int j = 0; j < k; j++)
                {
                    int v = a[i + j + k];
                    a[i + j + k] = multip(sub(a[i + j], v), w[k + j]);
                    a[i + j] = add(a[i + j], v);
                }
            }
        }
    }

    /// In-place inverse of dft.
    /// Takes bit-reversed input and runs decimation-in-time butterflies with
    /// the same (forward) roots. It then scales by 1/n and reverses
    /// a[1..n-1], which turns the forward transform into the inverse one.
    friend void idft(Polynomial &a)
    {
        const int n = a.size();
        for (int k = 1; k < n; k <<= 1)
        {
            for (int i = 0; i < n; i += k << 1)
            {
                for (int j = 0; j < k; j++)
                {
                    int u = a[i + j];
                    int v = multip(a[i + j + k], w[j + k]);
                    a[i + j + k] = sub(u, v);
                    a[i + j] = add(u, v);
                }
            }
        }
        // n divides P - 1, so n * (P - (P - 1) / n) = nP - (P - 1) == 1 (mod P).
        int val = P - (P - 1) / n;
        for (int i = 0; i < n; i++)
        {
            a[i] = multip(a[i], val);
        }
        std::reverse(a.begin() + 1, a.end());
    }

    /// Full product a * b of size a.size() + b.size() - 1.
    /// Returns an empty polynomial if either operand is empty.
    friend Polynomial operator*(Polynomial a, Polynomial b)
    {
        if (a.size() == 0 or b.size() == 0)
        {
            return Polynomial();
        }
        int n = a.size() + b.size() - 1;
        // s = smallest power of two >= n.
        int _log = std::__lg(2 * n - 1);
        int s = 1 << _log;
        // Schoolbook when the transform length does not divide P - 1, or when
        // one operand is too short for an NTT to pay off.
        if (((P - 1) & (s - 1)) != 0 or std::min(a.size(), b.size()) < 128)
        {
            Polynomial res(n);
            for (int i = 0; i < a.size(); i++)
            {
                for (int j = 0; j < b.size(); j++)
                {
                    res[i + j] = add(res[i + j], multip(a[i], b[j]));
                }
            }
            return res;
        }
        initW(_log);
        a.resize(s), b.resize(s);
        dft(a), dft(b);
        for (int i = 0; i < s; i++)
        {
            a[i] = multip(a[i], b[i]);
        }
        idft(a);
        return a.resize(n), a;
    }

    /// Formal derivative (size a.size() - 1; empty when a.size() <= 1).
    friend Polynomial deriv(const Polynomial &a)
    {
        int n = a.size();
        if (n <= 1)
        {
            return Polynomial();
        }
        Polynomial p(n - 1);
        for (int i = 1; i < n; i++)
        {
            p[i - 1] = multip(i, a[i]);
        }
        return p;
    }

    /// Formal integral with zero constant term (size a.size() + 1).
    friend Polynomial integr(const Polynomial &a)
    {
        int n = a.size();
        Polynomial p(n + 1);
        // Inverses 1..n via inv[i] = -(P / i) * inv[P % i]. Sized n + 2 so
        // that _inv[1] is valid even when a is empty.
        std::vector<int> _inv(n + 2);
        _inv[1] = 1;
        for (int i = 2; i <= n; i++)
        {
            _inv[i] = multip(_inv[P % i], (P - P / i));
        }
        for (int i = 0; i < n; ++i)
        {
            p[i + 1] = multip(a[i], _inv[i + 1]);
        }
        return p;
    }

    /// Power-series inverse of a modulo x^n (n = a.size()). Requires a[0] != 0.
    /// Newton step: if b = a^{-1} mod x^ceil(n/2), then a^{-1} = b(2 - ab) mod x^n.
    friend Polynomial inv(const Polynomial &a)
    {
        int n = a.size();
        if (n == 0)
        {
            // Nothing to invert modulo x^0; also prevents endless halving.
            return Polynomial();
        }
        if (n == 1)
        {
            return {qmi(a[0], P - 2)};
        }
        Polynomial half(a.begin(), a.begin() + (n + 1) / 2);
        Polynomial b = inv(half), c = a * b;
        for (auto &x : c)
        {
            x = (x == 0 ? 0 : P - x); // c = -a*b (mod P), kept in [0, P)
        }
        c[0] = add(c[0], 2); // c = 2 - a*b
        c = c * b;
        c.resize(n);
        return c;
    }

    /// Power-series logarithm modulo x^n, computed as the integral of a'/a.
    /// Assumes a[0] == 1: the constant term of the result is always set to 0.
    friend Polynomial ln(const Polynomial &a)
    {
        int n = a.size();
        if (n == 0)
        {
            return Polynomial();
        }
        Polynomial b(n, 0);
        for (int i = 1; i < n; i++)
        {
            b[i - 1] = multip(i, a[i]);
        }
        b = b * inv(a);
        b.resize(n);
        // One extra slot so that _inv[1] stays in range when n == 1.
        std::vector<int> _inv(n + 1);
        _inv[1] = 1;
        for (int i = 2; i < n; i++)
        {
            _inv[i] = multip(P - P / i, _inv[P % i]);
        }
        for (int i = n - 1; i; i--)
        {
            b[i] = multip(b[i - 1], _inv[i]);
        }
        b[0] = 0;
        return b;
    }

    /// Power-series exponential modulo x^n. Assumes a[0] == 0.
    /// Newton step: if b = exp(a) mod x^ceil(n/2), then
    /// exp(a) = b(1 - ln b + a) mod x^n.
    friend Polynomial exp(const Polynomial &a)
    {
        int n = a.size();
        if (n == 0)
        {
            // Nothing to compute modulo x^0; also prevents endless halving.
            return Polynomial();
        }
        if (n == 1)
        {
            return {1};
        }
        Polynomial half(a.begin(), a.begin() + (n + 1) / 2);
        Polynomial b = exp(half);
        b.resize(n);
        Polynomial c = ln(b);
        for (int i = 0; i < n; i++)
        {
            c[i] = sub(a[i], c[i]);
        }
        c[0] = add(c[0], 1);
        c = c * b;
        c.resize(n);
        return c;
    }

    /// Computes F^k modulo x^n (n = F.size()) for an exponent k given in decimal.
    /// F is overwritten with the result, which is also returned.
    friend Polynomial power(Polynomial &F, std::string s)
    {
        int k1 = 0; // k mod P: scales ln(F)
        int k2 = 0; // k mod (P - 1): exponent of the leading coefficient (Fermat)
        int k3 = 0; // first (at most) 7 digits of k: magnitude check for the shift
        for (int i = 0; i < s.length(); ++i)
        {
            k1 = (k1 * 10 + s[i] - '0') % P;
            k2 = (k2 * 10 + s[i] - '0') % (P - 1);
            if (i < 7)
            {
                k3 = k3 * 10 + s[i] - '0';
            }
        }
        int n = F.size();
        if (n == 0)
        {
            // Nothing to compute modulo x^0 (and F[0] below would be out of range).
            return F;
        }
        // F[0] == 0 means F = x^pos * G with pos >= 1, so F^k has no terms
        // below x^k. Once k >= n the truncated result is identically 0.
        if (!F[0] and k3 >= F.size())
        {
            F.assign(n, 0);
            return F;
        }
        // pos = index of the lowest nonzero coefficient (0 if F is all zero).
        int pos = 0;
        for (int i = 0; i < n; ++i)
        {
            if (F[i])
            {
                pos = i;
                break;
            }
        }
        if (pos)
        {
            // Divide out x^pos so that F[0] != 0.
            for (int i = pos; i < n; ++i)
            {
                F[i - pos] = F[i];
                F[i] = 0;
            }
        }
        // Normalize the constant term to 1; then F^k = val^k * exp(k * ln F).
        int val = F[0];
        int cur = qmi(val, P - 2);
        for (int i = 0; i < F.size(); ++i)
        {
            F[i] = F[i] * cur % P;
        }
        F = ln(F);
        for (int i = 0; i < F.size(); ++i)
        {
            F[i] = F[i] * k1 % P;
        }
        F = exp(F);
        cur = qmi(val, k2);
        for (int i = 0; i < F.size(); ++i)
        {
            F[i] = F[i] * cur % P;
        }
        // Multiply back by x^(pos * k), truncated to n coefficients.
        // When pos > 0 the early return guarantees k3 < n; k1 == k as long as
        // s has at most 7 digits. NOTE(review): longer exponents combined with
        // n > 10^6 are not handled exactly.
        // A zero shift (pos == 0, or k == 0) must be skipped: the in-place loop
        // below would otherwise clear every coefficient (0 instead of F^0 = 1).
        int shift = std::min(1ll * pos * k1, n);
        if (shift > 0)
        {
            for (int i = n - 1; i >= 0; --i)
            {
                if (i + shift < n)
                {
                    F[i + shift] = F[i];
                }
                F[i] = 0;
            }
        }
        return F;
    }

    /// F^b by repeated squaring.
    /// Every intermediate (and the result) is truncated to at most m
    /// coefficients, i.e. computed modulo x^m.
    friend Polynomial power(const Polynomial &F, int b, int m)
    {
        Polynomial res = {1};
        Polynomial G = F;
        while (b)
        {
            if (b & 1)
            {
                res = res * G;
                if (res.size() > m)
                {
                    res.resize(m);
                }
            }
            G = G * G;
            if (G.size() > m)
            {
                G.resize(m);
            }
            b >>= 1;
        }
        return res;
    }
};
// Storage for the twiddle table shared by all Polynomial objects. It starts
// empty and is grown on demand by Polynomial::initW.
std::vector<int> Polynomial::w{};
// Short alias used by solve().
using Poly = Polynomial;
/// Modular exponentiation a^b mod P by square-and-multiply.
/// Assumes 0 <= a < P and b >= 0; with `int` macro-expanded to long long,
/// a * a then fits in 64 bits.
int qmi(int a, int b)
{
    int result = 1;
    for (; b; b >>= 1, a = a * a % P)
    {
        if (b & 1)
        {
            result = result * a % P;
        }
    }
    return result;
}
using i64 = long long;

/// Binary exponentiation for any type T that has a T{1} identity and `*=`
/// (plain integers, MInt, ...). Returns a^b; b is expected to be >= 0.
template <class T>
constexpr T power(T a, i64 b)
{
    T result{1};
    while (b != 0)
    {
        if (b % 2 != 0)
        {
            result *= a;
        }
        b /= 2;
        a *= a;
    }
    return result;
}

/// (a * b) mod p for 64-bit operands, without __int128.
/// The quotient a*b/p is estimated in long double. The wrapped difference
/// a*b - q*p is then close to the true residue, and % p plus a sign fix-up
/// repair it.
/// NOTE(review): relies on two's-complement wrap of the signed products and
/// on long double being wider than double; presumably 0 <= a, b < p.
constexpr i64 mul(i64 a, i64 b, i64 p)
{
    const i64 quotient = i64(1.L * a * b / p);
    i64 r = (a * b - quotient * p) % p;
    return r < 0 ? r + p : r;
}
/// Integer modulo getMod(): the compile-time constant P when P > 0, otherwise
/// the runtime value MInt<0>::Mod (settable with setMod).
/// The stored value x is kept in [0, getMod()).
template <i64 P>
struct MInt
{
    i64 x;
    constexpr MInt() : x{0} {}
    /// Reduces any i64, negative values included, into [0, getMod()).
    constexpr MInt(i64 x) : x{norm(x % getMod())} {}
    /// Runtime modulus. This file defines it only for MInt<0>.
    static i64 Mod;
    constexpr static i64 getMod()
    {
        // `if constexpr` discards the unused branch, so for P > 0 the
        // never-defined static member MInt<P>::Mod is not odr-used. A plain
        // `if` odr-uses it and relies on the optimizer to avoid a link error.
        if constexpr (P > 0)
        {
            return P;
        }
        else
        {
            return Mod;
        }
    }
    constexpr static void setMod(i64 Mod_)
    {
        Mod = Mod_;
    }
    /// Maps x from (-getMod(), 2 * getMod()) into [0, getMod()) with a
    /// single correction step.
    constexpr i64 norm(i64 x) const
    {
        if (x < 0)
        {
            x += getMod();
        }
        if (x >= getMod())
        {
            x -= getMod();
        }
        return x;
    }
    constexpr i64 val() const
    {
        return x;
    }
    constexpr MInt operator-() const
    {
        MInt res;
        res.x = norm(getMod() - x);
        return res;
    }
    /// Multiplicative inverse via Fermat's little theorem; the modulus must be prime.
    constexpr MInt inv() const
    {
        return power(*this, getMod() - 2);
    }
    constexpr MInt &operator*=(MInt rhs) &
    {
        // Below 2^31 the plain 64-bit product cannot overflow; larger moduli
        // use the long-double-assisted mul().
        if (getMod() < (1ULL << 31))
        {
            x = x * rhs.x % (getMod());
        }
        else
        {
            x = mul(x, rhs.x, getMod());
        }
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) &
    {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) &
    {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) &
    {
        return *this *= rhs.inv();
    }
    friend constexpr MInt operator*(MInt lhs, MInt rhs)
    {
        MInt res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MInt operator+(MInt lhs, MInt rhs)
    {
        MInt res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MInt operator-(MInt lhs, MInt rhs)
    {
        MInt res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MInt operator/(MInt lhs, MInt rhs)
    {
        MInt res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr std::istream &operator>>(std::istream &is, MInt &a)
    {
        i64 v;
        is >> v;
        a = MInt(v);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a)
    {
        return os << a.val();
    }
    friend constexpr bool operator==(MInt lhs, MInt rhs)
    {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MInt lhs, MInt rhs)
    {
        return lhs.val() != rhs.val();
    }
    /// Orders by canonical representative (for use in ordered containers).
    friend constexpr bool operator<(MInt lhs, MInt rhs)
    {
        return lhs.val() < rhs.val();
    }
};
// Runtime modulus for the dynamic-modulus instantiation MInt<0>; Z below uses
// the compile-time P and does not read it.
template <>
i64 MInt<0>::Mod{998244353};
// Modular integer type over the file-wide prime P.
using Z = MInt<P>;
struct Comb
{
int n;
std::vector<Z> _fac;
std::vector<Z> _invfac;
std::vector<Z> _inv;
Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
Comb(int n) : Comb()
{
init(n);
}
void init(int m)
{
m = std::min<int>(m, Z::getMod() - 1);
if (m <= n)
return;
_fac.resize(m + 1);
_invfac.resize(m + 1);
_inv.resize(m + 1);
for (int i = n + 1; i <= m; i++)
{
_fac[i] = _fac[i - 1] * i;
}
_invfac[m] = _fac[m].inv();
for (int i = m; i > n; i--)
{
_invfac[i - 1] = _invfac[i] * i;
_inv[i] = _invfac[i] * _fac[i - 1];
}
n = m;
}
Z fac(int m)
{
if (m > n)
init(2 * m);
return _fac[m];
}
Z invfac(int m)
{
if (m > n)
init(2 * m);
return _invfac[m];
}
Z inv(int m)
{
if (m > n)
init(2 * m);
return _inv[m];
}
Z A(int a, int b)
{
if (a < b or b < 0)
return 0;
return fac(a) * invfac(a - b);
}
Z C(int n, int m)
{
if (n < m || m < 0)
return 0;
return fac(n) * invfac(m) * invfac(n - m);
}
} comb;
void solve()
{
int n, m;
std::cin >> n >> m;
std::vector<int> a(n * m + 1, 0);
for (int i = 1; i <= n * m; ++i)
{
std::cin >> a[i];
}
std::sort(begin(a) + 1, end(a));
std::vector<int> fac(n * m + 1, 0);
std::vector<int> ifac(n * m + 1, 0);
fac[0] = 1;
for (int i = 1; i <= n * m; ++i)
{
fac[i] = fac[i - 1] * i % P;
}
ifac[n] = qmi(fac[n], P - 2);
for (int i = n; i >= 1; i--)
{
ifac[i - 1] = i * ifac[i] % P;
}
Poly f(n * m + 1, 0);
for (int i = 1; i <= n * m; ++i)
{
f[i] = a[i] * fac[i - 1] % P;
}
Poly g(n * m + 1, 0);
for (int i = 0; i <= n * m; ++i)
{
g[i] = ifac[n * m - i];
}
f = f * g;
std::vector<Z> c(n * m + 1, 0);
for (int i = 0; i <= n * m; ++i)
{
c[i] = (Z)(fac[n * m - i] * i % P * f[n * m + i] % P);
}
Z ans = 0;
for (int x = 0; x <= n; ++x)
{
for (int y = 0; y <= m; ++y)
{
int C = x * m + y * n - x * y;
if ((x + y) & 1)
{
ans += comb.C(n, x) * comb.C(m, y) * c[C];
}
else
{
ans -= comb.C(n, x) * comb.C(m, y) * c[C];
}
}
}
std::cout << ans << '\n';
}
/// Entry point: unsynchronized iostreams, then the number of test cases
/// followed by that many independent cases.
signed main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int testCases;
    std::cin >> testCases;
    while (testCases-- != 0)
    {
        solve();
    }
}
詳細信息
Test #1:
score: 0
Wrong Answer
time: 0ms
memory: 3712kb
input:
4 2 2 1 3 2 4 3 1 10 10 10 1 3 20 10 30 3 4 1 1 4 5 1 4 1 9 1 9 8 10
output:
56 60 998244233 356425214
result:
wrong answer 3rd numbers differ - expected: '60', found: '998244233'