QOJ.ac

QOJ

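| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|---|---|---|---|---|---|---|---|---|---|
| #462776 | #7087. Counting Polygons | teraqqq | ML | 0ms | 0kb | C++14 | 6.5kb | 2024-07-04 03:55:24 | 2024-07-04 03:55:25 |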

Judging History

You are viewing the latest judging result.

  • [2024-07-04 03:55:25]
  • Judged
  • Result: ML (Memory Limit Exceeded)
  • Time: 0ms
  • Memory: 0kb
  • [2024-07-04 03:55:24]
  • Submitted

answer

#include <bits/stdc++.h>

using namespace std;
using ll = long long;

// Prime modulus used for every count.
constexpr int MOD = 1'000'000'007;
// Size bound of the number-theoretic tables: minp[] and phi[] are indexed
// by values in [0, N].
constexpr int N = 20'000'228;
// Size bound of the factorial tables: fact[] and rfct[] are indexed by [0, M].
constexpr int M = N;

// NOTE(review): these four int tables hold N+1 (== M+1) entries each, i.e.
// about 4 * 20'000'229 * sizeof(int) bytes (~320 MB with a 4-byte int).
// This submission was judged "Memory Limit Exceeded" at 0 ms / 0 kb, which is
// consistent with the static footprint alone exceeding the limit. Confirm
// the problem's real bound on n and shrink N/M (or drop a table) accordingly.

// minp[i]: smallest prime factor of i (stays 0 for i < 2); filled in main().
int minp[N + 1];
// fact[i] = i! mod MOD; filled in main().
int fact[M + 1];
// rfct[i] = (i!)^{-1} mod MOD; filled in main().
int rfct[M + 1];
// phi[i]: Euler's totient of i; filled in main().
int phi[N + 1];
int rev(int a, int m = MOD) {
    return a == 1 ? 1 : m - (ll)m*rev(m%a,a)/a;
}
int c(int n, int k) {
    return 0 <= k && k <= n ? (ll)fact[n]*rfct[k]%MOD*rfct[n-k]%MOD : 0;
}

// In-place a = (a + b) mod MOD, assuming both operands are already in [0, MOD).
void mod_add(int& a, int b) {
    const int sum = a + b;
    a = (sum >= MOD) ? sum - MOD : sum;
}

// In-place a = (a - b) mod MOD, assuming both operands are already in [0, MOD).
void mod_sub(int& a, int b) {
    const int diff = a - b;
    a = (diff < 0) ? diff + MOD : diff;
}


// cnt(n, mp, sn): number of ways (mod MOD) to write
//     n = 2*a + b,
// where a is a composition into exactly mp positive parts and b is a
// composition into exactly sn positive parts. Equivalently
//     sum over even u in [0, n] of comp(u/2, mp) * comp(n - u, sn),
// with comp(x, i) = C(x-1, i-1) for x >= 1 and comp(0, 0) = 1.
// Negative n yields 0.
//
// Worked check: cnt(65, 14, 2) = 2 * C(32, 15) mod MOD = 131445433, which is
// exactly what this implementation returns.
int cnt(int n, int mp, int sn) {
    if (n < 0) return 0;
    switch (sn) {
    case 0:
        // Only the doubled part: n must be even, giving comp(n/2, mp).
        if (n == 0) return mp == 0 ? 1 : 0;
        if (n % 2 != 0) return 0;
        return c(n / 2 - 1, mp - 1);
    case 2:
        // Closed form with t = floor(n/2): 2*C(t, mp+1) for odd n and
        // C(t-1, mp+1) + C(t, mp+1) for even n. Both are expressed through the
        // sn == 0 case: cnt(m', mp+2, 0) = C(m'/2 - 1, mp + 1) for even m' > 0.
        if (n % 2 != 0) return static_cast<int>(2LL * cnt(n + 1, mp + 2, 0) % MOD);
        return (cnt(n, mp + 2, 0) + cnt(n + 2, mp + 2, 0)) % MOD;
    default: {
        // Recurrence cnt(n, mp, sn) = cnt(n, mp+1, sn-1) + cnt(n+1, mp+1, sn-1).
        // In generating-function terms: with A = x^2/(1-x^2) and B = x/(1-x),
        // A * (1 + 1/x) = B. For sn == 1 it yields C(floor((n-1)/2), mp),
        // since exactly one of n and n+1 is even.
        int total = cnt(n, mp + 1, sn - 1);
        mod_add(total, cnt(n + 1, mp + 1, sn - 1));
        return total;
    }
    }
}

int solve(int n, int m) {
    vector<int> divs{{1}};
    for (int j = m; j != 1; ) {
        const int p = minp[j], s = divs.size();
        // cout << j << ", p=" << p << endl;
        int k = 0;
        while (j % p == 0) j /= p, ++k;
        for (int i = 0; i < s*k; ++i)
            divs.push_back(divs[i]*p);
    }

    int ans = 0;
    for (int d : divs) {
        if (d == m) {
            mod_add(ans, c(n-1, m-1));
            ans = (ans + (ll)(MOD - m)*c(n/2, m-1))%MOD;
            // cout << "d=" << d << ": " << ans << " " << 1 << endl;
        } else {
            const int cnt = m / d;
            if (n % cnt == 0)
                ans = (ans + (ll)c(n/cnt-1, d-1)*phi[cnt])%MOD;
            // cout << "d=" << d << ": " << ans << " " << phi[cnt] << endl;
        }
    }
    // cerr << ans << " half" << endl;

    if (m % 2) {
        ans = (ans + (ll)m*cnt(n, m/2, 1))%MOD;
        ans = (ans + (ll)(MOD-m)*cnt(n/2+1, m/2, 1))%MOD;
    } else {
        ans = (ans + (ll)(m/2)*cnt(n, m/2, 0))%MOD;
        ans = (ans + (ll)(m/2)*cnt(n, m/2-1, 2))%MOD;
        ans = (ans + (ll)(MOD-m)*cnt(n/2+1, m/2-1, 2))%MOD;
    }
    // cerr << ans << "!" << endl;

    ans = (ll)ans*rev(2*m)%MOD;
    if (ans < 0) ans += MOD;
    return ans;
}

// Brute-force reference for solve(). Extends the partial side list `st` with
// m more sides, each x satisfying 1 <= x and 2*x < max, summing to exactly n.
// It counts the completed lists that are lexicographically no larger than
// every rotation of themselves and of their reversal, i.e. one representative
// per rotation/reflection class. `st` is restored before returning.
int solve_brute(int n, int m, int max, vector<int>& st) {
    if (n < m) return 0;  // every remaining side needs at least 1

    if (m == 0) {
        if (n != 0) return 0;
        // Walk through all rotations of the reversal, then of st itself.
        vector<int> probe = st;
        for (int pass = 0; pass < 2; ++pass) {
            std::reverse(probe.begin(), probe.end());
            for (size_t shift = 0; shift < st.size(); ++shift) {
                std::rotate(probe.begin(), probe.begin() + 1, probe.end());
                if (probe < st) return 0;
            }
        }
        return 1;
    }

    int total = 0;
    st.push_back(0);
    for (int side = 1; side <= n; ++side) {
        if (2 * side >= max) break;  // bound only grows with `side`
        st.back() = side;
        total += solve_brute(n - side, m - 1, max, st);
    }
    st.pop_back();
    return total;
}

// Convenience entry point for the brute force: starts from an empty side list
// and uses n itself as the side bound (every side x must satisfy 2*x < n).
int solve_brute(int n, int m) {
    vector<int> prefix;
    return solve_brute(n, m, n, prefix);
}

// Fills minp[i] with the smallest prime factor of i for every 2 <= i <= N.
static void build_min_prime_table() {
    for (int p = 2; p <= N; ++p) {
        if (minp[p] != 0) continue;  // composite: already marked by a smaller prime
        for (int mult = p; mult <= N; mult += p)
            if (minp[mult] == 0) minp[mult] = p;
    }
}

// Fills phi[i] (Euler's totient) for 1 <= i <= N from minp[]. With p = minp[i]
// and r = i / p: phi(i) = phi(r) * p if p divides r, else phi(r) * (p - 1).
// minp[1] == 0, so r == 1 always takes the (p - 1) branch.
static void build_totient_table() {
    phi[1] = 1;
    for (int i = 2; i <= N; ++i) {
        const int p = minp[i];
        const int r = i / p;
        phi[i] = phi[r] * (minp[r] == p ? p : p - 1);
    }
}

// Fills fact[i] = i! and rfct[i] = (i!)^{-1} (both mod MOD) for 0 <= i <= M,
// using a single modular inverse and the relation 1/(i-1)! = i / i!.
static void build_factorial_tables() {
    fact[0] = 1;
    for (int i = 1; i <= M; ++i) fact[i] = (ll)fact[i - 1] * i % MOD;
    rfct[M] = rev(fact[M]);
    for (int i = M; i >= 1; --i) rfct[i - 1] = (ll)rfct[i] * i % MOD;
}

// Reads t test cases, each a pair "n m", and prints solve(n, m) per line.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    build_min_prime_table();
    build_totient_table();  // depends on minp[]
    build_factorial_tables();

    // Verification note: for small inputs, solve(n, m) can be cross-checked
    // against solve_brute(n, m), e.g. for all 3 <= m <= n <= 20. Likewise,
    // cnt(s, k, d) can be checked against the direct sum over even u of
    // comp(u/2, k) * comp(s - u, d), where comp(x, i) is the number of
    // compositions of x into i positive parts.

    int t = 0;
    cin >> t;
    while (t-- > 0) {
        int n = 0;
        int m = 0;
        if (!(cin >> n >> m)) break;  // truncated input: stop rather than use garbage
        cout << solve(n, m) << '\n';
    }
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Memory Limit Exceeded

input:

4
3 3
4 3
5 3
7 4

output:

1
0
1
3

result: