QOJ.ac

QOJ

ID题目提交者结果用时内存语言文件大小提交时间测评时间
#51100#1285. Stirling Numberckiseki#ML 54ms27632kbC++5.0kb2022-09-30 20:08:072022-09-30 20:08:07

Judging History

你现在查看的是最新测评结果

  • [2023-08-10 23:21:45]
  • System Update: QOJ starts to keep a history of the judgings of all the submissions.
  • [2022-09-30 20:08:07]
  • 评测
  • 测评结果:ML
  • 用时:54ms
  • 内存:27632kb
  • [2022-09-30 20:08:07]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;

template <int mod, int G, int maxn>
class NTT {
private:
    static_assert(maxn == (maxn & (-maxn)));
    static_assert((mod - 1) % maxn == 0);
    int roots[maxn];
public:
    constexpr static int add(int a, int b) { return a + b >= mod ? a + b - mod : a + b; }
    constexpr static int sub(int a, int b) { return a - b < 0 ? a - b + mod : a - b; }
    constexpr static int mul(int64_t a, int64_t b) { return static_cast<int>(a * b % mod); }
    constexpr static int qpow(int a, int64_t k) {
        int r = 1 % mod;
        while (k) {
            if (k & 1) r = mul(r, a);
            k >>= 1; a = mul(a, a);
        }
        return r;
    }
    constexpr static int minv(int a) { return qpow(a, mod - 2); }
    constexpr NTT() : roots{} {
        int r = qpow(G, (mod - 1) / maxn);
        for (int i = maxn / 2; i; i /= 2) {
            roots[i] = 1 % mod;
            for (int j = 1; j < i; ++j)
                roots[i + j] = mul(roots[i + j - 1], r);
            r = mul(r, r);
        }
    }
    void operator()(int F[], int n, bool inv = false) {
        for (int i = 0, j = 0; i < n; ++i) {
            if (i < j) swap(F[i], F[j]);
            for (int k = n / 2; (j ^= k) < k; k /= 2);
        }
        for (int s = 1; s < n; s *= 2) {
            for (int i = 0; i < n; i += s * 2) {
                for (int j = 0; j < s; ++j) {
                    int a = F[i + j];
                    int b = mul(F[i + j + s], roots[s + j]);
                    F[i + j] = add(a, b);
                    F[i + j + s] = sub(a, b);
                }
            }
        }
        if (inv) {
            int invn = minv(n);
            for (int i = 0; i < n; ++i)
                F[i] = mul(F[i], invn);
            reverse(F + 1, F + n);
        }
    }
};

static constexpr int maxn = 1 << 20;

static constexpr int M1 = 985661441;
static constexpr int M2 = 998244353;
static constexpr int M3 = 1004535809;

using NTT1 = NTT<M1, 3, maxn>;
using NTT2 = NTT<M2, 3, maxn>;
using NTT3 = NTT<M3, 3, maxn>;

NTT1 ntt1;
NTT2 ntt2;
NTT3 ntt3;

int superBigCRT(int64_t A, int64_t B, int64_t C, int mod) {
    static constexpr int64_t r12 = NTT2::qpow(M1, M2 - 2);
    static constexpr int64_t r13 = NTT3::qpow(M1, M3 - 2);
    static constexpr int64_t r23 = NTT3::qpow(M2, M3 - 2);
    const int64_t M1M2 = int64_t(M1) * M2 % mod;
    B = (B - A + M2) * r12 % M2;
    C = (C - A + M3) * r13 % M3;
    C = (C - B + M3) * r23 % M3;
    return (A + B * M1 + C * M1M2) % mod;
}

