QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#670569 | #9476. 012 Grid | bulijiojiodibuliduo# | AC ✓ | 165ms | 23208kb | C++17 | 13.6kb | 2024-10-23 22:21:09 | 2024-10-23 22:21:09 |
Judging History
answer
#include <bits/stdc++.h>
using namespace std;
// Loop shorthands: rep counts i upward over [a, n), per counts i downward over [a, n).
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
// Abbreviations for common container/pair members.
#define pb push_back
#define eb emplace_back
#define mp make_pair
// Expands to the begin/end iterator pair of x (function-like: only triggers on `all(`).
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
// Container size as a signed int.
#define SZ(x) ((int)(x).size())
// Short type aliases.
typedef vector<int> VI;
typedef basic_string<int> BI;
typedef long long ll;
typedef pair<int,int> PII;
typedef double db;
// Global Mersenne-Twister engine, seeded from random_device at start-up.
mt19937 mrand(random_device{}());
// Modulus used by powmod().
const ll mod=998244353;
// Draws the next value from the global mrand engine and reduces it modulo x.
// x must be non-zero; the reduction happens in the engine's unsigned result type.
int rnd(int x) {
    const auto raw = mrand();
    return raw % x;
}
// Computes a^b modulo the global `mod` by binary exponentiation.
//   a: base, any sign; it is first reduced into [0, mod).
//   b: exponent, must be non-negative (asserted).
// Returns a value in [0, mod).
ll powmod(ll a, ll b) {
    assert(b >= 0);
    a %= mod;
    // % keeps the sign of the dividend, so a negative base would otherwise
    // propagate and produce a negative "residue".
    if (a < 0) a += mod;
    ll res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
// Euclid's algorithm, written iteratively: repeatedly replaces (a, b) by
// (b, a % b) until b is zero, then returns a.
ll gcd(ll a, ll b) {
    while (b != 0) {
        const ll remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
// head
// Integer modulo the compile-time constant MOD. The stored value v is kept in
// [0, MOD) by every constructor and operator. RT is carried as a constant and
// exposed through rt().
//
// Construction, conversion to int and comparisons are constexpr, so rt() (which
// was already declared constexpr) can actually be used in constant expressions.
template<int MOD, int RT> struct mint {
    static const int mod = MOD;
    static constexpr mint rt() { return RT; } // primitive root for FFT
    int v;
    // explicit -> don't silently convert to int
    constexpr explicit operator int() const { return v; }
    constexpr mint() : v(0) {}
    // Accepts any 64-bit value and folds it into [0, MOD).
    constexpr mint(long long _v)
        : v(int((-MOD < _v && _v < MOD) ? _v : _v % MOD)) {
        if (v < 0) v += MOD;
    }
    constexpr bool operator==(const mint& o) const { return v == o.v; }
    friend constexpr bool operator!=(const mint& a, const mint& b) {
        return !(a == b);
    }
    // Orders by the stored representative.
    friend constexpr bool operator<(const mint& a, const mint& b) {
        return a.v < b.v;
    }
    mint& operator+=(const mint& o) {
        if ((v += o.v) >= MOD) v -= MOD;
        return *this;
    }
    mint& operator-=(const mint& o) {
        if ((v -= o.v) < 0) v += MOD;
        return *this;
    }
    mint& operator*=(const mint& o) {
        // Widen before multiplying so the product cannot overflow int.
        v = int((long long)v * o.v % MOD);
        return *this;
    }
    mint& operator/=(const mint& o) { return (*this) *= inv(o); }
    // a^p by binary exponentiation; p must be non-negative (asserted).
    friend mint pow(mint a, long long p) {
        mint ans = 1;
        assert(p >= 0);
        for (; p; p /= 2, a *= a) if (p & 1) ans *= a;
        return ans;
    }
    // Multiplicative inverse computed as a^(MOD-2); a must be non-zero (asserted).
    friend mint inv(const mint& a) {
        assert(a.v != 0);
        return pow(a, MOD - 2);
    }
    mint operator-() const { return mint(-v); }
    mint& operator++() { return *this += 1; }
    mint& operator--() { return *this -= 1; }
    friend mint operator+(mint a, const mint& b) { return a += b; }
    friend mint operator-(mint a, const mint& b) { return a -= b; }
    friend mint operator*(mint a, const mint& b) { return a *= b; }
    friend mint operator/(mint a, const mint& b) { return a /= b; }
};
// Modulus for the mint-based arithmetic (also used directly by simp::check).
const int MOD=998244353;
// mi is the modular integer type used by the combinatorics and solve* code.
// NOTE(review): the root argument only surfaces through mint::rt(), which
// nothing in the visible code calls.
using mi = mint<MOD,5>; // 5 is primitive root for both common mods
namespace simp {
vector<mi> fac,ifac,invn;
void check(int x) {
if (fac.empty()) {
fac={mi(1),mi(1)};
ifac={mi(1),mi(1)};
invn={mi(0),mi(1)};
}
while (SZ(fac)<=x) {
int n=SZ(fac),m=SZ(fac)*2;
fac.resize(m);
ifac.resize(m);
invn.resize(m);
for (int i=n;i<m;i++) {
fac[i]=fac[i-1]*mi(i);
invn[i]=mi(MOD-MOD/i)*invn[MOD%i];
ifac[i]=ifac[i-1]*invn[i];
}
}
}
mi gfac(int x) {
if (x<0) return 0;
check(x); return fac[x];
}
mi ginv(int x) {
assert(x>0);
check(x); return invn[x];
}
mi gifac(int x) {
if (x<0) return 0;
check(x); return ifac[x];
}
mi binom(int n,int m) {
if (m < 0 || m > n) return mi(0);
return gfac(n)*gifac(m)*gifac(n - m);
}
mi binombf(int n,int m) {
if (m < 0 || m > n) return mi(0);
if (m>n-m) m=n-m;
mi p=1,q=1;
for (int i=1;i<=m;i++) p=p*(n+1-i),q=q*i;
return p/q;
}
}
// Input dimensions, read by main().
int n, m;

// Returns C(dx + dy, dx) with dx = x1 - x2 and dy = y2 - y1.
// simp::binom yields 0 whenever dx or dy is negative.
// Presumably this counts monotone lattice paths from (x1, y1) to (x2, y2).
mi way(int x1, int y1, int x2, int y2) {
    const int step_x = x1 - x2;
    const int step_y = y2 - y1;
    const int total = step_x + step_y;
    return simp::binom(total, step_x);
}
// 2x2 determinant of way() counts for start points (x1,y1), (x3,y3) and end
// points (x2,y2), (x4,y4): the product for the matching 1->2, 3->4 minus the
// product for the swapped matching 1->4, 3->2.
// Presumably a non-intersecting path-pair count; confirm against the problem.
mi gao(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4) {
    const mi direct = way(x1, y1, x2, y2) * way(x3, y3, x4, y4);
    const mi swapped = way(x1, y1, x4, y4) * way(x3, y3, x2, y2);
    return direct - swapped;
}
// Modulus for the plain-int helpers and the polynomial code.
const int md = 998244353;

// x <- x + y, followed by a single conditional subtraction of md.
// The result is reduced when x and y both start in [0, md).
inline void add(int &x, int y) {
    const int sum = x + y;
    x = (sum >= md) ? sum - md : sum;
}
// x <- x - y, followed by a single conditional addition of md.
// The result is reduced when x and y both start in [0, md).
inline void sub(int &x, int y) {
    const int diff = x - y;
    x = (diff < 0) ? diff + md : diff;
}
// Returns x * y % md, multiplying in 64 bits so the product cannot overflow.
inline int mul(int x, int y) {
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % md);
}
// Returns x^y modulo md by square-and-multiply, consuming the bits of y from
// the least significant end.
inline int power(int x, int y) {
    int result = 1;
    while (y) {
        if (y & 1) {
            result = mul(result, x);
        }
        x = mul(x, x);
        y >>= 1;
    }
    return result;
}
// Modular inverse of a modulo md via the extended Euclidean algorithm.
//   a: any int; it is first reduced into [0, md).
// Returns the value u in [0, md) with u * a == 1 (mod md).
// a must not be a multiple of md: zero has no inverse, and the loop would
// otherwise fall through and silently return 0. This is now asserted,
// matching the assertion in mint's inv().
inline int inv(int a) {
    a %= md;
    if (a < 0) {
        a += md;
    }
    assert(a != 0);
    // Invariant: u * a0 == b and v * a0 == a (mod md), where a0 is the
    // reduced input. When a reaches 0, b holds the gcd and u its coefficient.
    int b = md, u = 0, v = 1;
    while (a) {
        int t = b / a;
        b -= t * a;
        swap(a, b);
        u -= t * v;
        swap(u, v);
    }
    if (u < 0) {
        u += md;
    }
    return u;
}
// Number-theoretic transform modulo md, with polynomial multiplication and
// power-series inversion built on top of it.
namespace ntt {
// base:     number of bits currently covered by the rev/roots tables.
// root:     element of order 2^max_base modulo md (set by init()).
// max_base: exponent of 2 in md - 1; -1 until init() has run.
int base = 1, root = -1, max_base = -1;
// rev:   bit-reversal permutation over `base` bits.
// roots: twiddle factors; entries [i, 2i) are the powers used by the
//        butterfly level whose half-length is i (see ensure_base / dft).
vector<int> rev = {0, 1}, roots = {0, 1};

// Computes max_base and searches upward from 2 for the first value whose
// order is exactly 2^max_base.
void init() {
    int temp = md - 1;
    max_base = 0;
    while (temp % 2 == 0) {
        temp >>= 1;
        ++max_base;
    }
    root = 2;
    while (true) {
        if (power(root, 1 << max_base) == 1 && power(root, 1 << (max_base - 1)) != 1) {
            break;
        }
        ++root;
    }
}

// Grows the rev and roots tables so that transforms of length 2^nbase are
// possible. A no-op when the tables are already large enough.
void ensure_base(int nbase) {
    if (max_base == -1) {
        init();
    }
    if (nbase <= base) {
        return;
    }
    assert(nbase <= max_base);
    rev.resize(1 << nbase);
    for (int i = 0; i < 1 << nbase; ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (nbase - 1));
    }
    roots.resize(1 << nbase);
    while (base < nbase) {
        // z has order 2^(base+1); it fills in the odd powers of the next level.
        int z = power(root, 1 << (max_base - 1 - base));
        for (int i = 1 << (base - 1); i < 1 << base; ++i) {
            roots[i << 1] = roots[i];
            roots[i << 1 | 1] = mul(roots[i], z);
        }
        ++base;
    }
}

// In-place forward transform. The length is expected to be a power of two:
// its count of trailing zero bits is used as log2 of the length.
void dft(vector<int> &a) {
    int n = a.size(), zeros = __builtin_ctz(n);
    ensure_base(zeros);
    // rev is stored for `base` bits; dropping the low bits yields the
    // reversal over `zeros` bits.
    int shift = base - zeros;
    for (int i = 0; i < n; ++i) {
        if (i < rev[i] >> shift) {
            swap(a[i], a[rev[i] >> shift]);
        }
    }
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j += i << 1) {
            for (int k = 0; k < i; ++k) {
                int x = a[j + k], y = mul(a[j + k + i], roots[i + k]);
                a[j + k] = (x + y) % md;
                a[j + k + i] = (x + md - y) % md;
            }
        }
    }
}

// Returns the product of polynomials a and b (coefficient vectors, lowest
// degree first), of length a.size() + b.size() - 1.
// Two empty inputs give an empty result; previously the length computation
// wrapped around and the final resize was asked for a huge size.
vector<int> multiply(vector<int> a, vector<int> b) {
    if (a.empty() && b.empty()) {
        return vector<int>();
    }
    int need = a.size() + b.size() - 1, nbase = 0;
    while (1 << nbase < need) {
        ++nbase;
    }
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    // Squaring needs only one forward transform.
    bool equal = a == b;
    dft(a);
    if (equal) {
        b = a;
    } else {
        dft(b);
    }
    int inv_sz = inv(sz);
    for (int i = 0; i < sz; ++i) {
        a[i] = mul(mul(a[i], b[i]), inv_sz);
    }
    // Inverse transform = reverse all but the first element, then forward
    // transform again (the 1/sz factor was applied above).
    reverse(a.begin() + 1, a.end());
    dft(a);
    a.resize(need);
    return a;
}

// Returns the first a.size() coefficients of the power-series inverse of a.
// Recurses on the first half, then applies one step b <- b * (2 - a * b).
// An empty input is returned unchanged; previously it recursed without end.
// a[0] is passed to inv() and therefore must be invertible.
vector<int> inverse(vector<int> a) {
    int n = a.size(), m = (n + 1) >> 1;
    if (n == 0) {
        return a;
    }
    if (n == 1) {
        return vector<int>(1, inv(a[0]));
    } else {
        vector<int> b = inverse(vector<int>(a.begin(), a.begin() + m));
        int need = n << 1, nbase = 0;
        while (1 << nbase < need) {
            ++nbase;
        }
        ensure_base(nbase);
        int sz = 1 << nbase;
        a.resize(sz);
        b.resize(sz);
        dft(a);
        dft(b);
        int inv_sz = inv(sz);
        for (int i = 0; i < sz; ++i) {
            a[i] = mul(mul(md + 2 - mul(a[i], b[i]), b[i]), inv_sz);
        }
        reverse(a.begin() + 1, a.end());
        dft(a);
        a.resize(n);
        return a;
    }
}
}
using ntt::multiply;
using ntt::inverse;
// Coefficient-wise modular addition of b into a; a is zero-extended first if
// b is longer.
vector<int>& operator += (vector<int> &a, const vector<int> &b) {
    if (b.size() > a.size()) {
        a.resize(b.size());
    }
    for (size_t idx = 0; idx != b.size(); ++idx) {
        add(a[idx], b[idx]);
    }
    return a;
}
// Returns a + b as a new polynomial, leaving both operands untouched.
vector<int> operator + (const vector<int> &a, const vector<int> &b) {
    vector<int> sum(a);
    sum += b;
    return sum;
}
// Coefficient-wise modular subtraction of b from a; a is zero-extended first
// if b is longer.
vector<int>& operator -= (vector<int> &a, const vector<int> &b) {
    if (b.size() > a.size()) {
        a.resize(b.size());
    }
    for (size_t idx = 0; idx != b.size(); ++idx) {
        sub(a[idx], b[idx]);
    }
    return a;
}
// Returns a - b as a new polynomial, leaving both operands untouched.
vector<int> operator - (const vector<int> &a, const vector<int> &b) {
    vector<int> diff(a);
    diff -= b;
    return diff;
}
// Replaces a by the polynomial product a * b.
// When the shorter operand has fewer than 128 coefficients the product is
// accumulated with the quadratic schoolbook loop; otherwise multiply() is used.
//
// Two cases that used to misbehave are handled up front:
//  * both operands empty: a.size() + b.size() - 1 wrapped around in unsigned
//    arithmetic; the product is now simply left empty;
//  * a and b being the same object (x *= x): the schoolbook branch re-sized a
//    and then read b, which is that same vector, running out of bounds. A copy
//    of the right-hand side is taken first.
vector<int>& operator *= (vector<int> &a, const vector<int> &b) {
    if (&a == &b) {
        const vector<int> rhs_copy = b;
        return a *= rhs_copy;
    }
    if (a.empty() && b.empty()) {
        return a;
    }
    if (min(a.size(), b.size()) < 128) {
        const vector<int> lhs = a;
        a.assign(lhs.size() + b.size() - 1, 0);
        for (size_t i = 0; i < lhs.size(); ++i) {
            for (size_t j = 0; j < b.size(); ++j) {
                add(a[i + j], mul(lhs[i], b[j]));
            }
        }
    } else {
        a = multiply(a, b);
    }
    return a;
}
// Returns a * b as a new polynomial, leaving both operands untouched.
vector<int> operator * (const vector<int> &a, const vector<int> &b) {
    vector<int> product(a);
    product *= b;
    return product;
}
// Replaces a by the polynomial quotient of a divided by b.
// If a is shorter than b the quotient is empty. Otherwise both polynomials
// are reversed, the reversed divisor is inverted as a power series to
// n - m + 1 terms, and the product is truncated and reversed back.
vector<int>& operator /= (vector<int> &a, const vector<int> &b) {
    const int n = a.size(), m = b.size();
    if (n < m) {
        a.clear();
        return a;
    }
    const int quotient_len = n - m + 1;
    // Take the reversed copy of b before touching a.
    vector<int> divisor_rev(b.rbegin(), b.rend());
    reverse(a.begin(), a.end());
    divisor_rev.resize(quotient_len);
    a *= inverse(divisor_rev);
    a.erase(a.begin() + quotient_len, a.end());
    reverse(a.begin(), a.end());
    return a;
}
// Returns the quotient of a divided by b as a new polynomial.
vector<int> operator / (const vector<int> &a, const vector<int> &b) {
    vector<int> quotient(a);
    quotient /= b;
    return quotient;
}
// Replaces a by the remainder of a divided by b.
// When a is already shorter than b it is left as is; otherwise
// (a / b) * b is subtracted on the low b.size() - 1 coefficients.
vector<int>& operator %= (vector<int> &a, const vector<int> &b) {
    const int n = a.size(), m = b.size();
    if (n < m) {
        return a;
    }
    const vector<int> multiple = (a / b) * b;
    a.resize(m - 1);
    for (int i = 0; i + 1 < m; ++i) {
        sub(a[i], multiple[i]);
    }
    return a;
}
// Returns the remainder of a divided by b as a new polynomial.
vector<int> operator % (const vector<int> &a, const vector<int> &b) {
    vector<int> remainder(a);
    remainder %= b;
    return remainder;
}
// Formal derivative of the polynomial a: coefficient i of a contributes
// i * a[i] at position i - 1. The result has one coefficient fewer than a.
// An empty input yields an empty result; previously the result vector was
// constructed with size -1.
vector<int> derivative(const vector<int> &a) {
    if (a.empty()) {
        return vector<int>();
    }
    const int n = a.size();
    vector<int> b(n - 1);
    for (int i = 1; i < n; ++i) {
        b[i - 1] = mul(a[i], i);
    }
    return b;
}
// Formal antiderivative of a with constant term 0: result[k] = a[k-1] / k.
// The inverses of 1..n are produced on the fly from the recurrence
// inv[k] = -(md / k) * inv[md % k].
vector<int> primitive(const vector<int> &a) {
    const int n = a.size();
    vector<int> result(n + 1), recip(n + 1);
    for (int k = 1; k <= n; ++k) {
        if (k == 1) {
            recip[k] = 1;
        } else {
            recip[k] = mul(md - md / k, recip[md % k]);
        }
        result[k] = mul(a[k - 1], recip[k]);
    }
    return result;
}
// Power-series logarithm of a, truncated to a.size() terms: the
// antiderivative of a' / a.
vector<int> logarithm(const vector<int> &a) {
    const vector<int> ratio = derivative(a) * inverse(a);
    vector<int> result = primitive(ratio);
    result.resize(a.size());
    return result;
}
vector<int> exponent(const vector<int> &a) {
vector<int> b(1, 1);
while (b.size() < a.size()) {
vector<int> c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
add(c[0], 1);
vector<int> old_b = b;
b.resize(b.size() << 1);
c -= logarithm(b);
c *= old_b;
for (int i = b.size() >> 1; i < b.size(); ++i) {
b[i] = c[i];
}
}
b.resize(a.size());
return b;
}
// Returns a^m as a power series truncated to a.size() terms.
// With p the index of the first non-zero coefficient, the series is written
// as x^p * a[p] * c with c[0] == 1, and the result is
// x^(m*p) * a[p]^m * exp(m * log c).
// Special cases:
//  * empty a: returns an empty vector (previously b[0] was written on an
//    empty vector);
//  * a identically zero: returns 1 if m == 0, otherwise the zero series;
//  * m * p >= a.size(): every term is shifted out, so the zero series.
vector<int> power(vector<int> a, int m) {
    const int n = a.size();
    vector<int> b(n);
    if (n == 0) {
        return b;
    }
    int p = -1;
    for (int i = 0; i < n; ++i) {
        if (a[i]) {
            p = i;
            break;
        }
    }
    if (p == -1) {
        b[0] = !m;
        return b;
    }
    // Compare in 64 bits: m * p may not fit in int.
    if ((long long) m * p >= n) {
        return b;
    }
    const int shift = m * p;
    const int len = n - shift;
    const int mu = power(a[p], m), di = inv(a[p]);
    vector<int> c(len);
    for (int i = 0; i < len; ++i) {
        c[i] = mul(a[i + p], di);
    }
    c = logarithm(c);
    for (int i = 0; i < len; ++i) {
        c[i] = mul(c[i], m);
    }
    c = exponent(c);
    for (int i = 0; i < len; ++i) {
        b[i + shift] = mul(c[i], mu);
    }
    return b;
}
// Power-series square root of a, truncated to a.size() terms.
// Starts from the one-term series {1} and doubles the precision on each pass
// with b <- (b + a / b) / 2; only the new upper half is written.
// NOTE(review): the starting value 1 suggests a[0] == 1 is expected; confirm
// with callers.
// The per-iteration copy of b that the original kept in an unused local has
// been removed.
vector<int> sqrt(const vector<int> &a) {
    vector<int> b(1, 1);
    const int half = (md + 1) >> 1;  // multiplying by this halves modulo md
    while (b.size() < a.size()) {
        vector<int> c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
        b.resize(b.size() << 1);
        c *= inverse(b);
        for (size_t i = b.size() >> 1; i < b.size(); ++i) {
            b[i] = mul(c[i], half);
        }
    }
    b.resize(a.size());
    return b;
}
// Product of the polynomials all[l..r] (inclusive), computed by splitting the
// range in half recursively. An empty range yields an empty vector.
vector<int> multiply_all(int l, int r, vector<vector<int>> &all) {
    if (l > r) {
        return vector<int>();
    }
    if (l == r) {
        return all[l];
    }
    const int mid = (l + r) >> 1;
    return multiply_all(l, mid, all) * multiply_all(mid + 1, r, all);
}
// Evaluates the polynomial f at every point of x and returns the values in
// the same order.
// Builds a heap-indexed product tree whose leaves are (X - x[i]), then pushes
// remainders of f down the tree; the constant term at each leaf is the value.
vector<int> evaluate(const vector<int> &f, const vector<int> &x) {
    const int count = x.size();
    if (count == 0) {
        return vector<int>();
    }
    // Node k has children 2k and 2k+1; leaves occupy [count, 2*count).
    vector<vector<int>> prod(count * 2);
    for (int i = 0; i < count; ++i) {
        prod[count + i] = vector<int>{(md - x[i]) % md, 1};
    }
    for (int node = count - 1; node >= 1; --node) {
        prod[node] = prod[node << 1] * prod[node << 1 | 1];
    }
    vector<vector<int>> rem(count * 2);
    rem[1] = f % prod[1];
    for (int node = 2; node < count * 2; ++node) {
        rem[node] = rem[node >> 1] % prod[node];
    }
    vector<int> values(count);
    for (int i = 0; i < count; ++i) {
        values[i] = rem[count + i][0];
    }
    return values;
}
// Returns the coefficients of the polynomial passing through the points
// (x[i], y[i]).
// Builds the product tree of (X - x[i]), evaluates the derivative of the full
// product at every x[i] to get the Lagrange denominators, then combines the
// weighted leaves bottom-up.
// An empty point set yields an empty polynomial, mirroring evaluate();
// previously up[1] was accessed on an empty vector.
// NOTE(review): the x[i] are presumably distinct modulo md; a repeated point
// makes a denominator zero, which inv() cannot invert.
vector<int> interpolate(vector<int> x, const vector<int> &y) {
    const int n = x.size();
    if (n == 0) {
        return vector<int>();
    }
    // Heap-indexed product tree: node k has children 2k and 2k+1.
    vector<vector<int>> up(n * 2);
    for (int i = 0; i < n; ++i) {
        x[i] %= md;
        up[i + n] = vector<int>{(md - x[i]) % md, 1};
    }
    for (int i = n - 1; i; --i) {
        up[i] = up[i << 1] * up[i << 1 | 1];
    }
    // a[i] = y[i] / P'(x[i]) where P is the product of all (X - x[j]).
    vector<int> a = evaluate(derivative(up[1]), x);
    for (int i = 0; i < n; ++i) {
        a[i] = mul(y[i], inv(a[i]));
    }
    vector<vector<int>> down(n * 2);
    for (int i = 0; i < n; ++i) {
        down[i + n] = vector<int>(1, a[i]);
    }
    for (int i = n - 1; i; --i) {
        // Each side is multiplied by the other side's product.
        down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
    }
    return down[1];
}
// Returns
//   sum over 0 <= i < n-1, 0 <= j < m-1 of
//       C(i+j, i)^2 - C(i+j, i-1) * C(i+j, i+1)
//   - C(n+m-2, n-1) + 1.
// Both double sums are obtained from one polynomial product each: the
// factorial-inverse weights are convolved and index k of the product is then
// scaled by (k!)^2.
mi solve00(int n, int m) {
    mi ans = 0;
    VI f(n - 1), g(m - 1);

    // First sum: weights 1 / (i!)^2.
    for (int i = 0; i < n - 1; ++i) {
        f[i] = (int)(simp::gifac(i) * simp::gifac(i));
    }
    for (int j = 0; j < m - 1; ++j) {
        g[j] = (int)(simp::gifac(j) * simp::gifac(j));
    }
    auto h = f * g;
    for (int k = 0; k < (int)h.size(); ++k) {
        ans += h[k] * simp::gfac(k) * simp::gfac(k);
    }

    // Second sum: weights 1 / ((i-1)! (i+1)!); gifac(-1) is 0, so index 0
    // contributes nothing.
    for (int i = 0; i < n - 1; ++i) {
        f[i] = (int)(simp::gifac(i - 1) * simp::gifac(i + 1));
    }
    for (int j = 0; j < m - 1; ++j) {
        g[j] = (int)(simp::gifac(j - 1) * simp::gifac(j + 1));
    }
    h = f * g;
    for (int k = 0; k < (int)h.size(); ++k) {
        ans -= h[k] * simp::gfac(k) * simp::gfac(k);
    }

    ans = ans - way(n, 1, 1, m) + 1;
    return ans;
}
// Returns the sum over d = 1 .. m-2 of
//   (m - d - 1) * ( C(n+d-2, n-1)^2 - C(n+d-2, n-2) * C(n+d-2, n) ).
mi solve02(int n, int m) {
    mi ans = 0;
    for (int d = 1; d <= m - 2; ++d) {
        const int top = n + d - 2;
        const mi centre = simp::binom(top, n - 1);
        const mi det = centre * centre - simp::binom(top, n - 2) * simp::binom(top, n);
        ans += mi(m - d - 1) * det;
    }
    return ans;
}
// Returns the sum over i = 1 .. m-1 of
//   gao(n-1, i, 0, m-1, n, i+1, 1, m) - way(n-1, i, 0, m-1).
mi solve01(int n, int m) {
    mi ans = 0;
    for (int i = 1; i < m; ++i) {
        const mi pair_count = gao(n - 1, i, 0, m - 1, n, i + 1, 1, m);
        const mi single_count = way(n - 1, i, 0, m - 1);
        ans += pair_count - single_count;
    }
    return ans;
}
// Reads n and m from standard input and prints the answer modulo 998244353.
// The answer is assembled from a constant, closed-form binomial terms, and
// the solve00 / solve02 / solve01 sums (each used for both orientations).
// Returns 1 without printing an answer when the input is not two positive
// integers; previously a failed read left n = m = 0 and the helpers were
// called with negative vector sizes.
int main() {
    if (scanf("%d%d", &n, &m) != 2 || n < 1 || m < 1) {
        fprintf(stderr, "expected two positive integers n and m\n");
        return 1;
    }
    // The 1x1 case is answered directly.
    if (n == 1 && m == 1) {
        printf("%d\n", 3);
        return 0;
    }
    mi ans = 3;
    ans += 2 * (simp::binom(n + m, n) - 2);
    // solve11
    ans += simp::binom(n + m - 2, n - 1) * simp::binom(n + m - 2, n - 1) -
           simp::binom(n + m - 2, n) * simp::binom(n + m - 2, n - 2);
    ans = ans - simp::binom(n + m - 2, n - 1) * 2 + 1;
    // solve00, solve22
    ans += 2 * solve00(n, m);
    // solve02, solve20
    ans += solve02(n, m) + solve02(m, n);
    // solve10
    ans += 2 * (solve01(n, m) + solve01(m, n));
    printf("%d\n", (int)ans);
    return 0;
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 0ms
memory: 4100kb
input:
2 2
output:
11
result:
ok "11"
Test #2:
score: 0
Accepted
time: 0ms
memory: 3828kb
input:
20 23
output:
521442928
result:
ok "521442928"
Test #3:
score: 0
Accepted
time: 134ms
memory: 23108kb
input:
200000 200000
output:
411160917
result:
ok "411160917"
Test #4:
score: 0
Accepted
time: 0ms
memory: 4104kb
input:
8 3
output:
2899
result:
ok "2899"
Test #5:
score: 0
Accepted
time: 0ms
memory: 3828kb
input:
10 9
output:
338037463
result:
ok "338037463"
Test #6:
score: 0
Accepted
time: 0ms
memory: 4088kb
input:
3 3
output:
64
result:
ok "64"
Test #7:
score: 0
Accepted
time: 0ms
memory: 3828kb
input:
9 4
output:
39733
result:
ok "39733"
Test #8:
score: 0
Accepted
time: 0ms
memory: 3896kb
input:
36 33
output:
545587245
result:
ok "545587245"
Test #9:
score: 0
Accepted
time: 0ms
memory: 3828kb
input:
35 39
output:
62117944
result:
ok "62117944"
Test #10:
score: 0
Accepted
time: 0ms
memory: 3804kb
input:
48 10
output:
264659761
result:
ok "264659761"
Test #11:
score: 0
Accepted
time: 0ms
memory: 3836kb
input:
46 30
output:
880000821
result:
ok "880000821"
Test #12:
score: 0
Accepted
time: 0ms
memory: 3808kb
input:
25 24
output:
280799864
result:
ok "280799864"
Test #13:
score: 0
Accepted
time: 0ms
memory: 3800kb
input:
17 10
output:
624958192
result:
ok "624958192"
Test #14:
score: 0
Accepted
time: 5ms
memory: 3964kb
input:
4608 9241
output:
322218996
result:
ok "322218996"
Test #15:
score: 0
Accepted
time: 4ms
memory: 4192kb
input:
3665 6137
output:
537704652
result:
ok "537704652"
Test #16:
score: 0
Accepted
time: 5ms
memory: 4208kb
input:
4192 6186
output:
971816887
result:
ok "971816887"
Test #17:
score: 0
Accepted
time: 4ms
memory: 4448kb
input:
4562 4403
output:
867628411
result:
ok "867628411"
Test #18:
score: 0
Accepted
time: 5ms
memory: 3988kb
input:
8726 3237
output:
808804305
result:
ok "808804305"
Test #19:
score: 0
Accepted
time: 5ms
memory: 3968kb
input:
5257 8166
output:
488829288
result:
ok "488829288"
Test #20:
score: 0
Accepted
time: 5ms
memory: 4036kb
input:
8013 7958
output:
215666893
result:
ok "215666893"
Test #21:
score: 0
Accepted
time: 5ms
memory: 3924kb
input:
8837 5868
output:
239261227
result:
ok "239261227"
Test #22:
score: 0
Accepted
time: 5ms
memory: 3948kb
input:
8917 5492
output:
706653412
result:
ok "706653412"
Test #23:
score: 0
Accepted
time: 5ms
memory: 4248kb
input:
9628 5378
output:
753685501
result:
ok "753685501"
Test #24:
score: 0
Accepted
time: 154ms
memory: 22348kb
input:
163762 183794
output:
141157510
result:
ok "141157510"
Test #25:
score: 0
Accepted
time: 71ms
memory: 12784kb
input:
83512 82743
output:
114622013
result:
ok "114622013"
Test #26:
score: 0
Accepted
time: 76ms
memory: 12312kb
input:
84692 56473
output:
263907717
result:
ok "263907717"
Test #27:
score: 0
Accepted
time: 38ms
memory: 8016kb
input:
31827 74195
output:
200356808
result:
ok "200356808"
Test #28:
score: 0
Accepted
time: 159ms
memory: 22456kb
input:
189921 163932
output:
845151158
result:
ok "845151158"
Test #29:
score: 0
Accepted
time: 72ms
memory: 12932kb
input:
27157 177990
output:
847356039
result:
ok "847356039"
Test #30:
score: 0
Accepted
time: 73ms
memory: 12968kb
input:
136835 39390
output:
962822638
result:
ok "962822638"
Test #31:
score: 0
Accepted
time: 63ms
memory: 12452kb
input:
118610 18795
output:
243423874
result:
ok "243423874"
Test #32:
score: 0
Accepted
time: 75ms
memory: 12588kb
input:
122070 19995
output:
531055604
result:
ok "531055604"
Test #33:
score: 0
Accepted
time: 78ms
memory: 13096kb
input:
20031 195670
output:
483162363
result:
ok "483162363"
Test #34:
score: 0
Accepted
time: 132ms
memory: 23036kb
input:
199992 199992
output:
262099623
result:
ok "262099623"
Test #35:
score: 0
Accepted
time: 165ms
memory: 23088kb
input:
200000 199992
output:
477266520
result:
ok "477266520"
Test #36:
score: 0
Accepted
time: 162ms
memory: 23208kb
input:
199999 199996
output:
165483205
result:
ok "165483205"
Test #37:
score: 0
Accepted
time: 0ms
memory: 3884kb
input:
1 1
output:
3
result:
ok "3"
Test #38:
score: 0
Accepted
time: 14ms
memory: 6268kb
input:
1 100000
output:
8828237
result:
ok "8828237"
Test #39:
score: 0
Accepted
time: 12ms
memory: 6220kb
input:
100000 2
output:
263711286
result:
ok "263711286"
Test #40:
score: 0
Accepted
time: 0ms
memory: 3812kb
input:
50 50
output:
634767411
result:
ok "634767411"
Extra Test:
score: 0
Extra Test Passed