QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#778679 | #7923. Ferris Wheel | danielz | TL | 1246ms | 99328kb | C++20 | 9.8kb | 2024-11-24 15:45:09 | 2024-11-24 15:45:09

Judging History

You are viewing the latest judging result.

  • [2024-11-24 15:45:09]
  • Judged
  • Result: TL
  • Time: 1246ms
  • Memory: 99328kb
  • [2024-11-24 15:45:09]
  • Submitted

answer

#include <bits/stdc++.h>
using namespace std;
/**
 * Author: Andrew He
 * Description: FFT/NTT, polynomial mod/log/exp
 * Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf
 * Papers about accuracy: http://www.daemonology.net/papers/fft.pdf, http://www.cs.berkeley.edu/~fateman/papers/fftvsothers.pdf
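 * For integers rounding works if $(|a| + |b|)\max(a, b) < \mathtt{\sim} 10^9$, or in theory maybe $10^6$.
 * (That rounding bound applies only to the complex-double FFT variant; this file uses the exact NTT modulo 998244353.)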
 */
using namespace std;
// Short aliases used throughout the library code below.
using ll = long long;
using pii = pair<int, int>;
using vi = vector<int>;
namespace fft {
// NTT
const int mod = 998244353, g = 3;
// For p < 2^30 there is also (5 << 25, 3), (7 << 26, 3),
// (479 << 21, 3) and (483 << 21, 5). Last two are > 10^9.
struct num { /// start-hash
 int v;
 num(ll v_ = 0) : v(int(v_ % mod)) { if (v<0) v+=mod; }
 explicit operator int() const { return v; }
};
inline num operator+(num a,num b){return num(a.v+b.v);}
inline num operator-(num a,num b){return num(a.v+mod-b.v);}
inline num operator*(num a,num b){return num(1ll*a.v*b.v);}
inline num pow(num a, int b) {
 num r = 1;
 do{if(b&1)r=r*a;a=a*a;}while(b>>=1);
 return r;
}
inline num inv(num a) { return pow(a, mod-2); }
/// end-hash
using vn = vector<num>;
vi rev({0, 1});
vn rt(2, num(1)), fa, fb;
inline void init(int n) { /// start-hash
 if (n <= (int)(rt).size()) return;
 rev.resize(n);
 for (int i = 0; i < (n); ++i) rev[i] = (rev[i>>1] | ((i&1)*n)) >> 1;
 rt.reserve(n);
 for (int k = (int)(rt).size(); k < n; k *= 2) {
  rt.resize(2*k);
  num z = pow(num(g), (mod-1)/(2*k)); // NTT
  for (int i = k/2; i < (k); ++i) rt[2*i] = rt[i], rt[2*i+1] = rt[i]*z;
 }
} /// end-hash
inline void fft(vector<num> &a, int n) { /// start-hash
 init(n);
 int s = __builtin_ctz((int)(rev).size()/n);
 for (int i = 0; i < (n); ++i) if (i < rev[i]>>s) swap(a[i], a[rev[i]>>s]);
 for (int k = 1; k < n; k *= 2)
  for (int i = 0; i < n; i += 2 * k) for (int j = 0; j < (k); ++j) {
   num t = rt[j+k] * a[i+j+k];
   a[i+j+k] = a[i+j] - t;
   a[i+j] = a[i+j] + t;
  }
} /// end-hash
// Complex/NTT
vn multiply(vn a, vn b) { /// start-hash
 int s = (int)(a).size() + (int)(b).size() - 1;
 if (s <= 0) return {};
 int L = s > 1 ? 32 - __builtin_clz(s-1) : 0, n = 1 << L;
 a.resize(n), b.resize(n);
 fft(a, n);
 fft(b, n);
 num d = inv(num(n));
 for (int i = 0; i < (n); ++i) a[i] = a[i] * b[i] * d;
 reverse(a.begin()+1, a.end());
 fft(a, n);
 a.resize(s);
 return a;
} /// end-hash
// Complex/NTT power-series inverse
// Doubles b as b[:n] = (2 - a[:n] * b[:n/2]) * b[:n/2]
vn inverse(const vn& a) { /// start-hash
 if (a.empty()) return {};
 vn b({inv(a[0])});
 b.reserve(2*a.size());
 while ((int)(b).size() < (int)(a).size()) {
  int n = 2*(int)(b).size();
  b.resize(2*n, 0);
  if ((int)(fa).size() < 2*n) fa.resize(2*n);
  fill(fa.begin(), fa.begin()+2*n, 0);
  copy(a.begin(), a.begin()+min(n,(int)(a).size()), fa.begin());
  fft(b, 2*n);
  fft(fa, 2*n);
  num d = inv(num(2*n));
  for (int i = 0; i < (2*n); ++i) b[i] = b[i] * (2 - fa[i] * b[i]) * d;
  reverse(b.begin()+1, b.end());
  fft(b, 2*n);
  b.resize(n);
 }
 b.resize(a.size());
 return b;
} /// end-hash
} // namespace fft
// For multiply_mod, use num = modnum, poly = vector<num>
using fft::num;
using poly = fft::vn;
using fft::multiply;
using fft::inverse;
/// start-hash
// Coefficient-wise a += b; a is first grown to b's length if it is shorter.
poly& operator+=(poly& a, const poly& b) {
 a.resize(std::max(a.size(), b.size()));
 for (size_t i = 0; i < b.size(); ++i) a[i] = a[i] + b[i];
 return a;
}
poly operator+(const poly& a, const poly& b) {
 poly sum(a);
 sum += b;
 return sum;
}
// Coefficient-wise a -= b; a is first grown to b's length if it is shorter.
poly& operator-=(poly& a, const poly& b) {
 a.resize(std::max(a.size(), b.size()));
 for (size_t i = 0; i < b.size(); ++i) a[i] = a[i] - b[i];
 return a;
}
poly operator-(const poly& a, const poly& b) {
 poly diff(a);
 diff -= b;
 return diff;
}
// Polynomial product; always goes through the NTT (no schoolbook path for tiny inputs).
poly operator*(const poly& a, const poly& b) {
 return multiply(a, b);
}
poly& operator*=(poly& a, const poly& b) {
 a = a * b;
 return a;
}
/// end-hash
// Multiplies every coefficient of a by the scalar b.
poly& operator*=(poly& a, const num& b) { // Optional
 for (size_t i = 0; i < a.size(); ++i) a[i] = a[i] * b;
 return a;
}
poly operator*(const poly& a, const num& b) {
 poly scaled(a);
 scaled *= b;
 return scaled;
}
// Polynomial floor division (quotient only). The leading coefficient of b
// (its last entry) must be nonzero.
poly operator/(poly a, poly b) { /// start-hash
 if (a.size() < b.size()) return {};
 const int s = (int)a.size() - (int)b.size() + 1; // number of quotient coefficients
 // Reversing both operands turns floor division into a power-series division mod x^s.
 std::reverse(a.begin(), a.end());
 a.resize(s);
 std::reverse(b.begin(), b.end());
 b.resize(s);
 poly q = a * inverse(std::move(b));
 q.resize(s);
 std::reverse(q.begin(), q.end());
 return q;
} /// end-hash
poly& operator/=(poly& a, const poly& b) {
 a = a / b;
 return a;
}
// Polynomial remainder. When a is at least as long as b, a becomes a mod b
// with b.size()-1 coefficients; otherwise a is left unchanged.
poly& operator%=(poly& a, const poly& b) { /// start-hash
 if (a.size() < b.size()) return a;
 const poly qb = (a / b) * b;
 a.resize((int)b.size() - 1);
 for (int i = 0; i < (int)a.size(); ++i) a[i] = a[i] - qb[i];
 return a;
} /// end-hash
poly operator%(const poly& a, const poly& b) {
 poly rem(a);
 rem %= b;
 return rem;
}
// Log/exp/pow
// Formal derivative: result[i] = (i + 1) * a[i + 1]; empty input gives empty output.
poly deriv(const poly& a) { /// start-hash
 if (a.empty()) return {};
 poly d(a.size() - 1);
 for (int i = 0; i + 1 < (int)a.size(); ++i) d[i] = a[i + 1] * (i + 1);
 return d;
} /// end-hash
// Formal antiderivative with zero constant term: result[i] = a[i-1] / i,
// one coefficient longer than a.
poly integ(const poly& a) { /// start-hash
 poly b((int)(a).size()+1);
 // Empty a: the result is just the constant 0 and b[1] does not exist.
 if ((int)b.size() < 2) return b;
 // b[i] = 1/i from inv(i) = -(mod / i) * inv(mod % i), valid since mod is prime.
 b[1]=1; // mod p
 for (int i = 2; i < ((int)(b).size()); ++i) b[i]=b[fft::mod%i]*(-fft::mod/i); // mod p
 for (int i = 1; i < ((int)(b).size()); ++i) b[i]=a[i-1]*b[i]; // mod p
 //rep(i,1,sz(b)) b[i]=a[i-1]*inv(num(i)); // else
 return b;
} /// end-hash
// Power-series logarithm (requires a[0] == 1): log a = integral of a'/a,
// truncated to a.size() coefficients.
poly log(const poly& a) { /// start-hash
 poly quotient = deriv(a) * inverse(a);
 poly res = integ(quotient);
 res.resize(a.size());
 return res;
} /// end-hash
// Power-series exponential (requires a[0] == 0) by Newton iteration
// b <- b * (1 + a - log b), doubling the precision each round up to a.size().
poly exp(const poly& a) { /// start-hash
 poly b{num(1)};
 if (a.empty()) return b;
 for (int len = 1; len < (int)a.size();) {
  len = std::min(2 * len, (int)a.size());
  b.resize(len);
  poly v(a.begin(), a.begin() + len);
  v -= log(b);
  v[0] = v[0] + num(1);
  b *= v;
  b.resize(len);
 }
 return b;
} /// end-hash
// a^m truncated to a.size() coefficients (m >= 0). Factors out the lowest
// nonzero term a[p] * x^p, raises the normalised rest through log/exp, and
// shifts the result up by m*p.
poly pow(const poly& a, int m) { // m >= 0 /// start-hash
 poly b(a.size());
 if (a.empty()) return b; // nothing to fill; b[0] below would be out of bounds
 if (!m) { b[0] = 1; return b; }
 int p = 0;
 while (p<(int)(a).size() && a[p].v==0) ++p;
 // The result starts at degree m*p; past the truncation it is all zero.
 if (1ll*m*p >= (int)(a).size()) return b;
 num mu = pow(a[p], m), di = inv(a[p]);
 poly c((int)(a).size() - m*p);
 for (int i = 0; i < ((int)(c).size()); ++i) c[i] = a[i+p] * di;
 // c[0] == 1 after normalisation. A single coefficient stays exactly 1, so
 // skip log/exp there (log would otherwise integrate an empty series).
 if ((int)c.size() > 1) {
  c = log(c);
  for (auto &v : c) v = v * m;
  c = exp(c);
 }
 for (int i = 0; i < ((int)(c).size()); ++i) b[i+m*p] = c[i] * mu;
 return b;
} /// end-hash
// Multipoint evaluation/interpolation
/// start-hash
// Evaluates a at every x[i]. Leaves of the product tree `up` are (X - x[i])
// and each inner node is the product of its two children; a is reduced
// modulo each node top-down, so every leaf remainder equals a(x[i]).
vector<num> eval(const poly& a, const vector<num>& x) {
 const int n = (int)x.size();
 if (!n) return {};
 vector<poly> up(2*n);
 for (int i = 0; i < n; ++i) up[i+n] = poly({0-x[i], 1});
 for (int i = n - 1; i >= 1; --i) up[i] = up[2*i]*up[2*i+1];
 vector<poly> down(2*n);
 down[1] = a % up[1];
 for (int i = 2; i < 2*n; ++i) down[i] = down[i/2] % up[i];
 vector<num> y(n);
 // An empty a (the zero polynomial) leaves empty remainders; its value is 0.
 for (int i = 0; i < n; ++i) y[i] = down[i+n].empty() ? num(0) : down[i+n][0];
 return y;
} /// end-hash
/// start-hash
// Interpolation through the points (x[i], y[i]) using a product tree; returns
// the polynomial of degree < n with value y[i] at x[i]. Requires n > 0, and the
// x[i] presumably must be distinct (inv of a zero weight would give 0).
poly interp(const vector<num>& x, const vector<num>& y) {
 const int n = (int)x.size();
 assert(n);
 vector<poly> up(2 * n);
 for (int i = 0; i < n; ++i) up[n + i] = poly({0 - x[i], 1});
 for (int i = n - 1; i > 0; --i) up[i] = up[2 * i] * up[2 * i + 1];
 // Lagrange weights 1 / P'(x[i]), where P = up[1] is the full product.
 const vector<num> w = eval(deriv(up[1]), x);
 vector<poly> down(2 * n);
 for (int i = 0; i < n; ++i) down[n + i] = poly({y[i] * inv(w[i])});
 for (int i = n - 1; i > 0; --i)
  down[i] = down[2 * i] * up[2 * i + 1] + down[2 * i + 1] * up[2 * i];
 return down[1];
} /// end-hash
using ll = long long;
// 10^18, exactly the value the original `1e18` converted to.
constexpr ll inf = 1000000000000000000LL;
using namespace std;
// Residue modulo M (M prime for inv()/operator/). Every constructor and
// arithmetic operator keeps v in [0, M).
template <int M>
struct mint {
    long long v = 0;
    mint() {}
    mint(long long v) { this->v = (v % M + M) % M; }
    mint operator+(const mint &o) const { return v + o.v; }
    mint& operator+=(const mint& o) { v = (v + o.v) % M; return *this; }
    mint operator*(const mint &o) const { return v * o.v; }
    mint& operator*=(const mint& o) { v = (v * o.v) % M; return *this; }
    mint operator-() const { return mint{0} - *this; }
    mint operator-(const mint &o) const { return v - o.v; }
    mint& operator-=(const mint& o) { mint t = v - o.v; v = t.v; return *this; }
    // this^y by binary exponentiation, y >= 0. Shifting y right only (never
    // left) avoids signed overflow when y >= 2^30, e.g. y = M - 2 for M near 2^31.
    mint exp(int y) const {
        mint r = 1, x = v;
        for (; y > 0; y >>= 1, x = x * x)
            if (y & 1) r = r * x;
        return r;
    }
    mint operator/(mint o) const { return *this * o.inv(); }
    // Inverse via Fermat's little theorem; v must be nonzero.
    mint inv() const { assert(v); return exp(M - 2); }
    // Reads an integer and reduces it mod M (negative or >= M inputs included).
    friend std::istream& operator>>(std::istream& s, mint& v) { long long x = 0; s >> x; v = mint(x); return s; }
    friend std::ostream& operator<<(std::ostream& s, const mint& v) { s << v.v; return s; }
};
using namespace std;
// Factorial tables modulo M: F[i] = i! and F_i[i] = 1 / i! for 0 <= i < A.
// Requires 1 <= A and A - 1 < M (so that (A-1)! is invertible).
template <int A, int M>
struct Combo {
  mint<M> F[A], F_i[A];
  Combo() {
    F[0] = 1;
    for (int i = 1; i < A; i++) F[i] = F[i - 1] * i;
    // One modular exponentiation, then 1/(i-1)! = i * (1/i!) walking down:
    // O(A) work instead of an exponentiation per entry.
    F_i[A - 1] = F[A - 1].inv();
    for (int i = A - 1; i > 0; i--) F_i[i - 1] = F_i[i] * i;
  }
  // Binomial coefficient C(n, k) mod M; 0 when k < 0 or k > n. Requires n < A.
  mint<M> C(int n, int k) const { return (k < 0 || n < k) ? 0 : F[n] * F_i[n - k] * F_i[k]; }
};
const int M = 998244353;
const int N = 6e6 + 1;
poly f, g, gf;
Combo<N, M> C;
// n-th Catalan number C(2n, n) / (n + 1) = (2n)! / (n! (n+1)!), read directly
// from the factorial tables instead of paying a modular inverse per call.
// Requires 2n < N.
mint<M> ctln(int n) { return C.F[2 * n] * C.F_i[n] * C.F_i[n + 1]; }
int c[N];
int main() {
    int n; mint<M> k; cin >> n >> k;
    f.resize(n + 1), g.resize(n + 2);
    g[1] = k.v;
    // gcd has to be either 2 * n or <= n
    for (int i = 1; 2 * i <= n; i++) {
        f[2 * i] = (ctln(i - 1) * k * (k - 1).exp(i - 1)).v;
        g[i * 2 + 1] = (C.C(2 * i, i) * k * (k - 1).exp(i)).v;
    }
    gf = (f * inverse(f * -1 + poly{1}) + poly{1});
    // cout << gf.size() << endl;
    // for (int i = 0; i < gf.size(); i++) cout << gf[i].v << " ";
    // cout << endl;
    gf *= (g + poly{1}); // [i] -> # of colorings for RBS + extra stuff of len i
    for (int i = 1; i <= 2 * n; i++) ++c[gcd(i, 2 * n)];
    gf.resize(2 * n + 1);
    for (int i = 1; i <= n; i++) gf[2 * n] = gf[2 * n] + (mint<M>{i} / (2 * n - i) * C.C(2 * n - i, n) * (k - 1).exp(n - i) * k.exp(i)).v;
    // cout << gf[2 * n].v << endl;
    mint<M> r = 0;
    for (int i = 1; i <= 2 * n; i++) {
        r += (ll)gf[i].v * c[i];
    }
    cout << r / (2 * n) << endl;
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 1246ms
memory: 99328kb

input:

3 2

output:

6

result:

ok single line: '6'

Test #2:

score: 0
Accepted
time: 1244ms
memory: 98308kb

input:

5 3

output:

372

result:

ok single line: '372'

Test #3:

score: 0
Accepted
time: 1242ms
memory: 99128kb

input:

2023 1126

output:

900119621

result:

ok single line: '900119621'

Test #4:

score: -100
Time Limit Exceeded

input:

2882880 2892778

output:


result: