QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#312126 | #6328. Many Products | PlentyOfPenalty# | WA | 18ms | 17224kb | C++20 | 9.6kb | 2024-01-23 13:57:30 | 2024-01-23 13:57:30 |
Judging History
answer
#include <bits/stdc++.h>
// Debug helper: writes the spelled-out name of container `v`, then its
// elements separated by spaces and terminated by a newline, to std::cerr.
// An empty container prints only the "name: " prefix.
// The body is wrapped in do { } while (0) so that `logvec(x);` expands to
// exactly one statement and stays correct inside an unbraced if/else.
#define logvec(v)                                                         \
  do {                                                                    \
    std::cerr << #v << ": ";                                              \
    for (std::size_t logvec_i = 0; logvec_i < (v).size(); logvec_i++)     \
      std::cerr << (v)[logvec_i] << " \n"[logvec_i + 1 == (v).size()];    \
  } while (0)
// all(x) expands to the iterator pair `begin(x), end(x)` for use as two
// consecutive arguments of a standard algorithm.
#define all(x) begin(x), end(x)
using namespace std;
// N   - length of a, fac, ifac, inv and of the first dimension of f and g.
// K   - length of the second dimension of f and g, both dimensions of h,
//       the length of dp, and the cap on polynomial length in solve().
// V   - length of the sieve array vis; primes below V are collected in pri.
// mod - modulus used by every modular helper in this file.
const int N = 2e5 + 9, K = 28, V = 1e6 + 9, mod = 998244353;
// n, m   - read from standard input in main().
// t      - not referenced anywhere in the visible code.
// a      - input values, 1-based; main() stores (value + 1) mod `mod` and
//          later overwrites the non-zero entries with their modular inverses.
// fac, ifac, inv - factorials, inverse factorials and inverses, filled in main().
// f, g   - tables indexed by [position in uni][count], filled in main().
// h      - h[x][y] accumulates products f[i][x] * g[j][y] in main().
// dp     - per-count coefficients derived from a[] in main().
int n, t, a[N], fac[N], ifac[N], inv[N], f[N][K], g[N][K], h[K][K], dp[K];
// Sieve marks: main() sets vis[j] for every multiple j of each i >= 2.
bool vis[V];
long long m;
// Primes below V, in increasing order (filled by the sieve in main()).
vector<int> pri;
// Divisors of m, produced by dfs() and then sorted in main().
vector<long long> uni;
// Factorisation of m as (prime, exponent) pairs, as returned by factor().
vector<pair<long long, int>> p;
// In-place modular accumulation: x becomes x + y, with a single subtraction
// of `mod` when the sum reaches it. Only one correction is applied, so the
// result is fully reduced only when both operands were already below `mod`.
void upd(int &x, int y) {
  const int total = x + y;
  x = total < mod ? total : total - mod;
}
// Returns x - y, lifted by `mod` once if the difference is negative.
int sub(int x, int y) {
  const int diff = x - y;
  if (diff < 0) return diff + mod;
  return diff;
}
// Returns x + y, lowered by `mod` once if the sum reaches `mod`.
int add(int x, int y) {
  const int total = x + y;
  if (total >= mod) return total - mod;
  return total;
}
// Binomial coefficient "n choose m" modulo `mod`, computed from the
// precomputed tables fac[] and ifac[].
// Returns 0 when m is negative or exceeds n; the negative check keeps
// ifac[m] from being indexed out of bounds.
// The caller must keep n below the table length N.
int C(int n, int m) {
  if (m < 0 || n < m) return 0;
  const long long numerator = fac[n];
  return static_cast<int>(numerator * ifac[m] % mod * ifac[n - m] % mod);
}
// Computes a^b modulo `mod` by binary exponentiation.
//
// The base is taken as long long and reduced into [0, mod) first. Callers in
// this file pass values derived from 64-bit divisors (e.g. uni[i] - 1); with
// an `int` parameter those were silently truncated, giving wrong and even
// negative residues.
// A non-positive exponent yields 1.
int power(long long a, int b) {
  a %= mod;
  if (a < 0) a += mod;
  long long result = 1;
  for (; b > 0; b >>= 1, a = a * a % mod)
    if (b & 1) result = result * a % mod;
  return static_cast<int>(result);
}
vector<int> rev, rt;
// Prepares the lookup tables used by ntt() for a transform of length n.
//  * rev[i] is set to the bit-reversal of i over __lg(n) bits.
//  * rt is grown (never shrunk) so that, for every power of two len < n,
//    rt[len + i] = w^i with w = power(3, (mod - 1) / (2 * len)).
// The function-local static remembers how far rt has already been built, so
// repeated calls only compute the missing levels.
// The only visible caller passes a power of two for n.
void getRevRoot(int n) {
  const int bits = __lg(n);
  rev.resize(n);
  for (int idx = 1; idx < n; ++idx) {
    const int top = (idx & 1) << (bits - 1);
    rev[idx] = (rev[idx >> 1] >> 1) | top;
  }
  static int built = 1;
  if (built >= n) return;
  rt.resize(n);
  while (built < n) {
    const int step = power(3, (mod - 1) / (built << 1));
    rt[built] = 1;
    for (int pos = built + 1; pos < 2 * built; ++pos) {
      rt[pos] = (long long)rt[pos - 1] * step % mod;
    }
    built <<= 1;
  }
}
// In-place iterative butterfly transform of length n modulo `mod`.
// f is resized to n, permuted according to rev[], and then combined level by
// level using the twiddle factors stored in rt[]. getRevRoot(n) must have
// been called beforehand so that both tables cover length n.
void ntt(vector<int> &f, int n) {
  f.resize(n);
  for (int idx = 0; idx < n; ++idx) {
    const int partner = rev[idx];
    if (partner > idx) swap(f[idx], f[partner]);
  }
  for (int half = 1; half < n; half <<= 1) {
    for (int base = 0; base < n; base += half << 1) {
      for (int off = 0; off < half; ++off) {
        int &lo = f[base + off];
        int &hi = f[base + off + half];
        const int u = lo;
        const int v = (long long)hi * rt[off + half] % mod;
        lo = add(u, v);
        hi = sub(u, v);
      }
    }
  }
}
// Polynomial product of f and g modulo `mod`, computed with ntt().
// The result has f.size() + g.size() - 1 coefficients.
//
// When that length is not positive (at least one operand is empty and the
// other has at most one coefficient) an empty vector is returned directly.
// Previously two empty operands produced a length of -1, which was passed to
// vector::resize and threw.
vector<int> operator*(vector<int> f, vector<int> g) {
  const int m = static_cast<int>(f.size() + g.size()) - 1;
  if (m <= 0) return {};
  int n = 1;
  while (n < m) n <<= 1;
  const int invn = power(n, mod - 2);
  getRevRoot(n);
  ntt(f, n);
  ntt(g, n);
  for (int i = 0; i < n; i++) f[i] = (long long)f[i] * g[i] % mod;
  // Inverse transform: reverse positions 1..n-1, run the forward transform
  // again, then scale by n^{-1}.
  reverse(f.begin() + 1, f.end());
  ntt(f, n);
  f.resize(m);
  for (int i = 0; i < m; i++) f[i] = (long long)f[i] * invn % mod;
  return f;
}
// Factorises x by trial division over the sieved primes in `pri`.
// Returns (prime, exponent) pairs in increasing order of prime.
//
// The bound check p * p <= x is evaluated in 64 bits. `pri` holds primes up
// to about 1e6, so the former int * int product overflowed (undefined
// behaviour) for any x whose search passed p = 46340.
//
// Whatever remains above 1 after the loop is appended as a single prime with
// exponent 1. That is correct only if `pri` covers every prime up to sqrt(x),
// which the caller must ensure.
vector<pair<long long, int>> factor(long long x) {
  vector<pair<long long, int>> fac;
  for (int p : pri) {
    if (static_cast<long long>(p) * p > x) break;
    if (x % p != 0) continue;
    fac.emplace_back(p, 0);
    while (x % p == 0) {
      x /= p;
      fac.back().second++;
    }
  }
  if (x > 1) fac.emplace_back(x, 1);
  return fac;
}
// Enumerates every product of prime powers described by the global
// factorisation `p`, starting at entry u with running product s, and appends
// each completed product to `uni`. Called as dfs(0, 1) it lists all divisors
// of the factorised number (unsorted).
void dfs(int u, long long s) {
  if (u == p.size()) {
    uni.emplace_back(s);
    return;
  }
  const long long prime = p[u].first;
  const int maxExp = p[u].second;
  // One recursive call per exponent 0..maxExp; s is multiplied after each.
  for (int e = 0; e <= maxExp; ++e, s *= prime) {
    dfs(u + 1, s);
  }
}
// Returns the index of x inside the sorted divisor list `uni`, or -1 when x
// exceeds m or is not present.
// The iterator is compared against end() before it is dereferenced, so a
// value larger than every stored element can no longer read past the end of
// the vector.
int find(long long x) {
  if (x > m) return -1;
  const auto it = lower_bound(uni.begin(), uni.end(), x);
  if (it == uni.end() || *it != x) return -1;
  return static_cast<int>(it - uni.begin());
}
// Divide-and-conquer product of the linear polynomials (1 + a[i] * z) for
// i in [l, r], with coefficients modulo `mod`. Every intermediate result is
// truncated to at most K coefficients, and the two halves are multiplied
// with a plain double loop.
vector<int> solve(int l, int r) {
  if (l == r) return {1, a[l]};
  const int mid = (l + r) >> 1;
  const vector<int> left = solve(l, mid);
  const vector<int> right = solve(mid + 1, r);
  const size_t len = min((size_t)K, left.size() + right.size() - 1);
  vector<int> prod(len);
  for (size_t i = 0; i < left.size(); ++i) {
    for (size_t j = 0; j < right.size(); ++j) {
      if (i + j >= len) break;  // higher-degree terms are discarded
      prod[i + j] = (prod[i + j] + (long long)left[i] * right[j]) % mod;
    }
  }
  return prod;
}
int main() {
#ifdef popteam
// freopen("L.in", "r", stdin);
freopen("L2.in", "r", stdin);
#endif
inv[0] = inv[1] = fac[0] = ifac[0] = 1;
for (int i = 2; i < N; i++) {
inv[i] = (long long)(mod - mod / i) * inv[mod % i] % mod;
}
for (int i = 1; i < N; i++) {
fac[i] = (long long)fac[i - 1] * i % mod;
ifac[i] = (long long)ifac[i - 1] * inv[i] % mod;
}
for (int i = 2; i < V; i++) {
if (!vis[i]) pri.push_back(i);
for (int j = i; j < V; j += i)
vis[j] = true;
}
cin.tie(0)->sync_with_stdio(0);
cin >> n >> m;
for (int i = 1; i <= n; i++) {
cin >> a[i];
if (++a[i] == mod) a[i] = 0;
}
p = factor(m);
dfs(0, 1);
sort(all(uni));
logvec(uni);
int zero = 0, all = 1;
for (int i = 1; i <= n; i++)
if (a[i] == 0) {
swap(a[++zero], a[i]);
} else {
all = (long long)all * a[i] % mod;
a[i] = power(a[i], mod - 2);
}
vector<int> p = zero < n ? solve(zero + 1, n) : vector<int>{};
for (int i = 0; i < V - zero; i++)
dp[i + zero] = i < p.size() ? p[i] : 0;
for (int i = 0; i < V; i++)
dp[i] = (long long)dp[i] * all % mod;
f[0][0] = g[0][0] = 1;
for (int i = 1; i < uni.size(); i++) {
int j = upper_bound(all(uni), m / uni[i]) - uni.begin() - 1;
for (; j >= 0; j--)
if (m % (uni[i] * uni[j]) == 0) {
int k = find(uni[i] * uni[j]);
assert(k != -1);
long long t = m / uni[j];
for (int y = 1; y < K; y++) {
if (t % uni[i]) break;
t /= uni[i];
fprintf(stderr, ">> i=%d j=%d k=%d y=%d\n", i, j, k, y);
for (int x = y; x < K; x++) {
upd(f[k][x], f[j][x - y]);
upd(g[k][x], (long long)g[j][x - y] * power(uni[i] - 1, y) % mod);
}
}
}
}
for (int i = 0; i < uni.size(); i++)
for (int j = 0; j < K; j++)
if (f[i][j] || g[i][j]) fprintf(stderr, "i=%d j=%d >> f=%d g=%d\n", i, j, f[i][j], g[i][j]);
for (int x = 0; x < K; x++)
for (int y = 0; x + y < K; y++) {
for (int i = 0; i < uni.size(); i++) {
int j = find(m / uni[i]);
assert(j != -1);
if (f[i][x] && g[j][y]) {
upd(h[x][y], (long long)f[i][x] * g[j][y] % mod);
fprintf(stderr, "f[%d][%d](%d) * g[%d][%d](%d) -> %d\n", i, x, f[i][x], j, y, g[j][y], h[x][y]);
}
}
}
int ans = 0;
for (int x = 0; x < K; x++)
for (int y = 0; x + y < K && x + y <= n; y++)
if (h[x][y] && dp[y]) {
fprintf(stderr, "x=%d y=%d h=%d C=%d dp=%d\n", x, y, h[x][y], C(n - y, x), dp[y]);
upd(ans, (long long)h[x][y] * C(n - y, x) % mod * dp[y] % mod);
}
cout << ans << endl;
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 8ms
memory: 16948kb
input:
2 3 0 1
output:
10
result:
ok 1 number(s): "10"
Test #2:
score: 0
Accepted
time: 18ms
memory: 14368kb
input:
5 1 0 1 2 3 4
output:
120
result:
ok 1 number(s): "120"
Test #3:
score: -100
Wrong Answer
time: 10ms
memory: 17224kb
input:
10 314159265358 0 1 2 3 4 5 6 7 8 9
output:
-606013828
result:
wrong answer 1st numbers differ - expected: '658270849', found: '-606013828'