QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#715383 | #5546. Sharing Bread | ucup-team3519 | AC ✓ | 225ms | 33216kb | C++17 | 10.8kb | 2024-11-06 11:43:14 | 2024-11-06 11:43:14

Judging History

你现在查看的是最新测评结果

  • [2024-11-06 11:43:14]
  • 评测
  • 测评结果:AC
  • 用时:225ms
  • 内存:33216kb
  • [2024-11-06 11:43:14]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;
// Shorthand macros used throughout the file.
// V<T> expands to vector<T>.
#define V vector
// all0(x): the full iterator range of x; all1(x): the range skipping the first element.
#define all0(x) (x).begin(),(x).end()
#define all1(x) (x).begin()+1,(x).end()
#define pb push_back
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
// These make the bare tokens expand to the qualified stream names.
// NOTE: because of this, writing std::cin / std::cout explicitly later in the
// file would expand to std::std::cin / std::std::cout.
#define cin std::cin
#define cout std::cout
typedef long long LL;
typedef pair<int, int> pi;
typedef pair<LL, LL> pl;

//const int N = 2e5 + 20;
// Large sentinels. Both initialisers are double expressions converted to the
// integer type; INF is 2000001000, INFLL is roughly 8e18 (subject to double rounding).
const int INF = 2e9 + 1000;
const LL INFLL = 8e18 + 1000;
// Mersenne Twister seeded from the steady clock; not referenced in the visible code.
mt19937 mrand(chrono::steady_clock().now().time_since_epoch().count());
// Template section ~~~~~~~
// Prime modulus used by every modular helper below.
const int mod = 998244353;

// In-place modular addition: x becomes x + y, with one conditional subtraction
// of `mod`. Presumably both arguments are expected to lie in [0, mod).
inline void add(int &x, int y) {
  x += y;
  x = (x >= mod) ? x - mod : x;
}

// In-place modular subtraction: x becomes x - y, with one conditional addition
// of `mod`. Presumably both arguments are expected to lie in [0, mod).
inline void sub(int &x, int y) {
  x -= y;
  x = (x < 0) ? x + mod : x;
}

// Returns x * y reduced by `mod`; the product is formed in 64 bits so it
// cannot overflow before the reduction.
inline int mul(int x, int y) {
  const long long wide = static_cast<long long>(x) * y;
  return wide % mod;
}

// Binary exponentiation: returns x raised to the y-th power modulo `mod`.
// The loop runs while y is non-zero and halves y with a right shift, so y is
// presumably expected to be non-negative.
inline int power(int x, int y) {
  int acc = 1;
  while (y) {
    if (y & 1) {
      acc = mul(acc, x);
    }
    x = mul(x, x);
    y >>= 1;
  }
  return acc;
}

// Modular inverse of a modulo `mod` via the extended Euclidean algorithm.
// The argument is first normalised into [0, mod); the returned coefficient is
// shifted into the non-negative range. For a == 0 the loop never runs and 0
// is returned.
inline int inv(int a) {
  a %= mod;
  if (a < 0) {
    a += mod;
  }
  int rest = mod;
  int coef = 0;
  int coef_next = 1;
  while (a != 0) {
    const int quot = rest / a;
    rest -= quot * a;
    swap(a, rest);
    coef -= quot * coef_next;
    swap(coef, coef_next);
  }
  return coef < 0 ? coef + mod : coef;
}

// Number-theoretic transform helpers over the prime `mod`, with lazily grown
// lookup tables shared by all calls.
namespace ntt {
// base:     log2 of the transform size the tables currently support.
// root:     element found by init() whose multiplicative order is 2^max_base.
// max_base: number of factors of two in mod - 1 (-1 until init() has run).
int base = 1, root = -1, max_base = -1;
// rev:   bit-reversal permutation for `base` bits.
// roots: twiddle factors; the block [2^k, 2^(k+1)) is filled level by level in ensure_base().
V<int> rev = {0, 1}, roots = {0, 1};

// Computes max_base and searches upward from 2 for the first value whose
// 2^max_base-th power is 1 while its 2^(max_base-1)-th power is not.
void init() {
  int temp = mod - 1;
  max_base = 0;
  while (temp % 2 == 0) {
    temp >>= 1;
    ++max_base;
  }
  root = 2;
  while (true) {
    if (power(root, 1 << max_base) == 1 && power(root, 1 << (max_base - 1)) != 1) {
      break;
    }
    ++root;
  }
}

// Grows the tables so that transforms of size 2^nbase are supported.
// Does nothing if the tables are already large enough; asserts that the
// requested size does not exceed what the modulus allows.
void ensure_base(int nbase) {
  if (max_base == -1) {
    init();
  }
  if (nbase <= base) {
    return;
  }
  assert(nbase <= max_base);
  // The bit-reversal table is rebuilt from scratch for nbase bits.
  rev.resize(1 << nbase);
  for (int i = 0; i < 1 << nbase; ++i) {
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (nbase - 1));
  }
  roots.resize(1 << nbase);
  // Each pass doubles the twiddle table: even slots copy the parent entry,
  // odd slots multiply it by z = root^(2^(max_base - 1 - base)).
  while (base < nbase) {
    int z = power(root, 1 << (max_base - 1 - base));
    for (int i = 1 << (base - 1); i < 1 << base; ++i) {
      roots[i << 1] = roots[i];
      roots[i << 1 | 1] = mul(roots[i], z);
    }
    ++base;
  }
}

// In-place forward transform of a.
// Assumes a.size() is a non-zero power of two: __builtin_ctz(n) is used as log2(n).
void dft(V<int> &a) {
  int n = a.size(), zeros = __builtin_ctz(n);
  ensure_base(zeros);
  // rev is stored for `base` bits; shifting right rescales it to `zeros` bits.
  int shift = base - zeros;
  for (int i = 0; i < n; ++i) {
    if (i < rev[i] >> shift) {
      swap(a[i], a[rev[i] >> shift]);
    }
  }
  // Iterative butterflies over blocks of width 2*i.
  for (int i = 1; i < n; i <<= 1) {
    for (int j = 0; j < n; j += i << 1) {
      for (int k = 0; k < i; ++k) {
        int x = a[j + k], y = mul(a[j + k + i], roots[i + k]);
        a[j + k] = (x + y) % mod;
        a[j + k + i] = (x + mod - y) % mod;
      }
    }
  }
}

// Returns the product of polynomials a and b (coefficient vectors, lowest
// degree first), of length a.size() + b.size() - 1.
// NOTE(review): with both inputs empty, `need` becomes -1 — callers appear to
// be expected to pass non-empty vectors.
V<int> multiply(V<int> a, V<int> b) {
  int need = a.size() + b.size() - 1, nbase = 0;
  while (1 << nbase < need) {
    ++nbase;
  }
  ensure_base(nbase);
  int sz = 1 << nbase;
  a.resize(sz);
  b.resize(sz);
  // When squaring, one forward transform is reused for both factors.
  bool equal = a == b;
  dft(a);
  if (equal) {
    b = a;
  } else {
    dft(b);
  }
  // Pointwise product, pre-scaled by 1/sz for the inverse transform.
  int inv_sz = inv(sz);
  for (int i = 0; i < sz; ++i) {
    a[i] = mul(mul(a[i], b[i]), inv_sz);
  }
  // Reversing all entries except the first and transforming again acts as
  // the inverse transform.
  reverse(a.begin() + 1, a.end());
  dft(a);
  a.resize(need);
  return a;
}

// Returns the first a.size() coefficients of the power-series inverse of a.
// Recurses on the first ceil(n/2) coefficients and refines with
// b * (2 - a * b). The base case calls inv(a[0]), so a[0] is presumably
// required to be non-zero modulo `mod`; a must be non-empty.
V<int> inverse(V<int> a) {
  int n = a.size(), m = (n + 1) >> 1;
  if (n == 1) {
    return V<int>(1, inv(a[0]));
  } else {
    V<int> b = inverse(V<int>(a.begin(), a.begin() + m));
    int need = n << 1, nbase = 0;
    while (1 << nbase < need) {
      ++nbase;
    }
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    dft(a);
    dft(b);
    int inv_sz = inv(sz);
    // Newton step evaluated pointwise in the transformed domain.
    for (int i = 0; i < sz; ++i) {
      a[i] = mul(mul(mod + 2 - mul(a[i], b[i]), b[i]), inv_sz);
    }
    // Same reverse-then-transform inverse as in multiply().
    reverse(a.begin() + 1, a.end());
    dft(a);
    a.resize(n);
    return a;
  }
}
}

using ntt::multiply;
using ntt::inverse;

// Coefficient-wise modular addition of polynomial b into a.
// a is extended with zeros first if b is longer.
V<int>& operator += (V<int> &a, const V<int> &b) {
  const size_t terms = b.size();
  if (terms > a.size()) {
    a.resize(terms);
  }
  for (size_t k = 0; k < terms; ++k) {
    add(a[k], b[k]);
  }
  return a;
}

// Returns the polynomial sum a + b without modifying either operand.
V<int> operator + (const V<int> &a, const V<int> &b) {
  V<int> total(a);
  total += b;
  return total;
}

// Coefficient-wise modular subtraction of polynomial b from a.
// a is extended with zeros first if b is longer.
V<int>& operator -= (V<int> &a, const V<int> &b) {
  const size_t terms = b.size();
  if (terms > a.size()) {
    a.resize(terms);
  }
  for (size_t k = 0; k < terms; ++k) {
    sub(a[k], b[k]);
  }
  return a;
}

// Returns the polynomial difference a - b without modifying either operand.
V<int> operator - (const V<int> &a, const V<int> &b) {
  V<int> diff(a);
  diff -= b;
  return diff;
}

// Multiplies polynomial a by b in place and returns a.
//
// When the shorter operand has fewer than 128 coefficients the schoolbook
// O(|a|*|b|) product is used; otherwise the NTT-based multiply() is called.
//
// Edge cases handled explicitly:
//  * b aliasing a (e.g. `p *= p`): the schoolbook branch overwrites a before
//    reading b, so the right-hand side is copied first.
//  * both operands empty: the result stays empty instead of computing the
//    length 0 + 0 - 1, which would wrap around as an unsigned value.
V<int>& operator *= (V<int> &a, const V<int> &b) {
  if (&a == &b) {
    const V<int> rhs_copy = b;
    return a *= rhs_copy;
  }
  if (a.empty() && b.empty()) {
    return a;
  }
  if (min(a.size(), b.size()) < 128) {
    const V<int> lhs = a;
    a.assign(lhs.size() + b.size() - 1, 0);
    for (size_t i = 0; i < lhs.size(); ++i) {
      for (size_t j = 0; j < b.size(); ++j) {
        add(a[i + j], mul(lhs[i], b[j]));
      }
    }
  } else {
    a = multiply(a, b);
  }
  return a;
}

// Returns the polynomial product a * b without modifying either operand.
V<int> operator * (const V<int> &a, const V<int> &b) {
  V<int> product(a);
  product *= b;
  return product;
}

// Polynomial division with remainder discarded: a becomes the quotient of a
// by b and is returned. If a is shorter than b the quotient is empty.
// Works on the reversed coefficient lists: the reversed quotient is the
// reversed dividend times the series inverse of the reversed divisor,
// truncated to n - m + 1 terms.
V<int>& operator /= (V<int> &a, const V<int> &b) {
  const int n = a.size(), m = b.size();
  if (n < m) {
    a.clear();
    return a;
  }
  const int quotient_len = n - m + 1;
  V<int> divisor_rev(b.rbegin(), b.rend());
  divisor_rev.resize(quotient_len);
  reverse(a.begin(), a.end());
  a *= inverse(divisor_rev);
  a.erase(a.begin() + quotient_len, a.end());
  reverse(a.begin(), a.end());
  return a;
}

// Returns the polynomial quotient a / b without modifying either operand.
V<int> operator / (const V<int> &a, const V<int> &b) {
  V<int> quotient(a);
  quotient /= b;
  return quotient;
}

// Polynomial remainder: a becomes a mod b and is returned.
// If a is already shorter than b it is left untouched; otherwise the low
// b.size() - 1 coefficients of a - (a / b) * b are kept.
V<int>& operator %= (V<int> &a, const V<int> &b) {
  const int n = a.size(), m = b.size();
  if (n < m) {
    return a;
  }
  const V<int> multiple = (a / b) * b;
  a.resize(m - 1);
  for (int k = 0; k + 1 < m; ++k) {
    sub(a[k], multiple[k]);
  }
  return a;
}

// Returns the polynomial remainder a mod b without modifying either operand.
V<int> operator % (const V<int> &a, const V<int> &b) {
  V<int> remainder(a);
  remainder %= b;
  return remainder;
}

// Returns the formal derivative of polynomial a (coefficients lowest degree
// first). The result has one coefficient fewer than a.
// An empty input yields an empty result; without this guard the result
// vector would be constructed with size -1.
V<int> derivative(const V<int> &a) {
  if (a.empty()) {
    return V<int>();
  }
  const int n = a.size();
  V<int> b(n - 1);
  for (int i = 1; i < n; ++i) {
    b[i - 1] = mul(a[i], i);
  }
  return b;
}

// Returns the formal antiderivative of polynomial a with constant term 0;
// the result has one coefficient more than a.
// Inverses of 1..n are built on the fly with the recurrence
// inv[k] = (mod - mod / k) * inv[mod % k].
V<int> primitive(const V<int> &a) {
  const int n = a.size();
  V<int> integral(n + 1), inv_table(n + 1);
  for (int k = 1; k <= n; ++k) {
    if (k == 1) {
      inv_table[k] = 1;
    } else {
      inv_table[k] = mul(mod - mod / k, inv_table[mod % k]);
    }
    integral[k] = mul(a[k - 1], inv_table[k]);
  }
  return integral;
}

// Returns the first a.size() coefficients of the power-series logarithm of
// a, computed as the antiderivative of a' / a.
// Presumably requires a[0] == 1 so that the logarithm is well defined.
V<int> logarithm(const V<int> &a) {
  V<int> series = primitive(derivative(a) * inverse(a));
  series.resize(a.size());
  return series;
}

// Returns the first a.size() coefficients of the power-series exponential
// of a, doubling the known precision on every pass of a Newton iteration:
// new = old * (1 + a - log(old)).
// Presumably requires a[0] == 0 (the iteration starts from the constant 1).
V<int> exponent(const V<int> &a) {
  V<int> series(1, 1);
  while (series.size() < a.size()) {
    const size_t half = series.size();
    const size_t doubled = half << 1;
    V<int> step(a.begin(), a.begin() + min(a.size(), doubled));
    add(step[0], 1);
    const V<int> previous = series;
    series.resize(doubled);
    step -= logarithm(series);
    step *= previous;
    // Only the newly gained upper half is taken from the refined product.
    for (size_t k = half; k < doubled; ++k) {
      series[k] = step[k];
    }
  }
  series.resize(a.size());
  return series;
}

// Returns the first a.size() coefficients of a raised to the m-th power.
//
// With p the index of the first non-zero coefficient, a is written as
// x^p * a[p] * c(x) with c[0] == 1, and the result is
// x^(m*p) * a[p]^m * exp(m * log(c)).
//
// m is expected to be non-negative (the scalar power() helper loops on the
// bits of its exponent).
//
// Edge cases handled explicitly:
//  * empty a: returns an empty vector (previously b[0] was written on an
//    empty vector when a had no non-zero coefficient).
//  * m == 0: returns 1 followed by zeros. Previously, when a had leading
//    zeros (p > 0), the normalisation loop read a[i + p] past the end.
V<int> power(V<int> a, int m) {
  const int n = a.size();
  V<int> b(n);
  if (n == 0) {
    return b;
  }
  if (m == 0) {
    b[0] = 1;
    return b;
  }
  int p = -1;
  for (int i = 0; i < n; ++i) {
    if (a[i]) {
      p = i;
      break;
    }
  }
  if (p == -1) {
    // The zero polynomial to a non-zero power is zero.
    return b;
  }
  if (static_cast<long long>(m) * p >= n) {
    // Every non-zero term of the result lies beyond the requested length.
    return b;
  }
  const int shift = m * p;
  const int len = n - shift;
  const int mu = power(a[p], m), di = inv(a[p]);
  V<int> c(len);
  for (int i = 0; i < len; ++i) {
    c[i] = mul(a[i + p], di);
  }
  c = logarithm(c);
  for (int i = 0; i < len; ++i) {
    c[i] = mul(c[i], m);
  }
  c = exponent(c);
  for (int i = 0; i < len; ++i) {
    b[i + shift] = mul(c[i], mu);
  }
  return b;
}

// Returns the first a.size() coefficients of a power-series square root of
// a, doubling the known precision on every pass of the Newton iteration
// new = (old + a / old) / 2. Because the upper half of `old` is zero, only
// the upper half of (a / old) / 2 has to be written back.
// The iteration starts from the constant 1, so a[0] is presumably required
// to be 1.
V<int> sqrt(const V<int> &a) {
  const int half_mod = (mod + 1) >> 1;  // modular inverse of 2
  V<int> b(1, 1);
  while (b.size() < a.size()) {
    V<int> c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    b.resize(b.size() << 1);
    c *= inverse(b);
    for (size_t i = b.size() >> 1; i < b.size(); ++i) {
      b[i] = mul(c[i], half_mod);
    }
  }
  b.resize(a.size());
  return b;
}

// Returns the product of the polynomials all[l..r] (inclusive) by splitting
// the range in half recursively. An empty range yields an empty vector.
V<int> multiply_all(int l, int r, V<V<int>> &all) {
  if (l > r) {
    return V<int>();
  }
  if (l == r) {
    return all[l];
  }
  const int mid = (l + r) >> 1;
  return multiply_all(l, mid, all) * multiply_all(mid + 1, r, all);
}

// Multipoint evaluation: returns f(x[0]), ..., f(x[n-1]) modulo `mod`.
// Builds a product tree of the linear factors (X - x[i]) in heap layout
// (leaves at n..2n-1, node k has children 2k and 2k+1), then pushes f down
// the tree by repeated remaindering. An empty x yields an empty result.
V<int> evaluate(const V<int> &f, const V<int> &x) {
  const int n = x.size();
  if (n == 0) {
    return V<int>();
  }
  V<V<int>> tree(2 * n);
  for (int leaf = 0; leaf < n; ++leaf) {
    tree[n + leaf] = V<int>{(mod - x[leaf]) % mod, 1};
  }
  for (int node = n - 1; node >= 1; --node) {
    tree[node] = tree[2 * node] * tree[2 * node + 1];
  }
  V<V<int>> rest(2 * n);
  rest[1] = f % tree[1];
  for (int node = 2; node < 2 * n; ++node) {
    rest[node] = rest[node / 2] % tree[node];
  }
  V<int> values(n);
  for (int leaf = 0; leaf < n; ++leaf) {
    values[leaf] = rest[n + leaf][0];
  }
  return values;
}

// Polynomial interpolation: returns the coefficients of the polynomial
// taking value y[i] at x[i] for every i.
//
// Builds the product tree of the linear factors (X - x[i]) in heap layout,
// evaluates the derivative of the full product at every x[i] to obtain the
// interpolation weights, and combines the weighted leaves bottom-up.
//
// y must provide at least x.size() values; the x[i] are presumably expected
// to be pairwise distinct (the weights are inverted with inv()).
// An empty x yields an empty result, matching evaluate(); without this guard
// the code would index the root of an empty tree.
V<int> interpolate(const V<int> &x, const V<int> &y) {
  const int n = x.size();
  if (!n) {
    return V<int>();
  }
  V<V<int>> up(n * 2);
  for (int i = 0; i < n; ++i) {
    up[i + n] = V<int>{(mod - x[i]) % mod, 1};
  }
  for (int i = n - 1; i; --i) {
    up[i] = up[i << 1] * up[i << 1 | 1];
  }
  // a[i] = y[i] / P'(x[i]) where P is the product of all linear factors.
  V<int> a = evaluate(derivative(up[1]), x);
  for (int i = 0; i < n; ++i) {
    a[i] = mul(y[i], inv(a[i]));
  }
  V<V<int>> down(n * 2);
  for (int i = 0; i < n; ++i) {
    down[i + n] = V<int>(1, a[i]);
  }
  // Each node combines its children cross-multiplied by the sibling product.
  for (int i = n - 1; i; --i) {
    down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
  }
  return down[1];
}


// Binary exponentiation on 64-bit values: returns a^k modulo `mod`.
// The loop runs while k is non-zero and halves it with a right shift, so k
// is presumably expected to be non-negative.
constexpr LL qpow(LL a, LL k) {
	LL result = 1;
	for (; k; k >>= 1) {
		if (k & 1) {
			result = result * a % mod;
		}
		a = a * a % mod;
	}
	return result;
}

// Factorial and inverse-factorial tables modulo `mod`, filled by ini_comb().
V<LL> _fac, _ifac;
// Fills _fac[0..n] with i! and _ifac[0..n] with the modular inverse of i!.
// Only _fac[n] is inverted directly (via qpow with exponent mod - 2); the
// remaining inverses follow from 1/(i-1)! = (1/i!) * i.
void ini_comb(int n) {
    _fac.resize(n + 1);
    _ifac.resize(n + 1);
    _fac[0] = 1;
    for (int i = 0; i < n; ++i) {
        _fac[i + 1] = _fac[i] * (i + 1) % mod;
    }
    _ifac[n] = qpow(_fac[n], mod - 2);
    for (int i = n; i > 0; --i) {
        _ifac[i - 1] = _ifac[i] * i % mod;
    }
}
// Returns entry x of the factorial table filled by ini_comb();
// no bounds check is performed.
LL fac(int x) {
    const LL value = _fac[x];
    return value;
}
// Returns entry x of the inverse-factorial table filled by ini_comb();
// no bounds check is performed.
LL ifac(int x) {
    const LL value = _ifac[x];
    return value;
}
// Binomial coefficient "n choose m" modulo `mod`, from the tables filled by
// ini_comb(). m == 0 is answered with 1 before any range check; otherwise a
// negative m or m > n gives 0.
LL C(int n, int m) {
    if (m == 0) {
        return 1;
    }
    const bool outside = (m < 0) || (n < m);
    if (outside) {
        return 0;
    }
    LL acc = fac(n);
    acc = acc * ifac(n - m) % mod;
    acc = acc * ifac(m) % mod;
    return acc;
}
// End of template section ~~~~~~~

void solve() {
    ini_comb(1e6);
    int n, m; cin >> n >> m;
    swap(n, m);
    int t = m - n;

    V<int> g(n + 1), ex(n + 1);
    for(int i = 1; i <= n; i++) {
        g[i] = 1LL * power(i - 1, i) * inv(fac(i)) % mod;
    }
    g[0] = 1;
    // assert(g[2] == mod - qpow(2, mod - 2));
    // cout << "g : " << endl;
    // for(int i = 0; i <= n; i++) cout << g[i] << " ";
    // cout << endl;

    for(int i = 0; i <= n; i++) {
        ex[i] = 1LL * power(i + t, i) * inv(fac(i)) % mod;
    }
    // assert(ex[2] == 9 * qpow(2, mod - 2) % mod);
    // cout << "ex : " << endl;
    // for(int i = 0; i <= n; i++) cout << ex[i] << " ";
    // cout << endl;
    // cout << "inv(g) : " << endl;
    V<int> test = inverse(g);
    // for(int i = 0; i <= n; i++) cout << test[i] << " ";
    // cout << endl;
    V<int> test_h = ex * test;

    // cout << "ex / g : " << endl;
    // for(int i = 0; i <= n; i++) cout << test_h[i] << " ";
    // cout << endl;

    cout << 1LL * test_h[n] * fac(n) % mod;
}

// Entry point: switches the standard streams to fast unsynchronised mode,
// makes failed extractions from the input stream throw, and runs solve()
// for a fixed count of one test case (reading the count from input is left
// disabled).
int32_t main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin.exceptions(cin.failbit);
    //cin >> t;
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 8ms
memory: 18800kb

input:

4 3

output:

50

result:

ok single line: '50'

Test #2:

score: 0
Accepted
time: 8ms
memory: 18752kb

input:

10 1

output:

10

result:

ok single line: '10'

Test #3:

score: 0
Accepted
time: 7ms
memory: 19020kb

input:

2 2

output:

3

result:

ok single line: '3'

Test #4:

score: 0
Accepted
time: 7ms
memory: 18952kb

input:

1 1

output:

1

result:

ok single line: '1'

Test #5:

score: 0
Accepted
time: 9ms
memory: 18776kb

input:

277 277

output:

124662617

result:

ok single line: '124662617'

Test #6:

score: 0
Accepted
time: 8ms
memory: 18768kb

input:

426 1

output:

426

result:

ok single line: '426'

Test #7:

score: 0
Accepted
time: 8ms
memory: 18932kb

input:

200000 1

output:

200000

result:

ok single line: '200000'

Test #8:

score: 0
Accepted
time: 223ms
memory: 33168kb

input:

200000 200000

output:

950017432

result:

ok single line: '950017432'

Test #9:

score: 0
Accepted
time: 105ms
memory: 26212kb

input:

200000 100000

output:

280947286

result:

ok single line: '280947286'

Test #10:

score: 0
Accepted
time: 108ms
memory: 25908kb

input:

200000 84731

output:

211985425

result:

ok single line: '211985425'

Test #11:

score: 0
Accepted
time: 113ms
memory: 26720kb

input:

200000 124713

output:

716696526

result:

ok single line: '716696526'

Test #12:

score: 0
Accepted
time: 54ms
memory: 22548kb

input:

129179 49655

output:

506429515

result:

ok single line: '506429515'

Test #13:

score: 0
Accepted
time: 37ms
memory: 20856kb

input:

87518 26040

output:

808454539

result:

ok single line: '808454539'

Test #14:

score: 0
Accepted
time: 17ms
memory: 19660kb

input:

178355 10116

output:

361555714

result:

ok single line: '361555714'

Test #15:

score: 0
Accepted
time: 8ms
memory: 18804kb

input:

2 1

output:

2

result:

ok single line: '2'

Test #16:

score: 0
Accepted
time: 46ms
memory: 22556kb

input:

192733 52550

output:

67181038

result:

ok single line: '67181038'

Test #17:

score: 0
Accepted
time: 54ms
memory: 22360kb

input:

76689 36632

output:

717949287

result:

ok single line: '717949287'

Test #18:

score: 0
Accepted
time: 8ms
memory: 18916kb

input:

200000 9

output:

158524471

result:

ok single line: '158524471'

Test #19:

score: 0
Accepted
time: 225ms
memory: 33216kb

input:

200000 199998

output:

879727659

result:

ok single line: '879727659'

Test #20:

score: 0
Accepted
time: 8ms
memory: 18680kb

input:

199952 1

output:

199952

result:

ok single line: '199952'

Test #21:

score: 0
Accepted
time: 223ms
memory: 33156kb

input:

199947 199947

output:

339118685

result:

ok single line: '339118685'

Test #22:

score: 0
Accepted
time: 111ms
memory: 26080kb

input:

199956 99978

output:

135867461

result:

ok single line: '135867461'

Test #23:

score: 0
Accepted
time: 6ms
memory: 18780kb

input:

2 2

output:

3

result:

ok single line: '3'

Test #24:

score: 0
Accepted
time: 11ms
memory: 18908kb

input:

10 3

output:

968

result:

ok single line: '968'

Test #25:

score: 0
Accepted
time: 8ms
memory: 19040kb

input:

10 5

output:

87846

result:

ok single line: '87846'

Test #26:

score: 0
Accepted
time: 11ms
memory: 18724kb

input:

10 9

output:

428717762

result:

ok single line: '428717762'

Test #27:

score: 0
Accepted
time: 11ms
memory: 18832kb

input:

279 166

output:

945780025

result:

ok single line: '945780025'

Test #28:

score: 0
Accepted
time: 6ms
memory: 18916kb

input:

361 305

output:

926296326

result:

ok single line: '926296326'

Test #29:

score: 0
Accepted
time: 9ms
memory: 18724kb

input:

305 262

output:

465560336

result:

ok single line: '465560336'