QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#715383 | #5546. Sharing Bread | ucup-team3519 | AC ✓ | 225ms | 33216kb | C++17 | 10.8kb | 2024-11-06 11:43:14 | 2024-11-06 11:43:14 |
Judging History
answer
#include <bits/stdc++.h>
using namespace std;
// Short aliases used throughout the template and solution.
#define V vector
#define all0(x) (x).begin(),(x).end()
#define all1(x) (x).begin()+1,(x).end()
#define pb push_back
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
// Map bare `cin` / `cout` to the std streams through the preprocessor.
#define cin std::cin
#define cout std::cout
typedef long long LL;
typedef pair<int, int> pi;
typedef pair<LL, LL> pl;
//const int N = 2e5 + 20;
// Large sentinels for int / long long (not referenced by the visible solve()).
const int INF = 2e9 + 1000;
const LL INFLL = 8e18 + 1000;
// Random engine seeded from the steady clock (not referenced by the visible solve()).
mt19937 mrand(chrono::steady_clock().now().time_since_epoch().count());
// Template section ~~~~~~~
// Prime modulus 998244353 = 119 * 2^23 + 1; every modular helper below reduces by it,
// and ntt::init() derives the maximal transform size (2^23) from mod - 1.
const int mod = 998244353;
// x <- (x + y) mod `mod`. Expects both operands already in [0, mod),
// so a single conditional subtraction suffices.
inline void add(int &x, int y) {
    if ((x += y) >= mod) x -= mod;
}
// x <- (x - y) mod `mod`. Expects both operands already in [0, mod),
// so a single conditional addition suffices.
inline void sub(int &x, int y) {
    if ((x -= y) < 0) x += mod;
}
// (x * y) mod `mod`, using a 64-bit intermediate product.
inline int mul(int x, int y) {
    const long long prod = 1LL * x * y;
    return static_cast<int>(prod % mod);
}
// x^y mod `mod` by binary exponentiation (square-and-multiply).
// Negative exponents are not supported.
inline int power(int x, int y) {
    int acc = 1;
    while (y) {
        if (y & 1) acc = mul(acc, x);
        y >>= 1;
        x = mul(x, x);
    }
    return acc;
}
// Modular inverse of a (mod `mod`) via the extended Euclidean algorithm.
// The argument is first normalised into [0, mod). For a == 0 the loop never
// runs and 0 is returned.
inline int inv(int a) {
    a %= mod;
    if (a < 0) a += mod;
    int r0 = mod, r1 = a;   // successive remainders
    int s0 = 0, s1 = 1;     // matching Bezout coefficients of a
    while (r1 != 0) {
        const int q = r0 / r1;
        r0 -= q * r1;
        swap(r0, r1);
        s0 -= q * s1;
        swap(s0, s1);
    }
    return s0 < 0 ? s0 + mod : s0;
}
// Number-theoretic transform (NTT) over Z/mod, with tables grown on demand.
//   base     - log2 of the largest transform size the tables currently cover,
//   root     - element of multiplicative order exactly 2^max_base (set by init()),
//   max_base - number of factors of 2 in (mod - 1); -1 until init() runs,
//   rev      - bit-reversal permutation for size 2^base,
//   roots    - twiddles: for i a power of two and 0 <= k < i, roots[i + k] is
//              w^k with w a primitive (2i)-th root of unity.
namespace ntt {
int base = 1, root = -1, max_base = -1;
V<int> rev = {0, 1}, roots = {0, 1};

// Computes max_base and searches upward from 2 for the first value whose
// multiplicative order is exactly 2^max_base.
void init() {
    int temp = mod - 1;
    max_base = 0;
    while (temp % 2 == 0) {
        temp >>= 1;
        ++max_base;
    }
    root = 2;
    while (true) {
        if (power(root, 1 << max_base) == 1 && power(root, 1 << (max_base - 1)) != 1) {
            break;
        }
        ++root;
    }
}

// Grows rev/roots so that transforms of size up to 2^nbase are supported.
// Asserts nbase <= max_base, i.e. 2^nbase divides mod - 1.
void ensure_base(int nbase) {
    if (max_base == -1) {
        init();
    }
    if (nbase <= base) {
        return;
    }
    assert(nbase <= max_base);
    rev.resize(1 << nbase);
    for (int i = 0; i < 1 << nbase; ++i) {
        // rev[i >> 1] is already recomputed for the new nbase since i >> 1 < i.
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (nbase - 1));
    }
    roots.resize(1 << nbase);
    while (base < nbase) {
        // z is a primitive 2^(base + 1)-th root of unity.
        int z = power(root, 1 << (max_base - 1 - base));
        for (int i = 1 << (base - 1); i < 1 << base; ++i) {
            roots[i << 1] = roots[i];
            roots[i << 1 | 1] = mul(roots[i], z);
        }
        ++base;
    }
}

// In-place forward transform (iterative radix-2, bit-reversal permutation
// first). a.size() must be a non-zero power of two.
void dft(V<int> &a) {
    int n = a.size(), zeros = __builtin_ctz(n);
    ensure_base(zeros);
    int shift = base - zeros;
    for (int i = 0; i < n; ++i) {
        if (i < rev[i] >> shift) {
            swap(a[i], a[rev[i] >> shift]);
        }
    }
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j += i << 1) {
            for (int k = 0; k < i; ++k) {
                int x = a[j + k], y = mul(a[j + k + i], roots[i + k]);
                a[j + k] = (x + y) % mod;
                a[j + k + i] = (x + mod - y) % mod;
            }
        }
    }
}

// Product of polynomials a and b (coefficients mod `mod`), length
// a.size() + b.size() - 1. An empty operand is the zero polynomial and
// yields an empty result (this also avoids the size_t underflow below).
V<int> multiply(V<int> a, V<int> b) {
    if (a.empty() || b.empty()) {
        return V<int>();
    }
    int need = a.size() + b.size() - 1, nbase = 0;
    while (1 << nbase < need) {
        ++nbase;
    }
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    // Squaring: transform once and reuse it.
    bool equal = a == b;
    dft(a);
    if (equal) {
        b = a;
    } else {
        dft(b);
    }
    int inv_sz = inv(sz);
    for (int i = 0; i < sz; ++i) {
        a[i] = mul(mul(a[i], b[i]), inv_sz);
    }
    // Reversing a[1..] and applying the forward transform gives the inverse
    // transform; the 1/sz factor was applied in the loop above.
    reverse(a.begin() + 1, a.end());
    dft(a);
    a.resize(need);
    return a;
}

