QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#420466 | #8715. 放苹果 | Scintilla# | AC ✓ | 378ms | 81848kb | C++20 | 13.3kb | 2024-05-24 18:52:04 | 2024-05-24 18:52:04 |
Judging History
answer
#include <map>
#include <set>
#include <cmath>
#include <queue>
#include <bitset>
#include <cstdio>
#include <vector>
#include <random>
#include <cassert>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <functional>
#include <unordered_map>
using namespace std;
// Inclusive ascending / descending loops; the bound `e` is re-evaluated on
// every iteration.
#define rep(i, s, e) for (int i = s; i <= e; ++i)
#define per(i, s, e) for (int i = s; i >= e; --i)
// Redirect stdin/stdout to <a>.in / <a>.out.
#define file(a) freopen(#a".in", "r", stdin), freopen(#a".out", "w", stdout)
// Debug helpers: print one expression, or the inclusive slice a[l..r].
#define pv(a) cout << #a << " = " << a << endl
// Wrapped in do/while(0) so the two statements it expands to behave as a
// single statement (safe inside an unbraced if/else).
#define pa(a, l, r) do { cout << #a " : "; rep(_, l, r) cout << a[_] << " \n"[_ == r]; } while (0)
const int P = 998244353;  // prime modulus used by all arithmetic below
const int N = 2e5 + 10;   // size of the global tables (fac/inv/finv, f2)
// Reads the next (optionally negative) decimal integer from stdin.
// Non-digit characters before it are skipped; any '-' among them makes the
// result negative. If end of input is reached before a digit, it returns 0
// instead of spinning forever on EOF.
int read() {
  int x = 0, f = 1;
  int c = getchar();  // int, not char, so EOF stays distinguishable
  for (; c != EOF && (c < '0' || c > '9'); c = getchar())
    if (c == '-') f = -1;
  for (; c >= '0' && c <= '9'; c = getchar()) x = x * 10 + (c - '0');
  return x * f;
}
// Modular arithmetic helpers modulo P; operands are expected in [0, P).
int inc(int a, int b) {
  int s = a + b;
  if (s >= P) s -= P;
  return s;
}
int dec(int a, int b) {
  int d = a - b;
  if (d < 0) d += P;
  return d;
}
int mul(int a, int b) { return static_cast<int>(static_cast<long long>(a) * b % P); }
void add(int &a, int b) { a = inc(a, b); }
void sub(int &a, int b) { a = dec(a, b); }
// (-1)^x modulo P.
int sgn(int x) { return (x & 1) ? P - 1 : 1; }
// a^b modulo P by binary exponentiation (b is expected to be >= 0).
int qpow(int a, int b) {
  int res = 1;
  while (b != 0) {
    if (b & 1) res = mul(res, a);
    a = mul(a, a);
    b >>= 1;
  }
  return res;
}
int fac[N], inv[N], finv[N];
void init(int w) {
fac[0] = inv[1] = finv[0] = 1;
rep(i, 1, w) fac[i] = mul(fac[i - 1], i);
rep(i, 2, w) inv[i] = P - mul(P / i, inv[P % i]);
rep(i, 1, w) finv[i] = mul(finv[i - 1], inv[i]);
}
// Binomial coefficient C(a, b) mod P; 0 when b < 0 or b > a.
// Requires init() to have filled fac/finv up to a.
int C(int a, int b) {
  const bool out_of_range = b < 0 || b > a;
  return out_of_range ? 0 : mul(mul(fac[a], finv[b]), finv[a - b]);
}
// Polynomial toolkit over Z_P built on a number-theoretic transform (NTT).
// Coefficients are expected to stay in [0, P). It uses P, mul, inc, dec, add,
// sub, qpow and the rep/per macros defined earlier in this file. init_w()
// must run once before any transform-based operation.
namespace Poly {
const int ng = 3;             // generator; assumed to be a primitive root mod P (true for 998244353)
const int ngi = (P + 1) / 3;  // ng^{-1} mod P (3 * ((P + 1) / 3) == P + 1 when 3 divides P + 1)
const int L = 20;
const int N = 1 << L;         // largest supported transform length
int w[N], pinv[N];            // twiddles consumed by dif/dit; pinv[i] = i^{-1} mod P for 1 <= i < N

// Fills w[] and pinv[]. Given that ng is a primitive root, w[2^i] is a
// primitive 2^(i+2)-th root of unity, and w[] is multiplicative over the set
// bits of its index.
void init_w() {
  w[0] = pinv[1] = 1;
  rep(i, 0, L - 1) {
    w[1 << i] = qpow(ng, (P - 1) >> (i + 2));
  }
  rep(i, 2, N - 1) {
    w[i] = mul(w[i & -i], w[i & (i - 1)]);
    pinv[i] = P - mul(P / i, pinv[P % i]);  // pinv[P % i] is already known
  }
}

// In-place forward transform (decimation in frequency) of f[0..n-1], where
// n is a power of two. The output is in the permuted order this butterfly
// produces (presumably bit-reversed). It is only meant to be multiplied
// pointwise and fed back to dit().
void dif(int *f, int n) {
  int t, k, *tw, *tf, *pf;
  for (k = n >> 1; k; k >>= 1) {
    for (tw = w, tf = f; tf != f + n; ++tw, tf += k << 1) {
      for (pf = tf; pf != tf + k; ++pf) {
        t = mul(pf[k], *tw);
        pf[k] = dec(*pf, t), add(*pf, t);
      }
    }
  }
}

// In-place inverse of dif(): takes dif-ordered values and returns natural
// order coefficients, including the 1/n scaling.
void dit(int *f, int n) {
  int t, k, *tw, *tf, *pf;
  for (k = 1; k < n; k <<= 1) {
    for (tw = w, tf = f; tf != f + n; ++tw, tf += k << 1) {
      for (pf = tf; pf != tf + k; ++pf) {
        t = pf[k];
        pf[k] = mul(dec(*pf, t), *tw), add(*pf, t);
      }
    }
  }
  // n divides P - 1, so n * ((P - 1) / n) == P - 1 == -1 (mod P); hence
  // P - (P - 1) / n is n^{-1} mod P.
  int iv = P - (P - 1) / n;
  rep(i, 0, n - 1) f[i] = mul(f[i], iv);
  // The same (non-conjugated) roots were used, so indices 1..n-1 come out
  // mirrored (i -> n - i); reversing them restores natural order.
  reverse(f + 1, f + n);
}

// Scratch buffer shared by poly::dft/idft (so those are not reentrant).
int f0[N];

// Dense polynomial: a[i] is the coefficient of x^i. Copy and move operations
// are the implicitly generated ones (Rule of Zero). A user-declared copy
// constructor would suppress the implicit move operations and force every
// temporary to be deep-copied.
struct poly {
  vector <int> a;
  // Number of stored coefficients (trailing zeros included).
  int size() { return a.size(); }
  int& operator [] (int i) { return a[i]; }
  void clear() { a.clear(); }
  void resize(int n) { a.resize(n); }
  // n coefficients, all equal to x.
  poly(int n = 0, int x = 0) { a.resize(n, x); }
  poly(const vector <int> &o) { a = o; }

  // Coefficient-wise sum; the result has max(size(), b.size()) terms.
  void operator += (poly b) {
    int n = size(), m = b.size();
    resize(max(n, m));
    rep(i, 0, m - 1) add(a[i], b[i]);
  }
  friend poly operator + (poly a, poly b) {
    a += b;
    return a;
  }
  // Coefficient-wise difference; the result has max(size(), b.size()) terms.
  void operator -= (poly b) {
    int n = size(), m = b.size();
    resize(max(n, m));
    rep(i, 0, m - 1) sub(a[i], b[i]);
  }
  friend poly operator - (poly a, poly b) {
    a -= b;
    return a;
  }
  // Multiply by x^k (prepends k zero coefficients).
  void operator <<= (int k) {
    int n = size();
    resize(n + k);
    per(i, n - 1, 0) a[i + k] = a[i];
    rep(i, 0, k - 1) a[i] = 0;
  }
  friend poly operator << (poly a, int k) {
    a <<= k;
    return a;
  }
  // Scale every coefficient by b.
  void operator *= (int b) {
    int n = size();
    rep(i, 0, n - 1) a[i] = mul(a[i], b);
  }
  friend poly operator * (poly a, int b) {
    a *= b;
    return a;
  }
  // Scale every coefficient by b^{-1} mod P.
  void operator /= (int b) {
    int n = size(), iv = qpow(b, P - 2);
    rep(i, 0, n - 1) a[i] = mul(a[i], iv);
  }
  friend poly operator / (poly a, int b) {
    a /= b;
    return a;
  }
  // Drop the k lowest coefficients (floor division by x^k).
  void operator >>= (int k) {
    int n = size();
    if (n <= k) {
      *this = poly();
      return;
    }
    rep(i, k, n - 1) a[i - k] = a[i];
    resize(n - k);
  }
  friend poly operator >> (poly a, int k) {
    a >>= k;
    return a;
  }
  // Formal derivative (size() - 1 terms; empty for an empty polynomial).
  poly diff() {
    int n = size();
    if (!n) return poly();
    poly res(n - 1);
    rep(i, 1, n - 1) res[i - 1] = mul(a[i], i);
    return res;
  }
  // Formal integral with zero constant term (size() + 1 terms).
  poly intg() {
    int n = size();
    poly res(n + 1);
    rep(i, 0, n - 1) res[i + 1] = mul(a[i], pinv[i + 1]);
    return res;
  }
  // In-place forward transform; size() must be a power of two (<= N).
  void dft() {
    int n = size();
    assert(n == (n & -n));
    rep(i, 0, n - 1) f0[i] = a[i];
    dif(f0, n);
    rep(i, 0, n - 1) a[i] = f0[i];
  }
  // In-place inverse transform; size() must be a power of two (<= N).
  void idft() {
    int n = size();
    assert(n == (n & -n));
    rep(i, 0, n - 1) f0[i] = a[i];
    dit(f0, n);
    rep(i, 0, n - 1) a[i] = f0[i];
  }
  // Full product; the result has size() + b.size() - 1 terms. A product with
  // an empty (zero) polynomial is empty. The guard also avoids resize(-1)
  // when both operands are empty.
  void operator *= (poly b) {
    int n = size(), m = b.size();
    if (!n || !m) {
      clear();
      return;
    }
    int lim = 1;
    while (lim < n + m - 1) lim <<= 1;
    resize(lim), dft();
    b.resize(lim), b.dft();
    rep(i, 0, lim - 1) a[i] = mul(a[i], b[i]);
    idft(), resize(n + m - 1);
  }
  friend poly operator * (poly a, poly b) {
    a *= b;
    return a;
  }
  // Polynomial quotient via the reversal + series-inverse trick. It requires
  // b's leading stored coefficient to be non-zero (it becomes the constant
  // term inverted by inv()). The quotient is empty when size() < b.size().
  void operator /= (poly b) {
    int n = size(), m = b.size();
    if (n < m) {
      resize(0);
      return;
    }
    reverse(a.begin(), a.end());
    reverse(b.a.begin(), b.a.end());
    resize(n - m + 1), b.resize(n - m + 1);
    *this *= b.inv(), resize(n - m + 1);
    reverse(a.begin(), a.end());
  }
  friend poly operator / (poly a, poly b) {
    a /= b;
    return a;
  }
  // Remainder modulo b. When size() >= b.size() the result has exactly
  // b.size() - 1 terms; otherwise *this is left untouched.
  void operator %= (poly b) {
    int n = size(), m = b.size();
    if (n < m) return;
    poly q = *this / b;
    *this -= q * b, resize(m - 1);
  }
  friend poly operator % (poly a, poly b) {
    a %= b;
    return a;
  }
  // Returns {quotient, remainder}; if size() < b.size() the quotient is
  // empty and the remainder is *this.
  pair <poly, poly> div(poly b) {
    int n = size(), m = b.size();
    if (n < m) return make_pair(poly(), *this);
    poly q = *this / b, r = *this - b * q;
    r.resize(m - 1);
    return make_pair(q, r);
  }
  // Multiplicative inverse modulo x^{size()} by Newton iteration
  // (res <- res * (2 - res * this)); requires a[0] != 0.
  poly inv() {
    int n = size();
    assert(n);
    assert(a[0]);
    poly res(1), cp;
    res[0] = qpow(a[0], P - 2);
    for (int o = 2; o < (n << 1); o <<= 1) {
      cp = a, cp.resize(o), cp.resize(o << 1);
      res.resize(o << 1);
      res.dft(), cp.dft();
      for (int i = 0; i < o << 1; ++i) {
        res[i] = mul(res[i], dec(2, mul(res[i], cp[i])));
      }
      res.idft(), res.resize(o);
    }
    res.resize(n);
    return res;
  }
  // Logarithm modulo x^{size()} as integral(this' / this); requires a[0] == 1.
  poly ln() {
    int n = size();
    assert(n);
    assert(a[0] == 1);
    poly res = (diff() * inv()).intg();
    res.resize(n);
    return res;
  }
  // Exponential modulo x^{size()} by Newton iteration
  // (res <- res * (1 - ln(res) + this)); requires a[0] == 0.
  poly exp() {
    int n = size();
    assert(n);
    assert(!a[0]);
    poly res(1), fln, cp;
    res[0] = 1;
    for (int o = 2; o < (n << 1); o <<= 1) {
      cp = a, cp.resize(o), cp.resize(o << 1);
      res.resize(o), fln = res.ln(), fln.resize(o << 1);
      res.resize(o << 1);
      res.dft(), fln.dft(), cp.dft();
      for (int i = 0; i < o << 1; ++i) {
        res[i] = mul(res[i], dec(1, dec(fln[i], cp[i])));
      }
      res.idft(), res.resize(o);
    }
    res.resize(n);
    return res;
  }
  // this^k modulo x^{size()} via ln/exp after normalising a[0] to 1;
  // requires a[0] != 0.
  poly pow(int k) {
    poly cp = *this;
    assert(cp.size());
    assert(cp.a[0]);
    int c = cp.a[0];
    cp /= c;
    cp = (cp.ln() * k).exp();
    cp *= qpow(c, k);
    return cp;
  }
  // Square root modulo x^{size()} by Newton iteration
  // (res <- (res^2 + this) / (2 res)); requires a[0] == 1.
  poly sqrt() {
    int n = size();
    assert(n);
    assert(a[0] == 1);
    poly res(1), fiv, cp;
    res[0] = a[0];
    for (int o = 2; o < (n << 1); o <<= 1) {
      cp = a, cp.resize(o), cp.resize(o << 1);
      res.resize(o), fiv = res.inv(), fiv.resize(o << 1);
      res.resize(o << 1);
      res.dft(), fiv.dft(), cp.dft();
      for (int i = 0; i < o << 1; ++i) {
        res[i] = mul(inc(mul(res[i], res[i]), cp[i]), mul((P + 1) / 2, fiv[i]));
      }
      res.idft(), res.resize(o);
    }
    res.resize(n);
    return res;
  }
} ;

namespace linear_recur {
// Returns term n of the order-k (k = f.size()) linear recurrence
//   a[t] = sum_{j=0}^{k-1} f[j] * a[t - 1 - j]   (t >= k)
// given a[0..k-1]. It computes x^n mod g with
// g(x) = x^k - sum_j f[j] x^{k-1-j}, then combines the result with a[0..k-1].
int linear_recur(poly f, poly a, int n) {
  int k = f.size(), ans = 0;
  poly g(k + 1), cur(2), res(1);
  rep(i, 0, k - 1) g[i] = dec(0, f[k - 1 - i]);
  g[k] = cur[1] = res[0] = 1;
  for (int b = n; b; b >>= 1, cur = cur * cur % g) {
    if (b & 1) res = res * cur % g;
  }
  // `%` leaves operands shorter than g untouched, so res can have fewer
  // than k coefficients (e.g. when n < k). Zero-pad before reading res[i].
  res.resize(k);
  rep(i, 0, k - 1) add(ans, mul(res[i], a[i]));
  return ans;
}
}

namespace evaluation_and_interpolation {
// Transposed multiplication: returns c of size n - m + 1 with
// c[i] = sum_j a[i + j] * b[j] (n = a.size(), m = b.size()).
poly MulT(poly a, poly b) {
  int n = a.size(), m = b.size();
  reverse(b.a.begin(), b.a.end());
  b *= a, a.resize(n - m + 1);
  rep(i, 0, n - m) a[i] = b[i + m - 1];
  return a;
}
// Number of points must satisfy n <= _N. The node arrays below hold
// 4 * _N polynomials each and exist for the whole program run.
const int _L = 17;
const int _N = 1 << _L;
poly M[_N << 2], MD, Q[_N << 2], F[_N << 2], G[_N << 2];
#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)
// Returns f(a[0]), ..., f(a[m-1]) in order (m = a.size()). It uses a
// subproduct tree of prod (1 - a_i x) and transposed multiplication.
vector <int> multi_evaluation(poly f, vector <int> a) {
  int n = f.size(), m = a.size();
  if (n < m) f.resize(n = m);
  if (n > m) a.resize(n);
  vector <int> ans;
  if (!n) return ans;  // nothing to evaluate; the recursion needs n >= 1
  auto dfs1 = [&](auto self, int u, int l, int r) {
    if (l == r) {
      Q[u] = poly(vector <int> ({ 1, dec(0, a[l]) }));
      return;
    }
    self(self, ls, l, mid), self(self, rs, mid + 1, r);
    Q[u] = Q[ls] * Q[rs];
  } ;
  auto dfs2 = [&](auto self, int u, int l, int r) {
    if (l >= m) return;  // padding points are never reported
    if (l == r) {
      ans.emplace_back(F[u][0]);
      return;
    }
    F[ls] = MulT(F[u], Q[rs]), F[rs] = MulT(F[u], Q[ls]);
    self(self, ls, l, mid), self(self, rs, mid + 1, r);
  } ;
  dfs1(dfs1, 1, 0, n - 1);
  poly t = Q[1].inv();
  t.resize(n);
  reverse(t.a.begin(), t.a.end());
  t *= f;
  F[1].a.resize(n);
  rep(i, 0, n - 1) F[1][i] = t[i + n - 1];
  dfs2(dfs2, 1, 0, n - 1);
  assert((int)ans.size() == m);
  return ans;
}
// Lagrange interpolation: returns the polynomial of degree < n through
// (a[i], b[i]) for n = a.size(). It assumes the a[i] are distinct. A repeated
// point makes M'(a_i) == 0, whose "inverse" qpow(0, P - 2) is silently 0.
poly interpolation(vector <int> a, vector <int> b) {
  int n = a.size();
  if (!n) return poly();  // no points: the zero polynomial
  auto dfs0 = [&](auto self, int u, int l, int r) {
    if (l == r) {
      M[u] = poly(vector <int> ({ dec(0, a[l]), 1 }));
      return;
    }
    self(self, ls, l, mid), self(self, rs, mid + 1, r);
    M[u] = M[ls] * M[rs];
  } ;
  auto dfs1 = [&](auto self, int u, int l, int r) {
    if (l == r) {
      Q[u] = poly(vector <int> ({ 1, dec(0, a[l]) }));
      return;
    }
    self(self, ls, l, mid), self(self, rs, mid + 1, r);
    Q[u] = Q[ls] * Q[rs];
  } ;
  auto dfs2 = [&](auto self, int u, int l, int r) {
    if (l == r) {
      G[u] = poly(vector <int> ({ mul(b[l], qpow(F[u][0], P - 2)) }));
      return;
    }
    F[ls] = MulT(F[u], Q[rs]), F[rs] = MulT(F[u], Q[ls]);
    self(self, ls, l, mid), self(self, rs, mid + 1, r);
    G[u] = G[ls] * M[rs] + M[ls] * G[rs];
  } ;
  dfs0(dfs0, 1, 0, n - 1);
  dfs1(dfs1, 1, 0, n - 1);
  // MD = M', evaluated at every a_i below to get the Lagrange denominators.
  MD.resize(n);
  rep(i, 0, n - 1) MD[i] = mul(i + 1, M[1][i + 1]);
  poly t = Q[1].inv();
  t.resize(n);
  reverse(t.a.begin(), t.a.end());
  t *= MD;
  F[1].resize(n);
  rep(i, 0, n - 1) F[1][i] = t[i + n - 1];
  dfs2(dfs2, 1, 0, n - 1);
  return G[1];
}
#undef ls
#undef rs
#undef mid
}
}
// Pull the polynomial toolkit into the global namespace.
using Poly::poly;
using Poly::init_w;
using Poly::evaluation_and_interpolation::interpolation;
using Poly::evaluation_and_interpolation::multi_evaluation;
using Poly::linear_recur::linear_recur;
// Problem input, read in main().
int n;
int m;
// Memo table for fk2(): f2[k] caches (m - x)^k once computed (k >= 2).
poly f2[N];
// Returns (m - x)^k as a coefficient vector over Z_P, where m is the global
// read in main(). Results for k >= 2 are memoised in f2[k], so k must be < N.
poly fk2(int k) {
  assert(k >= 0);
  if (k == 0) return poly(1, 1);
  if (k == 1) {
    poly res(2);
    // m can exceed P (e.g. m = 1e9). Reduce it so every stored coefficient
    // lies in [0, P), as the NTT add/dec helpers assume.
    res[0] = (m % P + P) % P;
    res[1] = P - 1;  // -1 mod P
    return res;
  }
  if (f2[k].size()) return f2[k];
  // Split the exponent in halves; both halves are themselves memoised.
  return f2[k] = fk2(k / 2) * fk2(k - k / 2);
}
// Divide and conquer over the index range [l, r]. It returns
//   sum_{i=l}^{r} min(i, n - i) * C(n, i) * x^{r-i} * (m - x)^{i-l},
// which has r - l + 1 coefficients. The left half is shifted by x^{r-mid};
// the right half is multiplied by (m - x)^{mid-l+1}.
poly solve(int l, int r) {
  if (l != r) {
    const int half = (l + r) >> 1;
    poly lifted = fk2(half - l + 1) * solve(half + 1, r);
    return (solve(l, half) << (r - half)) + lifted;
  }
  return poly(1, mul(min(l, n - l), C(n, l)));
}
// Returns pw (n + 1 coefficients) with pw[i] = i! * [x^i] (e^{mx} - 1) / (e^x - 1).
// Since (e^{mx} - 1) / (e^x - 1) = sum_{j=0}^{m-1} e^{jx}, this is
// pw[i] = sum_{j=0}^{m-1} j^i (mod P), with 0^0 = 1.
poly calc() {
  poly f(n + 2), g(n + 2);
  int power = 1;  // m^i mod P, built incrementally instead of one qpow per term
  rep(i, 0, n + 1) {
    f[i] = mul(power, finv[i]);  // [x^i] e^{mx}
    g[i] = finv[i];              // [x^i] e^x
    power = mul(power, m);
  }
  // Drop the constant term 1 of both series and divide by x:
  // f = (e^{mx} - 1) / x and g = (e^x - 1) / x, each with n + 1 terms.
  f >>= 1, g >>= 1;
  f *= g.inv();
  // The product has 2n + 1 coefficients, but only the first n + 1 are exact
  // (both factors were truncated). Discard the rest.
  f.resize(n + 1);
  rep(i, 0, n) f[i] = mul(f[i], fac[i]);
  return f;
}
int main() {
init(N - 5);
init_w();
n = read(), m = read();
poly res = solve(0, n);
// pa(res, 0, n);
poly pw = calc();
// pa(pw, 0, n);
sub(pw[0], 1);
int ans = 0;
rep(i, 0, n) {
add(ans, mul(res[i], pw[i]));
}
// pv(ans);
printf("%d\n", ans);
return 0;
}
这程序好像有点Bug,我给组数据试试?
详细
Test #1:
score: 100
Accepted
time: 19ms
memory: 70436kb
input:
2 3
output:
8
result:
ok 1 number(s): "8"
Test #2:
score: 0
Accepted
time: 20ms
memory: 70172kb
input:
3 3
output:
36
result:
ok 1 number(s): "36"
Test #3:
score: 0
Accepted
time: 19ms
memory: 70148kb
input:
1 1
output:
0
result:
ok 1 number(s): "0"
Test #4:
score: 0
Accepted
time: 9ms
memory: 70156kb
input:
1 2
output:
0
result:
ok 1 number(s): "0"
Test #5:
score: 0
Accepted
time: 23ms
memory: 70236kb
input:
1 3
output:
0
result:
ok 1 number(s): "0"
Test #6:
score: 0
Accepted
time: 16ms
memory: 70228kb
input:
2 1
output:
0
result:
ok 1 number(s): "0"
Test #7:
score: 0
Accepted
time: 12ms
memory: 70212kb
input:
3 1
output:
0
result:
ok 1 number(s): "0"
Test #8:
score: 0
Accepted
time: 12ms
memory: 70328kb
input:
3719 101
output:
78994090
result:
ok 1 number(s): "78994090"
Test #9:
score: 0
Accepted
time: 19ms
memory: 70584kb
input:
2189 1022
output:
149789741
result:
ok 1 number(s): "149789741"
Test #10:
score: 0
Accepted
time: 19ms
memory: 70316kb
input:
2910 382012013
output:
926541722
result:
ok 1 number(s): "926541722"
Test #11:
score: 0
Accepted
time: 283ms
memory: 79644kb
input:
131072 3837829
output:
487765455
result:
ok 1 number(s): "487765455"
Test #12:
score: 0
Accepted
time: 368ms
memory: 81720kb
input:
183092 100000000
output:
231786691
result:
ok 1 number(s): "231786691"
Test #13:
score: 0
Accepted
time: 376ms
memory: 81848kb
input:
197291 937201572
output:
337054675
result:
ok 1 number(s): "337054675"
Test #14:
score: 0
Accepted
time: 378ms
memory: 81740kb
input:
200000 328194672
output:
420979346
result:
ok 1 number(s): "420979346"
Test #15:
score: 0
Accepted
time: 362ms
memory: 81800kb
input:
200000 1000000000
output:
961552572
result:
ok 1 number(s): "961552572"
Extra Test:
score: 0
Extra Test Passed