QOJ.ac

QOJ

ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间
#51102 | #1285. Stirling Number | ckiseki# | ML | 58ms | 40028kb | C++ | 5.1kb | 2022-09-30 20:19:37 | 2022-09-30 20:19:41

Judging History

你现在查看的是最新测评结果

  • [2023-08-10 23:21:45]
  • System Update: QOJ starts to keep a history of the judgings of all the submissions.
  • [2022-09-30 20:19:41]
  • 评测
  • 测评结果:ML
  • 用时:58ms
  • 内存:40028kb
  • [2022-09-30 20:19:37]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;

/// Iterative radix-2 number-theoretic transform over Z/modZ.
///
/// Template parameters:
///   mod  - modulus; (mod - 1) must be divisible by maxn (checked below).
///   G    - base whose power G^((mod-1)/maxn) seeds the twiddle table
///          (assumed to be a primitive root of mod).
///   maxn - power of two; the largest transform length supported.
///
/// The arithmetic helpers expect operands already reduced into [0, mod).
template <int mod, int G, int maxn>
class NTT {
private:
    static_assert(maxn == (maxn & (-maxn)));
    static_assert((mod - 1) % maxn == 0);
    // Twiddle table: for every power of two h < maxn and 0 <= j < h,
    // roots[h + j] == w_h^j with w_h = G^((mod-1)/(2h)).
    int roots[maxn];
public:
    /// (a + b) mod `mod`.
    constexpr static int add(int a, int b) {
        const int sum = a + b;
        return sum < mod ? sum : sum - mod;
    }
    /// (a - b) mod `mod`.
    constexpr static int sub(int a, int b) {
        const int diff = a - b;
        return diff >= 0 ? diff : diff + mod;
    }
    /// (a * b) mod `mod`, multiplied in 64-bit to avoid overflow.
    constexpr static int mul(int64_t a, int64_t b) {
        const int64_t product = a * b;
        return static_cast<int>(product % mod);
    }
    /// a^k mod `mod` by binary exponentiation.
    constexpr static int qpow(int a, int64_t k) {
        int result = 1 % mod;
        int base = a;
        for (; k != 0; k >>= 1) {
            if ((k & 1) != 0) result = mul(result, base);
            base = mul(base, base);
        }
        return result;
    }
    /// Modular inverse via Fermat's little theorem (valid when mod is prime).
    constexpr static int minv(int a) { return qpow(a, mod - 2); }

    /// Precomputes the twiddle table for all lengths up to maxn.
    NTT() : roots{} {
        int w = qpow(G, (mod - 1) / maxn);
        for (int half = maxn / 2; half >= 1; half >>= 1) {
            int cur = 1 % mod;
            for (int j = 0; j < half; ++j) {
                roots[half + j] = cur;
                cur = mul(cur, w);
            }
            w = mul(w, w);  // root for the next (halved) level
        }
    }

    /// In-place transform of F[0..n); n must be a power of two <= maxn.
    /// With inv == true the inverse transform is computed, including the
    /// 1/n scaling.
    void operator()(int F[], int n, bool inv = false) {
        // Reorder F into bit-reversed index order.
        for (int i = 0, rev = 0; i < n; ++i) {
            if (i < rev) swap(F[i], F[rev]);
            int bit = n >> 1;
            while ((rev ^= bit) < bit) bit >>= 1;
        }
        // Butterfly passes, doubling the sub-transform length each time.
        for (int half = 1; half < n; half <<= 1) {
            const int *tw = roots + half;
            for (int base = 0; base < n; base += half << 1) {
                int *lo = F + base;
                int *hi = lo + half;
                for (int j = 0; j < half; ++j) {
                    const int x = lo[j];
                    const int y = mul(hi[j], tw[j]);
                    lo[j] = add(x, y);
                    hi[j] = sub(x, y);
                }
            }
        }
        if (!inv) return;
        // Inverse transform: scale by 1/n, then reverse F[1..n).
        const int scale = minv(n);
        for (int i = 0; i < n; ++i) F[i] = mul(F[i], scale);
        reverse(F + 1, F + n);
    }
};

// Largest transform length supported by the three NTT instances (2^21).
constexpr int maxn = 1 << 21;

// NTT-friendly primes: each p - 1 is divisible by maxn (enforced by the
// static_asserts in NTT). Their product (about 9.9e26) bounds the exact
// convolution coefficients that superBigCRT can reconstruct.
constexpr int M1 = 985661441;
constexpr int M2 = 998244353;
constexpr int M3 = 1004535809;

// One transform type per prime, all seeded with base 3.
using NTT1 = NTT<M1, 3, maxn>;
using NTT2 = NTT<M2, 3, maxn>;
using NTT3 = NTT<M3, 3, maxn>;

// Global instances; each owns an int[maxn] twiddle table (8 MiB).
NTT1 ntt1{};
NTT2 ntt2{};
NTT3 ntt3{};

// Garner's algorithm: given the residues A (mod M1), B (mod M2), C (mod M3)
// of some x in [0, M1 * M2 * M3), returns x mod `mod`.
// Inputs are expected to be reduced (0 <= A < M1, 0 <= B < M2, 0 <= C < M3);
// every intermediate product then fits in int64_t.
int superBigCRT(int64_t A, int64_t B, int64_t C, int mod) {
    // Modular inverses (Fermat), evaluated at compile time.
    static constexpr int64_t inv_m1_mod_m2 = NTT2::qpow(M1, M2 - 2);
    static constexpr int64_t inv_m1_mod_m3 = NTT3::qpow(M1, M3 - 2);
    static constexpr int64_t inv_m2_mod_m3 = NTT3::qpow(M2, M3 - 2);

    // Mixed-radix digits: x = A + M1 * y + M1 * M2 * z.
    const int64_t y = (B - A + M2) * inv_m1_mod_m2 % M2;
    int64_t z = (C - A + M3) * inv_m1_mod_m3 % M3;
    z = (z - y + M3) * inv_m2_mod_m3 % M3;

    // Reduce M1 * M2 first so the final product stays within int64_t.
    const int64_t m1m2 = int64_t(M1) * M2 % mod;
    return static_cast<int>((A + y * M1 + z * m1m2) % mod);
}

void conv(vector<int> &a, vector<int> &b, int mod) {
    const size_t sa = a.size(), sb = b.size();
    int sz = 1;
    while (sz < a.size() + b.size()) sz *= 2;
    a.resize(sz);
    b.resize(sz);

    auto a1 = a, a2 = a;
    auto b1 = b, b2 = b;

    ntt1(a1.data(), sz);
    ntt1(b1.data(), sz);
    for (int i = 0; i < sz; ++i) a1[i] = NTT1::mul(a1[i], b1[i]);
    ntt1(a1.data(), sz, true);
    
    ntt2(a2.data(), sz);
    ntt2(b2.data(), sz);
    for (int i = 0; i < sz; ++i) a2[i] = NTT2::mul(a2[i], b2[i]);
    ntt2(a2.data(), sz, true);

    ntt3(a.data(), sz);
    ntt3(b.data(), sz);
    for (int i = 0; i < sz; ++i) a[i] = NTT3::mul(a[i], b[i]);
    ntt3(a.data(), sz, true);

    a.resize(sa + sb - 1);
    for (size_t i = 0; i < a.size(); ++i)
        a[i] = superBigCRT(a1[i], a2[i], a[i], mod);
}

// Returns the coefficients (constant term first) of
//     \prod_{l <= i < r} (x + i)   (mod `mod`),
// computed by divide and conquer. An empty range (r <= l) yields the
// constant polynomial 1.
vector<int> getPoly(int l, int r, int mod) {
    // Empty product. Without this guard, r == l recursed forever on the same
    // range, e.g. from solve() when n is divisible by p (n % p == 0).
    if (r <= l) {
        return {1 % mod};
    }
    if (r - l == 1) {
        return {l % mod, 1 % mod};
    }
    const int m = l + (r - l) / 2;
    auto lhs = getPoly(l, m, mod);
    auto rhs = getPoly(m, r, mod);
    conv(lhs, rhs, mod);
    return lhs;
}

using ll = int64_t;

// Factorials, inverse factorials and inverses modulo the prime p. main()
// fills them for indices below p. They are maxn long, so p must not exceed
// maxn.
int fac[maxn];
int ifac[maxn];
int inv[maxn];

// Binomial coefficient C(n, k) modulo the prime p via Lucas' theorem: the
// product of C(n_i, k_i) over the base-p digits n_i, k_i of n and k.
// Returns 0 when k < 0 or k > n.
int C(int64_t n, int64_t k, int p) {
    int64_t acc = 1;
    for (;;) {
        if (k < 0 || n < k)
            return 0;
        if (n < p && k < p) {
            // Last (most significant) digit pair: direct table lookup.
            const int64_t top = 1LL * fac[n] * ifac[k] * ifac[n - k] % p;
            return static_cast<int>(acc * top % p);
        }
        const int64_t nd = n % p, kd = k % p;
        if (kd > nd)
            return 0;  // this digit contributes C(nd, kd) == 0
        acc = acc * (1LL * fac[nd] * ifac[kd] * ifac[nd - kd] % p) % p;
        n /= p;
        k /= p;
    }
}
 
 // Returns sum_{0 <= i <= k} (-1)^(n - i) * C(n, i)  (mod p), i.e. the sum of
 // the coefficients of x^0..x^k in (x - 1)^n.
 //
 // Closed form: for n >= 1 the alternating partial sum satisfies
 //     sum_{i <= k} (-1)^i C(n, i) = (-1)^k C(n - 1, k),
 // so the result is (-1)^(n - k) * C(n - 1, k). That binomial is evaluated with
 // Lucas' theorem in O(log_p n); it vanishes when k >= n. For n == 0 the
 // polynomial is the constant 1.
 // The shortcut of summing only the last partial block of p terms was wrong
 // when p divides n > 0 (e.g. n = 3, k = 3, p = 3 gave 1 instead of 0). That is
 // why the closed form is used here.
 // k is 64-bit because callers may pass values of order n / (p - 1).
 int f(int64_t n, int64_t k, int p) {
     if (k < 0)
         return 0;
     if (n == 0)
         return 1 % p;
     const int c = C(n - 1, k, p);
     const bool negative = ((n - k) & 1) != 0;
     return negative ? (p - c) % p : c;
 }
 
 // Returns (sum_{0 <= k <= pre} [x^k] x(x+1)...(x+n-1)) mod p, i.e. a prefix
 // sum of unsigned Stirling numbers of the first kind. It returns 0 when
 // pre < n / p.
 // It relies on x(x+1)...(x+n-1) == (x^p - x)^q * x(x+1)...(x+rem-1) (mod p),
 // where q = n / p and rem = n % p. Also (x^p - x)^q = x^q * (x^(p-1) - 1)^q,
 // whose coefficient sums are supplied by f().
 int solve(ll n, ll pre, int64_t p) {
     const ll q = n / p;
     const ll budget = pre - q;  // degree budget left after the factor x^q
     if (budget < 0)
         return 0;
     const int rem = n % p;
     const auto coef = getPoly(0, rem, p);
     const ll last = min<ll>(rem, budget);
     int64_t total = 0;
     for (ll k = 0; k <= last; ++k) {
         total += 1LL * coef[k] * f(q, (budget - k) / (p - 1), p);
         total %= p;
     }
     return static_cast<int>(total);
 }

int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    ll n, l, r;
    int p;
    cin >> n >> l >> r >> p;
    fac[0] = ifac[0] = 1; inv[1] = 1;
    for (int i = 2; i < p; i++)
        inv[i] = 1LL * inv[p % i] * (p - p / i) % p;
    for (int i = 1; i < p; i++) {
        fac[i] = 1LL * fac[i-1] * i % p;
        ifac[i] = 1LL * ifac[i-1] * inv[i] % p;
    }
    cout << (solve(n, r, p) - solve(n, l - 1, p) + p) % p << '\n';
    int x; cin >> x;
    return 0;
}

詳細信息

Test #1:

score: 100
Accepted
time: 33ms
memory: 28168kb

input:

4 1 4 5

output:

4

result:

ok "4"

Test #2:

score: 0
Accepted
time: 30ms
memory: 28164kb

input:

6 5 5 29

output:

15

result:

ok "15"

Test #3:

score: 0
Accepted
time: 58ms
memory: 40028kb

input:

1000 685 975 999983

output:

482808

result:

ok "482808"

Test #4:

score: 0
Accepted
time: 30ms
memory: 28276kb

input:

8 7 8 7

output:

1

result:

ok "1"

Test #5:

score: -100
Memory Limit Exceeded

input:

6 4 6 3

output:


result: