QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#778673 | #7923. Ferris Wheel | danielz | WA | 1243ms | 98444kb | C++20 | 9.8kb | 2024-11-24 15:43:37 | 2024-11-24 15:43:37 |
Judging History
answer
#include <bits/stdc++.h>
using namespace std;
/**
* Author: Andrew He
* Description: FFT/NTT, polynomial mod/log/exp
* Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf
* Papers about accuracy: http://www.daemonology.net/papers/fft.pdf, http://www.cs.berkeley.edu/~fateman/papers/fftvsothers.pdf
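* For integers, rounding works if $(|a| + |b|)\max(a, b) < \mathtt{\sim} 10^9$, or in theory maybe $10^6$. (This bound concerns the floating-point FFT variant only; the NTT used below is exact arithmetic modulo 998244353, so no rounding error arises.)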
*/
using namespace std;
// Short aliases shared by the rest of the file.
using ll = long long;
using pii = pair<int, int>;
using vi = vector<int>;
namespace fft {
// NTT
// Number-theoretic transform over Z/998244353 with primitive root g = 3. All
// arithmetic below is exact modular arithmetic (no floating point).
const int mod = 998244353, g = 3;
// For p < 2^30 there is also (5 << 25, 3), (7 << 26, 3),
// (479 << 21, 3) and (483 << 21, 5). Last two are > 10^9.
// Residue modulo `mod`; the constructor reduces any long long into [0, mod).
struct num { /// start-hash
int v;
num(ll v_ = 0) : v(int(v_ % mod)) { if (v<0) v+=mod; }
explicit operator int() const { return v; }
};
inline num operator+(num a,num b){return num(a.v+b.v);}
inline num operator-(num a,num b){return num(a.v+mod-b.v);}
// 1ll widens before multiplying so the product of two residues cannot overflow int.
inline num operator*(num a,num b){return num(1ll*a.v*b.v);}
// Binary exponentiation a^b for b >= 0 (the do/while still returns 1 for b == 0).
inline num pow(num a, int b) {
num r = 1;
do{if(b&1)r=r*a;a=a*a;}while(b>>=1);
return r;
}
// Fermat inverse a^(mod-2). Note: inv(0) silently returns 0 rather than failing.
inline num inv(num a) { return pow(a, mod-2); }
/// end-hash
using vn = vector<num>;
// Shared tables grown lazily by init(): rev = bit-reversal permutation for the
// largest transform size seen so far, rt = twiddle factors. fa is scratch space
// for inverse(); fb is not referenced anywhere in this file.
vi rev({0, 1});
vn rt(2, num(1)), fa, fb;
// Grows rev/rt so that transforms of length n (a power of two) can be done;
// returns immediately if the tables are already large enough.
inline void init(int n) { /// start-hash
if (n <= (int)(rt).size()) return;
rev.resize(n);
for (int i = 0; i < (n); ++i) rev[i] = (rev[i>>1] | ((i&1)*n)) >> 1;
rt.reserve(n);
// Each level k fills rt[k..2k): the twiddles used by butterflies of half-length k
// (see rt[j+k] in fft()), derived from the previous level and z, a primitive
// (2k)-th root of unity modulo `mod`.
for (int k = (int)(rt).size(); k < n; k *= 2) {
rt.resize(2*k);
num z = pow(num(g), (mod-1)/(2*k)); // NTT
for (int i = k/2; i < (k); ++i) rt[2*i] = rt[i], rt[2*i+1] = rt[i]*z;
}
} /// end-hash
// In-place forward transform of a[0..n). n must be a power of two and
// a.size() >= n. Callers obtain the inverse transform by scaling by 1/n,
// reversing a[1..n) and transforming again (see multiply and inverse).
inline void fft(vector<num> &a, int n) { /// start-hash
init(n);
// rev was built for length rev.size() >= n; shifting by s adapts it to length n.
int s = __builtin_ctz((int)(rev).size()/n);
for (int i = 0; i < (n); ++i) if (i < rev[i]>>s) swap(a[i], a[rev[i]>>s]);
for (int k = 1; k < n; k *= 2)
for (int i = 0; i < n; i += 2 * k) for (int j = 0; j < (k); ++j) {
num t = rt[j+k] * a[i+j+k];
a[i+j+k] = a[i+j] - t;
a[i+j] = a[i+j] + t;
}
} /// end-hash
// Complex/NTT
// Linear (non-cyclic) convolution: returns a*b of length a.size()+b.size()-1,
// or an empty vector if that length is <= 0. Arguments are taken by value
// because they are zero-padded and transformed in place.
vn multiply(vn a, vn b) { /// start-hash
int s = (int)(a).size() + (int)(b).size() - 1;
if (s <= 0) return {};
// n = smallest power of two >= s, so the cyclic product does not wrap around.
int L = s > 1 ? 32 - __builtin_clz(s-1) : 0, n = 1 << L;
a.resize(n), b.resize(n);
fft(a, n);
fft(b, n);
num d = inv(num(n));
for (int i = 0; i < (n); ++i) a[i] = a[i] * b[i] * d;
// Reverse + forward transform == inverse transform (scaling by d done above).
reverse(a.begin()+1, a.end());
fft(a, n);
a.resize(s);
return a;
} /// end-hash
// Complex/NTT power-series inverse
// Doubles b as b[:n] = (2 - a[:n] * b[:n/2]) * b[:n/2]
// Returns b of length a.size() with a*b == 1 modulo x^{a.size()}. Requires
// a[0] != 0 (inv(0) would silently give 0). Empty input gives empty output.
vn inverse(const vn& a) { /// start-hash
if (a.empty()) return {};
vn b({inv(a[0])});
b.reserve(2*a.size());
while ((int)(b).size() < (int)(a).size()) {
int n = 2*(int)(b).size();
// Transform length 2n: b has n/2 terms and a[:n] has n terms, so b*b*a[:n]
// has degree <= 2n-3 and the cyclic product cannot wrap.
b.resize(2*n, 0);
if ((int)(fa).size() < 2*n) fa.resize(2*n);
fill(fa.begin(), fa.begin()+2*n, 0);
copy(a.begin(), a.begin()+min(n,(int)(a).size()), fa.begin());
fft(b, 2*n);
fft(fa, 2*n);
num d = inv(num(2*n));
for (int i = 0; i < (2*n); ++i) b[i] = b[i] * (2 - fa[i] * b[i]) * d;
reverse(b.begin()+1, b.end());
fft(b, 2*n);
b.resize(n);
}
b.resize(a.size());
return b;
} /// end-hash
} // namespace fft
// For multiply_mod, use num = modnum, poly = vector<num>
using fft::num;
using poly = fft::vn;
using fft::multiply;
using fft::inverse;
/// start-hash
/// Coefficient-wise sum, in place. `a` is zero-extended to b.size() if shorter.
poly& operator+=(poly& a, const poly& b) {
  a.resize(std::max(a.size(), b.size()));
  for (std::size_t idx = 0; idx < b.size(); ++idx) {
    a[idx] = a[idx] + b[idx];
  }
  return a;
}
/// Coefficient-wise sum; the result has length max(a.size(), b.size()).
poly operator+(const poly& a, const poly& b) {
  poly sum(a);
  sum += b;
  return sum;
}
/// Coefficient-wise difference, in place. `a` is zero-extended to b.size() if shorter.
poly& operator-=(poly& a, const poly& b) {
  a.resize(std::max(a.size(), b.size()));
  for (std::size_t idx = 0; idx < b.size(); ++idx) {
    a[idx] = a[idx] - b[idx];
  }
  return a;
}
/// Coefficient-wise difference; the result has length max(a.size(), b.size()).
poly operator-(const poly& a, const poly& b) {
  poly diff(a);
  diff -= b;
  return diff;
}
/// Full polynomial product, delegated to the NTT convolution `multiply`.
/// TODO (carried over): a schoolbook path for tiny operands might be faster.
poly operator*(const poly& a, const poly& b) {
  poly product = multiply(a, b);
  return product;
}
/// In-place product: replaces `a` with a * b.
poly& operator*=(poly& a, const poly& b) {
  poly product = a * b;
  a = std::move(product);
  return a;
}
/// end-hash
/// Multiplies every coefficient of `a` by the constant `b`, in place.
/// Coefficients are updated front to back (relevant only if `b` aliases an
/// element of `a`).
poly& operator*=(poly& a, const num& b) { // Optional
  for (std::size_t idx = 0; idx < a.size(); ++idx) {
    a[idx] = a[idx] * b;
  }
  return a;
}
/// Returns a copy of `a` with every coefficient multiplied by `b`.
poly operator*(const poly& a, const num& b) {
  poly scaled(a);
  scaled *= b;
  return scaled;
}
// Polynomial floor division: returns the quotient q of a = q*b + r with
// deg r < deg b (coefficients stored lowest degree first). The last coefficient
// of b must be nonzero, since it is inverted via inverse(); neither operand
// should carry leading zero coefficients.
poly operator/(poly a, poly b) { /// start-hash
  const int qlen = (int)a.size() - (int)b.size() + 1;
  if (qlen <= 0) return {};
  // Reversing both operands turns the division into a power-series inversion;
  // only the first qlen coefficients of the reversed quotient are needed.
  std::reverse(a.begin(), a.end());
  std::reverse(b.begin(), b.end());
  a.resize(qlen);
  b.resize(qlen);
  poly q = a * inverse(std::move(b));
  q.resize(qlen);
  std::reverse(q.begin(), q.end());
  return q;
} /// end-hash
poly& operator/=(poly& a, const poly& b) {
  a = a / b;
  return a;
}
/// Polynomial remainder, in place: a becomes a - (a / b) * b, truncated to
/// b.size() - 1 coefficients. If a is already shorter than b it is left unchanged.
poly& operator%=(poly& a, const poly& b) { /// start-hash
  if (a.size() < b.size()) return a;
  const poly multiple = (a / b) * b;
  a.resize(b.size() - 1);
  for (std::size_t idx = 0; idx < a.size(); ++idx) {
    a[idx] = a[idx] - multiple[idx];
  }
  return a;
} /// end-hash
/// Returns the remainder of a divided by b (see operator%=).
poly operator%(const poly& a, const poly& b) {
  poly rem(a);
  rem %= b;
  return rem;
}
// Log/exp/pow
/// Formal derivative: result[d] = (d + 1) * a[d + 1]. The result has
/// a.size() - 1 coefficients, or none when a has at most one coefficient.
poly deriv(const poly& a) { /// start-hash
  poly out;
  if (a.size() > 1) {
    out.resize(a.size() - 1);
    for (int deg = 0; deg + 1 < (int)a.size(); ++deg) {
      out[deg] = a[deg + 1] * (deg + 1);
    }
  }
  return out;
} /// end-hash
/// Formal antiderivative with zero constant term: result[i] = a[i-1] / i, with
/// a.size() + 1 coefficients in total.
/// The inverses 1/1..1/a.size() are built in O(n) using the recurrence
/// inv(i) = -(p / i) * inv(p mod i), which relies on fft::mod being prime.
poly integ(const poly& a) { /// start-hash
  poly b((int)(a).size()+1);
  // Empty input: b == {0}. Without this guard, b[1] below would write past the
  // end (reachable e.g. via log() of a length-1 series, whose derivative is empty).
  if (a.empty()) return b;
  b[1]=1; // mod p
  for (int i = 2; i < ((int)(b).size()); ++i) b[i]=b[fft::mod%i]*(-fft::mod/i); // mod p
  for (int i = 1; i < ((int)(b).size()); ++i) b[i]=a[i-1]*b[i]; // mod p
  //rep(i,1,sz(b)) b[i]=a[i-1]*inv(num(i)); // else
  return b;
} /// end-hash
/// Power-series logarithm modulo x^{a.size()}: log a = integral(a' / a).
/// Requires a[0] == 1 (not checked).
poly log(const poly& a) { // a[0] == 1 /// start-hash
  const poly ratio = deriv(a) * inverse(a);
  poly result = integ(ratio);
  result.resize(a.size());
  return result;
} /// end-hash
/// Power-series exponential modulo x^{a.size()} by Newton iteration
/// b <- b * (1 - log b + a), doubling the number of kept terms each round.
/// Requires a[0] == 0 (not checked). Empty input yields {1}.
poly exp(const poly& a) { // a[0] == 0 /// start-hash
  poly b{num(1)};
  const int target = (int)a.size();
  if (target == 0) return b;
  for (int len = 1; len < target; len = (int)b.size()) {
    const int next = std::min(2 * len, target);
    b.resize(next);
    poly step = poly(a.begin(), a.begin() + next) - log(b);
    step[0] = step[0] + num(1);
    b *= step;
    b.resize(next);
  }
  return b;
} /// end-hash
/// Computes a(x)^m modulo x^{a.size()} for m >= 0 (negative m is not supported).
/// Method: skip the p leading zero coefficients, divide by a[p] so the remaining
/// series c has c[0] == 1, form c^m = exp(m * log c), then multiply by a[p]^m and
/// shift the result up by m*p positions.
poly pow(const poly& a, int m) { // m >= 0 /// start-hash
  poly b(a.size());
  // Precision 0: nothing to compute, and b[0] below would be out of bounds.
  if (b.empty()) return b;
  if (!m) { b[0] = 1; return b; }
  int p = 0;
  while (p < (int)a.size() && a[p].v == 0) ++p;
  // x^{m*p} already lies beyond the kept precision (this also covers a == 0).
  if (1ll * m * p >= (int)a.size()) return b;
  num mu = pow(a[p], m), di = inv(a[p]);
  poly c((int)a.size() - m * p);
  for (int i = 0; i < (int)c.size(); ++i) c[i] = a[i + p] * di;
  // c[0] == 1, so a single-term c raised to m is just {1}. Only run log/exp when
  // there are higher-order terms; log of a length-1 series would otherwise pass
  // an empty derivative into integ().
  if (c.size() > 1) {
    c = log(c);
    for (auto &v : c) v = v * m;
    c = exp(c);
  }
  for (int i = 0; i < (int)c.size(); ++i) b[i + m * p] = c[i] * mu;
  return b;
} /// end-hash
// Multipoint evaluation/interpolation
/// start-hash
/// Multipoint evaluation: returns y with y[i] = a(x[i]).
/// Builds a product tree in an implicit heap layout (leaves up[n + i] = X - x[i],
/// up[j] = up[2j] * up[2j+1]), then pushes remainders a mod up[j] down the tree.
/// Each leaf remainder is a constant equal to a(x[i]).
vector<num> eval(const poly& a, const vector<num>& x) {
  int n = (int)x.size();
  if (!n) return {};
  // The zero polynomial (empty a) is 0 everywhere. Without this guard every
  // remainder stays empty and down[i+n][0] below would read out of bounds.
  if (a.empty()) return vector<num>(n);
  vector<poly> up(2 * n);
  for (int i = 0; i < n; ++i) up[i + n] = poly({0 - x[i], 1});
  for (int i = n - 1; i >= 1; --i) up[i] = up[2 * i] * up[2 * i + 1];
  vector<poly> down(2 * n);
  down[1] = a % up[1];
  for (int i = 2; i < 2 * n; ++i) down[i] = down[i / 2] % up[i];
  vector<num> y(n);
  for (int i = 0; i < n; ++i) y[i] = down[i + n][0];
  return y;
} /// end-hash
/// start-hash
/// Lagrange interpolation: returns the polynomial P with P(x[i]) == y[i].
/// x must be non-empty (asserted). Distinct x values are presumably required:
/// a repeated point makes the derivative of the product tree vanish there, and
/// inv() of 0 silently yields 0 instead of failing.
poly interp(const vector<num>& x, const vector<num>& y) {
  const int n = (int)x.size();
  assert(n);
  // Product tree: tree[n + i] = (X - x[i]); tree[j] = tree[2j] * tree[2j+1].
  vector<poly> tree(2 * n);
  for (int leaf = n; leaf < 2 * n; ++leaf) tree[leaf] = poly({0 - x[leaf - n], 1});
  for (int node = n - 1; node > 0; --node) tree[node] = tree[2 * node] * tree[2 * node + 1];
  // With Q = tree[1], the weight of point i is y[i] / Q'(x[i]).
  const vector<num> dq = eval(deriv(tree[1]), x);
  vector<poly> acc(2 * n);
  for (int i = 0; i < n; ++i) acc[n + i] = poly({y[i] * inv(dq[i])});
  // Merge siblings: acc[j] = acc[2j] * tree[2j+1] + acc[2j+1] * tree[2j].
  for (int node = n - 1; node > 0; --node) {
    acc[node] = acc[2 * node] * tree[2 * node + 1] + acc[2 * node + 1] * tree[2 * node];
  }
  return acc[1];
} /// end-hash
using ll = long long;
const ll inf = 1e18;
using namespace std;
/// Integer modulo the compile-time modulus M (M must be prime for inv() and /).
/// Invariant: 0 <= v < M after every operation, including stream input.
template <int M>
struct mint {
  ll v = 0;
  mint() {}
  /// Reduces any (possibly negative) value into [0, M).
  mint(ll v) { this->v = (v % M + M) % M; }
  mint operator+(const mint &o) const { return v + o.v; }
  mint& operator+=(const mint& o) { v = (v + o.v) % M; return *this; }
  // v, o.v < M <= 2^31 - 1, so the product always fits in a long long.
  mint operator*(const mint &o) const { return v * o.v; }
  mint& operator*=(const mint& o) { v = (v * o.v) % M; return *this; }
  mint operator-() const { return mint{0} - *this; }
  mint operator-(const mint &o) const { return v - o.v; }
  mint& operator-=(const mint& o) { mint t = v - o.v; v = t.v; return *this; }
  /// Binary exponentiation: returns (*this)^y for y >= 0 (any y <= 0 yields 1).
  /// The exponent is only ever shifted right, so there is no signed overflow;
  /// the `y <<= 1` pre-doubling idiom overflows int for y > INT_MAX / 2.
  mint exp(ll y) const {
    mint r = 1, x = *this;
    for (; y > 0; y >>= 1, x = x * x)
      if (y & 1) r = r * x;
    return r;
  }
  /// Division via Fermat's little theorem; requires o != 0 (asserted by inv()).
  mint operator/(const mint& o) const { return *this * o.inv(); }
  /// Multiplicative inverse x^(M-2); asserts the value is nonzero.
  mint inv() const { assert(v); return exp(M - 2); }
  /// Reads an integer (possibly negative or >= M) and reduces it into [0, M).
  friend istream& operator>>(istream& s, mint& m) { ll t = 0; s >> t; m = mint(t); return s; }
  friend ostream& operator<<(ostream& s, const mint& v) { s << v.v; return s; }
};
using namespace std;
/// Factorial tables for binomial coefficients modulo M, covering 0 <= n < A.
template <int A, int M>
struct Combo {
  mint<M> F[A], F_i[A];  // F[i] = i!, F_i[i] = 1 / i!
  /// Builds both tables with a single modular inversion: F_i[A-1] = 1 / (A-1)!,
  /// then 1/(i-1)! = i * (1/i!) walking downward. This replaces one modular
  /// exponentiation per entry. Requires A <= M so that (A-1)! is invertible
  /// (inv() asserts otherwise).
  Combo() {
    F[0] = 1;
    for (int i = 1; i < A; i++) F[i] = F[i - 1] * i;
    F_i[A - 1] = F[A - 1].inv();
    for (int i = A - 1; i > 0; i--) F_i[i - 1] = F_i[i] * i;
  }
  /// Binomial coefficient "n choose k"; 0 when k < 0 or k > n. Requires n < A.
  mint<M> C(int n, int k) {
    return (k < 0 || n < k) ? mint<M>(0) : F[n] * F_i[n - k] * F_i[k];
  }
};
// Modulus for mint<> arithmetic (the same prime as fft::mod).
constexpr int M = 998244353;
// Table bound: main() indexes c[] and the factorial tables up to 2 * n, so it
// needs 2 * n < N.
constexpr int N = 6000001;
poly f, g, gf;
Combo<N, M> C;
// Catalan number: binom(2n, n) / (n + 1), with the division performed as a
// multiplication by the modular inverse of (n + 1).
mint<M> ctln(int n) {
  const mint<M> central = C.C(2 * n, n);
  const mint<M> denom = n + 1;
  return central * denom.inv();
}
int c[N];
// Reads n and k and prints exactly one line: the rotation-averaged count modulo
// 998244353.
int main() {
  int n;
  mint<M> k;
  cin >> n >> k;
  f.resize(n + 1), g.resize(n + 2);
  g[1] = k.v;
  // gcd has to be either 2 * n or <= n
  for (int i = 1; 2 * i <= n; i++) {
    f[2 * i] = (ctln(i - 1) * k * (k - 1).exp(i - 1)).v;
    g[i * 2 + 1] = (C.C(2 * i, i) * k * (k - 1).exp(i)).v;
  }
  // f / (1 - f) + 1 == 1 / (1 - f) as truncated power series.
  gf = (f * inverse(f * -1 + poly{1}) + poly{1});
  gf *= (g + poly{1}); // [i] -> # of colorings for RBS + extra stuff of len i
  // c[d] = number of rotations i in [1, 2n] with gcd(i, 2n) == d.
  for (int i = 1; i <= 2 * n; i++) ++c[gcd(i, 2 * n)];
  gf.resize(2 * n + 1);
  for (int i = 1; i <= n; i++) gf[2 * n] = gf[2 * n] + (mint<M>{i} / (2 * n - i) * C.C(2 * n - i, n) * (k - 1).exp(n - i) * k.exp(i)).v;
  // Average over the 2n rotations (Burnside-style). The product is formed in
  // mint (64-bit): gf[i].v < 998244353 and c[i] can be as large as 2n, so a
  // plain int * int product overflows.
  mint<M> r = 0;
  for (int i = 1; i <= 2 * n; i++) {
    r += mint<M>(gf[i].v) * c[i];
  }
  // Only the final answer is printed. A leftover debug print of gf[2 * n]
  // produced the extra first line that was judged Wrong Answer.
  cout << r / (2 * n) << endl;
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 0
Wrong Answer
time: 1243ms
memory: 98444kb
input:
3 2
output:
20 6
result:
wrong answer 1st lines differ - expected: '6', found: '20'