QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#853766 | #9732. Gathering Mushrooms | ucup-team6275# | RE | 0ms | 3656kb | C++20 | 8.4kb | 2025-01-11 19:06:48 | 2025-01-11 19:06:49

Judging History

你现在查看的是最新测评结果

  • [2025-01-11 19:06:49]
  • 评测
  • 测评结果:RE
  • 用时:0ms
  • 内存:3656kb
  • [2025-01-11 19:06:48]
  • 提交

answer

#include <iostream>
#include <vector>
#include <array>
#include <string>
#include <algorithm>
#include <iomanip>
#include <map>
#include <deque>
#include <set>
#include <random>

using namespace std;
#define ll long long
// Sentinel for "no answer computed yet" in opt[]; large enough that INF + 1 still fits in ll.
const ll INF = 1e18;

// Solves one test case read from standard input and prints one number.
//
// Input read: n, k, then n vertex types (1-based), then n successor indices
// a[i] (1-based). Every vertex has exactly one outgoing edge i -> a[i], so
// each component is a cycle with trees hanging off it.
//
// Steps performed by the code:
//   1. Find every cycle, and DFS over reverse edges from each cycle vertex to
//      build the hanging trees (depth gl, parent par, jump pointer bp,
//      entry/exit times tin/tout).
//   2. opt[v] = number of moves along successor edges, starting at v, until
//      the k-th vertex of type[v] is visited (v itself counts as the first).
//      Stays INF when it is not computed.
//   3. res[v] = best (moves, type) over v and everything reachable from v,
//      where each step towards the successor costs one extra move.
//   4. Print sum over i of (i + 1) * (res[i].second + 1).
//
// NOTE(review): presumably this is QOJ #9732 "Gathering Mushrooms"; the
// statement is not part of this file, so the above describes the code only.
void solve() {
    int n, k;
    cin >> n >> k;
    vector <int> type(n);
    vector <vector <int>> g(n);      // reverse edges: g[a[i]] contains i
    vector <vector <int>> tree(n);   // children lists filled by dfs
    vector <int> a(n);               // successor of each vertex, 0-based
    for (int i = 0; i < n; ++i) {
        cin >> type[i];
        --type[i];
    }
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        --a[i];
        g[a[i]].push_back(i);
    }

    vector <int> par(n, -1);          // tree parent; -1 for cycle vertices
    vector <int> num_cycle(n, -1);    // index into `cycles` of the vertex's component
    vector <int> is_on_cycle(n, -1);  // 1 = cycle vertex, 0 = tree vertex, -1 = not classified yet
    vector <int> cycle_pos(n, -1);    // position on the cycle (of the tree root for tree vertices)
    vector <vector <int>> cycles;
    vector <int> used1(n);
    vector <int> gl(n);               // depth below the cycle vertex that roots the tree
    vector <int> tin(n);
    vector <int> bp(n, -1);           // jump pointer used by parent()/lca()
    vector <int> tout(n);
    int cur_time = 0;

    // DFS over reverse edges from a cycle vertex (p == -1) into its tree.
    auto dfs = [&](auto&& dfs, int v, int p) -> void {
        used1[v] = 1;
        tin[v] = cur_time++;
        if (p != -1) {
            par[v] = p;
            is_on_cycle[v] = 0;
            num_cycle[v] = num_cycle[p];
            cycle_pos[v] = cycle_pos[p];
            tree[p].push_back(v);
            gl[v] = gl[p] + 1;

            // Jump pointer: reuse the grandparent jump when the two previous jumps have equal length.
            if (gl[par[v]] - gl[bp[par[v]]] == gl[bp[par[v]]] - gl[bp[bp[par[v]]]]) bp[v] = bp[bp[par[v]]];
            else bp[v] = par[v];
        } else {
            bp[v] = v;
        }

        for (int i : g[v]) {
            if (is_on_cycle[i] == 1) continue;  // never walk along the cycle itself
            dfs(dfs, i, v);
        }

        tout[v] = cur_time++;
    };

    for (int i = 0; i < n; ++i) {
        if (used1[i]) continue;
        int cur = i;
        vector <int> st;

        // Follow successors until a vertex repeats; that vertex lies on the cycle.
        while (!used1[cur]) {
            st.push_back(cur);
            used1[cur] = 1;
            cur = a[cur];
        }

        // The cycle is the suffix of the walk starting at `cur`.
        vector <int> cycle;
        while (true) {
            cycle.push_back(st.back());
            if (st.back() == cur) break;
            st.pop_back();
        }
        reverse(cycle.begin(), cycle.end());  // now cycle[j + 1] == a[cycle[j]]

        int cycle_len = cycle.size();
        cycles.push_back(cycle);
        for (int j = 0; j < cycle_len; ++j) {
            is_on_cycle[cycle[j]] = 1;
            num_cycle[cycle[j]] = cycles.size() - 1;
            cycle_pos[cycle[j]] = j;
        }

        for (int j : cycle) {
            dfs(dfs, j, -1);
        }
    }

    // Vertices grouped by type.
    vector <vector <int>> all_such_col(n);
    for (int i = 0; i < n; ++i) {
        all_such_col[type[i]].push_back(i);
    }

    // Ancestor of v that is k levels above it, using the jump pointers.
    auto parent = [&](int v, int k) {
        while (k) {
            int delta = gl[v] - gl[bp[v]];
            if (delta <= k) {
                k -= delta;
                v = bp[v];
            } else {
                k--;
                v = par[v];
            }
        }
        return v;
    };

    // Lowest common ancestor of two vertices of the same tree.
    auto lca = [&](int x, int y) {
        if (gl[x] < gl[y]) swap(x, y);
        x = parent(x, gl[x] - gl[y]);
        if (x == y) return x;

        while (par[x] != par[y]) {
            if (bp[x] != bp[y]) {
                x = bp[x];
                y = bp[y];
            } else {
                x = par[x];
                y = par[y];
            }
        }

        return par[x];
    };

    // Forward distance from position `from` to position `to` on a cycle of length `len`.
    auto get_dist = [&](int len, int from, int to) {
        if (from <= to) return to - from;
        return len - from + to;
    };

    // True when `son` lies in the DFS subtree of `root`.
    auto check_subtree = [&](int root, int son) {
        return tin[son] >= tin[root] && tin[son] <= tout[root];
    };

    vector <ll> opt(n, INF);
    // (type, cycle index) -> sorted cycle positions of the vertices of that type on that cycle.
    map <pair <int, int>, vector <int>> poses;
    //on_cycle
    for (int col = 0; col < n; ++col) {

        for (int i : all_such_col[col]) {
            if (is_on_cycle[i]) {
                poses[make_pair(col, num_cycle[i])].push_back(cycle_pos[i]);
            }
        }
    }
    for (auto& i : poses) {
        sort(i.second.begin(), i.second.end());
        // Every entry was created by a push_back above, so cnt >= 1 here.
        int cnt = (int)i.second.size();
        int ln = (int)cycles[i.first.second].size();
        int ful = (k - 1) / cnt;  // whole laps around the cycle
        int mem = (k - 1) % cnt;  // additional same-type vertices after the laps
        for (int j = 0; j < cnt; ++j) {
            // The index wraps inside the occurrence list (cnt entries), not the
            // cycle length; wrapping by ln could index past the end of the list.
            ll moves = (ll)ln * (ll)ful + get_dist(ln, i.second[j], i.second[(j + mem) % cnt]);
            opt[cycles[i.first.second][i.second[j]]] = moves;
        }
    }
    //in_tree
    for (int col = 0; col < n; ++col) {
        // Vertices of this type grouped by the tree (cycle index, root position) they belong to.
        map <pair <int, int>, vector <int>> flex;

        for (int i : all_such_col[col]) {
            flex[make_pair(num_cycle[i], cycle_pos[i])].push_back(i);
        }

        for (auto& i : flex) {
            vector <int> ver = i.second;

            sort(ver.begin(), ver.end(), [&](int x, int y) {
                return tin[x] < tin[y];
            });

            // Add the LCAs of tin-adjacent vertices so the set is closed under LCA.
            int prev_len = ver.size();
            for (int j = 0; j < prev_len - 1; ++j) {
                ver.push_back(lca(ver[j], ver[j + 1]));
            }

            sort(ver.begin(), ver.end(), [&](int x, int y) {
                return tin[x] < tin[y];
            });

            ver.resize(unique(ver.begin(), ver.end()) - ver.begin());
            vector <int> st_all;     // current root-to-vertex chain within `ver`
            vector <int> st_my_col;  // tree vertices of type `col` on that chain

            for (int j : ver) {
                while (!st_all.empty() && !check_subtree(st_all.back(), j)) {
                    if (!st_my_col.empty() && st_my_col.back() == st_all.back() && !is_on_cycle[st_all.back()]) st_my_col.pop_back();
                    st_all.pop_back();
                }

                st_all.push_back(j);
                if (type[j] == col && !is_on_cycle[j]) {
                    st_my_col.push_back(j);
                    if (k <= st_my_col.size()) {
                        // The k-th same-type vertex is still inside the tree.
                        opt[j] = gl[j] - gl[st_my_col[st_my_col.size() - k]];
                    } else {
                        ll cur = gl[j];  // moves needed to reach the cycle
                        int cycle = num_cycle[j];
                        int temp_k = k - (int)st_my_col.size();  // occurrences still needed on the cycle
                        vector <int>& ps = poses[make_pair(col, cycle)];
                        if (ps.empty()) continue;  // type never appears on the cycle: opt[j] stays INF
                        int ln = cycles[cycle].size();

                        // Walk from the tree's root position to the first same-type cycle vertex.
                        int my_pos = cycle_pos[j];
                        auto it = lower_bound(ps.begin(), ps.end(), my_pos);

                        if (it != ps.end()) {
                            temp_k--;
                            cur += *it - my_pos;
                            my_pos = *it;
                        } else {
                            temp_k--;
                            cur += ln - my_pos + ps[0];
                            my_pos = ps[0];
                        }

                        // From here on my_pos is an index into ps rather than a cycle position.
                        my_pos = lower_bound(ps.begin(), ps.end(), my_pos) - ps.begin();

                        int cnt = (int)ps.size();
                        int ful = temp_k / cnt;
                        cur += 1ll * ful * ln;
                        // The remainder is taken by the occurrence count. Taking it by
                        // `ful` divided by zero whenever fewer than one full lap remained.
                        temp_k %= cnt;
                        cur += get_dist(ln, ps[my_pos], ps[(my_pos + temp_k) % cnt]);
                        opt[j] = cur;
                    }
                }
            }
        }
    }

    // res[v] = (moves, type) of the best candidate known for v; starts as v's own.
    vector <pair <ll, int>> res(n);
    for (int i = 0; i < n; ++i) {
        res[i] = make_pair(opt[i], type[i]);
    }

    // Take the successor's answer when it is strictly better after paying one extra move.
    auto recalc = [&](pair <ll, int>& cur_ans, pair <ll, int> other_ans) {
        if (other_ans.first + 1 < cur_ans.first) {
            cur_ans.first = other_ans.first + 1;
            cur_ans.second = other_ans.second;
        }
    };

    // Relax around each cycle for three laps.
    for (int i = 0; i < cycles.size(); ++i) {
        int oper = 0;
        int ln = cycles[i].size();
        for (int j = 0; oper < 3 * ln; j = (j + 1) % ln, oper++) {
            recalc(res[cycles[i][j]], res[cycles[i][(j + 1) % ln]]);
        }
    }

    // Push answers down the trees in order of increasing depth so parents are final first.
    vector <pair <int, int>> flex;
    for (int i = 0; i < n; ++i) flex.emplace_back(gl[i], i);
    sort(flex.begin(), flex.end());
    for (auto i : flex) {
        if (par[i.second] != -1) {
            recalc(res[i.second], res[par[i.second]]);
        }
    }

//    for (auto i : res) cerr << i.second + 1 << "\n";
    ll REAL_RES = 0;
    for (int i = 0; i < n; ++i) {
        REAL_RES += 1ll * (i + 1) * (res[i].second + 1);
    }
    cout << REAL_RES << "\n";
//    cerr << "---------\n";
}

// Program entry: switches to fast stream I/O, reads the number of test cases
// and runs solve() once per case.
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int t;
    cin >> t;
    for (; t != 0; --t) {
        solve();
    }
    return 0;
}

/*
3
5 3
2 2 1 3 3
2 5 1 2 4
5 4
2 2 1 3 3
2 5 1 2 4
3 10
1 2 3
1 3 2
 */

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 0ms
memory: 3656kb

input:

3
5 3
2 2 1 3 3
2 5 1 2 4
5 4
2 2 1 3 3
2 5 1 2 4
3 10
1 2 3
1 3 2

output:

41
45
14

result:

ok 3 lines

Test #2:

score: -100
Runtime Error

input:

6000
19 48
18 19 18 19 11 9 15 19 12 18 11 18 9 18 9 18 19 11 15
12 14 18 8 1 3 19 5 13 14 15 2 14 5 19 2 19 12 9
15 23
3 1 1 3 6 1 4 1 1 6 6 4 12 4 6
14 1 8 8 6 6 12 14 6 8 5 7 14 2 5
9 140979583
4 5 8 9 2 7 6 8 2
8 9 4 6 9 2 4 7 8
4 976357580
2 3 1 3
2 1 1 4
6 508962809
4 3 4 3 4 4
4 5 4 5 5 6
13 ...

output:


result: