QOJ.ac

QOJ

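| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|---|---|---|---|---|---|---|---|---|---|
| #312126 | #6328. Many Products | PlentyOfPenalty | WA | 18ms | 17224kb | C++20 | 9.6kb | 2024-01-23 13:57:30 | 2024-01-23 13:57:30 |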

Judging History

You are viewing the latest judging result.

  • [2024-01-23 13:57:30]
  • Judged
  • Result: WA
  • Time: 18ms
  • Memory: 17224kb
  • [2024-01-23 13:57:30]
  • Submitted

answer

#include <bits/stdc++.h>
// Debug helper: prints "<name>: e0 e1 ... ek" and a newline to stderr (just the
// name and a newline for an empty container). Expands to a single statement.
#define logvec(v)                                                \
    do {                                                         \
        std::cerr << #v << ": ";                                 \
        if ((v).empty()) std::cerr << '\n';                      \
        for (std::size_t i_ = 0; i_ < (v).size(); i_++)          \
            std::cerr << (v)[i_] << " \n"[i_ + 1 == (v).size()]; \
    } while (0)
#define all(x) begin(x), end(x)
using namespace std;
// mod: prime modulus (998244353) used for all arithmetic.
// N: size of the factorial/inverse tables and of the per-index arrays below.
// K: coefficient bound; tuple lengths x, y and polynomial degrees are kept < K.
//    NOTE(review): a product equal to m can contain up to Omega(m) factors > 1
//    (39 for m = 2^39 < 1e12). If m can be that large, K = 28 truncates valid
//    terms -- confirm against the problem's bound on m.
// V: sieve limit; trial division by the primes below V is enough to factor
//    any m < V^2 (presumably m <= ~1e12 -- confirm).
const int N = 2e5 + 9, K = 28, V = 1e6 + 9, mod = 998244353;
// a: input values; main turns them into a_i + 1 and then replaces the nonzero
//    ones by their modular inverses.
// fac / ifac / inv: factorials, inverse factorials and inverses of 1..N-1.
// f, g: DP tables indexed by divisor position in `uni` and a length (filled in main).
// h, dp: per-length aggregates combined into the answer in main.
// t: appears unused at global scope; main declares its own local `t`.
int n, t, a[N], fac[N], ifac[N], inv[N], f[N][K], g[N][K], h[K][K], dp[K];
// vis: sieve marks for 0..V-1.
bool vis[V];
// m: the target value read from input.
long long m;
// pri: primes below V.
vector<int> pri;
// uni: divisors of m (sorted ascending in main).
vector<long long> uni;
// p: prime factorization of m as (prime, exponent) pairs.
vector<pair<long long, int>> p;

// In-place modular addition: x = (x + y) mod `mod`, for x, y already in [0, mod).
void upd(int &x, int y) {
    if ((x += y) >= mod) x -= mod;
}
// Modular subtraction of values in [0, mod): returns (x - y) mod `mod`.
int sub(int x, int y) {
    const int diff = x - y;
    if (diff < 0) return diff + mod;
    return diff;
}
// Modular addition of values in [0, mod): returns (x + y) mod `mod`.
int add(int x, int y) {
    const int total = x + y;
    if (total >= mod) return total - mod;
    return total;
}
// Binomial coefficient C(n, m) mod `mod` from the precomputed factorial tables;
// 0 when m > n.
int C(int n, int m) {
    return n < m ? 0 : int((long long)fac[n] * ifac[m] % mod * ifac[n - m] % mod);
}
// Fast modular exponentiation: returns a^b mod `mod` (1 for b <= 0).
// The base is accepted as long long and reduced into [0, mod) first, so callers
// may pass values wider than int (e.g. divisors of m minus one, up to ~1e12) or
// negative values without silent truncation.
int power(long long a, long long b) {
    a %= mod;
    if (a < 0) a += mod;
    long long s = 1;
    for (; b > 0; b >>= 1, a = a * a % mod)
        if (b & 1) s = s * a % mod;
    return (int)s;
}
// rev: bit-reversal permutation for the current NTT size.
// rt:  table of roots of unity; rt[len + i] = w^i with w a primitive (2*len)-th
//      root of unity, for every power of two len below the largest size requested.
vector<int> rev, rt;
// Prepares `rev` and `rt` for an NTT of length n.
// NOTE(review): assumes n is a power of two (operator* only passes such n).
void getRevRoot(int n) {
    int m = __lg(n); // floor(log2(n)); formerly log(n)/log(2)+1e-8;
    rev.resize(n);
    // rev[i] = bits of i reversed over m bits, built from rev[i / 2].
    for (int i = 1; i < n; i++) {
        rev[i] = rev[i >> 1] >> 1 | (i & 1) << (m - 1);
    }
    // The root table only ever grows; `len` remembers how far it was filled.
    static int len = 1;
    if (len < n) {
        rt.resize(n);
        for (; len < n; len <<= 1) {
            // 3 is used as the primitive root of mod (true for 998244353).
            int uni = power(3, (mod - 1) / (len << 1));
            rt[len] = 1;
            for (int i = 1; i < len; i++) {
                rt[i + len] = (long long)rt[i + len - 1] * uni % mod;
            }
        }
    }
}
// In-place forward NTT of length n (a power of two) modulo `mod`.
// f is resized to n (zero-padded or truncated). getRevRoot(n) must have been
// called first so that `rev` and `rt` cover size n.
void ntt(vector<int> &f, int n) {
    f.resize(n);
    // Bit-reversal permutation so the butterflies below can run bottom-up.
    for (int i = 0; i < n; i++) {
        if (i < rev[i]) swap(f[i], f[rev[i]]);
    }
    // Iterative radix-2 butterflies: merge blocks of size len into blocks of
    // size 2*len, using twiddle factor rt[len + j] for offset j.
    for (int len = 1; len < n; len *= 2) {
        for (int i = 0; i < n; i += len * 2) {
            for (int j = 0; j < len; j++) {
                int x = f[i + j];
                int y = (long long)f[i + j + len] * rt[j + len] % mod;
                f[i + j] = add(x, y);
                f[i + j + len] = sub(x, y);
            }
        }
    }
}
// Polynomial multiplication modulo `mod` via NTT: returns the coefficients of
// f * g, of length |f| + |g| - 1. Inputs are taken by value and reused as buffers.
// NOTE(review): not called anywhere in the visible code (solve() multiplies
// directly). Two empty inputs would give m == -1 and make f.resize(m) fail.
vector<int> operator*(vector<int> f, vector<int> g) {
    int n = 1, m = (int)(f.size() + g.size()) - 1;
    // Transform size: smallest power of two >= the result length.
    while (n < m)
        n <<= 1;
    int invn = power(n, mod - 2);
    getRevRoot(n), ntt(f, n), ntt(g, n);
    // Pointwise product in the transformed domain.
    for (int i = 0; i < n; i++)
        f[i] = (long long)f[i] * g[i] % mod;
    // Inverse transform: reversing entries 1..n-1 and applying the forward
    // transform again yields n times the inverse; the 1/n factor is applied below.
    reverse(f.begin() + 1, f.end()), ntt(f, n);
    f.resize(m);
    for (int i = 0; i < m; i++)
        f[i] = (long long)f[i] * invn % mod;
    return f;
}

// Trial-divides x by the sieved primes in `pri` and returns its factorization
// as (prime, exponent) pairs in increasing order of prime. Any cofactor > 1
// left after the loop is appended with exponent 1; it is only guaranteed to be
// prime when x < (next prime after the sieve limit)^2.
vector<pair<long long, int>> factor(long long x) {
    vector<pair<long long, int>> fac;
    for (int p : pri) {
        // Widen before squaring: p reaches ~1e6, so int p * p would overflow (UB).
        if ((long long)p * p > x) break;
        if (x % p != 0) continue;
        fac.emplace_back(p, 0);
        while (x % p == 0) {
            x /= p;
            fac.back().second++;
        }
    }
    if (x > 1) fac.emplace_back(x, 1);
    return fac;
}
// Enumerates s times every divisor of prod_{k >= u} p[k].first^p[k].second and
// appends each value to `uni` (in generation order, not sorted).
void dfs(int u, long long s) {
    if (u == (int)p.size()) {
        uni.emplace_back(s);
        return;
    }
    long long cur = s;
    for (int e = 0;; e++) {
        dfs(u + 1, cur);
        // Stop before forming s * prime^(exponent + 1): that value is never used
        // and can overflow long long when m has a large prime factor.
        if (e == p[u].second) break;
        cur *= p[u].first;
    }
}

// Returns the index of x in the sorted divisor list `uni`, or -1 if x is not
// present (in particular when x > m).
int find(long long x) {
    if (x > m) return -1;
    auto it = lower_bound(uni.begin(), uni.end(), x);
    // Never dereference end(): guards against x exceeding every stored value.
    if (it == uni.end() || *it != x) return -1;
    return int(it - uni.begin());
}
// Returns the coefficients of prod_{i=l..r} (1 + a[i] * z) modulo `mod`,
// truncated to degree < K, by divide and conquer. Requires l <= r.
vector<int> solve(int l, int r) {
    if (l == r) return {1, a[l]};
    const int mid = (l + r) >> 1;
    const vector<int> left = solve(l, mid), right = solve(mid + 1, r);
    const size_t len = min<size_t>(K, left.size() + right.size() - 1);
    vector<int> prod(len, 0);
    for (size_t i = 0; i < left.size() && i < len; ++i) {
        for (size_t j = 0; j < right.size() && i + j < len; ++j) {
            prod[i + j] = (prod[i + j] + (long long)left[i] * right[j]) % mod;
        }
    }
    return prod;
}

int main() {
#ifdef popteam
    // freopen("L.in", "r", stdin);
    freopen("L2.in", "r", stdin);
#endif
    inv[0] = inv[1] = fac[0] = ifac[0] = 1;
    for (int i = 2; i < N; i++) {
        inv[i] = (long long)(mod - mod / i) * inv[mod % i] % mod;
    }
    for (int i = 1; i < N; i++) {
        fac[i] = (long long)fac[i - 1] * i % mod;
        ifac[i] = (long long)ifac[i - 1] * inv[i] % mod;
    }
    for (int i = 2; i < V; i++) {
        if (!vis[i]) pri.push_back(i);
        for (int j = i; j < V; j += i)
            vis[j] = true;
    }

    cin.tie(0)->sync_with_stdio(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        if (++a[i] == mod) a[i] = 0;
    }

    p = factor(m);
    dfs(0, 1);
    sort(all(uni));
    logvec(uni);

    int zero = 0, all = 1;
    for (int i = 1; i <= n; i++)
        if (a[i] == 0) {
            swap(a[++zero], a[i]);
        } else {
            all = (long long)all * a[i] % mod;
            a[i] = power(a[i], mod - 2);
        }
    vector<int> p = zero < n ? solve(zero + 1, n) : vector<int>{};
    for (int i = 0; i < V - zero; i++)
        dp[i + zero] = i < p.size() ? p[i] : 0;
    for (int i = 0; i < V; i++)
        dp[i] = (long long)dp[i] * all % mod;

    f[0][0] = g[0][0] = 1;
    for (int i = 1; i < uni.size(); i++) {
        int j = upper_bound(all(uni), m / uni[i]) - uni.begin() - 1;
        for (; j >= 0; j--)
            if (m % (uni[i] * uni[j]) == 0) {
                int k = find(uni[i] * uni[j]);
                assert(k != -1);
                long long t = m / uni[j];
                for (int y = 1; y < K; y++) {
                    if (t % uni[i]) break;
                    t /= uni[i];
                    fprintf(stderr, ">> i=%d j=%d k=%d y=%d\n", i, j, k, y);
                    for (int x = y; x < K; x++) {
                        upd(f[k][x], f[j][x - y]);
                        upd(g[k][x], (long long)g[j][x - y] * power(uni[i] - 1, y) % mod);
                    }
                }
            }
    }
    for (int i = 0; i < uni.size(); i++)
        for (int j = 0; j < K; j++)
            if (f[i][j] || g[i][j]) fprintf(stderr, "i=%d j=%d >> f=%d g=%d\n", i, j, f[i][j], g[i][j]);
    for (int x = 0; x < K; x++)
        for (int y = 0; x + y < K; y++) {
            for (int i = 0; i < uni.size(); i++) {
                int j = find(m / uni[i]);
                assert(j != -1);
                if (f[i][x] && g[j][y]) {
                    upd(h[x][y], (long long)f[i][x] * g[j][y] % mod);
                    fprintf(stderr, "f[%d][%d](%d) * g[%d][%d](%d) -> %d\n", i, x, f[i][x], j, y, g[j][y], h[x][y]);
                }
            }
        }
    int ans = 0;
    for (int x = 0; x < K; x++)
        for (int y = 0; x + y < K && x + y <= n; y++)
            if (h[x][y] && dp[y]) {
                fprintf(stderr, "x=%d y=%d h=%d C=%d dp=%d\n", x, y, h[x][y], C(n - y, x), dp[y]);
                upd(ans, (long long)h[x][y] * C(n - y, x) % mod * dp[y] % mod);
            }
    cout << ans << endl;
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 8ms
memory: 16948kb

input:

2 3
0 1

output:

10

result:

ok 1 number(s): "10"

Test #2:

score: 0
Accepted
time: 18ms
memory: 14368kb

input:

5 1
0 1 2 3 4

output:

120

result:

ok 1 number(s): "120"

Test #3:

score: -100
Wrong Answer
time: 10ms
memory: 17224kb

input:

10 314159265358
0 1 2 3 4 5 6 7 8 9

output:

-606013828

result:

wrong answer 1st numbers differ - expected: '658270849', found: '-606013828'