QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#508453 | #7632. Balanced Arrays | pandapythoner | AC ✓ | 1030ms | 88692kb | C++23 | 15.4kb | 2024-08-07 15:44:20 | 2024-08-07 15:44:22 |
Judging History
answer
#include <bits/stdc++.h>
using namespace std;
// Shared aliases, helper macros and constants used throughout the solution.
using ll = long long;
// Floating-point alias (a macro; not referenced elsewhere in this file).
#define flt double
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
// rep(i, n): i = 0 .. n-1. Note `n` is not parenthesised in the expansion.
#define rep(i, n) for(int i = 0; i < n; i += 1)
#define len(a) ((int)(a).size())
// Fixed seed (234) so stress() runs are reproducible.
mt19937 rnd(234);
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const ll mod = 998244353;
// Binary exponentiation: returns x^n mod `mod`.
// Exponent bits are scanned from least significant upward; for n <= 0 the
// loop never runs and 1 is returned.
ll bin_pow(ll x, ll n) {
    ll result = 1;
    ll base = x;
    for (ll bit = 1; bit <= n; bit *= 2) {
        if (n & bit) {
            result = result * base % mod;
        }
        base = base * base % mod;
    }
    return result;
}
// Modular inverse by Fermat's little theorem (mod is prime): x^(mod - 2).
// Yields 0 when x is a multiple of mod (no inverse exists).
ll inv(ll x) {
    const ll exponent = mod - 2;
    return bin_pow(x, exponent);
}
namespace fft {
int mxpw;
int mxn;
ll w;
void build_w() {
ll phi = mod - 1;
ll f = phi;
vector<ll> p;
for (ll i = 2; i * i <= f; i += 1) {
if (f % i == 0) {
p.push_back(i);
while (f % i == 0) {
f /= i;
}
}
}
if (f > 1) {
p.push_back(f);
}
for (int i = 1; i < mod; i += 1) {
bool ok = true;
for (auto q : p) {
if (bin_pow(i, phi / q) == 1) {
ok = false;
break;
}
}
if (ok) {
w = bin_pow(i, phi / (1 << mxpw));
break;
}
}
}
// rvx[i] holds the modular inverse of i for 1 <= i <= n; rvx[0] is never written.
vector<ll> rvx;
// Fills rvx with the inverses of 1..n.
void build_rvx(int n) {
    rvx.resize(n + 1);
    int i = 1;
    while (i <= n) {
        rvx[i] = inv(i);
        ++i;
    }
}
// rvi[i]: i with its low mxpw bits reversed (bit-reversal permutation table).
// wpws[k] = w^k for 0 <= k <= mxn (twiddle factors).
vector<ll> rvi, wpws;
// Precomputes every table needed for transforms of length up to 2^_mxpw:
// the root of unity w, the bit-reversal table, powers of w, and inverses 1..2^_mxpw.
void build(int _mxpw) {
    mxpw = _mxpw;
    mxn = 1 << mxpw;
    build_w();
    const int size = mxn;
    const int top_bit = size >> 1;  // equals 1 << (mxpw - 1) whenever it is used
    rvi.resize(size);
    rvi[0] = 0;
    for (int i = 1; i < size; ++i) {
        rvi[i] = (rvi[i >> 1] >> 1) + ((i & 1) ? top_bit : 0);
    }
    wpws.resize(size + 1);
    wpws[0] = 1;
    for (int k = 1; k <= size; ++k) {
        wpws[k] = wpws[k - 1] * w % mod;
    }
    build_rvx(mxn);
}
// In-place iterative radix-2 number-theoretic transform of a[0..n), n = 2^nk.
// Requires build() to have been called with mxpw >= nk, and a.size() >= n.
// Inputs are bit-reverse permuted first, so the output is in natural order.
void fft(vector<ll>& a, int nk) {
int n = (1 << nk);
// Bit-reversal permutation for nk bits, derived from the mxpw-bit table.
for (int i = 0; i < n; i += 1) {
int mrv = (rvi[i] >> (mxpw - nk));
// Swap each pair only once.
if (mrv < i) {
swap(a[mrv], a[i]);
}
}
// Butterfly stages: merge two length-ln halves into each length-ln2 block.
for (int ln = 1; ln < n; ln *= 2) {
int ln2 = ln + ln;
for (int i = 0; i < n; i += ln2) {
for (int j = 0; j < ln; j += 1) {
// j-th power of the ln2-th root of unity, w^(mxn / ln2 * j), from the power table.
ll mw = wpws[mxn / ln2 * j];
int u = i + j;
int v = u + ln;
ll y = a[v] * mw % mod;
// (a[u], a[v]) <- (a[u] + y, a[u] - y); stays in [0, mod) when inputs do.
a[v] = a[u] - y;
if (a[v] < 0) {
a[v] += mod;
}
a[u] += y;
if (a[u] >= mod) {
a[u] -= mod;
}
}
}
}
}
// Inverse NTT of length n = 2^nk, in place on a[0..n).
// Uses INTT(a) = n^{-1} * (NTT(a) with entries 1..n-1 reversed).
// Like fft(), only the first n entries are touched; any entries past index n
// (when a is longer than the transform) are left unchanged.
void rev_fft(vector<ll>& a, int nk) {
    const int n = 1 << nk;
    fft(a, nk);
    // Reverse only inside the transform window; reversing up to a.end() would
    // mix unrelated tail entries into the result when a.size() > n.
    reverse(a.begin() + 1, a.begin() + n);
    const ll rvn = inv(n);
    for (int i = 0; i < n; ++i) {
        a[i] = a[i] * rvn % mod;
    }
}
// Returns the polynomial a * a (coefficients mod `mod`) with trailing zero
// coefficients stripped. Only one forward transform is needed, unlike mul(a, a).
vector<ll> square(vector<ll> a) {
    const int need = 2 * (int)a.size() - 1;
    int nk = 0;
    while ((1 << nk) < need) {
        ++nk;
    }
    a.resize(1 << nk, 0);
    fft(a, nk);
    for (ll& v : a) {
        v = v * v % mod;
    }
    rev_fft(a, nk);
    auto last_nonzero = find_if(a.rbegin(), a.rend(), [](ll v) { return v != 0; });
    a.erase(last_nonzero.base(), a.end());
    return a;
}
// Returns the product a * b of two polynomials (coefficients mod `mod`).
// Trailing zero coefficients are stripped, so the result can be shorter than
// a.size() + b.size() - 1 and is empty for a zero product.
vector<ll> mul(vector<ll> a, vector<ll> b) {
    const int need = (int)a.size() + (int)b.size() - 1;
    int nk = 0;
    while ((1 << nk) < need) {
        ++nk;
    }
    const int n = 1 << nk;
    a.resize(n, 0);
    b.resize(n, 0);
    fft(a, nk);
    fft(b, nk);
    transform(a.begin(), a.end(), b.begin(), a.begin(),
              [](ll x, ll y) { return x * y % mod; });
    rev_fft(a, nk);
    auto last_nonzero = find_if(a.rbegin(), a.rend(), [](ll v) { return v != 0; });
    a.erase(last_nonzero.base(), a.end());
    return a;
}
// a += k * b coefficient-wise (mod `mod`), growing a to b's length if needed.
// k is first reduced into [0, mod): with entries of b in [0, mod) this keeps
// b[i] * k within 64 bits for any k, and negative k yields proper residues.
void add_inplace(vector<ll>& a, const vector<ll>& b, ll k = 1) {
    k %= mod;
    if (k < 0) {
        k += mod;
    }
    a.resize(max(a.size(), b.size()), 0);
    for (int i = 0; i < (int)b.size(); ++i) {
        a[i] = (a[i] + b[i] * k) % mod;
    }
}
// Returns a + k * b coefficient-wise (mod `mod`); the result has length
// max(a.size(), b.size()). k is reduced into [0, mod) first so b[i] * k cannot
// overflow (for b[i] in [0, mod)) and negative k gives non-negative residues.
vector<ll> add(vector<ll> a, const vector<ll>& b, ll k = 1) {
    k %= mod;
    if (k < 0) {
        k += mod;
    }
    a.resize(max(a.size(), b.size()), 0);
    for (int i = 0; i < (int)b.size(); ++i) {
        a[i] = (a[i] + b[i] * k) % mod;
    }
    return a;
}
// Returns a - k * b coefficient-wise (mod `mod`); the result has length
// max(a.size(), b.size()). k is reduced into [0, mod) first so b[i] * k cannot
// overflow for large |k| (assuming b[i] in [0, mod)).
vector<ll> sub(vector<ll> a, const vector<ll>& b, ll k = 1) {
    k %= mod;
    if (k < 0) {
        k += mod;
    }
    a.resize(max(a.size(), b.size()), 0);
    for (int i = 0; i < (int)b.size(); ++i) {
        a[i] = (a[i] + mod - b[i] * k % mod) % mod;
    }
    return a;
}
// Reference polynomial composition a(b(x)) = sum_i a[i] * b(x)^i, computed by
// accumulating successive powers of b with repeated mul() calls.
vector<ll> replace_x_slow(vector<ll>& a, const vector<ll>& b) {
    vector<ll> result;
    vector<ll> power{1};  // b^i for the current i
    for (size_t i = 0; i < a.size(); ++i) {
        if (i != 0) {
            power = mul(power, b);
        }
        add_inplace(result, power, a[i]);
    }
    return result;
}
// Polynomial composition a(b(x)) = sum_i a[i] * b(x)^i by pairwise
// divide and conquer. At each level, adjacent partial results are merged as
// d[2k] + d[2k+1] * b^(2^level), and the power of b is squared between levels.
// Returns an empty vector when a is empty.
vector<ll> replace_x(vector<ll>& a, const vector<ll>& b) {
    if (a.empty()) {
        return {};  // nothing to compose; avoids reading d[0] of an empty list
    }
    vector<ll> bpw = b;  // b^(2^level)
    int n = a.size();
    vector<vector<ll>> d(n);
    for (int i = 0; i < n; ++i) {
        d[i] = {a[i]};
    }
    while (n > 1) {
        const int m = (n + 1) / 2;
        vector<vector<ll>> nd(m);
        for (int i = 0; i < n; ++i) {
            if (i % 2 == 0) {
                nd[i / 2] = move(d[i]);  // d is discarded after this level
            } else {
                add_inplace(nd[i / 2], mul(d[i], bpw));
            }
        }
        n = m;
        d.swap(nd);
        if (n != 1) {
            bpw = square(bpw);
        }
    }
    return d[0];
}
// Taylor shift: returns the coefficients of a(x + t) modulo `mod`.
// With A_j = a[j] * j! and C_m = t^m / m!, coefficient i of the result is
// (1 / i!) * sum_{j >= i} A_j * C_{j-i}, a correlation evaluated with one mul().
// t is reduced modulo `mod` (negative values allowed). t == 0 returns a unchanged,
// because no inverse of t is needed.
vector<ll> shift_x(vector<ll> a, ll t) {
    if (a.empty()) {
        return {};
    }
    t %= mod;
    if (t < 0) {
        t += mod;
    }
    const int n = (int)a.size() - 1;
    vector<ll> f(n + 1), rf(n + 1);
    f[0] = 1;
    for (int i = 1; i <= n; ++i) {
        f[i] = f[i - 1] * i % mod;
    }
    // Inverse factorials from a single modular inverse.
    rf[n] = inv(f[n]);
    for (int i = n; i > 0; --i) {
        rf[i - 1] = rf[i] * i % mod;
    }
    // b[j] = a[j] * j!, c[n - m] = t^m / m! (c is stored reversed for the correlation).
    vector<ll> b(n + 1), c(n + 1);
    ll tpw = 1;
    for (int i = 0; i <= n; ++i) {
        b[i] = a[i] * f[i] % mod;
        c[n - i] = tpw * rf[i] % mod;
        tpw = tpw * t % mod;
    }
    vector<ll> prod = mul(b, c);
    // mul() strips trailing zero coefficients; pad back so prod[i + n] is always valid.
    prod.resize(2 * n + 1, 0);
    vector<ll> d(n + 1);
    for (int i = 0; i <= n; ++i) {
        d[i] = rf[i] * prod[i + n] % mod;
    }
    return d;
}
// Power-series inverse: returns b with a * b == 1 (mod x^n), using Newton
// iteration b <- b * (2 - a * b) and doubling the precision each round.
// For n > 0, a must be non-empty with a[0] != 0 (inv(0) returns 0).
// For n <= 0 the result is empty and a is not read, so an empty a is allowed.
vector<ll> rev_polynom(const vector<ll>& a, int n) {
    if (n <= 0) {
        return {};
    }
    const int sz = a.size();
    vector<ll> b = {inv(a[0])};
    int m = 1;   // b is correct modulo x^m
    int mk = 0;  // log2(m)
    while (m < n) {
        const int m2 = m + m;
        // Transform length 4m: a (deg < 2m) times b^2 (deg <= 2m - 2) has deg < 4m, so no wrap-around.
        const int m4 = m2 + m2;
        b.resize(m4);
        fft(b, mk + 2);
        vector<ll> nb(m4);
        for (int i = 0; i < sz && i < m2; ++i) {
            nb[i] = a[i];
        }
        fft(nb, mk + 2);
        for (int i = 0; i < m4; ++i) {
            nb[i] = (2 * b[i] - nb[i] * b[i] % mod * b[i]) % mod;
            if (nb[i] < 0) {
                nb[i] += mod;
            }
        }
        rev_fft(nb, mk + 2);
        nb.resize(m2);
        b.swap(nb);
        m = m2;
        mk += 1;
    }
    b.resize(n);
    return b;
}
// Power-series square root: returns b with b^2 == a (mod x^n).
// The iteration starts from b = 1, so a[0] is expected to be 1.
// Newton step b <- (b + a / b) / 2, doubling precision each round.
vector<ll> square_root(const vector<ll>& a, int n) {
    const ll half = inv(2);
    const int avail = a.size();
    vector<ll> b{1};
    for (int m = 1; m < n; m *= 2) {
        const int prec = 2 * m;
        vector<ll> inverse = rev_polynom(b, prec);
        vector<ll> head(a.begin(), a.begin() + min(prec, avail));
        head.resize(prec, 0);
        vector<ll> quotient = mul(head, inverse);
        quotient.resize(prec);
        b.resize(prec);
        for (int i = 0; i < prec; ++i) {
            b[i] = half * ((b[i] + quotient[i]) % mod) % mod;
        }
    }
    b.resize(n);
    return b;
}
// Formal derivative: returns a' with a.size() - 1 coefficients
// (empty for an empty or constant polynomial).
vector<ll> derivative(vector<ll> a) {
    if (a.empty()) {
        return {};
    }
    vector<ll> d(a.size() - 1);
    for (size_t i = 1; i < a.size(); ++i) {
        d[i - 1] = a[i] * (ll)i % mod;
    }
    return d;
}
// Formal antiderivative with zero constant term: result[i] = a[i - 1] / i.
// Uses the precomputed inverses in rvx, so a.size() must be < rvx.size().
vector<ll> integrate(vector<ll> a) {
    vector<ll> res(a.size() + 1, 0);
    for (size_t i = 1; i < res.size(); ++i) {
        res[i] = a[i - 1] * rvx[i] % mod;
    }
    return res;
}
// First n Taylor coefficients of sin(x) mod `mod`: x - x^3/3! + x^5/5! - ...
vector<ll> sin_polynomial(int n) {
    vector<ll> coef(n, 0);
    ll factorial = 1;  // i!
    for (int i = 0; i < n; ++i) {
        if (i > 0) {
            factorial = factorial * i % mod;
        }
        if (i % 2 != 1) {
            continue;
        }
        const ll term = inv(factorial);
        const bool negative = (i / 2) % 2 == 1;
        coef[i] = negative ? (mod - term) % mod : term % mod;
    }
    return coef;
}
// First n Taylor coefficients of cos(x) mod `mod`: 1 - x^2/2! + x^4/4! - ...
vector<ll> cos_polynomial(int n) {
    vector<ll> coef(n, 0);
    ll factorial = 1;  // i!
    for (int i = 0; i < n; ++i) {
        if (i > 0) {
            factorial = factorial * i % mod;
        }
        if (i % 2 != 0) {
            continue;
        }
        const ll term = inv(factorial);
        const bool negative = (i / 2) % 2 == 1;
        coef[i] = negative ? (mod - term) % mod : term % mod;
    }
    return coef;
}
// First n coefficients of sum_j (-1)^j x^(k*j) / (k*j)! mod `mod`
// (cos(x) for k = 2). k must be non-zero.
vector<ll> super_cos_polynomial(int n, int k) {
    vector<ll> coef(n, 0);
    ll factorial = 1;  // i!
    for (int i = 0; i < n; ++i) {
        if (i > 0) {
            factorial = factorial * i % mod;
        }
        if (i % k != 0) {
            continue;
        }
        const ll term = inv(factorial);
        const bool negative = (i / k) % 2 == 1;
        coef[i] = negative ? (mod - term) % mod : term % mod;
    }
    return coef;
}
// Power-series logarithm mod x^n, computed as the integral of a' / a.
// The result's constant term is 0, which matches ln(a) only when a[0] == 1.
vector<ll> logarithm(const vector<ll>& a, int n) {
    if (n == 0) {
        return {};
    }
    const vector<ll> da = derivative(a);
    const vector<ll> ra = rev_polynom(a, n);
    vector<ll> result = integrate(mul(da, ra));
    result.resize(n);
    return result;
}
// Power-series exponential: returns exp(a) mod x^n, starting from b = 1
// (so a[0] is expected to be 0). Newton step: b <- b - b * (ln b - a),
// doubling precision each round.
vector<ll> exponent(const vector<ll>& a, int n) {
    vector<ll> b{1};
    const int alen = a.size();
    for (int m = 1; m < n; m *= 2) {
        const int prec = m * 2;
        vector<ll> diff = logarithm(b, prec);
        const int lim = min(prec, alen);
        for (int i = 0; i < lim; ++i) {
            diff[i] -= a[i];
            if (diff[i] < 0) {
                diff[i] += mod;
            }
        }
        vector<ll> corr = mul(diff, b);
        corr.resize(prec);
        b.resize(prec);
        for (int i = 0; i < prec; ++i) {
            b[i] -= corr[i];
            if (b[i] < 0) {
                b[i] += mod;
            }
        }
    }
    b.resize(n);
    return b;
}
// Solves the linear ODE y' = a(x) * y + b(x) with y(0) = 0, returning y mod x^n.
// With e = exp(integral of a), the solution is y = e * integral(b / e).
// Returns an empty series for n <= 0. Without this guard, n == 0 made
// rev_polynom read e[0] of an empty series, and n < 0 threw from resize.
vector<ll> solve_differential(const vector<ll>& a, const vector<ll>& b, int n) {
    if (n <= 0) {
        return {};
    }
    const vector<ll> e = exponent(integrate(a), n);
    vector<ll> result = mul(e, integrate(mul(b, rev_polynom(e, n))));
    result.resize(n);
    return result;
}
// First n coefficients of exp(k * x): result[i] = k^i / i! mod `mod`.
// k may be negative; it is reduced into [0, mod). Uses rvx, so n - 1 < rvx.size().
vector<ll> pure_exponent(int n, ll k = 1) {
    if (n == 0) {
        return {};
    }
    const ll kk = ((k % mod) + mod) % mod;
    vector<ll> rs(n);
    rs[0] = 1;
    for (int i = 1; i < n; ++i) {
        rs[i] = rs[i - 1] * rvx[i] % mod * kk % mod;
    }
    return rs;
}
} // namespace fft
// Factorial tables: f[i] = i! and invf[i] = (i!)^{-1} mod `mod` for 0 <= i <= n.
vector<ll> f, invf;
// Fills f and invf up to index n (expects 0 <= n < mod).
// Only one modular exponentiation is done: invf[n] is inverted directly and the
// rest follow from invf[i - 1] = invf[i] * i, instead of one inv() per index.
void build_f(int n) {
    f.resize(n + 1);
    invf.resize(n + 1);
    f[0] = 1;
    for (int i = 1; i <= n; ++i) {
        f[i] = f[i - 1] * i % mod;
    }
    invf[n] = inv(f[n]);
    for (int i = n; i > 0; --i) {
        invf[i - 1] = invf[i] * i % mod;
    }
}
// Binomial coefficient C(n, k) mod `mod`; 0 unless 0 <= k <= n.
// Requires n < f.size(), i.e. build_f was called with a large enough bound.
ll cnk(ll n, ll k) {
    const bool valid = n >= 0 && k >= 0 && k <= n;
    if (!valid) {
        return 0;
    }
    return f[n] * invf[k] % mod * invf[n - k] % mod;
}
// n-th Catalan number: C(2n, n) / (n + 1) mod `mod`.
ll catalan(ll n) {
    const ll central = cnk(2 * n, n);
    return central * inv(n + 1) % mod;
}
// Multiset coefficient: ways to pick k items from n kinds with repetition,
// C(n + k - 1, k) mod `mod`.
ll choose_repeating(ll n, ll k) {
    const ll top = n + k - 1;
    return cnk(top, k);
}
// Term-by-term reference for solve(): evaluates the same double sum directly
// in O(n * m) instead of via convolution. For each t in [1, m] it adds
// C(2t + n - 1, n - 1), plus, for every cnt with 1 <= cnt <= t - 1 and
// 2cnt + 1 <= n, a binomial-weighted term.
ll solve_smart(int n, int m) {
    // The largest factorial index used is 2m + n - 1 (from choose_repeating), so
    // the tables are sized like solve(); n + m + 1e5 overflows once m > ~1e5.
    build_f(2 * n + 2 * m + 100000);
    ll result = 1;
    for (int t = 1; t <= m; ++t) {
        result = (result + choose_repeating(2 * t + 1, n - 1)) % mod;
        const ll inv_2t1 = inv(2 * t + 1);  // invariant for the inner sum
        for (int cnt = 1; 2 * cnt + 1 <= n && cnt <= t - 1; ++cnt) {
            ll coeff = (cnk(t, cnt) * cnk(t - 1, cnt) % mod * 2 +
                        cnk(t, cnt + 1) * cnk(t - 1, cnt) % mod +
                        cnk(t - 1, cnt + 1) * cnk(t, cnt) % mod) % mod;
            coeff = coeff * inv_2t1 % mod;
            result = (result + coeff * cnk(2 * t + n - 2 * cnt - 1, n - 2 * cnt - 1)) % mod;
        }
    }
    return result;
}
// Main solution: evaluates the same sum as solve_smart(), but replaces the
// inner sum over cnt with one NTT convolution per (x, y) pair, indexed by
// t = cnt + t_cnt (t_cnt = t - cnt).
// For (x, y) in {(0,0), (1,0), (0,1)}, the inner factor
// C(t, cnt + x) * C(t - 1, cnt + y) (weighted 2 for (0,0)) splits as
//   t! (t-1)! * [1 / ((cnt+x)! (cnt+y)!)] * [1 / ((t_cnt-x)! (t_cnt-1-y)!)],
// and C(2t + n - 2cnt - 1, n - 2cnt - 1) = (n - 1 + 2 t_cnt)! / ((n - 2cnt - 1)! (2t)!).
// The cnt-only factors go into a[], the t_cnt-only factors into b[].
ll solve(int n, int m) {
    build_f(2 * n + 2 * m + 1e5);
    ll result = 1;
    for (int t = 1; t <= m; t += 1) {
        result = (result + choose_repeating(2 * t + 1, n - 1)) % mod;
    }
    for (int mask = 0; mask <= 2; mask += 1) {
        const int x = mask % 2;
        const int y = (mask / 2) % 2;
        const int mx_cnt = (n - 1) / 2;  // largest cnt with 2cnt + 1 <= n
        vector<ll> a(mx_cnt + 1);
        for (int cnt = 1; cnt <= mx_cnt; cnt += 1) {
            a[cnt] = invf[cnt + x] * invf[cnt + y] % mod;
            if (mask == 0) {
                a[cnt] = a[cnt] * 2 % mod;
            }
            a[cnt] = a[cnt] * invf[n - 2 * cnt - 1] % mod;
        }
        // cnt <= t - 1 means t_cnt >= 1; y == 1 additionally needs t_cnt - 1 - y >= 0.
        vector<ll> b(m);
        for (int t_cnt = y + 1; t_cnt <= m - 1; t_cnt += 1) {
            b[t_cnt] = invf[t_cnt - x] * invf[t_cnt - 1 - y] % mod;
            b[t_cnt] = b[t_cnt] * f[n - 1 + 2 * t_cnt] % mod;
        }
        const auto c = fft::mul(a, b);
        for (int t = 1; t <= m && t < len(c); t += 1) {
            // invf[2t] * (2t + 1)^{-1} == invf[2t + 1], so no per-term modular inverse is needed.
            result += c[t] * f[t] % mod * f[t - 1] % mod * invf[2 * t + 1] % mod;
        }
        result %= mod;
    }
    return result;
}
// Brute-force DP used by stress() to validate solve(); O(n * m^3) time and
// O(n * m^2) memory, so only suitable for small n and m.
// dp[i][x][y] is filled for x + y <= m, starting from the single state (0, m).
ll solve_slow(int n, int m) {
    vector<vector<vector<ll>>> dp(n + 1, vector<vector<ll>>(m + 1, vector<ll>(m + 1)));
    // One assignment of the initial state is enough (it was previously repeated m + 1 times).
    if (m >= 0) {
        dp[0][0][m] = 1;
    }
    rep(i, n) {
        rep(x, m + 1) rep(y, m + 1 - x) {
            ll& cur = dp[i + 1][x][y];
            rep(k, x) cur += dp[i][k][y];
            for (int k = m; k >= y; k -= 1) cur += dp[i][x][k];
            cur %= mod;
        }
    }
    ll result = 0;
    rep(x, m + 1) rep(y, m + 1 - x) result += dp[n][x][y];
    return result % mod;
}
// Randomised self-check: compares solve() with solve_slow() for n, m in [1, 100].
// Prints an iteration counter every round; on the first mismatch it prints
// "n m" and returns. Loops forever while the two agree.
void stress() {
    for (int iter = 1;; ++iter) {
        cout << iter << "\n";
        const int n = rnd() % 100 + 1;
        const int m = rnd() % 100 + 1;
        const auto fast = solve(n, m);
        const auto slow = solve_slow(n, m);
        if (fast == slow) {
            continue;
        }
        cout << n << " " << m << "\n";
        return;
    }
}
int32_t main() {
fft::build(20);
// stress();
if (1) {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
}
int n, m;
cin >> n >> m;
cout << solve(n, m) << "\n";
return 0;
}
这程序好像有点Bug,我给组数据试试?
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 132ms
memory: 29444kb
input:
2 2
output:
9
result:
ok 1 number(s): "9"
Test #2:
score: 0
Accepted
time: 1028ms
memory: 88676kb
input:
500000 500000
output:
984531374
result:
ok 1 number(s): "984531374"
Test #3:
score: 0
Accepted
time: 332ms
memory: 54260kb
input:
500000 1
output:
219705876
result:
ok 1 number(s): "219705876"
Test #4:
score: 0
Accepted
time: 456ms
memory: 62260kb
input:
1 500000
output:
500001
result:
ok 1 number(s): "500001"
Test #5:
score: 0
Accepted
time: 953ms
memory: 81840kb
input:
500000 353535
output:
33730077
result:
ok 1 number(s): "33730077"
Test #6:
score: 0
Accepted
time: 985ms
memory: 83052kb
input:
353535 500000
output:
182445298
result:
ok 1 number(s): "182445298"
Test #7:
score: 0
Accepted
time: 1015ms
memory: 88684kb
input:
499999 499999
output:
933220940
result:
ok 1 number(s): "933220940"
Test #8:
score: 0
Accepted
time: 1030ms
memory: 88692kb
input:
499999 499998
output:
899786345
result:
ok 1 number(s): "899786345"
Test #9:
score: 0
Accepted
time: 925ms
memory: 79208kb
input:
377773 400009
output:
206748715
result:
ok 1 number(s): "206748715"
Test #10:
score: 0
Accepted
time: 520ms
memory: 61820kb
input:
499999 100001
output:
482775773
result:
ok 1 number(s): "482775773"
Test #11:
score: 0
Accepted
time: 971ms
memory: 86092kb
input:
444445 488884
output:
70939759
result:
ok 1 number(s): "70939759"
Test #12:
score: 0
Accepted
time: 1003ms
memory: 85576kb
input:
488885 444449
output:
591315327
result:
ok 1 number(s): "591315327"
Test #13:
score: 0
Accepted
time: 337ms
memory: 54192kb
input:
500000 111
output:
313439156
result:
ok 1 number(s): "313439156"
Test #14:
score: 0
Accepted
time: 920ms
memory: 79628kb
input:
333333 444444
output:
403492103
result:
ok 1 number(s): "403492103"
Test #15:
score: 0
Accepted
time: 909ms
memory: 81232kb
input:
499994 343433
output:
334451699
result:
ok 1 number(s): "334451699"
Test #16:
score: 0
Accepted
time: 941ms
memory: 83776kb
input:
477774 411113
output:
63883341
result:
ok 1 number(s): "63883341"
Test #17:
score: 0
Accepted
time: 621ms
memory: 63952kb
input:
123456 432109
output:
238795570
result:
ok 1 number(s): "238795570"
Test #18:
score: 0
Accepted
time: 916ms
memory: 73000kb
input:
131331 467777
output:
834790039
result:
ok 1 number(s): "834790039"
Test #19:
score: 0
Accepted
time: 350ms
memory: 54256kb
input:
500000 2
output:
304727284
result:
ok 1 number(s): "304727284"
Test #20:
score: 0
Accepted
time: 133ms
memory: 29380kb
input:
1111 111
output:
98321603
result:
ok 1 number(s): "98321603"
Test #21:
score: 0
Accepted
time: 962ms
memory: 85156kb
input:
416084 493105
output:
916827025
result:
ok 1 number(s): "916827025"
Test #22:
score: 0
Accepted
time: 320ms
memory: 42024kb
input:
53888 138663
output:
57263952
result:
ok 1 number(s): "57263952"
Test #23:
score: 0
Accepted
time: 608ms
memory: 65892kb
input:
219161 382743
output:
304889787
result:
ok 1 number(s): "304889787"
Test #24:
score: 0
Accepted
time: 558ms
memory: 60140kb
input:
181392 318090
output:
12528742
result:
ok 1 number(s): "12528742"
Test #25:
score: 0
Accepted
time: 607ms
memory: 64496kb
input:
135930 422947
output:
554153000
result:
ok 1 number(s): "554153000"
Test #26:
score: 0
Accepted
time: 529ms
memory: 58372kb
input:
280507 210276
output:
812816587
result:
ok 1 number(s): "812816587"
Test #27:
score: 0
Accepted
time: 934ms
memory: 75344kb
input:
253119 420465
output:
124024302
result:
ok 1 number(s): "124024302"
Test #28:
score: 0
Accepted
time: 480ms
memory: 59544kb
input:
446636 97448
output:
150388382
result:
ok 1 number(s): "150388382"
Test #29:
score: 0
Accepted
time: 460ms
memory: 54636kb
input:
284890 126665
output:
786559507
result:
ok 1 number(s): "786559507"
Test #30:
score: 0
Accepted
time: 239ms
memory: 39928kb
input:
186708 28279
output:
607509026
result:
ok 1 number(s): "607509026"
Extra Test:
score: 0
Extra Test Passed