QOJ.ac

QOJ

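| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|---|---|---|---|---|---|---|---|---|---|
| #420466 | #8715. 放苹果 | Scintilla | AC ✓ | 378ms | 81848kb | C++20 | 13.3kb | 2024-05-24 18:52:04 | 2024-05-24 18:52:04 |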

Judging History

This is the latest submission verdict.

  • [2024-05-24 18:52:04]
  • Judged
  • Verdict: AC
  • Time: 378ms
  • Memory: 81848kb
  • [2024-05-24 18:52:04]
  • Submitted

answer

#include <algorithm>
#include <bitset>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <functional>
#include <iostream>
#include <map>
#include <queue>
#include <random>
#include <set>
#include <unordered_map>
#include <vector>

using namespace std;

// Inclusive loops: rep walks i = s, s + 1, ..., e; per walks i = s, s - 1, ..., e.
// The bound e is re-evaluated on every iteration. Arguments are parenthesised
// so that expressions such as `x & 1` keep their meaning after expansion.
#define rep(i, s, e) for (int i = (s); i <= (e); ++i)
#define per(i, s, e) for (int i = (s); i >= (e); --i)
// Redirects stdin / stdout to the files <a>.in / <a>.out.
#define file(a) freopen(#a".in", "r", stdin), freopen(#a".out", "w", stdout)
// Debug printers: pv prints "expr = value"; pa prints "name : a[l] ... a[r]".
// pa is wrapped in do { } while (0) so it acts as a single statement, e.g. as
// the body of an unbraced if.
#define pv(a) cout << #a << " = " << (a) << endl
#define pa(a, l, r) do { cout << #a " : "; rep(_, l, r) cout << (a)[_] << " \n"[_ == (r)]; } while (0)

// Prime modulus for all arithmetic; P - 1 = 119 * 2^23, so power-of-two NTT
// lengths up to 2^23 are available.
constexpr int P = 998244353;

// Capacity of the global tables indexed by n (fac / inv / finv, f2).
constexpr int N = 200010;

// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Returns 0 if end of input is reached before any digit.
// c is an int (not a char) so that EOF can be told apart from real characters;
// without the EOF check the skip loop would never terminate on truncated input.
int read() {
  int x = 0, f = 1;
  int c = getchar();
  for (; c < '0' || c > '9'; c = getchar()) {
    if (c == EOF) return 0;
    if (c == '-') f = -1;
  }
  for (; c >= '0' && c <= '9'; c = getchar()) x = x * 10 + (c - '0');
  return x * f;
}

// Modular helpers; operands are expected to be residues in [0, P).

// (a + b) mod P.
int inc(int a, int b) {
  const int sum = a + b;
  return sum >= P ? sum - P : sum;
}

// (a - b) mod P.
int dec(int a, int b) {
  const int diff = a - b;
  return diff < 0 ? diff + P : diff;
}

// (a * b) mod P, multiplied in 64-bit to avoid overflow.
int mul(int a, int b) {
  const long long prod = static_cast<long long>(a) * b;
  return static_cast<int>(prod % P);
}

// a = (a + b) mod P, in place.
void add(int &a, int b) {
  a += b;
  if (a >= P) a -= P;
}

// a = (a - b) mod P, in place.
void sub(int &a, int b) {
  a -= b;
  if (a < 0) a += P;
}

// (-1)^x mod P.
int sgn(int x) {
  if (x & 1) return P - 1;
  return 1;
}

// a^b mod P by binary exponentiation (b is expected to be non-negative).
int qpow(int a, int b) {
  int result = 1;
  while (b) {
    if (b & 1) result = mul(result, a);
    a = mul(a, a);
    b >>= 1;
  }
  return result;
}

int fac[N], inv[N], finv[N];
void init(int w) {
  fac[0] = inv[1] = finv[0] = 1;
  rep(i, 1, w) fac[i] = mul(fac[i - 1], i);
  rep(i, 2, w) inv[i] = P - mul(P / i, inv[P % i]);
  rep(i, 1, w) finv[i] = mul(finv[i - 1], inv[i]);
}

// Binomial coefficient C(a, b) mod P; 0 unless 0 <= b <= a.
// Requires init(w) with w >= a.
int C(int a, int b) {
  const bool outside = b < 0 || b > a;
  if (outside) return 0;
  return mul(mul(fac[a], finv[b]), finv[a - b]);
}

namespace Poly {

  const int ng = 3;
  const int ngi = (P + 1) / 3;

  const int L = 20;
  const int N = 1 << L;

  int w[N], pinv[N];

  void init_w() {
    w[0] = pinv[1] = 1;
    rep(i, 0, L - 1) {
      w[1 << i] = qpow(ng, (P - 1) >> (i + 2));
    }
    rep(i, 2, N - 1) {
      w[i] = mul(w[i & -i], w[i & i - 1]);
      pinv[i] = P - mul(P / i, pinv[P % i]);
    }
  }

  void dif(int *f, int n) {
    int t, k, *tw, *tf, *pf;
    for (k = n >> 1; k; k >>= 1) {
      for (tw = w, tf = f; tf != f + n; ++tw, tf += k << 1) {
        for (pf = tf; pf != tf + k; ++pf) {
          t = mul(pf[k], *tw);
          pf[k] = dec(*pf, t), add(*pf, t);
        }
      }
    }
  }

  void dit(int *f, int n) {
    int t, k, *tw, *tf, *pf;
    for (k = 1; k < n; k <<= 1) {
      for (tw = w, tf = f; tf != f + n; ++tw, tf += k << 1) {
        for (pf = tf; pf != tf + k; ++pf) {
          t = pf[k];
          pf[k] = mul(dec(*pf, t), *tw), add(*pf, t);
        }
      }
    }
    int iv = P - (P - 1) / n;
    rep(i, 0, n - 1) f[i] = mul(f[i], iv);
    reverse(f + 1, f + n);
  }

  int f0[N];

  struct poly {

    vector <int> a;

    int size() { return a.size(); }
    int& operator [] (int i) { return a[i]; }
    void clear() { a.clear(); }
    void resize(int n) { a.resize(n); }

    poly(int n = 0, int x = 0) { a.resize(n, x); }
    poly(const vector <int> &o) { a = o; }
    poly(const poly &o) { a = o.a; }

    void operator += (poly b) {
      int n = size(), m = b.size();
      resize(max(n, m));
      rep(i, 0, m - 1) add(a[i], b[i]);
    }

    friend poly operator + (poly a, poly b) {
      a += b;
      return a;
    }

    void operator -= (poly b) {
      int n = size(), m = b.size();
      resize(max(n, m));
      rep(i, 0, m - 1) sub(a[i], b[i]);
    }

    friend poly operator - (poly a, poly b) {
      a -= b;
      return a;
    }

    void operator <<= (int k) {
      int n = size();
      resize(n + k);
      per(i, n - 1, 0) a[i + k] = a[i];
      rep(i, 0, k - 1) a[i] = 0;
    }

    friend poly operator << (poly a, int k) {
      a <<= k;
      return a;
    }

    void operator *= (int b) {
      int n = size();
      rep(i, 0, n - 1) a[i] = mul(a[i], b);
    }

    friend poly operator * (poly a, int b) {
      a *= b;
      return a;
    }

    void operator /= (int b) {
      int n = size(), iv = qpow(b, P - 2);
      rep(i, 0, n - 1) a[i] = mul(a[i], iv);
    }

    friend poly operator / (poly a, int b) {
      a /= b;
      return a;
    }

    void operator >>= (int k) {
      int n = size();
      if (n <= k) {
        *this = poly();
        return;
      }
      rep(i, k, n - 1) a[i - k] = a[i];
      resize(n - k);
    }

    friend poly operator >> (poly a, int k) {
      a >>= k;
      return a;
    }

    poly diff() {
      int n = size();
      if (!n) return poly();
      poly res(n - 1);
      rep(i, 1, n - 1) res[i - 1] = mul(a[i], i);
      return res;
    }

    poly intg() {
      int n = size();
      poly res(n + 1);
      rep(i, 0, n - 1) res[i + 1] = mul(a[i], pinv[i + 1]);
      return res;
    }

    void dft() {
      int n = size();
      assert(n == (n & -n));
      rep(i, 0, n - 1) f0[i] = a[i];
      dif(f0, n);
      rep(i, 0, n - 1) a[i] = f0[i];
    }

    void idft() {
      int n = size();
      assert(n == (n & -n));
      rep(i, 0, n - 1) f0[i] = a[i];
      dit(f0, n);
      rep(i, 0, n - 1) a[i] = f0[i];
    }

    void operator *= (poly b) {
      int n = size(), m = b.size();
      int lim = 1;
      while (lim < n + m - 1) lim <<= 1;
      resize(lim), dft();
      b.resize(lim), b.dft();
      rep(i, 0, lim - 1) a[i] = mul(a[i], b[i]);
      idft(), resize(n + m - 1);
    }

    friend poly operator * (poly a, poly b) {
      a *= b;
      return a;
    }

    void operator /= (poly b) {
      int n = size(), m = b.size();
      if (n < m) {
        resize(0);
        return;
      }
      reverse(a.begin(), a.end());
      reverse(b.a.begin(), b.a.end());
      resize(n - m + 1), b.resize(n - m + 1);
      *this *= b.inv(), resize(n - m + 1);
      reverse(a.begin(), a.end());
    }

    friend poly operator / (poly a, poly b) {
      a /= b;
      return a;
    }

    void operator %= (poly b) {
      int n = size(), m = b.size();
      if (n < m) return;
      poly q = *this / b;
      *this -= q * b, resize(m - 1);
    }

    friend poly operator % (poly a, poly b) {
      a %= b;
      return a;
    }

    pair <poly, poly> div(poly b) {
      int n = size(), m = b.size();
      if (n < m) return make_pair(poly(), *this);
      poly q = *this / b, r = *this - b * q;
      r.resize(m - 1);
      return make_pair(q, r);
    }

    poly inv() {
      int n = size();
      assert(n);
      assert(a[0]);
      poly res(1), cp;
      res[0] = qpow(a[0], P - 2);
      for (int o = 2; o < (n << 1); o <<= 1) {
        cp = a, cp.resize(o), cp.resize(o << 1);
        res.resize(o << 1);
        res.dft(), cp.dft();
        for (int i = 0; i < o << 1; ++i) {
          res[i] = mul(res[i], dec(2, mul(res[i], cp[i])));
        }
        res.idft(), res.resize(o);
      }
      res.resize(n);
      return res;
    }

    poly ln() {
      int n = size();
      assert(n);
      assert(a[0] == 1);
      poly res = (diff() * inv()).intg();
      res.resize(n);
      return res;
    }
    
    poly exp() {
      int n = size();
      assert(n);
      assert(!a[0]);
      poly res(1), fln, cp;
      res[0] = 1;
      for (int o = 2; o < (n << 1); o <<= 1) {
        cp = a, cp.resize(o), cp.resize(o << 1);
        res.resize(o), fln = res.ln(), fln.resize(o << 1);
        res.resize(o << 1);
        res.dft(), fln.dft(), cp.dft();
        for (int i = 0; i < o << 1; ++i) {
          res[i] = mul(res[i], dec(1, dec(fln[i], cp[i])));
        }
        res.idft(), res.resize(o);
      }
      res.resize(n);
      return res;
    }

    poly pow(int k) {
      poly cp = *this;
      assert(cp.size());
      assert(cp.a[0]);
      int c = cp.a[0];
      cp /= c;
      cp = (cp.ln() * k).exp();
      cp *= qpow(c, k);
      return cp;
    }

    poly sqrt() {
      int n = size();
      assert(n);
      assert(a[0] == 1);
      poly res(1), fiv, cp;
      res[0] = a[0];
      for (int o = 2; o < (n << 1); o <<= 1) {
        cp = a, cp.resize(o), cp.resize(o << 1);
        res.resize(o), fiv = res.inv(), fiv.resize(o << 1);
        res.resize(o << 1);
        res.dft(), fiv.dft(), cp.dft();
        for (int i = 0; i < o << 1; ++i) {
          res[i] = mul(inc(mul(res[i], res[i]), cp[i]), mul((P + 1) / 2, fiv[i]));
        }
        res.idft(), res.resize(o);
      }
      res.resize(n);
      return res;
    }

  } ;

  namespace linear_recur {

    int linear_recur(poly f, poly a, int n) {
      int k = f.size(), ans = 0;
      poly g(k + 1), cur(2), res(1);
      rep(i, 0, k - 1) g[i] = dec(0, f[k - 1 - i]);
      g[k] = cur[1] = res[0] = 1;
      for (int b = n; b; b >>= 1, cur = cur * cur % g) {
        if (b & 1) res = res * cur % g;
      }
      rep(i, 0, k - 1) add(ans, mul(res[i], a[i]));
      return ans;
    }

  }

  namespace evaluation_and_interpolation {

    poly MulT(poly a, poly b) {
      int n = a.size(), m = b.size();
      reverse(b.a.begin(), b.a.end());
      b *= a, a.resize(n - m + 1);
      rep(i, 0, n - m) a[i] = b[i + m - 1];
      return a;
    }

    // n <= _N
    const int _L = 17;
    const int _N = 1 << _L;

    poly M[_N << 2], MD, Q[_N << 2], F[_N << 2], G[_N << 2];

    #define ls (u << 1)
    #define rs (u << 1 | 1)
    #define mid ((l + r) >> 1)

    vector <int> multi_evaluation(poly f, vector <int> a) {
      int n = f.size(), m = a.size();
      if (n < m) f.resize(n = m);
      if (n > m) a.resize(n);
      vector <int> ans;
      auto dfs1 = [&](auto self, int u, int l, int r) {
        if (l == r) {
          Q[u] = poly(vector <int> ({ 1, dec(0, a[l]) }));
          return;
        }
        self(self, ls, l, mid), self(self, rs, mid + 1, r);
        Q[u] = Q[ls] * Q[rs];
      } ;
      auto dfs2 = [&](auto self, int u, int l, int r) {
        if (l >= m) return;
        if (l == r) {
          ans.emplace_back(F[u][0]);
          return;
        }
        F[ls] = MulT(F[u], Q[rs]), F[rs] = MulT(F[u], Q[ls]);
        self(self, ls, l, mid), self(self, rs, mid + 1, r);
      } ;
      dfs1(dfs1, 1, 0, n - 1);
      poly t = Q[1].inv();
      t.resize(n);
      reverse(t.a.begin(), t.a.end());
      t *= f;
      F[1].a.resize(n);
      rep(i, 0, n - 1) F[1][i] = t[i + n - 1];
      dfs2(dfs2, 1, 0, n - 1);
      assert(ans.size() == m);
      return ans;
    }

    poly interpolation(vector <int> a, vector <int> b) {
      int n = a.size();
      auto dfs0 = [&](auto self, int u, int l, int r) {
        if (l == r) {
          M[u] = poly(vector <int> ({ dec(0, a[l]), 1 }));
          return;
        }
        self(self, ls, l, mid), self(self, rs, mid + 1, r);
        M[u] = M[ls] * M[rs];
      } ;
      auto dfs1 = [&](auto self, int u, int l, int r) {
        if (l == r) {
          Q[u] = poly(vector <int> ({ 1, dec(0, a[l]) }));
          return;
        }
        self(self, ls, l, mid), self(self, rs, mid + 1, r);
        Q[u] = Q[ls] * Q[rs];
      } ;
      auto dfs2 = [&](auto self, int u, int l, int r) {
        if (l == r) {
          G[u] = poly(vector <int> ({ mul(b[l], qpow(F[u][0], P - 2)) }));
          return;
        }
        F[ls] = MulT(F[u], Q[rs]), F[rs] = MulT(F[u], Q[ls]);
        self(self, ls, l, mid), self(self, rs, mid + 1, r);
        G[u] = G[ls] * M[rs] + M[ls] * G[rs];
      } ;
      dfs0(dfs0, 1, 0, n - 1);
      dfs1(dfs1, 1, 0, n - 1);
      MD.resize(n);
      rep(i, 0, n - 1) MD[i] = mul(i + 1, M[1][i + 1]);
      poly t = Q[1].inv();
      t.resize(n);
      reverse(t.a.begin(), t.a.end());
      t *= MD;
      F[1].resize(n);
      rep(i, 0, n - 1) F[1][i] = t[i + n - 1];
      dfs2(dfs2, 1, 0, n - 1);
      return G[1];
    }

    #undef ls
    #undef rs
    #undef mid

  }

}

// Bring the Poly pieces used by the solution into the global scope.
using Poly::poly;
using Poly::init_w;

// Further Poly entry points; main() does not call these.
using Poly::linear_recur::linear_recur;
using Poly::evaluation_and_interpolation::interpolation;
using Poly::evaluation_and_interpolation::multi_evaluation;

// The two input values, read in main().
int n;
int m;

// f2[k] caches (m - x)^k for fk2 when k >= 2; it is empty until first computed.
poly f2[N];
// Returns (m - x)^k as k + 1 coefficients.
// k = 0 and k = 1 are built directly. Larger k is the product of the two
// halves k / 2 and k - k / 2, memoised in f2[k].
poly fk2(int k) {
  assert(k >= 0);
  if (!k) return poly(1, 1);
  else if (k == 1) {
    poly res(2);
    // m comes straight from the input and may be >= P. The NTT helpers
    // (add / dec) assume residues in [0, P), so reduce it first.
    res[0] = (m % P + P) % P, res[1] = P - 1;
    return res;
  }
  if (f2[k].size()) return f2[k];
  return f2[k] = fk2(k / 2) * fk2(k - k / 2);
}

// Returns sum_{i=l}^{r} c_i * x^{r-i} * (m - x)^{i-l} as r - l + 1 coefficients,
// where c_i = min(i, n - i) * C(n, i).
// Divide and conquer: the left half is shifted by x^{r-mid}, and
// (m - x)^{mid-l+1} times the right half is added to it.
poly solve(int l, int r) {
  if (l == r) {
    const int weight = mul(min(l, n - l), C(n, l));
    return poly(1, weight);
  }
  const int mid = (l + r) >> 1;
  const poly factor = fk2(mid - l + 1);
  poly left = solve(l, mid);
  left <<= r - mid;
  poly right = factor * solve(mid + 1, r);
  return left + right;
}

// Returns pw with pw[i] = sum_{j=0}^{m-1} j^i (mod P, with 0^0 = 1) for i in [0, n].
// f = (e^{mx} - 1) / x and g = (e^x - 1) / x are taken to n + 1 terms.
// Their quotient is (e^{mx} - 1) / (e^x - 1) = sum_{j<m} e^{jx}, whose x^i
// coefficient times i! is the i-th power sum.
poly calc() {
  poly f(n + 2), g(n + 2);
  rep(i, 0, n + 1) {
    f[i] = mul(qpow(m, i), finv[i]);
    g[i] = finv[i];
  }
  f >>= 1, g >>= 1;  // drop the constant term 1 and divide by x
  f *= g.inv();
  // The product has 2n + 1 terms, but only the first n + 1 are the truncated
  // quotient. Discard the rest so every returned entry is a power sum.
  f.resize(n + 1);
  rep(i, 0, n) f[i] = mul(f[i], fac[i]);
  return f;
}

int main() {
  init(N - 5);
  init_w();
  n = read(), m = read();
  poly res = solve(0, n);
  // pa(res, 0, n);
  poly pw = calc();
  // pa(pw, 0, n);
  sub(pw[0], 1);
  int ans = 0;
  rep(i, 0, n) {
    add(ans, mul(res[i], pw[i]));
  }
  // pv(ans);
  printf("%d\n", ans);
  return 0;
}

这程序好像有点Bug,我给组数据试试? ("This program seems to have a bug — shall I try it with some test data?")

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 19ms
memory: 70436kb

input:

2 3

output:

8

result:

ok 1 number(s): "8"

Test #2:

score: 0
Accepted
time: 20ms
memory: 70172kb

input:

3 3

output:

36

result:

ok 1 number(s): "36"

Test #3:

score: 0
Accepted
time: 19ms
memory: 70148kb

input:

1 1

output:

0

result:

ok 1 number(s): "0"

Test #4:

score: 0
Accepted
time: 9ms
memory: 70156kb

input:

1 2

output:

0

result:

ok 1 number(s): "0"

Test #5:

score: 0
Accepted
time: 23ms
memory: 70236kb

input:

1 3

output:

0

result:

ok 1 number(s): "0"

Test #6:

score: 0
Accepted
time: 16ms
memory: 70228kb

input:

2 1

output:

0

result:

ok 1 number(s): "0"

Test #7:

score: 0
Accepted
time: 12ms
memory: 70212kb

input:

3 1

output:

0

result:

ok 1 number(s): "0"

Test #8:

score: 0
Accepted
time: 12ms
memory: 70328kb

input:

3719 101

output:

78994090

result:

ok 1 number(s): "78994090"

Test #9:

score: 0
Accepted
time: 19ms
memory: 70584kb

input:

2189 1022

output:

149789741

result:

ok 1 number(s): "149789741"

Test #10:

score: 0
Accepted
time: 19ms
memory: 70316kb

input:

2910 382012013

output:

926541722

result:

ok 1 number(s): "926541722"

Test #11:

score: 0
Accepted
time: 283ms
memory: 79644kb

input:

131072 3837829

output:

487765455

result:

ok 1 number(s): "487765455"

Test #12:

score: 0
Accepted
time: 368ms
memory: 81720kb

input:

183092 100000000

output:

231786691

result:

ok 1 number(s): "231786691"

Test #13:

score: 0
Accepted
time: 376ms
memory: 81848kb

input:

197291 937201572

output:

337054675

result:

ok 1 number(s): "337054675"

Test #14:

score: 0
Accepted
time: 378ms
memory: 81740kb

input:

200000 328194672

output:

420979346

result:

ok 1 number(s): "420979346"

Test #15:

score: 0
Accepted
time: 362ms
memory: 81800kb

input:

200000 1000000000

output:

961552572

result:

ok 1 number(s): "961552572"

Extra Test:

score: 0
Extra Test Passed