QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#641919 | #433. 抽卡 | Elegia | 100 ✓ | 1331ms | 22876kb | C++14 | 25.6kb | 2024-10-15 03:18:53 | 2024-10-15 03:19:01

Judging History

你现在查看的是最新测评结果

  • [2024-10-15 03:19:01]
  • 评测
  • 测评结果:100
  • 用时:1331ms
  • 内存:22876kb
  • [2024-10-15 03:18:53]
  • 提交

answer

#include <cstdio>
#include <cstring>

#include <algorithm>
#include <chrono>
#include <random>
#include <functional>
#include <vector>

#define LOG(FMT...) fprintf(stderr, FMT)

using namespace std;

typedef long long ll;

const int P = 998244353, R = 3;
const int BRUTE_N2_LIMIT = 50;

// Modular exponentiation by repeated squaring: returns x^k reduced modulo p.
// The exponent is consumed bit by bit from its least-significant end, so the
// loop only terminates once k has been shifted down to zero.
int mpow(int x, int k, int p = P) {
  int acc = 1;
  for (; k != 0; k >>= 1) {
    if ((k & 1) != 0)
      acc = acc * (ll) x % p;  // fold in the current power for a set bit
    x = x * (ll) x % p;        // advance to the next power of two
  }
  return acc;
}

// Subtracts P once when x >= P, so any x in [0, 2P) lands in [0, P).
int norm(int x) {
  if (x >= P)
    return x - P;
  return x;
}

// Scalar number-theory helpers: extended Euclid, modular inverse, a
// quadratic-residue style test and a randomised modular square root.
// One global instance, `nt`, is declared together with the struct.
struct NumberTheory {

  // Pair (first, second) standing for first + second * sqrt(v) inside sqrt().
  typedef pair<int, int> _P2_Field;

  // Random source used by sqrt(); seeded from the steady clock.
  mt19937 rng;

  NumberTheory() : rng(chrono::steady_clock::now().time_since_epoch().count()) {}

  // Extended Euclid: on return, a * x + b * y == gcd(a, b).
  void _exGcd(int a, int b, int &x, int &y) {
    if (!b) {
      x = 1;
      y = 0;
      return;
    }
    // Solve for (b, a % b) with the two outputs swapped, then fold the
    // quotient a / b back into y.
    _exGcd(b, a % b, y, x);
    y -= a / b * x;
  }

  // Returns the Bezout coefficient of a with respect to p, with p added once
  // when it is negative. This is the inverse of a modulo p when gcd(a, p) == 1.
  int inv(int a, int p = P) {
    int x, y;
    _exGcd(a, p, x, y);
    if (x < 0)
      x += p;
    return x;
  }

  // Recursive residue test of a against b: strips factors of 4 and 2 from a,
  // then swaps the arguments via quadRes(b % a, a).
  // NOTE(review): this looks like a Jacobi-symbol evaluation through quadratic
  // reciprocity (true meaning symbol +1, a <= 1 treated as a residue) - confirm.
  template<class Integer>
  bool quadRes(Integer a, Integer b) {
    if (a <= 1)
      return true;
    while (a % 4 == 0)
      a /= 4;
    if (a % 2 == 0)
      return (b % 8 == 1 || b % 8 == 7) == quadRes(a / 2, b);
    return ((a - 1) % 4 == 0 || (b - 1) % 4 == 0) == quadRes(b % a, a);
  }

  // Modular square root. Assumes p is prime and x is a quadratic residue
  // modulo p. Returns the smaller of the two candidates r and p - r.
  int sqrt(int x, int p = P) {
    // For p == 2 or x in {0, 1} the value is returned unchanged.
    if (p == 2 || x <= 1)
      return x;
    int w, v, k = (p + 1) / 2;
    // Draw w until v = w^2 - x (mod p) is rejected by quadRes.
    do {
      w = rng() % p;
    } while (quadRes(v = int((w * (ll) w - x + p) % p), p));
    // Square-and-multiply of (w + 1 * sqrt(v)) to the power (p + 1) / 2;
    // products of the "second" components are multiplied by v.
    _P2_Field res(1, 0), a(w, 1);
    while (k) {
      if (k & 1)
        res = _P2_Field((res.first * (ll) a.first + res.second * (ll) a.second % p * v) % p,
                        (res.first * (ll) a.second + res.second * (ll) a.first) % p);
      // The base is only squared while exponent bits remain.
      if (k >>= 1)
        a = _P2_Field((a.first * (ll) a.first + a.second * (ll) a.second % p * v) % p,
                      (a.first * (ll) a.second << 1) % p);
    }
    return min(res.first, p - res.first);
  }

} nt;

// Fixed-window (2^k-ary) power helper for an arbitrary associative operation.
//
//   T    - element type
//   Comp - callable combining two elements: comp(x, y)
//
// prepare[i] holds t combined with itself i times (prepare[0] == unit), for
// i in [0, 2^k). pow(n) then consumes n in chunks of k bits.
template<class T, class Comp>
struct AdditionChain {
  int k;                    // window width in bits
  std::vector<T> prepare;   // t^0 .. t^(2^k - 1)
  T t, unit;
  Comp comp;

  // Builds the table of the first 2^k powers of t. The initialisers are
  // listed in member declaration order, which is the order they actually run.
  AdditionChain(const T &t, const Comp &comp, int k, const T &unit = 1)
      : k(k), prepare(1U << k), t(t), unit(unit), comp(comp) {
    prepare[0] = unit;
    for (int i = 1; i < 1 << k; ++i)
      prepare[i] = comp(prepare[i - 1], t);
  }

  // Picks a window width from the exponent bit length lgn and builds the
  // chain. The table gets about lgn / log2(lgn) entries. The ratio used to be
  // written as lglgn / lgn, which is at most 1 and pinned k to 1 for every
  // input; it is lgn / lglgn here.
  static AdditionChain fourRussians(const T &t, const Comp &comp, int lgn, const T &unit = 1) {
    lgn = std::max(lgn, 1);
    int k = 1, lglgn = 1;
    while (2 << lglgn <= lgn)
      ++lglgn;
    int w = lgn / lglgn;
    while (1 << k < w)
      ++k;
    return AdditionChain(t, comp, k, unit);
  }

  // Returns t combined with itself n times. n is expected to be
  // non-negative: values below 2^k index the table directly.
  T pow(int n) const {
    if (n < 1 << k)
      return prepare[n];
    int r = n & ((1 << k) - 1);
    // Recurse on the high bits, then raise the result to the 2^k-th power by
    // k self-combinations before appending the low k bits from the table.
    T step = pow(n >> k);
    for (int rep = 0; rep < k; ++rep)
      step = comp(step, step);
    return comp(step, prepare[r]);
  }
};

// Lazily grown tables of factorials, inverse factorials and modular inverses
// modulo P. The global instance `simp` starts with tables covering 0..1 and
// doubles its range whenever a larger index is requested.
struct Simple {
  int n;
  std::vector<int> fac, ifac, inv;

  // (Re)computes all three tables for the index range 0..n.
  void build(int n) {
    this->n = n;
    const int size = n + 1;
    fac.resize(size);
    ifac.resize(size);
    inv.resize(size);
    fac[0] = 1;
    ifac[0] = 1;
    inv[1] = 1;
    for (int x = 1; x <= n; ++x) {
      // inv[x] is derived from inv[P % x], which has a smaller index.
      if (x >= 2)
        inv[x] = -(P / x) * (ll) inv[P % x] % P + P;
      fac[x] = fac[x - 1] * (ll) x % P;
      ifac[x] = ifac[x - 1] * (ll) inv[x] % P;
    }
  }

  Simple() {
    build(1);
  }

  // Ensures the tables cover index k, doubling the range until they do.
  void check(int k) {
    if (k <= n)
      return;
    int cap = n;
    while (cap < k)
      cap <<= 1;
    build(cap);
  }

  // k! modulo P.
  int gfac(int k) {
    check(k);
    return fac[k];
  }

  // Inverse of k! modulo P.
  int gifac(int k) {
    check(k);
    return ifac[k];
  }

  // Inverse of k modulo P.
  int ginv(int k) {
    check(k);
    return inv[k];
  }

  // Binomial coefficient C(n, m) modulo P; 0 when m lies outside [0, n].
  int binom(int n, int m) {
    const bool inRange = m >= 0 && m <= n;
    if (!inRange)
      return 0;
    return gfac(n) * (ll) gifac(m) % P * gifac(n - m) % P;
  }
} simp;

const int L2 = 11;

// Iterative radix-2 number-theoretic transform modulo P using the primitive
// root R. The global instance `ntt` caches a table of roots of unity that is
// rebuilt whenever a longer transform is requested.
struct NTT {
  int L;              // log2 of the current root table size; -1 before first use
  int brev[1 << L2];  // bit-reversal of every L2-bit index
  std::vector<int> root;

  NTT() : L(-1) {
    // brev[0] is read by the recurrence below (i == 1), so set it explicitly
    // instead of relying on the object being a zero-initialised global.
    brev[0] = 0;
    for (int i = 1; i < (1 << L2); ++i)
      brev[i] = brev[i >> 1] >> 1 | ((i & 1) << (L2 - 1));
  }

  // Fills root[i] = g^i for i in [0, 2^l), where g = R^((P - 1) >> l).
  void prepRoot(int l) {
    L = l;
    root.resize(1 << L);
    int n = 1 << L;
    int primitive = mpow(R, (P - 1) >> L);
    root[0] = 1;
    for (int i = 1; i < n; ++i) root[i] = root[i - 1] * (ll) primitive % P;
  }

  // In-place transform of a[0 .. 2^lgn).
  //   d == 1  : forward transform.
  //   d == -1 : inverse transform (same butterflies, then the tail is
  //             reversed and every entry is scaled by 1 / n).
  // lgn must not exceed 2 * L2: the index reversal is assembled from two
  // L2-bit table lookups and then shifted right by 2 * L2 - lgn.
  void fft(int *a, int lgn, int d = 1) {
    if (L < lgn) prepRoot(lgn);
    const int n = 1 << lgn;
    for (int i = 0; i < n; ++i) {
      int rev = (brev[i >> L2] | (brev[i & ((1 << L2) - 1)] << L2)) >> ((L2 << 1) - lgn);
      if (i < rev)
        std::swap(a[i], a[rev]);
    }
    // Butterfly passes with half-width t; k selects the stride into the root
    // table so that root[j << k] is the j-th power of the 2t-th root.
    for (int k = L - 1, t = 1; t < n; t <<= 1, --k) {
      for (int i = 0; i < n; i += t << 1) {
        int *p1 = a + i, *p2 = a + i + t;
        for (int j = 0; j < t; ++j) {
          int x = p2[j] * (ll) root[j << k] % P;
          p2[j] = norm(p1[j] + P - x);
          p1[j] = norm(p1[j] + x);
        }
      }
    }
    if (d == -1) {
      std::reverse(a + 1, a + n);
      int nv = mpow(n, P - 2);
      for (int i = 0; i < n; ++i) a[i] = a[i] * (ll) nv % P;
    }
  }
} ntt;

// Dense polynomial / truncated power series with coefficients modulo P;
// a[i] is the coefficient of z^i. The converting constructors are left
// implicit on purpose: the rest of the file returns ints, vectors and brace
// lists where a Poly is expected.
struct Poly {
  std::vector<int> a;

  // Constant polynomial; v is first reduced into [0, P).
  Poly(int v = 0) : a(1) {
    v %= P;
    if (v < 0)
      v += P;
    a[0] = v;
  }

  Poly(const std::vector<int> &a) : a(a) {}

  Poly(std::initializer_list<int> init) : a(init) {}

  // ---- Access helpers ----

  // Read-only access: indices at or beyond the stored size read as 0.
  int operator[](int k) const {
    if (static_cast<std::size_t>(k) < a.size())
      return a[k];
    return 0;
  }

  // Mutable access: grows the coefficient vector so that index k exists.
  int &operator[](int k) {
    if (static_cast<std::size_t>(k) >= a.size())
      a.resize(k + 1);
    return a[k];
  }

  // Index of the last stored coefficient (stored size minus one).
  int deg() const { return static_cast<int>(a.size()) - 1; }

  // Truncates or zero-extends to exactly d + 1 coefficients.
  void redeg(int d) { a.resize(d + 1); }

  Poly monic() const;

  Poly sunic() const;

  // Copy with exactly d + 1 coefficients (truncated or zero-padded).
  Poly slice(int d) const {
    std::vector<int> out;
    if (static_cast<std::size_t>(d) < a.size()) {
      out.assign(a.begin(), a.begin() + d + 1);
    } else {
      out = a;
      out.resize(d + 1);
    }
    return out;
  }

  int *base() { return a.data(); }

  const int *base() const { return a.data(); }

  // Writes the coefficients space-separated on one line; returns a copy.
  Poly println(FILE *fp) const {
    fprintf(fp, "%d", a[0]);
    for (std::size_t i = 1; i < a.size(); ++i)
      fprintf(fp, " %d", a[i]);
    fputc('\n', fp);
    return *this;
  }

  // ---- Arithmetic ----

  // Coefficient-wise sum; the shorter operand is treated as zero-padded.
  Poly operator+(const Poly &rhs) const {
    const int len = static_cast<int>(std::max(a.size(), rhs.a.size()));
    std::vector<int> sum(len);
    for (int i = 0; i < len; ++i) {
      const int s = (*this)[i] + rhs[i];
      sum[i] = s >= P ? s - P : s;
    }
    return sum;
  }

  // Coefficient-wise negation modulo P.
  Poly operator-() const {
    Poly neg(a);
    for (int &c : neg.a)
      if (c != 0)
        c = P - c;
    return neg;
  }

  Poly operator-(const Poly &rhs) const { return *this + (-rhs); }

  Poly operator*(const Poly &rhs) const;

  Poly operator/(const Poly &rhs) const;

  Poly operator%(const Poly &rhs) const;

  Poly der() const;  // derivative; drops the last stored coefficient slot
  Poly integ() const;

  Poly inv() const;

  Poly sqrt() const;

  Poly ln() const;

  Poly exp() const;

  pair<Poly, Poly> sqrti() const;

  pair<Poly, Poly> expi() const;

  Poly quo(const Poly &rhs) const;

  pair<Poly, Poly> iquo(const Poly &rhs) const;

  pair<Poly, Poly> div(const Poly &rhs) const;

  Poly taylor(int k) const;

  Poly pow(int k) const;

  Poly exp(int k) const;
};

// Polynomial made of deg + 1 zero coefficients.
Poly zeroes(int deg) {
  std::vector<int> coeffs(deg + 1);
  return coeffs;
}

// Single doubling steps of Newton-style iterations on power series. Every
// method handles one level t (block size n = 2^t); callers loop over t.
// A global instance, `nit`, is declared together with the struct.
struct Newton {
  // Inverse step working on ready-made transforms. nttf and nttg are read at
  // indices 0 .. 2n-1 as length-2n transforms; g is updated in place to
  // g - prod. The parameter f is not read by this overload.
  // NOTE(review): presumably extends g ~ 1/f from n to 2n coefficients - confirm.
  void inv(const Poly &f, const Poly &nttf, Poly &g, const Poly &nttg, int t) {
    int n = 1 << t;
    Poly prod = nttf;
    // Pointwise product of the two transforms, brought back to coefficients.
    for (int i = 0; i < (n << 1); ++i)
      prod[i] = prod[i] * (ll) nttg[i] % P;
    ntt.fft(prod.base(), t + 1, -1);
    // Only the upper half is kept before multiplying by g a second time.
    for (int i = 0; i < n; ++i)
      prod[i] = 0;
    ntt.fft(prod.base(), t + 1, 1);
    for (int i = 0; i < (n << 1); ++i)
      prod[i] = prod[i] * (ll) nttg[i] % P;
    ntt.fft(prod.base(), t + 1, -1);
    // Again discard the lower half, then subtract the correction from g.
    for (int i = 0; i < n; ++i)
      prod[i] = 0;
    g = g - prod;
  }

  // Same step when only the transform of f is available: g is padded to 2n
  // entries and transformed here.
  void inv(const Poly &f, const Poly &nttf, Poly &g, int t) {
    Poly nttg = g;
    nttg.redeg((2 << t) - 1);
    ntt.fft(nttg.base(), t + 1, 1);
    inv(f, nttf, g, nttg, t);
  }

  // Same step starting from coefficient forms of both f and g.
  void inv(const Poly &f, Poly &g, int t) {
    Poly nttg = g;
    nttg.redeg((2 << t) - 1);
    ntt.fft(nttg.base(), t + 1, 1);
    Poly nttf = f;
    nttf.redeg((2 << t) - 1);
    ntt.fft(nttf.base(), t + 1, 1);
    inv(f, nttf, g, nttg, t);
  }

  // Square-root step. nttg is consumed as scratch space (its first 2^t
  // entries are squared in place); h is read only; g is replaced by
  // g - tmp * inv(2).
  // NOTE(review): h is presumably the current approximation of 1/g - confirm.
  void sqrt(const Poly &f, Poly &g, Poly &nttg, Poly &h, int t) {
    // Square the transform pointwise and return to coefficients (length 2^t).
    for (int i = 0; i < (1 << t); ++i)
      nttg[i] = mpow(nttg[i], 2);
    ntt.fft(nttg.base(), t, -1);
    nttg = nttg - f;
    // Fold the lower half onto the upper half, then clear the lower half.
    for (int i = 0; i < (1 << t); ++i)
      if ((nttg[i + (1 << t)] += nttg[i]) >= P)
        nttg[i + (1 << t)] -= P;
    memset(nttg.base(), 0, sizeof(int) << t);
    ntt.fft(nttg.base(), t + 1, 1);
    // Multiply by h in the transform domain.
    Poly tmp = h;
    tmp.redeg((2 << t) - 1);
    ntt.fft(tmp.base(), t + 1, 1);
    for (int i = 0; i < (2 << t); ++i)
      tmp[i] = tmp[i] * (ll) nttg[i] % P;
    ntt.fft(tmp.base(), t + 1, -1);
    // Keep only the upper half of the product and apply it with factor 1/2.
    memset(tmp.base(), 0, sizeof(int) << t);
    g = g - tmp * nt.inv(2);
  }

  // Exponential step. nttg is read at even indices (nttg[i << 1]) and later
  // at indices 0 .. 2^(t+1)-1; h is read only. g is resized to 2^(t+1)
  // entries and its upper half is written from the negated product.
  // NOTE(review): presumably g ~ exp(f) and h ~ 1/g on entry - confirm.
  void exp(const Poly &f, Poly &g, Poly &nttg, Poly &h, int t) {
    Poly ntth(h);
    ntt.fft(ntth.base(), t, 1);
    Poly dg = g.der().slice((1 << t) - 1);
    ntt.fft(dg.base(), t, 1);
    // tmp = g * h and dg = g' * h, both as length-2^t cyclic products.
    Poly tmp = zeroes((1 << t) - 1);
    for (int i = 0; i < (1 << t); ++i) {
      tmp[i] = nttg[i << 1] * (ll) ntth[i] % P;
      dg[i] = dg[i] * (ll) ntth[i] % P;
    }
    ntt.fft(tmp.base(), t, -1);
    ntt.fft(dg.base(), t, -1);
    // Subtract 1 from the constant term of tmp (modulo P).
    if (--tmp[0] < 0)
      tmp[0] = P - 1;
    dg.redeg((2 << t) - 1);
    Poly df0 = f.der().slice((1 << t) - 1);
    df0[(1 << t) - 1] = 0;
    // Upper half of dg receives dg - df0; the lower half is then overwritten
    // with the first 2^t - 1 entries of df0.
    for (int i = 0; i < (1 << t); ++i) {
      if ((dg[i | 1 << t] = dg[i] - df0[i]) < 0)
        dg[i | 1 << t] += P;
    }
    memcpy(dg.base(), df0.base(), sizeof(int) * ((1 << t) - 1));
    // df0 <- df0 * tmp (length 2^(t+1)), then moved into the upper half.
    tmp.redeg((2 << t) - 1);
    ntt.fft(tmp.base(), t + 1, 1);
    df0.redeg((2 << t) - 1);
    ntt.fft(df0.base(), t + 1, 1);
    for (int i = 0; i < (2 << t); ++i)
      df0[i] = df0[i] * (ll) tmp[i] % P;
    ntt.fft(df0.base(), t + 1, -1);
    memcpy(df0.base() + (1 << t), df0.base(), sizeof(int) << t);
    memset(df0.base(), 0, sizeof(int) << t);
    // Integrate the difference, subtract f, and multiply by g.
    dg = (dg - df0).integ().slice((2 << t) - 1) - f;
    ntt.fft(dg.base(), t + 1, 1);
    for (int i = 0; i < (2 << t); ++i)
      tmp[i] = dg[i] * (ll) nttg[i] % P;
    ntt.fft(tmp.base(), t + 1, -1);
    // Only the upper half of g is written, with the product negated.
    g.redeg((2 << t) - 1);
    for (int i = 1 << t; i < (2 << t); ++i)
      if (tmp[i])
        g[i] = P - tmp[i];
  }
} nit;

// Online ("semi-relaxed") convolution driver: a is known in full, while b is
// produced index by index. A global instance, `src`, is declared with it.
struct SemiRelaxedConvolution {

  // For each index i, accumulates into b[i] the sum of b[j] * a[i - j] over
  // earlier j, then calls relax(i) so the caller can finalise b[i].
  // Indices are processed in increasing order, covering 0 .. a.size() - 1.
  // b must already have at least a.size() entries.
  template<class Function>
  void run(const vector<int> &a, vector<int> &b, const Function &relax) {
    int n = a.size() - 1;
    function<void(int, int)> divideConquer = [&](int l, int r) {
      // Small ranges: quadratic accumulation restricted to j in [l, i).
      if (r - l <= BRUTE_N2_LIMIT) {
        for (int i = l; i <= r; ++i) {
          for (int j = l; j < i; ++j)
            b[i] = (b[i] + b[j] * (ll) a[i - j]) % P;
          relax(i);
        }
        return;
      }
      // Split [l, r] into lg blocks of width d; lgd is chosen so that
      // 2^lgd >= 2 * d.
      int lg = 31 - __builtin_clz(r - l);
      int d = (r - l) / lg + 1;
      int lgd = 0;
      while ((1 << lgd) < d) ++lgd;
      ++lgd;

      // top holds 2 * lg slots of 2^lgd ints each: slots [0, lg) store
      // transforms of windows of a, slots [lg, 2 * lg) store per-block
      // accumulators in the transform domain.
      vector<int> top((lg << (lgd + 1)));
      for (int i = 0; i < lg; ++i) {
        copy(a.begin() + i * d, a.begin() + min((i + 2) * d, n + 1), top.begin() + (i << lgd));
        ntt.fft(top.data() + (i << lgd), lgd, 1);
      }
      for (int i = 0; i < lg; ++i) {
        // Bring the contributions collected for block i back to coefficients
        // (block 0 has none) and add them before recursing.
        if (i)
          ntt.fft(top.data() + ((lg + i) << lgd), lgd, -1);
        for (int j = 0; j < min(d, r - l - i * d + 1); ++j)
          b[l + i * d + j] = norm(b[l + i * d + j] + top[((lg + i) << lgd) + d + j]);
        divideConquer(l + i * d, min(l + (i + 1) * d - 1, r));
        // Block i is now final: reuse its accumulator slot for its transform.
        if (i + 1 < lg) {
          copy(b.begin() + l + i * d, b.begin() + min(l + (i + 1) * d, n + 1), top.begin() + ((lg + i) << lgd));
          fill(top.data() + ((lg + i) << lgd) + d, top.data() + ((lg + i + 1) << lgd), 0);
          ntt.fft(top.data() + ((lg + i) << lgd), lgd, 1);
        }
        // Push block i's contribution into the accumulator of every later block.
        for (int j = i + 1; j < lg; ++j) {
          for (int k = 0; k < (1 << lgd); ++k)
            top[((lg + j) << lgd) + k] =
                    (top[((lg + j) << lgd) + k] + top[((j - i - 1) << lgd) + k] * (ll) top[((lg + i) << lgd) + k]) % P;
        }
      }
    };
    divideConquer(0, n);
  }
} src;

// Transposed ("middle") products plus multipoint evaluation and interpolation
// built on them. A global instance, `tp`, is declared with the struct.
struct Transposition {

  // res is an already-transformed buffer of length 2^l (taken by value).
  // b is copied into a zeroed buffer, entries 1 .. 2^l-1 are reversed, and
  // the pointwise product is transformed back.
  vector<int> _mul(int l, vector<int> res, const Poly &b) {
    vector<int> tmp(1 << l);
    memcpy(tmp.data(), b.a.data(), sizeof(int) * (b.deg() + 1));
    reverse(tmp.begin() + 1, tmp.end());
    ntt.fft(tmp.data(), l, 1);
    for (int i = 0; i < (1 << l); ++i)
      res[i] = res[i] * (ll) tmp[i] % P;
    ntt.fft(res.data(), l, -1);
    return res;
  }

  // Quadratic reference: ret[i] = sum over j of a[i + j] * b[j], for
  // i in [0, a.deg() - b.deg()].
  Poly bfMul(const Poly &a, const Poly &b) {
    int n = a.deg(), m = b.deg();
    Poly ret = zeroes(n - m);
    for (int i = 0; i <= n - m; ++i)
      for (int j = 0; j <= m; ++j)
        ret[i] = (ret[i] + a[i + j] * (ll) b[j]) % P;
    return ret;
  }

  // Transposed product of a with b, a.deg() - b.deg() + 1 coefficients long.
  // Returns the zero constant when a is shorter than b; small inputs use bfMul.
  Poly mul(const Poly &a, const Poly &b) {
    if (a.deg() < b.deg()) return 0;
    if (a.deg() <= BRUTE_N2_LIMIT) return bfMul(a, b);
    int l = 0;
    while ((1 << l) <= a.deg()) ++l;
    vector<int> res(1 << l);
    memcpy(res.data(), a.a.data(), sizeof(int) * (a.deg() + 1));
    ntt.fft(res.data(), l, 1);
    res = _mul(l, res, b);
    res.resize(a.deg() - b.deg() + 1);
    return res;
  }

  // Two transposed products sharing a single forward transform of a.
  pair<Poly, Poly> mul2(const Poly &a, const Poly &b1, const Poly &b2) {
    if (a.deg() <= BRUTE_N2_LIMIT) return make_pair(bfMul(a, b1), bfMul(a, b2));
    int l = 0;
    while ((1 << l) <= a.deg()) ++l;
    vector<int> fa(1 << l);
    memcpy(fa.data(), a.a.data(), sizeof(int) * (a.deg() + 1));
    ntt.fft(fa.data(), l, 1);
    vector<int> res1 = _mul(l, fa, b1), res2 = _mul(l, fa, b2);
    res1.resize(a.deg() - b1.deg() + 1);
    res2.resize(a.deg() - b2.deg() + 1);
    return make_pair(res1, res2);
  }

  // Binary tree over the index range [0, n - 1], nodes numbered in pre-order:
  // ls / rs are child ids (0 marks a leaf, since the root has id 0 and is
  // never a child), pos[i] is the node id of leaf i. p and q are per-node
  // polynomials.
  vector<int> ls, rs, pos;
  vector<Poly> p, q;

  // Resets the tree storage for n leaves and assigns node ids.
  void _build(int n) {
    ls.assign(n * 2 - 1, 0);
    rs.assign(n * 2 - 1, 0);
    p.assign(n * 2 - 1, 0);
    q.assign(n * 2 - 1, 0);
    pos.resize(n);
    int cnt = 0;
    function<int(int, int)> dfs = [&](int l, int r) {
      if (l == r) {
        pos[l] = cnt;
        return cnt++;
      }
      int ret = cnt++;
      int mid = (l + r) >> 1;
      ls[ret] = dfs(l, mid);
      rs[ret] = dfs(mid + 1, r);
      return ret;
    };
    dfs(0, n - 1);
  }

  // Evaluates the polynomial with coefficients f at the points x, where
  // f.size() == x.size() (eval() pads both to a common length).
  vector<int> _eval(vector<int> f, const vector<int> &x) {
    int n = f.size();
    _build(n);
    // Leaves hold 1 - x[i] * z; internal nodes hold the product of their
    // children. Descending ids visit children before parents.
    for (int i = 0; i < n; ++i)
      q[pos[i]] = {1, norm(P - x[i])};
    for (int i = n * 2 - 2; i >= 0; --i)
      if (ls[i])
        q[i] = q[ls[i]] * q[rs[i]];
    f.resize(n * 2);
    p[0] = mul(f, q[0].inv());
    // Push p down the tree: each child receives the transposed product with
    // its sibling's q.
    for (int i = 0; i < n * 2 - 1; ++i)
      if (ls[i])
        tie(p[ls[i]], p[rs[i]]) = mul2(p[i], q[rs[i]], q[ls[i]]);
    // The constant term at leaf i is the value at x[i].
    vector<int> ret(n);
    for (int i = 0; i < n; ++i)
      ret[i] = p[pos[i]][0];
    return ret;
  }

  // Multipoint evaluation: returns f evaluated at every x[i]. Both inputs
  // are zero-padded to max(f.deg() + 1, x.size()) and the output is cut back
  // to x.size() entries.
  vector<int> eval(const Poly &f, const vector<int> &x) {
    int n = f.deg() + 1, m = x.size();
    vector<int> tmpf = f.a, tmpx = x;
    tmpf.resize(max(n, m));
    tmpx.resize(max(n, m));
    vector<int> ret = _eval(tmpf, tmpx);
    ret.resize(m);
    return ret;
  }

  // Interpolation through the points (x[i], y[i]); returns p[0] after the
  // bottom-up combination.
  // NOTE(review): assumes the x[i] are pairwise distinct, otherwise the
  // nt.inv call below has no meaningful result - confirm with callers.
  Poly inter(const vector<int> &x, const vector<int> &y) {
    int n = x.size();
    _build(n);
    for (int i = 0; i < n; ++i)
      q[pos[i]] = {1, norm(P - x[i])};
    for (int i = n * 2 - 2; i >= 0; --i)
      if (ls[i])
        q[i] = q[ls[i]] * q[rs[i]];
    // Reverse the root product and differentiate it, then evaluate that
    // derivative at every x[i] with the same descent as _eval.
    Poly tmp = q[0];
    reverse(tmp.a.begin(), tmp.a.end());
    vector<int> f = tmp.der().a;
    f.resize(n * 2);
    p[0] = mul(f, q[0].inv());
    for (int i = 0; i < n * 2 - 1; ++i)
      if (ls[i])
        tie(p[ls[i]], p[rs[i]]) = mul2(p[i], q[rs[i]], q[ls[i]]);
    // Leaf weights: y[i] divided by the derivative value at x[i].
    for (int i = 0; i < n; ++i)
      p[pos[i]] = nt.inv(p[pos[i]][0]) * (ll) y[i] % P;
    // Reverse every node polynomial, then combine bottom-up.
    for (int i = 0; i < n * 2 - 1; ++i)
      reverse(q[i].a.begin(), q[i].a.end());
    for (int i = n * 2 - 2; i >= 0; --i)
      if (ls[i])
        p[i] = p[ls[i]] * q[rs[i]] + p[rs[i]] * q[ls[i]];
    return p[0];
  }

} tp;

// Literal suffix: N_z yields the polynomial with coefficients {0, N}, i.e. N * z.
Poly operator""_z(unsigned long long a) {
  const int coeff = static_cast<int>(a);
  return Poly{0, coeff};
}

// Enables `scalar + poly`: the scalar is first lifted to a constant polynomial.
Poly operator+(int v, const Poly &rhs) {
  const Poly lifted(v);
  return lifted + rhs;
}

// Full product of two polynomials. Short or low-degree operands use the
// schoolbook double loop; everything else goes through the NTT.
Poly Poly::operator*(const Poly &rhs) const {
  const int da = deg(), db = rhs.deg();
  const bool small = da <= 10 || db <= 10 || da + db <= BRUTE_N2_LIMIT;
  if (small) {
    std::vector<int> out(da + db + 1);
    for (int i = 0; i <= da; ++i) {
      const ll ai = a[i];
      for (int j = 0; j <= db; ++j)
        out[i + j] = (out[i + j] + ai * rhs[j]) % P;
    }
    return out;
  }

  // Smallest power of two strictly larger than the product degree.
  const int total = da + db;
  int lg = 0;
  while ((1 << lg) <= total)
    ++lg;
  const int len = 1 << lg;

  std::vector<int> fa(len), fb(len);
  std::copy(a.begin(), a.end(), fa.begin());
  std::copy(rhs.a.begin(), rhs.a.end(), fb.begin());
  ntt.fft(fa.data(), lg, 1);
  ntt.fft(fb.data(), lg, 1);
  for (int i = 0; i < len; ++i)
    fa[i] = fa[i] * (ll) fb[i] % P;
  ntt.fft(fa.data(), lg, -1);
  fa.resize(total + 1);
  return fa;
}

// Power-series inverse with deg() + 1 coefficients, obtained by repeatedly
// doubling the precision of the approximation through nit.inv.
Poly Poly::inv() const {
  const int d = deg();
  Poly approx(nt.inv(a[0]));
  int t = 0;
  while ((1 << t) <= d) {
    nit.inv(slice((2 << t) - 1), approx, t);
    ++t;
  }
  approx.redeg(d);
  return approx;
}

