QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submitted at | Judged at |
---|---|---|---|---|---|---|---|---|---|
#462776 | #7087. Counting Polygons | teraqqq | ML | 0ms | 0kb | C++14 | 6.5kb | 2024-07-04 03:55:24 | 2024-07-04 03:55:25 |
answer
#include <bits/stdc++.h>
using namespace std;
using ll = long long;  // wide type for products of two residues mod MOD

// Prime modulus used for every answer.
constexpr int MOD = 1000000007;
// Largest index of the sieve / totient tables (minp, phi).
constexpr int N = 20000228;
// Largest index of the factorial tables (fact, rfct); same bound as N.
constexpr int M = N;

// Lookup tables, filled once at program start by main():
//   minp[i] - smallest prime factor of i (left 0 for i < 2)
//   fact[i] - i! mod MOD
//   rfct[i] - (i!)^{-1} mod MOD
//   phi[i]  - Euler's totient of i
// NOTE(review): four int tables of N + 1 entries take roughly 320 MB
// (assuming 4-byte int). This submission was judged Memory Limit Exceeded,
// so these tables are the likely cause; confirm the real bound on n before
// shrinking N or eliminating a table.
int minp[N + 1];
int fact[M + 1];
int rfct[M + 1];
int phi[N + 1];
int rev(int a, int m = MOD) {
return a == 1 ? 1 : m - (ll)m*rev(m%a,a)/a;
}
int c(int n, int k) {
return 0 <= k && k <= n ? (ll)fact[n]*rfct[k]%MOD*rfct[n-k]%MOD : 0;
}
// In-place a = (a + b) mod MOD; assumes a and b are already in [0, MOD),
// so one conditional subtraction suffices.
void mod_add(int& a, int b) {
    if ((a += b) >= MOD) a -= MOD;
}
// In-place a = (a - b) mod MOD; assumes a and b are already in [0, MOD),
// so one conditional addition suffices. Not called from any code visible here.
void mod_sub(int& a, int b) {
    if ((a -= b) < 0) a += MOD;
}
// cnt(n, mp, sn): number of ways, mod MOD, to write
//     n = 2 * (a_1 + ... + a_mp) + (b_1 + ... + b_sn)
// with every a_i and b_j a positive integer ("mp" paired parts counted twice,
// "sn" single parts counted once). Returns 0 for n < 0.
//
//   sn == 0 : n must be even and n / 2 is split into mp positive parts,
//             i.e. C(n/2 - 1, mp - 1) (and 1 for n == mp == 0).
//   sn >= 1 : one single part b is folded into the paired parts by parity:
//             b = 2y contributes cnt(n, mp + 1, sn - 1) and
//             b = 2y - 1 contributes cnt(n + 1, mp + 1, sn - 1).
//   sn == 2 : the same folding applied to both single parts at once; the
//             parity of n decides which parity combinations are possible.
// sn must not be negative (a negative sn recurses without terminating).
//
// Spot check: cnt(65, 14, 2) = 2 * C(32, 15) mod MOD = 131445433.
int cnt(int n, int mp, int sn) {
    if (n < 0) return 0;
    if (sn == 0) {
        if (n == 0) return mp == 0 ? 1 : 0;
        if (n % 2 != 0) return 0;
        return c(n / 2 - 1, mp - 1);
    }
    if (sn == 2) {
        if (n % 2 != 0) {
            // Exactly one of the two single parts is odd (two choices of which).
            return static_cast<int>(2LL * cnt(n + 1, mp + 2, 0) % MOD);
        }
        // Both single parts even, or both odd.
        int both = cnt(n, mp + 2, 0);
        mod_add(both, cnt(n + 2, mp + 2, 0));
        return both;
    }
    int folded = cnt(n, mp + 1, sn - 1);
    mod_add(folded, cnt(n + 1, mp + 1, sn - 1));
    return folded;
}
int solve(int n, int m) {
vector<int> divs{{1}};
for (int j = m; j != 1; ) {
const int p = minp[j], s = divs.size();
// cout << j << ", p=" << p << endl;
int k = 0;
while (j % p == 0) j /= p, ++k;
for (int i = 0; i < s*k; ++i)
divs.push_back(divs[i]*p);
}
int ans = 0;
for (int d : divs) {
if (d == m) {
mod_add(ans, c(n-1, m-1));
ans = (ans + (ll)(MOD - m)*c(n/2, m-1))%MOD;
// cout << "d=" << d << ": " << ans << " " << 1 << endl;
} else {
const int cnt = m / d;
if (n % cnt == 0)
ans = (ans + (ll)c(n/cnt-1, d-1)*phi[cnt])%MOD;
// cout << "d=" << d << ": " << ans << " " << phi[cnt] << endl;
}
}
// cerr << ans << " half" << endl;
if (m % 2) {
ans = (ans + (ll)m*cnt(n, m/2, 1))%MOD;
ans = (ans + (ll)(MOD-m)*cnt(n/2+1, m/2, 1))%MOD;
} else {
ans = (ans + (ll)(m/2)*cnt(n, m/2, 0))%MOD;
ans = (ans + (ll)(m/2)*cnt(n, m/2-1, 2))%MOD;
ans = (ans + (ll)(MOD-m)*cnt(n/2+1, m/2-1, 2))%MOD;
}
// cerr << ans << "!" << endl;
ans = (ll)ans*rev(2*m)%MOD;
if (ans < 0) ans += MOD;
return ans;
}
// Brute-force reference for solve() on small inputs. It counts sequences of m
// positive side lengths that:
//   - sum to n,
//   - have every side x satisfying 2 * x < max, and
//   - are lexicographically smallest among all their rotations and
//     reflections (so each equivalence class is counted exactly once).
// st holds the sides chosen so far and is restored to its original contents
// before returning.
int solve_brute(int n, int m, int max, vector<int>& st) {
    if (n < m) return 0;  // every remaining side needs length >= 1
    if (m == 0) {
        if (n != 0) return 0;
        // Compare st with every rotation of its mirror image, then with every
        // rotation of itself; any strictly smaller variant means st is not
        // the canonical representative.
        vector<int> probe(st.rbegin(), st.rend());
        for (int pass = 0; pass < 2; ++pass) {
            for (size_t step = 0; step < st.size(); ++step) {
                rotate(probe.begin(), probe.begin() + 1, probe.end());
                if (probe < st) return 0;
            }
            reverse(probe.begin(), probe.end());
        }
        return 1;
    }
    int found = 0;
    st.push_back(0);
    for (int side = 1; side <= n; ++side) {
        if (2 * side >= max) break;
        st.back() = side;
        found += solve_brute(n - side, m - 1, max, st);
    }
    st.pop_back();
    return found;
}

// Brute-force count of m-sided polygons with perimeter n (the side bound
// 2 * x < n is applied by passing n as max to the enumerator above).
int solve_brute(int n, int m) {
    vector<int> sides;
    return solve_brute(n, m, n, sides);
}
int main() {
ios::sync_with_stdio(0); cin.tie(0);
for (int i = 2; i <= N; ++i) {
if (minp[i]) continue;
for (int j = i; j <= N; j += i)
if (!minp[j]) minp[j] = i;
}
phi[1] = 1;
for (int i = 2; i <= N; ++i) {
const int p = minp[i];
const int j = i / p;
if (p == minp[j]) phi[i] = phi[j] * p;
else phi[i] = phi[j] * (p-1);
}
fact[0] = 1;
for (int i = 1; i <= M; ++i) fact[i] = (ll)fact[i-1]*i%MOD;
rfct[M] = rev(fact[M]);
for (int i = M; i >= 1; --i) rfct[i-1] = (ll)rfct[i]*i%MOD;
// constexpr int N = 300;
// vector<vector<int>> sex(N+1, vector<int>(N+1));
// sex[0][0] = 1;
// for (int n = 1; n <= N; ++n) {
// sex[n][0] = 1;
// for (int k = 1; k <= n; ++k)
// sex[n][k] = (sex[n-1][k-1] + sex[n-1][k])%MOD;
// }
// for (int k = 0; k <= N; ++k)
// for (int n = 0; n <= N; ++n) {
// assert(sex[n][k] == c(n,k));
// }
// // exit(0);
// vector<vector<int>> dp(N+1, vector<int>(N+1));
// dp[0][0] = 1;
// for (int i = 1; i <= N; ++i) {
// for (int s = 1; s <= N; ++s)
// for (int d = 1; d <= s; ++d)
// mod_add(dp[i][s], dp[i-1][s-d]);
// }
// for (int i = 1; i <= N; ++i) {
// for (int s = 1; s <= N; ++s) {
// if (c(s-1, i-1) != dp[i][s]) {
// cout << "k=" << i << ", s=" << s << endl;
// cout << "WTF " << dp[i][s] << " vs " << c(s-1, i-1) << endl;
// }
// assert(c(s-1, i-1) == dp[i][s]);
// }
// }
// for (int i = 0; i <= 10; ++i) {
// for (int j = 0; j <= 10; ++j)
// cout << dp[i][j] << ' ';
// cout << endl;
// }
// for (int s = 0; s <= N; ++s)
// for (int d = 0; d <= 2; ++d) {
// for (int k = 0; k <= N; ++k) {
// int cc = 0;
// for (int u = 0; u <= s; u += 2)
// cc = (cc + (ll)dp[k][u/2]*dp[d][s-u])%MOD;
// if (d == 0 && s % 2 == 0) assert(cc == dp[k][s/2]);
// int pans = cnt(s, k, d);
// if (cc != pans) {
// cout << "WA" << endl;
// cout << s << " " << k << " " << d << endl;
// cout << "pans=" << pans << endl;
// cout << "jans=" << cc << endl;
// exit(0);
// }
// assert(cc == cnt(s, k, d));
// }
// }
// for (int n = 1; n <= 20; ++n) {
// // cerr << "m=" << m << endl;
// for (int m = 3; m <= n; ++m) {
// int pans = solve(n, m);
// int jans = solve_brute(n, m);
// if (pans != jans) {
// cout << "WA!" << endl;
// cout << n << " " << m << endl;
// cout << "pans=" << pans << endl;
// cout << "jans=" << jans << endl;
// exit(0);
// }
// }
// }
int t; cin >> t;
while (t--) {
int n, m; cin >> n >> m;
int ans = solve(n, m);
cout << ans << '\n';
}
}
Details
Test #1:
score: 0
Memory Limit Exceeded
input:
4
3 3
4 3
5 3
7 4
output:
1
0
1
3