void conv(vector<int> &a, vector<int> &b, int mod) {
    const size_t sa = a.size(), sb = b.size();
    int sz = 1;
    while (sz < a.size() + b.size()) sz *= 2;
    a.resize(sz);
    b.resize(sz);

    auto a1 = a, a2 = a;
    auto b1 = b, b2 = b;

    ntt1(a1.data(), sz);
    ntt1(b1.data(), sz);
    for (int i = 0; i < sz; ++i) a1[i] = NTT1::mul(a1[i], b1[i]);
    ntt1(a1.data(), sz, true);
    
    ntt2(a2.data(), sz);
    ntt2(b2.data(), sz);
    for (int i = 0; i < sz; ++i) a2[i] = NTT2::mul(a2[i], b2[i]);
    ntt2(a2.data(), sz, true);

    ntt3(a.data(), sz);
    ntt3(b.data(), sz);
    for (int i = 0; i < sz; ++i) a[i] = NTT3::mul(a[i], b[i]);
    ntt3(a.data(), sz, true);

    a.resize(sa + sb - 1);
    for (size_t i = 0; i < a.size(); ++i)
        a[i] = superBigCRT(a1[i], a2[i], a[i], mod);
}

// \Pi_{l <= i < r} (x + a[i])
vector<int> getPoly(int l, int r, int mod) {
    if (r - l == 1) {
        return {l, 1};
    }
    int m = (l + r) >> 1;
    auto lhs = getPoly(l, m, mod);
    auto rhs = getPoly(m, r, mod);
    conv(lhs, rhs, mod);
    return lhs;
}

using ll = int64_t;

int fac[maxn], ifac[maxn], inv[maxn];
int C(int64_t n, int64_t k, int p) {
    if (k < 0 || n < k)
        return 0;
    if (n < p && k < p) {
        return 1LL * fac[n] * ifac[k] * ifac[n-k] % p;
    }
    return 1LL * C(n % p, k % p, p) * C(n / p, k / p, p) % p;
}

map<tuple<int64_t,int>, int> mp;
int f(int64_t n, int k, int p) {
    if (k > n) return 0;
    if (mp.count({ n, k })) {
        return mp[{n,k}];
    }
    int64_t res = 0;
    for (int i = p*(k/p); i <= k; i++) {
        if ((n - i) % 2 == 0)
            res += C(n, i, p);
        else
            res -= C(n, i, p);
        if (res >= p) res -= p;
        if (res < 0) res += p;
    }
    mp[{n,k}] = res;
    return res;
}

int solve(ll n, ll pre, int64_t p) {
    pre -= n / p;
    if (pre < 0)
        return 0;
    int r = n % p;
    int64_t ans = 0;
    auto coef = getPoly(0, r, p);
    for (int k = 0; k <= r; k++) {
        if (pre - k < 0) continue;
        ans += 1LL * coef[k] * f(n / p, (pre - k) / (p - 1), p);
        ans %= p;
    }
    return ans;
}

int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    ll n, l, r;
    int p;
    cin >> n >> l >> r >> p;
    fac[0] = ifac[0] = 1; inv[1] = 1;
    for (int i = 2; i < p; i++)
        inv[i] = 1LL * inv[p % i] * (p - p / i) % p;
    for (int i = 1; i < p; i++) {
        fac[i] = 1LL * fac[i-1] * i % p;
        ifac[i] = 1LL * ifac[i-1] * inv[i] % p;
    }
    cout << (solve(n, r, p) - solve(n, l - 1, p) + p) % p << '\n';
    return 0;
}

詳細信息

Test #1:

score: 100
Accepted
time: 18ms
memory: 15932kb

input:

4 1 4 5

output:

4

result:

ok "4"

Test #2:

score: 0
Accepted
time: 15ms
memory: 15876kb

input:

6 5 5 29

output:

15

result:

ok "15"

Test #3:

score: 0
Accepted
time: 54ms
memory: 27632kb

input:

1000 685 975 999983

output:

482808

result:

ok "482808"

Test #4:

score: 0
Accepted
time: 11ms
memory: 15872kb

input:

8 7 8 7

output:

1

result:

ok "1"

Test #5:

score: -100
Memory Limit Exceeded

input:

6 4 6 3

output:


result: