QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#499976 | #6551. Forever Young | pandapythoner | TL | 1ms | 3668kb | C++23 | 4.5kb | 2024-07-31 20:42:16 | 2024-07-31 20:42:16 |
Judging History
answer
#include <bits/stdc++.h>
using namespace std;
// Short type aliases and helper macros used throughout the file.
#define ll long long
#define flt double
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
// rep(i, n): loops i = 0 .. n-1; the expression n is re-evaluated on every check.
#define rep(i, n) for(int i = 0; i < n; i += 1)
// len(a): container size as a signed int.
#define len(a) ((int)(a).size())
// Not referenced in the code shown here.
const ll inf = 1e18;
// Fixed seed, so the random cases generated by stress() repeat from run to run.
mt19937 rnd(234);
// Modulus for every count; inv() computes inverses as x^(mod-2), which needs mod to be prime.
const ll mod = 998244353;
// Computes x^n modulo `mod` by square-and-multiply.
// Preconditions (asserted): 0 <= x < mod and n >= 0. bin_pow(x, 0) == 1.
ll bin_pow(ll x, ll n) {
    assert(0 <= x and x < mod and n >= 0);
    ll acc = 1;
    ll base = x;
    for (ll e = n; e > 0; e >>= 1) {
        // The lowest remaining bit of the exponent selects the current square.
        if (e & 1) {
            acc = acc * base % mod;
        }
        base = base * base % mod;
    }
    return acc;
}
// Modular inverse of a nonzero residue x, using Fermat's little theorem:
// x^(mod-2) is the inverse because `mod` is prime. The result is checked by
// an assertion before it is returned.
ll inv(ll x) {
    assert(x != 0);
    const ll res = bin_pow(x, mod - 2);
    assert(x * res % mod == 1);
    return res;
}
// Counts walks of exactly k steps from partition `a` to partition `b`, modulo
// `mod`. Each step adds or removes one box while keeping the row lengths
// non-increasing. This is the same quantity solve_slow() computes by brute force.
//
// Method: sum over every partition mu contained in both a and b of
//   dpa[mu] * dpb[mu] * k! / (down! * up! * 2^r * r!).
// Here dpx[mu] is the number of box-removal sequences from x down to mu,
// down = |a| - |mu|, up = |b| - |mu|, and k = down + up + 2r. Terms with
// down + up > k, or with odd k - down - up, contribute nothing.
//
// The factor is the closed form of
//   C(down + up, down) * prod_{x = down+up, step 2}^{k-2} (x+1)(x+2)/2 / r!.
// Using the closed form makes each term O(1) instead of O(k).
//
// Cost: O(k) table setup, plus O(1) modular work per common sub-partition,
// plus the enumeration of the sub-partitions of a and b.
ll solve(vector<int> a, vector<int> b, ll k) {
    // Maps every partition mu obtainable from `start` by removing boxes to the
    // number of box-removal sequences start -> mu (mod `mod`).
    // States are processed in BFS order. Each removal lowers the box count by
    // exactly one, so all predecessors of a state are processed before the
    // state itself is popped, and its count is final at that point.
    auto get_dp = [](const vector<int>& start) -> map<vector<int>, ll> {
        map<vector<int>, ll> dp;
        dp[start] = 1;
        queue<vector<int>> q;
        q.push(start);
        while (!q.empty()) {
            vector<int> v = q.front();
            q.pop();
            const ll val = dp[v];
            rep(i, len(v)) {
                assert(v[i] > 0);
                // Row i may lose a box only if it stays >= the row below it.
                if (i + 1 < len(v) and v[i] == v[i + 1]) continue;
                vector<int> to = v;
                to[i] -= 1;
                // Only the last row can become empty; drop it.
                if (to.back() == 0) {
                    to.pop_back();
                }
                auto [it, inserted] = dp.try_emplace(to, 0);
                if (inserted) {
                    q.push(to);
                }
                it->second += val;
                if (it->second >= mod) {
                    it->second -= mod;
                }
            }
        }
        return dp;
    };
    const auto dpa = get_dp(a);
    const auto dpb = get_dp(b);
    const ll sma = accumulate(a.begin(), a.end(), 0LL);
    const ll smb = accumulate(b.begin(), b.end(), 0LL);

    // Every index used below is <= k: terms with down + up > k are skipped,
    // and r = (k - down - up) / 2 <= k / 2.
    const ll fsize = max(k, 0LL);
    vector<ll> f(fsize + 1), invf(fsize + 1);
    f[0] = 1;
    for (ll i = 1; i <= fsize; i += 1) {
        f[i] = f[i - 1] * i % mod;
    }
    // Invert once, then walk down with 1/(i-1)! = i * (1/i!).
    // This avoids one modular exponentiation per entry.
    invf[fsize] = inv(f[fsize]);
    for (ll i = fsize; i > 0; i -= 1) {
        invf[i - 1] = invf[i] * i % mod;
    }
    // inv2pow[r] = 2^(-r) mod `mod`, used for the 1 / 2^r factor.
    const ll inv2 = inv(2);
    vector<ll> inv2pow(fsize / 2 + 1);
    inv2pow[0] = 1;
    for (ll r = 1; r <= fsize / 2; r += 1) {
        inv2pow[r] = inv2pow[r - 1] * inv2 % mod;
    }

    ll result = 0;
    for (const auto& [v, vala] : dpa) {
        auto it = dpb.find(v);
        if (it == dpb.end()) continue;
        const ll valb = it->second;
        const ll msum = accumulate(v.begin(), v.end(), 0LL);
        assert(msum <= sma and msum <= smb);
        const ll down = sma - msum;
        const ll up = smb - msum;
        const ll j = down + up;
        if (j > k or (k - j) % 2 != 0) continue;
        const ll r = (k - j) / 2;
        // k! / (down! * up! * 2^r * r!)
        ll coeff = f[k] * invf[down] % mod * invf[up] % mod;
        coeff = coeff * invf[r] % mod * inv2pow[r] % mod;
        result = (result + coeff * vala % mod * valb) % mod;
    }
    return result;
}
// Brute-force reference used by stress().
// layers[s] maps every partition reachable from `a` in exactly s single-box
// moves to its number of walks (mod `mod`). A move is:
//   - changing one row by +-1 while keeping rows non-increasing, with an
//     emptied last row dropped, or
//   - appending a new row of length 1.
// Returns the count for `b` after k moves.
ll solve_slow(vector<int> a, vector<int> b, ll k) {
    vector<map<vector<int>, ll>> layers(k + 1);
    layers[0][a] = 1;
    for (ll step = 0; step < k; step += 1) {
        auto& upcoming = layers[step + 1];
        auto add = [&upcoming](const vector<int>& shape, ll ways) {
            ll& slot = upcoming[shape];
            slot = (slot + ways) % mod;
        };
        for (const auto& [shape, ways] : layers[step]) {
            const int rows = len(shape);
            for (int row = 0; row < rows; row += 1) {
                for (int delta : { -1, 1 }) {
                    vector<int> moved = shape;
                    moved[row] += delta;
                    const bool breaks_above = row > 0 and moved[row - 1] < moved[row];
                    const bool breaks_below = row + 1 < rows and moved[row] < moved[row + 1];
                    if (breaks_above or breaks_below) continue;
                    if (moved.back() == 0) {
                        moved.pop_back();
                    }
                    add(moved, ways);
                }
            }
            vector<int> grown = shape;
            grown.push_back(1);
            add(grown, ways);
        }
    }
    return layers[k][b];
}
void stress() {
ll c = 0;
while (1) {
cout << ++c << "\n";
vector<int> a, b;
for (int x = rnd() % 5 + 1; x > 0; x = rnd() % (x + 1)) a.push_back(x);
for (int x = rnd() % 5 + 1; x > 0; x = rnd() % (x + 1)) b.push_back(x);
int k = rnd() % 10 + 1;
ll right_res = solve_slow(a, b, k);
ll my_res = solve(a, b, k);
if (right_res != my_res) {
break;
}
}
}
int32_t main() {
// stress();
if (1) {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
}
int n; cin >> n;
vector<int> a(n); rep(i, n) cin >> a[i];
int m; cin >> m;
vector<int> b(m); rep(i, m) cin >> b[i];
ll k;
cin >> k;
ll result = solve(a, b, k);
cout << result << "\n";
return 0;
}
/*
2
1 1
2
2 2
6
*/
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 1ms
memory: 3572kb
input:
3 3 2 1 3 3 2 1 2
output:
7
result:
ok 1 number(s): "7"
Test #2:
score: 0
Accepted
time: 1ms
memory: 3668kb
input:
3 3 2 1 3 3 2 1 1111
output:
0
result:
ok 1 number(s): "0"
Test #3:
score: 0
Accepted
time: 1ms
memory: 3564kb
input:
0 0 10
output:
945
result:
ok 1 number(s): "945"
Test #4:
score: -100
Time Limit Exceeded
input:
10 10 9 8 7 6 5 4 4 4 3 10 10 9 8 7 6 5 4 4 4 3 1000000