QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#619987 | #2834. Nonsense | emsger | RE | 0ms | 0kb | C++20 | 3.5kb | 2024-10-07 16:13:46 | 2024-10-07 16:13:51 |
answer
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
// Short aliases for the integer types used by the modular arithmetic below.
using i64 = long long;
using u32 = unsigned int;
using u64 = unsigned long long;
/// Redirects stdin/stdout to "<name>.in" / "<name>.out" and speeds up iostreams.
///
/// The redirection is compiled out entirely when NO_FREOPEN is defined.
/// Otherwise it only happens when "<name>.in" can be opened for reading:
/// freopen() closes the stream it was given when it fails, so blindly
/// redirecting with a missing input file would leave stdin unusable. When the
/// input file is absent the program keeps using the ordinary standard streams.
///
/// @param name  Base name of the input/output files (without extension).
void set_io(std::string name)
{
#ifndef NO_FREOPEN
    const std::string in_path = name + ".in";
    const std::string out_path = name + ".out";
    // Probe first so that a missing input file does not cost us stdin.
    if (std::FILE *probe = std::fopen(in_path.c_str(), "r")) {
        std::fclose(probe);
        if (std::freopen(in_path.c_str(), "r", stdin) == nullptr) {
            std::perror(in_path.c_str());
        }
        if (std::freopen(out_path.c_str(), "w", stdout) == nullptr) {
            std::perror(out_path.c_str());
        }
    }
#endif
    // Untie cin from cout and drop C stdio synchronisation for faster bulk I/O.
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);
}
/// Integer modulo the compile-time constant M, kept reduced in [0, M).
///
/// Addition and subtraction are carried out in 32-bit unsigned arithmetic, so
/// the sum of two residues must fit in a u32; multiplication widens to u64
/// before reducing.
template <u32 M>
struct static_modint
{
    using mint = static_modint;

    /// The stored residue.
    u32 v;

    /// Zero residue.
    static_modint() : v(0) {}

    /// Reduces any signed 64-bit value (negative values included) into [0, M).
    static_modint(i64 x)
    {
        x %= M;
        if (x < 0) {
            x += M;
        }
        v = static_cast<u32>(x);
    }

    /// Returns the stored residue.
    u32 val() const { return v; }

    mint &operator+=(mint x)
    {
        v += x.v;
        if (v >= M) {
            v -= M;
        }
        return *this;
    }

    mint &operator-=(mint x)
    {
        // Add the additive inverse so the intermediate never goes negative.
        v += M - x.v;
        if (v >= M) {
            v -= M;
        }
        return *this;
    }

    mint &operator*=(mint x)
    {
        const u64 product = static_cast<u64>(v) * x.v;
        v = static_cast<u32>(product % M);
        return *this;
    }

    friend mint operator+(mint a, mint b)
    {
        a += b;
        return a;
    }

    friend mint operator-(mint a, mint b)
    {
        a -= b;
        return a;
    }

    friend mint operator*(mint a, mint b)
    {
        a *= b;
        return a;
    }

    /// Raises this value to the power x by binary exponentiation
    /// (anything to the power 0, including 0, yields 1).
    mint pow(u64 x) const
    {
        mint result = 1;
        for (mint square = *this; x != 0; x >>= 1) {
            if ((x & 1) != 0) {
                result *= square;
            }
            square *= square;
        }
        return result;
    }

    /// Returns this value raised to M - 2, which is the multiplicative inverse
    /// when M is prime and the value is non-zero; for zero the result is zero.
    mint inv() const
    {
        return pow(M - 2);
    }
};
// Modular integer type used by the solvers below; the modulus is 998244353.
using mint = static_modint<998'244'353>;
void solve_x_eq_y(int n, mint x, int maxa, int maxb, const std::vector<std::pair<int, int>> &q)
{
int maxab = maxa + maxb + 1;
std::vector<mint> decr(maxab + 1);
decr[0] = 1;
for (int i = 0; i < maxab; i++) decr[i + 1] = decr[i] * (n + 1 - i);
std::vector<mint> fac(maxab + 1), ifac(maxab + 1);
fac[0] = 1;
for (int i = 1; i <= maxab; i++) fac[i] = fac[i - 1] * i;
ifac[maxab] = fac[maxab].inv();
for (int i = maxab; i >= 1; i--) ifac[i - 1] = ifac[i] * i;
for (auto [a, b] : q) {
std::cout << (decr[a + b + 1] * ifac[a + b + 1] * x.pow(n - a - b)).val() << std::endl;
}
}
void solve_x_neq_y(int n, mint x, mint y, int maxa, int maxb, const std::vector<std::pair<int, int>> &q)
{
std::vector f(maxa + 1, std::vector<mint>(maxb + 1));
int maxab = maxa + maxb;
std::vector<mint> decr(maxab + 1);
decr[0] = 1;
for (int i = 0; i < maxab; i++) decr[i + 1] = decr[i] * (n + 1 - i);
std::vector<mint> fac(maxab + 1), ifac(maxab + 1);
fac[0] = 1;
for (int i = 1; i <= maxab; i++) fac[i] = fac[i - 1] * i;
ifac[maxab] = fac[maxab].inv();
for (int i = maxab; i >= 1; i--) ifac[i - 1] = ifac[i] * i;
mint ixy = (x - y).inv();
f[0][0] = (x.pow(n + 1) - y.pow(n + 1)) * ixy;
for (int a = 1; a <= maxa; a++) f[a][0] = (decr[a] * ifac[a] * x.pow(n + 1 - a) - f[a - 1][0]) * ixy;
for (int b = 1; b <= maxb; b++) f[0][b] = (f[0][b - 1] - decr[b] * ifac[b] * y.pow(n + 1 - b)) * ixy;
for (int a = 1; a <= maxa; a++) {
for (int b = 1; b <= maxb; b++) {
f[a][b] = (f[a][b - 1] - f[a - 1][b]) * ixy;
}
}
for (auto [a, b] : q) {
std::cout << f[a][b].val() << std::endl;
}
}
int main()
{
set_io("count");
int n, x, y, m;
while (std::cin >> n >> x >> y >> m) {
std::vector<std::pair<int, int>> q;
int maxa = 0, maxb = 0;
for (int i = 0; i < m; i++) {
int a, b;
std::cin >> a >> b;
maxa = std::max(maxa, a);
maxb = std::max(maxb, b);
q.emplace_back(a, b);
}
if (x == y) {
solve_x_eq_y(n, x, maxa, maxb, q);
} else {
solve_x_neq_y(n, x, y, maxa, maxb, q);
}
}
}
詳細信息
Test #1:
score: 0
Dangerous Syscalls
input:
3 1 2 2 1 1 1 2 100 2 3 1 1 1