// Polynomial shift by k, computed with one product: coefficients are scaled
// by factorials and reversed, multiplied by the series k^i / i!, then the
// result is read back reversed and divided by factorials.
Poly Poly::taylor(int k) const {
  const int n = deg();
  simp.check(n);

  // rev[n - i] = a[i] * i!
  Poly rev = zeroes(n);
  for (int i = 0; i <= n; ++i)
    rev[n - i] = a[i] * (ll) simp.fac[i] % P;

  // shifted[i] = k^i / i!
  Poly shifted = std::vector<int>(simp.ifac.begin(), simp.ifac.begin() + n + 1);
  int kPow = 1;
  for (int i = 0; i <= n; ++i) {
    shifted[i] = shifted[i] * (ll) kPow % P;
    kPow = kPow * (ll) k % P;
  }

  rev = rev * shifted;

  // The buffer is reused for the answer: out[i] = rev[n - i] / i!
  for (int i = 0; i <= n; ++i)
    shifted[i] = rev[n - i] * (ll) simp.ifac[i] % P;
  return shifted;
}

// Exact k-th power: the transform length exceeds deg() * k, so raising every
// point value to the k-th power cannot wrap around.
Poly Poly::pow(int k) const {
  switch (k) {
    case 0:
      return 1;
    case 1:
      return *this;
    default:
      break;
  }
  const int top = deg() * k;
  int lg = 0;
  while ((1 << lg) <= top)
    ++lg;
  std::vector<int> spec(a);
  spec.resize(1 << lg);
  ntt.fft(spec.data(), lg, 1);
  for (int &v : spec)
    v = mpow(v, k);
  ntt.fft(spec.data(), lg, -1);
  return spec;
}

// Formal derivative. A single-coefficient polynomial yields the zero
// constant; otherwise the result has one coefficient fewer.
Poly Poly::der() const {
  const int d = deg();
  if (d == 0)
    return 0;
  std::vector<int> out(d);
  for (int i = 1; i <= d; ++i)
    out[i - 1] = a[i] * (ll) i % P;
  return out;
}

// Formal antiderivative with zero constant term; one coefficient longer than
// the input.
Poly Poly::integ() const {
  const int d = deg();
  simp.check(d + 1);
  std::vector<int> out(d + 2);
  for (int j = 1; j <= d + 1; ++j)
    out[j] = a[j - 1] * (ll) simp.inv[j] % P;
  return out;
}

// Power-series quotient (*this) / rhs. The result holds rhs.deg() + 1
// coefficients (see the final slice); for a constant rhs only the constant
// term a[0] / rhs[0] is returned.
// NOTE(review): the inverse of rhs is only built to n terms (n the largest
// power of two with 2n <= rhs.deg(), or 1), and the upper part of the answer
// is recovered through the correction term below - confirm the derivation.
Poly Poly::quo(const Poly &rhs) const {
  if (rhs.deg() == 0)
    return a[0] * (ll) nt.inv(rhs[0]) % P;
  // g approximates 1 / rhs; each pass of the loop doubles n.
  Poly g = nt.inv(rhs[0]);
  int t = 0, n;
  for (n = 1; (n << 1) <= rhs.deg(); ++t, n <<= 1)
    nit.inv(rhs.slice((n << 1) - 1), g, t);
  Poly nttg = g;
  nttg.redeg((n << 1) - 1);
  ntt.fft(nttg.base(), t + 1, 1);
  // eps1: upper half of rhs * g (length 2n), shifted down to the lower half
  // and transformed again.
  Poly eps1 = rhs.slice((n << 1) - 1);
  ntt.fft(eps1.base(), t + 1, 1);
  for (int i = 0; i < (n << 1); ++i)
    eps1[i] = eps1[i] * (ll) nttg[i] % P;
  ntt.fft(eps1.base(), t + 1, -1);
  memcpy(eps1.base(), eps1.base() + n, sizeof(int) << t);
  memset(eps1.base() + n, 0, sizeof(int) << t);
  ntt.fft(eps1.base(), t + 1, 1);
  // h0: the first n coefficients of the numerator, transformed at length 2n.
  Poly h0 = slice(n - 1);
  h0.redeg((n << 1) - 1);
  ntt.fft(h0.base(), t + 1);
  // h0g0 = h0 * g.
  Poly h0g0 = zeroes((n << 1) - 1);
  for (int i = 0; i < (n << 1); ++i)
    h0g0[i] = h0[i] * (ll) nttg[i] % P;
  ntt.fft(h0g0.base(), t + 1, -1);
  // h0eps1 = h0 * eps1.
  Poly h0eps1 = zeroes((n << 1) - 1);
  for (int i = 0; i < (n << 1); ++i)
    h0eps1[i] = h0[i] * (ll) eps1[i] % P;
  ntt.fft(h0eps1.base(), t + 1, -1);
  // Lower half becomes (numerator coefficient i + n) minus the product term.
  for (int i = 0; i < n; ++i) {
    h0eps1[i] = operator[](i + n) - h0eps1[i];
    if (h0eps1[i] < 0)
      h0eps1[i] += P;
  }
  // Multiply that correction by g and move it into the upper half.
  memset(h0eps1.base() + n, 0, sizeof(int) << t);
  ntt.fft(h0eps1.base(), t + 1);
  for (int i = 0; i < (n << 1); ++i)
    h0eps1[i] = h0eps1[i] * (ll) nttg[i] % P;
  ntt.fft(h0eps1.base(), t + 1, -1);
  memcpy(h0eps1.base() + n, h0eps1.base(), sizeof(int) << t);
  memset(h0eps1.base(), 0, sizeof(int) << t);
  return (h0g0 + h0eps1).slice(rhs.deg());
}

// Series logarithm computed as the integral of f' / f. A single-coefficient
// input yields the zero constant.
Poly Poly::ln() const {
  const int d = deg();
  if (d == 0)
    return 0;
  const Poly numer = der();
  const Poly denom = slice(d - 1);
  return numer.quo(denom).integ();
}

// Series square root together with the second series maintained alongside it
// (refined by nit.inv against the root every round). Both are returned with
// deg() + 1 coefficients.
pair<Poly, Poly> Poly::sqrti() const {
  const int d = deg();
  Poly rt(nt.sqrt(a[0]));
  Poly rtInv(nt.inv(rt[0]));
  Poly rtHat = rt;
  int t = 0;
  while ((1 << t) <= d) {
    nit.sqrt(slice((2 << t) - 1), rt, rtHat, rtInv, t);
    // Refresh the transform of the root, then refine the companion series.
    rtHat = rt;
    ntt.fft(rtHat.base(), t + 1, 1);
    nit.inv(rt, rtHat, rtInv, t);
    ++t;
  }
  return std::make_pair(rt.slice(d), rtInv.slice(d));
}

// Series square root with deg() + 1 coefficients. Same iteration as sqrti(),
// except that the companion series is not refined after the final round.
Poly Poly::sqrt() const {
  const int d = deg();
  Poly rt(nt.sqrt(a[0]));
  Poly rtInv(nt.inv(rt[0]));
  Poly rtHat = rt;
  for (int t = 0; (1 << t) <= d; ++t) {
    nit.sqrt(slice((2 << t) - 1), rt, rtHat, rtInv, t);
    if ((2 << t) > d)
      continue;  // final round: nothing further needs the companion series
    rtHat = rt;
    ntt.fft(rtHat.base(), t + 1, 1);
    nit.inv(rt, rtHat, rtInv, t);
  }
  return rt.slice(d);
}

// Series exponential via the online convolution in `src`: the input is
// weighted by index (i * a[i]); each output coefficient is the accumulated
// sum divided by its index, with the constant term fixed to 1.
Poly Poly::exp() const {
  const int len = static_cast<int>(a.size());
  std::vector<int> weighted(len), out(len);
  for (int i = 0; i < len; ++i)
    weighted[i] = a[i] * (ll) i % P;
  auto finalize = [&out](int i) {
    if (i != 0)
      out[i] = out[i] * (ll) simp.ginv(i) % P;
    else
      out[0] = 1;
  };
  src.run(weighted, out, finalize);
  return out;
}

// Newton-style series exponential that also maintains a second series,
// refined by nit.inv against the first every round. Both are returned with
// deg() + 1 coefficients.
pair<Poly, Poly> Poly::expi() const {
  const int d = deg();
  Poly g(1), h(1);
  Poly gHat{1, 1};
  int t = 0;
  while ((1 << t) <= d) {
    nit.exp(slice((2 << t) - 1), g, gHat, h, t);
    // Transform g at four times the block size; the even-indexed samples
    // form the half-length transform handed to the inverse step.
    gHat = g;
    gHat.redeg((4 << t) - 1);
    ntt.fft(gHat.base(), t + 2);
    Poly evens = zeroes((2 << t) - 1);
    for (int i = 0; i < (2 << t); ++i)
      evens[i] = gHat[2 * i];
    nit.inv(g, evens, h, t);
    ++t;
  }
  return std::make_pair(g.slice(d), h.slice(d));
}

// k-th power truncated to deg() + 1 coefficients, computed as
// exp(k * ln(f)) after stripping leading zeros and normalising the first
// non-zero coefficient to 1. An all-zero input is returned unchanged.
Poly Poly::exp(int k) const {
  const int d = deg();

  // lz = number of leading zero coefficients (at most d).
  int lz = 0;
  while (lz < d && !a[lz])
    ++lz;
  if (lz == d && !a[lz])
    return *this;
  const int lead = a[lz];

  // The shifted result would start beyond the truncation point.
  if (lz * (ll) k > d)
    return zeroes(d);

  // Only d - lz * k + 1 coefficients of the stripped series are needed.
  const int tail = d - lz * (k - 1) + 1;
  Poly core(std::vector<int>(a.begin() + lz, a.begin() + tail));
  core = core * nt.inv(lead);
  const Poly logCore = core.ln();
  const Poly scaled = logCore * k;
  core = scaled.exp() * mpow(lead, k);

  // Re-insert the lz * k leading zeros.
  std::vector<int> out(d + 1);
  memcpy(out.data() + lz * k, core.base(), sizeof(int) * (d - lz * k + 1));
  return out;
}

// Polynomial (Euclidean) quotient: both operands are reversed, cut to
// deg() - rhs.deg() + 1 coefficients, divided as power series, and the
// result is reversed back. Returns the zero constant when deg() < rhs.deg().
Poly Poly::operator/(const Poly &rhs) const {
  const int n = deg(), m = rhs.deg();
  if (n < m)
    return 0;
  std::vector<int> revA(a), revB(rhs.a);
  std::reverse(revA.begin(), revA.end());
  std::reverse(revB.begin(), revB.end());
  revA.resize(n - m + 1);
  revB.resize(n - m + 1);
  Poly quotient = Poly(revA).quo(Poly(revB));
  std::reverse(quotient.a.begin(), quotient.a.end());
  return quotient;
}

// Division with remainder: returns (quotient, remainder). When deg() is
// below rhs.deg() the quotient is the zero constant and the remainder is
// *this. Otherwise the remainder is a - q * rhs, evaluated cyclically at
// length 2^lgn >= rhs.deg(); only its lowest rhs.deg() coefficients are kept.
pair<Poly, Poly> Poly::div(const Poly &rhs) const {
  if (deg() < rhs.deg())
    return make_pair(0, *this);
  int n = deg(), m = rhs.deg();
  Poly q = operator/(rhs), r;
  int lgn = 0;
  while ((1 << lgn) < rhs.deg())
    ++lgn;
  // t is the wrap-around mask (length - 1) for the cyclic buffers.
  int t = (1 << lgn) - 1;
  r = zeroes(t);
  Poly tmp = zeroes(t);
  // Fold rhs and q into cyclic buffers of length 2^lgn.
  for (int i = 0; i <= m; ++i)
    if ((r[i & t] += rhs[i]) >= P)
      r[i & t] -= P;
  for (int i = 0; i <= n - m; ++i)
    if ((tmp[i & t] += q[i]) >= P)
      tmp[i & t] -= P;
  // tmp <- cyclic product q * rhs.
  ntt.fft(r.base(), lgn, 1);
  ntt.fft(tmp.base(), lgn, 1);
  for (int i = 0; i <= t; ++i)
    tmp[i] = tmp[i] * (ll) r[i] % P;
  ntt.fft(tmp.base(), lgn, -1);
  // r is reused: fold the dividend the same way, then subtract the product
  // on the m lowest coefficients.
  memset(r.base(), 0, sizeof(int) << lgn);
  for (int i = 0; i <= n; ++i)
    if ((r[i & t] += a[i]) >= P)
      r[i & t] -= P;
  for (int i = 0; i < m; ++i)
    if ((r[i] -= tmp[i]) < 0)
      r[i] += P;
  return make_pair(q, r.slice(m - 1));
}

// Polynomial remainder; *this is returned unchanged when it is already
// shorter than the divisor.
Poly Poly::operator%(const Poly &rhs) const {
  if (deg() >= rhs.deg()) {
    const pair<Poly, Poly> qr = div(rhs);
    return qr.second;
  }
  return *this;
}


#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cctype>

#include <algorithm>
#include <random>
#include <bitset>
#include <queue>
#include <functional>
#include <set>
#include <map>
#include <vector>
#include <chrono>
#include <iostream>
#include <limits>
#include <numeric>

#define LOG(FMT...) fprintf(stderr, FMT)

using namespace std;

typedef long long ll;
typedef unsigned long long ull;

// mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Reads v.size() whitespace-separated values into an already-sized vector.
template <class T>
std::istream& operator>>(std::istream& is, std::vector<T>& v) {
  for (auto it = v.begin(); it != v.end(); ++it)
    is >> *it;
  return is;
}

// Writes the elements separated by single spaces; an empty vector writes nothing.
template <class T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
  auto it = v.begin();
  if (it == v.end())
    return os;
  os << *it;
  for (++it; it != v.end(); ++it)
    os << ' ' << *it;
  return os;
}

// Reads m, k and m integers from standard input and prints one value modulo P.
// Outline: group the sorted inputs into runs of consecutive integers, build
// one polynomial per run from a shared power series, multiply all of them
// together, and combine the product's coefficients into the answer.
int main() {
#ifdef ELEGIA
  // Local debugging only: read from test.in and time the run.
  freopen("test.in", "r", stdin);
  int nol_cl = clock();
#endif
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int m, k;
  cin >> m >> k;
  vector<int> a(m);
  cin >> a;
  sort(a.begin(), a.end());
  // Sentinel so that the last run is closed by the comparison below
  // (assumes no input value equals -2 - TODO confirm against constraints).
  a.push_back(-1);
  // lens: lengths of the maximal runs of consecutive integers in a.
  vector<int> lens;
  int cur = 0;
  for (int i = 0; i + 1 < a.size(); ++i) {
    ++cur;
    if (a[i] + 1 == a[i + 1])
      continue;
    else {
      lens.push_back(cur);
      cur = 0;
    }
  }

  // Recursive doubling: the result for n / 2 is extended and corrected by
  // one step  half - (comp - z) / der,  where
  //   comp = (half - half^(k+1)) / (1 - half)   and   der = comp' / half'.
  // NOTE(review): this appears to compute the compositional inverse of
  // (z - z^(k+1)) / (1 - z) up to degree n - confirm.
  function<Poly(int)> compinv = [&](int n) {
    if (n == 1) return 1_z;
    auto half = compinv(n / 2);
    half.redeg(n + 1);
    auto comp = (half - half.exp(k + 1)).quo(Poly(1) - half);
    auto der = comp.der().quo(half.der());
    return (half - (comp - 1_z).quo(der)).slice(n);
  };
  // One shared series, long enough for the longest run.
  Poly tmp = compinv(*max_element(lens.begin(), lens.end()) + 1);

  // Heap ordered so that the polynomial of smallest degree is on top.
  priority_queue<Poly, vector<Poly>, function<bool(const Poly&, const Poly&)>> que([&](const Poly& a, const Poly& b) {
    return a.deg() > b.deg();
  });

  // Per run of length n: take tmp up to degree n + 1, drop its constant term
  // (i.e. divide by z), raise it to the power P - (n + 1), and weight
  // coefficient n - i by (1 + i) / (n + 1).
  for (int n : lens) {
    Poly pwr = tmp.slice(n + 1);
    pwr.a.erase(pwr.a.begin());
    pwr = pwr.exp(P - (n + 1));
    Poly res = zeroes(n);
    for (int i = 0; i <= n; ++i) {
      res[n - i] = pwr[n - i] * (ll)simp.ginv(n + 1) % P * (1 + i) % P;
    }
    que.push(res);
  }

  // Multiply everything together, always merging the two smallest first.
  while (que.size() > 1) {
    auto x = que.top(); que.pop();
    auto y = que.top(); que.pop();
    que.push(x * y);
  }
  int ans = 0;
  Poly pol = que.top(); que.pop();
  //pol.println(stderr);
  // ans = sum over i < m of pol[i] * m / (C(m, i) * (m - i)).
  for (int i = 0; i < m; ++i) {
    int inner = nt.inv(simp.binom(m, i)) * (ll)m % P * simp.ginv(m - i) % P;
    ans = (ans + inner * (ll)pol[i]) % P;
  }
  cout << ans << '\n';

#ifdef ELEGIA
  LOG("Time: %dms\n", int ((clock()
          -nol_cl) / (double)CLOCKS_PER_SEC * 1000));
#endif
  return 0;
}

Details

Tip: Click on the bar to expand more detailed information

Pretests


Final Tests

Test #1:

score: 10
Accepted
time: 1ms
memory: 3876kb

input:

10 3
1 2 3 5 6 7 8 9 10 11

output:

23767731

result:

ok 1 number(s): "23767731"

Test #2:

score: 10
Accepted
time: 1ms
memory: 3768kb

input:

500 499
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 10...

output:

137319993

result:

ok 1 number(s): "137319993"

Test #3:

score: 10
Accepted
time: 1ms
memory: 3648kb

input:

500 13
397 284 276 435 348 557 14 561 516 352 125 56 240 337 417 37 2 281 414 307 210 336 232 382 224 122 58 133 137 393 488 92 469 238 48 453 241 533 492 424 260 213 487 571 33 178 97 158 135 575 292 331 478 118 121 11 85 552 460 219 415 272 293 479 481 304 254 55 212 351 230 164 472 422 7 28 19 39...

output:

93022584

result:

ok 1 number(s): "93022584"

Test #4:

score: 10
Accepted
time: 0ms
memory: 3940kb

input:

500 9
464 333 321 508 405 641 14 645 599 412 147 69 279 394 488 40 2 328 486 359 243 393 270 444 261 145 71 158 162 458 568 111 547 278 59 529 280 617 571 496 302 247 567 658 35 209 115 183 160 665 343 386 556 139 144 11 103 637 539 254 487 318 344 557 558 356 293 68 246 408 268 189 549 494 7 30 20 ...

output:

508272560

result:

ok 1 number(s): "508272560"

Test #5:

score: 10
Accepted
time: 2ms
memory: 3772kb

input:

500 30
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 ...

output:

730181983

result:

ok 1 number(s): "730181983"

Test #6:

score: 10
Accepted
time: 11ms
memory: 3940kb

input:

5000 100
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100...

output:

979829264

result:

ok 1 number(s): "979829264"

Test #7:

score: 10
Accepted
time: 15ms
memory: 3916kb

input:

5000 300
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100...

output:

369651884

result:

ok 1 number(s): "369651884"

Test #8:

score: 10
Accepted
time: 1305ms
memory: 22876kb

input:

200000 5
1 2 3 4 5 6 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 ...

output:

297583105

result:

ok 1 number(s): "297583105"

Test #9:

score: 10
Accepted
time: 1275ms
memory: 20220kb

input:

200000 2000
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 ...

output:

645299783

result:

ok 1 number(s): "645299783"

Test #10:

score: 10
Accepted
time: 1331ms
memory: 21472kb

input:

200000 50
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 69 70 71 72 73 74 75 76 77 78 79 80 81 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 ...

output:

66899726

result:

ok 1 number(s): "66899726"

Extra Test:

score: 0
Extra Test Passed