QOJ.ac

QOJ

ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间
#51480 | #4483. Count Set | lqhsmash | AC ✓ | 7201ms | 60692kb | C++ | 4.4kb | 2022-10-02 15:00:33 | 2022-10-02 15:00:36

Judging History

你现在查看的是最新测评结果

  • [2023-08-10 23:21:45]
  • System Update: QOJ started keeping a history of the judgings of all submissions.
  • [2022-10-02 15:00:36]
  • 评测
  • 测评结果:AC
  • 用时:7201ms
  • 内存:60692kb
  • [2022-10-02 15:00:33]
  • 提交

answer

#include <bits/stdc++.h>
#define ll long long
using namespace std;

// Array capacity, sized for n up to about 5e5; the extra 50 leaves slack for
// 1-indexed arrays.
const int N = 500050;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int MOD = 998244353;
// Computes x^p mod MOD by square-and-multiply, consuming the exponent one bit
// at a time. p is expected to be non-negative.
int fpow (int x, int p) {
    long long base = x, acc = 1;
    while (p) {
        if (p & 1) acc = acc * base % MOD;
        base = base * base % MOD;
        p >>= 1;
    }
    return static_cast<int>(acc);
}
// 3 is a primitive root modulo 998244353; Gi = 3^(MOD-2) is its inverse
// (Fermat), used by the inverse transform.
const int G = 3;
const int Gi = fpow (G, MOD - 2);
 
// In-place iterative number-theoretic transform of f[0..lim) modulo MOD.
//   f   : coefficients; must hold at least lim entries (mul resizes to lim).
//   lim : transform length, expected to be a power of two (as built by mul).
//   rev : bit-reversal permutation of [0, lim); taken by const reference so
//         the table is no longer copied on every call.
//   inv : 1 for the forward transform; -1 for the inverse transform, which
//         uses the inverse root Gi and then divides every entry by lim.
// Entries stay in [0, MOD) when the inputs are in [0, MOD).
void ntt (vector<int>& f, int lim, const vector<int>& rev, int inv) {
    // Bit-reversal reordering; swapping only when rev[i] > i visits each pair
    // once. swap replaces the chained XOR swap, whose evaluation order is only
    // well-defined since C++17.
    for (int i = 0; i < lim; i ++)
        if (rev[i] > i) swap (f[i], f[rev[i]]);
    for (int k = 1; k < lim; k <<= 1) {
        // (2k)-th root of unity (or its inverse) for this butterfly stage.
        const int gn = fpow (inv == 1 ? G : Gi, (MOD - 1) / (k << 1));
        for (int i = 0; i < lim; i += k << 1) {
            int g = 1;
            for (int j = 0; j < k; j ++, g = 1LL * g * gn % MOD) {
                const int nx = f[i + j];
                const int ny = 1LL * g * f[i + j + k] % MOD;
                // nx, ny < MOD, so nx + ny < 2 * MOD still fits in int and a
                // single correction brings each result back into [0, MOD).
                int sum = nx + ny, diff = nx - ny;
                if (sum >= MOD) sum -= MOD;
                if (diff < 0) diff += MOD;
                f[i + j] = sum;
                f[i + j + k] = diff;
            }
        }
    }
    if (inv == -1) {
        const int lim_inv = fpow (lim, MOD - 2);
        for (int i = 0; i < lim; i ++) f[i] = 1LL * f[i] * lim_inv % MOD;
    }
}

vector<int> mul (vector<int> a, vector<int> b) {
    int n = a.size (), m = b.size (), lim, bit;
    for (lim = 1, bit = 0; lim < n + m - 1; lim <<= 1) bit ++;
    vector<int> rev(lim, 0);
    a.resize (lim), b.resize (lim);
    for (int i = 0; i < lim; i ++) 
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    for (int i = n; i < lim; i ++) a[i] = 0;
    for (int i = m; i < lim; i ++) b[i] = 0;
    ntt (a, lim, rev, 1), ntt (b, lim, rev, 1);
    for (int i = 0; i < lim; i ++) a[i] = (ll)a[i] * b[i] % MOD;
    ntt (a, lim, rev, -1);
    return a.resize (n + m - 1), a;
}

int T = 1;              // number of test cases (read in main)
int n, k;               // current case: sequence length and requested coefficient index
int a[N];               // input sequence a[1..n] (presumably a permutation)
int fa[N], siz[N];      // DSU parent links; siz[] is meaningful at roots only
int fac[N], fnv[N];     // factorials and inverse factorials modulo MOD

// Returns the DSU representative of x and compresses the path, so every node
// visited ends up pointing directly at the root.
// Iterative on purpose: fa[] chains can be about n long. For example,
// a[1] = n, a[i] = i - 1 links n -> 1 -> 2 -> ... -> n-1 before the last
// merge. The recursive version would then need ~n stack frames (n up to 5e5).
int find (int x) {
    int root = x;
    while (fa[root] != root) root = fa[root];
    while (x != root) {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

void merge (int x, int y) {
    x = find (x), y = find (y);
    if (x == y) return ;
    fa[x] = y, siz[y] += siz[x];
}

// Fills fac[i] = i! and fnv[i] = (i!)^-1 modulo MOD for all 0 <= i < N.
// Only one exponentiation is needed: (N-1)! is inverted via Fermat's little
// theorem, then (i-1)!^-1 = i!^-1 * i walks the table downwards.
void init () {
    fac[0] = 1;
    for (int i = 1; i < N; i ++)
        fac[i] = 1LL * fac[i - 1] * i % MOD;
    fnv[N - 1] = fpow (fac[N - 1], MOD - 2);
    for (int i = N - 1; i >= 1; i --)
        fnv[i - 1] = 1LL * fnv[i] * i % MOD;
}

// Binomial coefficient C(x, y) modulo MOD, read from the factorial tables.
// Conventions:
//   C(x, y) = 0 for y < 0;
//   C(x, 0) = 1 for any x, even negative;
//   C(x, y) = 0 when x < y.
// Otherwise x must be below N.
int C (int x, int y) {
    if (y <= 0) return y == 0 ? 1 : 0;
    if (y > x) return 0;
    const long long part = 1LL * fac[x] * fnv[y] % MOD;
    return part * fnv[x - y] % MOD;
}

int cal (int x, int y) {
    if (y == 0) return 1;
    if (y * 2 > x) return 0;
    // printf("%d %d %d %d\n", x, y, C (x - y + 1, y), C (x - 4 - (y - 2) + 1, y - 2));
    return (C (x - y + 1, y) - (ll)C (x - 4 - (y - 2) + 1, y - 2) + MOD) % MOD;
}

vector<int> g[N];   // g[t]: polynomial (lowest degree first) of component t in solve()
int tot = 0;        // number of components found in the current test case
struct node { int x, id; };
bool operator < (node l, node r) {
    return l.x > r.x;
}

void solve () {
    scanf ("%d%d", &n, &k);
    for (int i = 1; i <= n; i ++) fa[i] = i, siz[i] = 1, g[i].resize (0);
    priority_queue<node> que;
    tot = 0;
    for (int i = 1; i <= n; i ++) {
        scanf ("%d", &a[i]);
        merge (a[i], i);
    }
    // cout << "cal = " << cal (1, 1) << endl;
    for (int i = 1; i <= n; i ++) {
        if (fa[i] == i) {
            tot ++;
            // cout << "siz = " << siz[i] << endl;
            for (int j = 0; j <= siz[i]; j ++) {
                // cout << cal (siz[i], j) << ' ';
                if (j > k) break;
                int val = cal (siz[i], j);
                if (val == 0) break;
                g[tot].push_back (val);
            }
            // cout << endl;
            que.push ({g[tot].size (), tot});
        }
    }
    // cerr << "siz = " << que.size () << endl;
    while (que.size () > 1) {
        node u = que.top (); que.pop ();
        // cerr << k << " sz = " << u.x << endl;
        node v = que.top (); que.pop ();
        g[u.id] = mul (g[u.id], g[v.id]);
        que.push ({g[u.id].size (), u.id});
    }
    node u = que.top (); que.pop ();
    // cerr << "id = " << k << ' ' << u.id << endl;
    // for (int x : g[u.id]) cerr << x << ' ';
    // cerr << endl;
    if (g[u.id].size () > k) printf("%d\n", g[u.id][k]);
    else printf("%d\n", 0);
}

// Entry point: reads the number of test cases T, builds the factorial tables
// once (they are shared by every case), then runs solve() T times.
int main() {
    scanf ("%d", &T);
    init ();
    for (; T --; ) solve ();
    return 0;
}

详细信息

Test #1:

score: 100
Accepted
time: 7201ms
memory: 60692kb

input:

14
5 1
5 3 2 1 4
5 2
2 5 1 3 4
10 3
10 9 3 8 6 4 5 7 2 1
30 5
1 16 28 30 27 23 2 20 10 12 7 13 11 15 17 24 14 25 21 4 22 29 3 6 19 18 8 26 9 5
30 5
29 6 21 30 14 18 24 26 3 11 23 13 2 12 16 9 4 22 25 20 28 19 5 17 8 10 15 1 7 27
500000 200000
293510 102358 252396 467703 280403 93120 462332 442364 31...

output:

5
5
40
51129
51359
371836159
565197945
0
0
844811446
803690398
638630160
14371218
1

result:

ok 14 lines