// First a.size() coefficients of the power series 1 / a, by Newton
// iteration b <- b * (2 - a * b), doubling precision per recursion level.
// Requires a[0] != 0 (inv(0) returns 0, giving a meaningless result).
// An empty input yields an empty result.
V<int> inverse(V<int> a) {
    int n = a.size(), m = (n + 1) >> 1;
    if (n == 0) {
        // Without this guard the recursion on the first m == 0 terms never ends.
        return V<int>();
    }
    if (n == 1) {
        return V<int>(1, inv(a[0]));
    } else {
        V<int> b = inverse(V<int>(a.begin(), a.begin() + m));
        int need = n << 1, nbase = 0;
        while (1 << nbase < need) {
            ++nbase;
        }
        ensure_base(nbase);
        int sz = 1 << nbase;
        a.resize(sz);
        b.resize(sz);
        dft(a);
        dft(b);
        int inv_sz = inv(sz);
        for (int i = 0; i < sz; ++i) {
            // mod + 2 keeps 2 - a*b non-negative before reduction.
            a[i] = mul(mul(mod + 2 - mul(a[i], b[i]), b[i]), inv_sz);
        }
        reverse(a.begin() + 1, a.end());
        dft(a);
        a.resize(n);
        return a;
    }
}
}
using ntt::multiply;
using ntt::inverse;
// In-place coefficient-wise sum modulo `mod`; `a` is zero-extended to
// b.size() if it is shorter.
V<int>& operator += (V<int> &a, const V<int> &b) {
    a.resize(max(a.size(), b.size()));
    for (size_t i = 0; i != b.size(); ++i) add(a[i], b[i]);
    return a;
}
// Returns a + b (coefficient-wise, mod `mod`) without modifying the operands.
V<int> operator + (const V<int> &a, const V<int> &b) {
    V<int> sum(a);
    sum += b;
    return sum;
}
// In-place coefficient-wise difference modulo `mod`; `a` is zero-extended
// to b.size() if it is shorter.
V<int>& operator -= (V<int> &a, const V<int> &b) {
    a.resize(max(a.size(), b.size()));
    for (size_t i = 0; i != b.size(); ++i) sub(a[i], b[i]);
    return a;
}
// Returns a - b (coefficient-wise, mod `mod`) without modifying the operands.
V<int> operator - (const V<int> &a, const V<int> &b) {
    V<int> diff(a);
    diff -= b;
    return diff;
}
// In-place polynomial product modulo `mod`. Uses schoolbook multiplication
// when either operand has fewer than 128 coefficients, NTT otherwise.
// An empty vector is treated as the zero polynomial, so the product is empty.
// Safe when `b` aliases `a` (e.g. `a *= a`).
V<int>& operator *= (V<int> &a, const V<int> &b) {
    if (a.empty() || b.empty()) {
        // a.size() + b.size() - 1 would underflow (size_t) for two empty operands.
        a.clear();
        return a;
    }
    if (min(a.size(), b.size()) < 128) {
        // Accumulate into a separate buffer: writing into `a` while reading `b`
        // is wrong when b aliases a.
        V<int> res(a.size() + b.size() - 1, 0);
        for (size_t i = 0; i < a.size(); ++i) {
            for (size_t j = 0; j < b.size(); ++j) {
                add(res[i + j], mul(a[i], b[j]));
            }
        }
        a = std::move(res);
    } else {
        // multiply() takes both operands by value, so aliasing is harmless here.
        a = multiply(a, b);
    }
    return a;
}
// Returns the polynomial product a * b (mod `mod`) without modifying the operands.
V<int> operator * (const V<int> &a, const V<int> &b) {
    V<int> prod(a);
    prod *= b;
    return prod;
}
// In-place polynomial quotient a / b (mod `mod`). With q = n - m + 1:
// rev(quotient) = rev(a) * rev(b)^{-1} mod x^q.
// The result has q coefficients, or is empty when a.size() < b.size().
// NOTE(review): assumes b's last coefficient is non-zero.
V<int>& operator /= (V<int> &a, const V<int> &b) {
    const int n = a.size(), m = b.size();
    if (n < m) {
        a.clear();
        return a;
    }
    const int q = n - m + 1;
    V<int> rb(b.rbegin(), b.rend());
    rb.resize(q);
    reverse(a.begin(), a.end());
    a *= inverse(rb);
    a.resize(q);
    reverse(a.begin(), a.end());
    return a;
}
// Returns the polynomial quotient a / b without modifying the operands.
V<int> operator / (const V<int> &a, const V<int> &b) {
    V<int> quot(a);
    quot /= b;
    return quot;
}
// In-place polynomial remainder a mod b. When a.size() >= b.size(), the
// result is a - (a / b) * b truncated to b.size() - 1 coefficients.
// Otherwise `a` is returned unchanged.
V<int>& operator %= (V<int> &a, const V<int> &b) {
    const int n = a.size(), m = b.size();
    if (n < m) return a;
    const V<int> prod = (a / b) * b;
    a.resize(m - 1);
    for (int i = 0; i + 1 < m; ++i) sub(a[i], prod[i]);
    return a;
}
// Returns the polynomial remainder a mod b without modifying the operands.
V<int> operator % (const V<int> &a, const V<int> &b) {
    V<int> rem(a);
    rem %= b;
    return rem;
}
// Formal derivative: returns b with b[i - 1] = i * a[i] (mod `mod`).
// Empty and constant inputs both yield an empty result; previously an empty
// input built V<int>(-1), which throws std::length_error.
V<int> derivative(const V<int> &a) {
    if (a.empty()) {
        return V<int>();
    }
    const int n = a.size();
    V<int> b(n - 1);
    for (int i = 1; i < n; ++i) {
        b[i - 1] = mul(a[i], i);
    }
    return b;
}
// Formal antiderivative with zero constant term: res[i] = a[i - 1] / i.
// The inverses 1..n come from the linear recurrence
// inv(i) = -(mod / i) * inv(mod % i).
V<int> primitive(const V<int> &a) {
    const int n = a.size();
    V<int> res(n + 1), iv(n + 1);
    if (n >= 1) iv[1] = 1;
    for (int i = 2; i <= n; ++i) iv[i] = mul(mod - mod / i, iv[mod % i]);
    for (int i = 1; i <= n; ++i) res[i] = mul(a[i - 1], iv[i]);
    return res;
}
// ln(a) truncated to a.size() terms, computed as the integral of a' / a.
// The result's constant term is always 0, so it is meaningful when a[0] == 1.
V<int> logarithm(const V<int> &a) {
    V<int> ratio = derivative(a) * inverse(a);
    V<int> res = primitive(ratio);
    res.resize(a.size());
    return res;
}
// exp(a) truncated to a.size() terms via Newton iteration
//   b <- b * (1 + a - ln b),
// doubling the length of b each round.
// NOTE(review): the start value b = 1 means this presumes a[0] == 0 (not checked).
V<int> exponent(const V<int> &a) {
V<int> b(1, 1);
while (b.size() < a.size()) {
// c = 1 + (first 2*|b| coefficients of a, fewer if a is shorter).
V<int> c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
add(c[0], 1);
V<int> old_b = b;
// Extend b to the new length; the new upper half starts as zeros.
b.resize(b.size() << 1);
c -= logarithm(b);
c *= old_b;
// Only the upper half is copied from the product; the lower half keeps the previous iterate.
for (int i = b.size() >> 1; i < b.size(); ++i) {
b[i] = c[i];
}
}
b.resize(a.size());
return b;
}
// a(x)^m truncated to a.size() coefficients (mod `mod`), for m >= 0.
// Method:
//   1. Factor out the lowest non-zero term a[p] * x^p.
//   2. Raise the normalised series (constant term 1) to the m-th power via exp(m * ln(.)).
//   3. Shift the result by m * p and scale it by a[p]^m.
V<int> power(V<int> a, int m) {
    int n = a.size(), p = -1;
    V<int> b(n);
    if (n == 0) {
        // Nothing to fill in.
        return b;
    }
    if (m == 0) {
        // a^0 = 1. Handled up front: the general path sizes c as n - m * p = n,
        // so with p > 0 it would read a[i + p] past the end.
        b[0] = 1;
        return b;
    }
    for (int i = 0; i < n; ++i) {
        if (a[i]) {
            p = i;
            break;
        }
    }
    if (p == -1) {
        // a is identically zero and m >= 1, so the result is zero.
        return b;
    }
    if ((long long) m * p >= n) {
        // The lowest surviving power x^(m*p) is beyond the truncation length.
        return b;
    }
    int mu = power(a[p], m), di = inv(a[p]);
    const int len = n - m * p;
    V<int> c(len);
    for (int i = 0; i < len; ++i) {
        c[i] = mul(a[i + p], di);
    }
    c = logarithm(c);
    for (int i = 0; i < len; ++i) {
        c[i] = mul(c[i], m);
    }
    c = exponent(c);
    for (int i = 0; i < len; ++i) {
        b[i + m * p] = mul(c[i], mu);
    }
    return b;
}
// Square root of a power series truncated to a.size() terms, via Newton
// iteration b <- (b + a / b) / 2, doubling the length of b each round.
// The start value b = 1 means the result is only valid when a[0] == 1
// (not checked).
V<int> sqrt(const V<int> &a) {
    V<int> b(1, 1);
    const int inv2 = (mod + 1) >> 1;
    while (b.size() < a.size()) {
        V<int> c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
        b.resize(b.size() << 1);
        c *= inverse(b);
        // The new upper half of b is zero, so (b + a/b) / 2 there is just (a/b) / 2.
        for (size_t i = b.size() >> 1; i < b.size(); ++i) {
            b[i] = mul(c[i], inv2);
        }
    }
    b.resize(a.size());
    return b;
}
// Product of the polynomials all[l..r] by balanced divide and conquer.
// Returns an empty vector when l > r.
V<int> multiply_all(int l, int r, V<V<int>> &all) {
    if (l > r) return V<int>();
    if (l == r) return all[l];
    const int mid = l + (r - l) / 2;
    V<int> left = multiply_all(l, mid, all);
    V<int> right = multiply_all(mid + 1, r, all);
    return left * right;
}
// Multipoint evaluation: returns y with y[i] = f(x[i]) (mod `mod`).
// Builds a product tree of (X - x[i]) and reduces f down the tree.
// Leaves sit at indices n..2n-1; node i combines nodes 2i and 2i+1.
// An empty f is the zero polynomial. Assumes each x[i] is in [0, mod).
V<int> evaluate(const V<int> &f, const V<int> &x) {
    int n = x.size();
    if (!n) {
        return V<int>();
    }
    if (f.empty()) {
        // Every remainder would also be empty, so down[leaf][0] below would
        // read out of bounds. The zero polynomial evaluates to 0 everywhere.
        return V<int>(n, 0);
    }
    V<V<int>> up(n * 2);
    for (int i = 0; i < n; ++i) {
        up[i + n] = V<int>{(mod - x[i]) % mod, 1};
    }
    for (int i = n - 1; i; --i) {
        up[i] = up[i << 1] * up[i << 1 | 1];
    }
    // down[i] = f mod up[i]; each up[i] divides its parent's polynomial.
    V<V<int>> down(n * 2);
    down[1] = f % up[1];
    for (int i = 2; i < n * 2; ++i) {
        down[i] = down[i >> 1] % up[i];
    }
    V<int> y(n);
    for (int i = 0; i < n; ++i) {
        y[i] = down[i + n][0];
    }
    return y;
}
// Lagrange interpolation via a product tree: returns the polynomial P with
// fewer than n coefficients such that P(x[i]) = y[i] (mod `mod`).
// An empty x yields the empty (zero) polynomial.
// NOTE(review): assumes the x[i] are pairwise distinct mod `mod` and that
// y.size() >= x.size(). With duplicates, inv(0) returns 0 and the result is
// silently wrong.
V<int> interpolate(const V<int> &x, const V<int> &y) {
    int n = x.size();
    if (!n) {
        // up[1] below would index into an empty vector.
        return V<int>();
    }
    V<V<int>> up(n * 2);
    for (int i = 0; i < n; ++i) {
        up[i + n] = V<int>{(mod - x[i]) % mod, 1};
    }
    for (int i = n - 1; i; --i) {
        up[i] = up[i << 1] * up[i << 1 | 1];
    }
    // a[i] = Q'(x[i]) with Q = prod (X - x[j]); the Lagrange weight is y[i] / Q'(x[i]).
    V<int> a = evaluate(derivative(up[1]), x);
    for (int i = 0; i < n; ++i) {
        a[i] = mul(y[i], inv(a[i]));
    }
    V<V<int>> down(n * 2);
    for (int i = 0; i < n; ++i) {
        down[i + n] = V<int>(1, a[i]);
    }
    for (int i = n - 1; i; --i) {
        down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
    }
    return down[1];
}
// a^k mod `mod` by binary exponentiation, for k >= 0.
// The base is first reduced into [0, mod), so a * a can no longer overflow
// long long for large bases, and negative bases give a non-negative result.
constexpr LL qpow(LL a, LL k) {
    a %= mod;
    if (a < 0) a += mod;
    LL ans = 1;
    while (k) {
        if (k & 1) ans = ans * a % mod;
        k >>= 1;
        a = a * a % mod;
    }
    return ans;
}
// Factorial tables filled by ini_comb(): _fac[i] = i! and _ifac[i] = (i!)^{-1}, both mod `mod`.
V<LL> _fac, _ifac;
// Fills _fac[0..n] and _ifac[0..n]. The top inverse comes from Fermat
// (mod is prime); the rest follow from (i-1)!^{-1} = i!^{-1} * i.
void ini_comb(int n) {
    _fac.resize(n + 1);
    _ifac.resize(n + 1);
    _fac[0] = 1;
    for (int k = 1; k <= n; ++k) _fac[k] = _fac[k - 1] * k % mod;
    _ifac[n] = qpow(_fac[n], mod - 2);
    for (int k = n; k > 0; --k) _ifac[k - 1] = _ifac[k] * k % mod;
}
// x! mod `mod`; valid after ini_comb(N) for 0 <= x <= N (no bounds check).
LL fac(int x) { return _fac[x]; }
// (x!)^{-1} mod `mod`; valid after ini_comb(N) for 0 <= x <= N (no bounds check).
LL ifac(int x) { return _ifac[x]; }
// Binomial coefficient C(n, m) mod `mod`:
//   - 1 when m == 0 (checked first),
//   - 0 when m < 0 or m > n,
//   - otherwise n! / (m! (n-m)!) from the ini_comb tables.
LL C(int n, int m) {
    if (m == 0) return 1;
    if (m < 0 || n < m) return 0;
    const LL num = fac(n);
    return num * ifac(m) % mod * ifac(n - m) % mod;
}
// End of template ~~~~~~~
void solve() {
ini_comb(1e6);
int n, m; cin >> n >> m;
swap(n, m);
int t = m - n;
V<int> g(n + 1), ex(n + 1);
for(int i = 1; i <= n; i++) {
g[i] = 1LL * power(i - 1, i) * inv(fac(i)) % mod;
}
g[0] = 1;
// assert(g[2] == mod - qpow(2, mod - 2));
// cout << "g : " << endl;
// for(int i = 0; i <= n; i++) cout << g[i] << " ";
// cout << endl;
for(int i = 0; i <= n; i++) {
ex[i] = 1LL * power(i + t, i) * inv(fac(i)) % mod;
}
// assert(ex[2] == 9 * qpow(2, mod - 2) % mod);
// cout << "ex : " << endl;
// for(int i = 0; i <= n; i++) cout << ex[i] << " ";
// cout << endl;
// cout << "inv(g) : " << endl;
V<int> test = inverse(g);
// for(int i = 0; i <= n; i++) cout << test[i] << " ";
// cout << endl;
V<int> test_h = ex * test;
// cout << "ex / g : " << endl;
// for(int i = 0; i <= n; i++) cout << test_h[i] << " ";
// cout << endl;
cout << 1LL * test_h[n] * fac(n) % mod;
}
// Entry point: unsynchronised iostreams, throw on malformed input
// (failbit exceptions), then run a single test case.
int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin.exceptions(cin.failbit);
    int tests = 1;
    // cin >> tests;  // enable for multi-test input
    for (int tc = 0; tc < tests; ++tc) {
        solve();
    }
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 8ms
memory: 18800kb
input:
4 3
output:
50
result:
ok single line: '50'
Test #2:
score: 0
Accepted
time: 8ms
memory: 18752kb
input:
10 1
output:
10
result:
ok single line: '10'
Test #3:
score: 0
Accepted
time: 7ms
memory: 19020kb
input:
2 2
output:
3
result:
ok single line: '3'
Test #4:
score: 0
Accepted
time: 7ms
memory: 18952kb
input:
1 1
output:
1
result:
ok single line: '1'
Test #5:
score: 0
Accepted
time: 9ms
memory: 18776kb
input:
277 277
output:
124662617
result:
ok single line: '124662617'
Test #6:
score: 0
Accepted
time: 8ms
memory: 18768kb
input:
426 1
output:
426
result:
ok single line: '426'
Test #7:
score: 0
Accepted
time: 8ms
memory: 18932kb
input:
200000 1
output:
200000
result:
ok single line: '200000'
Test #8:
score: 0
Accepted
time: 223ms
memory: 33168kb
input:
200000 200000
output:
950017432
result:
ok single line: '950017432'
Test #9:
score: 0
Accepted
time: 105ms
memory: 26212kb
input:
200000 100000
output:
280947286
result:
ok single line: '280947286'
Test #10:
score: 0
Accepted
time: 108ms
memory: 25908kb
input:
200000 84731
output:
211985425
result:
ok single line: '211985425'
Test #11:
score: 0
Accepted
time: 113ms
memory: 26720kb
input:
200000 124713
output:
716696526
result:
ok single line: '716696526'
Test #12:
score: 0
Accepted
time: 54ms
memory: 22548kb
input:
129179 49655
output:
506429515
result:
ok single line: '506429515'
Test #13:
score: 0
Accepted
time: 37ms
memory: 20856kb
input:
87518 26040
output:
808454539
result:
ok single line: '808454539'
Test #14:
score: 0
Accepted
time: 17ms
memory: 19660kb
input:
178355 10116
output:
361555714
result:
ok single line: '361555714'
Test #15:
score: 0
Accepted
time: 8ms
memory: 18804kb
input:
2 1
output:
2
result:
ok single line: '2'
Test #16:
score: 0
Accepted
time: 46ms
memory: 22556kb
input:
192733 52550
output:
67181038
result:
ok single line: '67181038'
Test #17:
score: 0
Accepted
time: 54ms
memory: 22360kb
input:
76689 36632
output:
717949287
result:
ok single line: '717949287'
Test #18:
score: 0
Accepted
time: 8ms
memory: 18916kb
input:
200000 9
output:
158524471
result:
ok single line: '158524471'
Test #19:
score: 0
Accepted
time: 225ms
memory: 33216kb
input:
200000 199998
output:
879727659
result:
ok single line: '879727659'
Test #20:
score: 0
Accepted
time: 8ms
memory: 18680kb
input:
199952 1
output:
199952
result:
ok single line: '199952'
Test #21:
score: 0
Accepted
time: 223ms
memory: 33156kb
input:
199947 199947
output:
339118685
result:
ok single line: '339118685'
Test #22:
score: 0
Accepted
time: 111ms
memory: 26080kb
input:
199956 99978
output:
135867461
result:
ok single line: '135867461'
Test #23:
score: 0
Accepted
time: 6ms
memory: 18780kb
input:
2 2
output:
3
result:
ok single line: '3'
Test #24:
score: 0
Accepted
time: 11ms
memory: 18908kb
input:
10 3
output:
968
result:
ok single line: '968'
Test #25:
score: 0
Accepted
time: 8ms
memory: 19040kb
input:
10 5
output:
87846
result:
ok single line: '87846'
Test #26:
score: 0
Accepted
time: 11ms
memory: 18724kb
input:
10 9
output:
428717762
result:
ok single line: '428717762'
Test #27:
score: 0
Accepted
time: 11ms
memory: 18832kb
input:
279 166
output:
945780025
result:
ok single line: '945780025'
Test #28:
score: 0
Accepted
time: 6ms
memory: 18916kb
input:
361 305
output:
926296326
result:
ok single line: '926296326'
Test #29:
score: 0
Accepted
time: 9ms
memory: 18724kb
input:
305 262
output:
465560336
result:
ok single line: '465560336'