QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#856820 | #9734. Identify Chord | ucup-team045 | TL | 0ms | 0kb | C++20 | 6.0kb | 2025-01-14 17:08:48 | 2025-01-14 17:08:50

Judging History

This is the latest submission verdict.

  • [2025-01-14 17:08:50]
  • Judged
  • Verdict: TL
  • Time: 0ms
  • Memory: 0kb
  • [2025-01-14 17:08:48]
  • Submitted

answer

#include<iostream>
#include<cstring>
#include<vector>
#include<array>
#include<set>
#include<algorithm>
#include<utility>
using namespace std;
using LL = long long;

/**
 * Enumerate the cycles of the functional graph i -> a[i] on nodes 1..n.
 *
 * @param a  successor array, 1-indexed; a[0] is ignored and n = a.size() - 1.
 *           NOTE(review): every a[i] (1 <= i <= n) is assumed to lie in [1, n].
 * @return   one vector per cycle, listing its nodes in successor order:
 *           cycle[k + 1] == a[cycle[k]] and a[cycle.back()] == cycle.front().
 *           Cycles are reported in the order they are discovered while
 *           scanning start nodes i = 1..n.
 */
std::vector<std::vector<int> > get_cycle(const std::vector<int> &a){
    const int n = static_cast<int>(a.size()) - 1;
    std::vector<std::vector<int> > cycles;
    // owner[t] = start index of the walk that first visited t (-1 = unvisited).
    std::vector<int> owner(n + 1, -1);
    for(int i = 1; i <= n; i++){
        if (owner[i] != -1) continue;
        int t = i;
        while(owner[t] == -1){
            owner[t] = i;
            t = a[t];
        }
        // Stopping on a node this same walk visited means t is on a new cycle;
        // stopping on a node owned by an earlier walk means that cycle is already recorded.
        if (owner[t] == i){
            std::vector<int> cycle{t};
            while(a[cycle.back()] != t){
                cycle.push_back(a[cycle.back()]);
            }
            cycles.push_back(std::move(cycle));
        }
    }
    return cycles;
}

// Reads T test cases. Each test case consists of:
//   - n and k;
//   - colors c[1..n];
//   - successors a[1..n], which form a functional graph.
// For every node u, it follows k - 1 "next node of the same color along the
// successor walk" jumps using binary lifting. Each node's result (ans, tim) is
// then relaxed from its successor via tim[x] = min(tim[x], tim[a[x]] + 1).
// Finally it prints sum(i * ans[i]) for each test case.
int main(){

#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif

    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Number of binary-lifting levels; k - 1 is decomposed into bits 0..kLog-1.
    constexpr int kLog = 30;

    int T;
    cin >> T;
    while(T--){
        int n, k;
        cin >> n >> k;
        vector<int> c(n + 1), a(n + 1);
        // The lifting below performs k - 1 jumps.
        // NOTE(review): assumes k >= 1; with k == 0 every bit of the decremented value is set.
        k--;
        for(int i = 1; i <= n; i++) cin >> c[i];
        // Undirected adjacency: each node is linked to a[i] and to every i with a[i] == node.
        vector<vector<int> > g(n + 1);
        for(int i = 1; i <= n; i++){
            cin >> a[i];
            g[a[i]].push_back(i);
            g[i].push_back(a[i]);
        }

        // NOTE(review): colors are used as indices into last/pos, so they are assumed to lie in [1, n].
        vector<int> ans(n + 1);                 // per-node answer color (0 when the jumps ran out)
        vector<LL> tim(n + 1, 1e18);            // step count paired with ans; >= 1e18 means "no valid target"
        vector<array<int, kLog> > f(n + 1);     // f[u][i]: node reached after 2^i same-color jumps (0 = none)
        vector<array<LL, kLog> > sum(n + 1);    // sum[u][i]: total successor steps of those 2^i jumps
        vector<bool> oncycle(n + 1);
        vector<vector<int> > pos(n + 1);        // per color: its indices on the current cycle, smallest at back()
        vector<int> dep(n + 1);                 // tree nodes: distance to their cycle root (roots stay 0)
        vector<int> last(n + 1);                // per color: deepest node of that color on the current DFS path

        for(const auto& cycle : get_cycle(a)){
            for(int x : cycle) oncycle[x] = true;
            const int m = static_cast<int>(cycle.size());

            // Record each color's cycle indices twice (j and j + m). This guarantees that a
            // "next occurrence strictly after j" exists for any j < m. Reverse each list so
            // the smallest remaining index sits at back().
            set<int> s;
            for(int j = 0; j < m; j++){
                pos[c[cycle[j]]].push_back(j);
                s.insert(c[cycle[j]]);
            }
            for(int j = 0; j < m; j++){
                pos[c[cycle[j]]].push_back(j + m);
            }
            for(int x : s){
                reverse(pos[x].begin(), pos[x].end());
            }

            vector<int> pt;   // every node of this component, in DFS preorder

            for(int j = 0; j < m; j++){
                // Visit the tree hanging off cycle[j] and set the first lifting level f[u][0] / sum[u][0].
                // NOTE(review): recursion depth equals tree height (up to n); deep path-like trees
                // may exceed the default stack size.
                auto dfs = [&](auto &&dfs, int u, int fa) -> void {
                    pt.push_back(u);

                    if (last[c[u]]){
                        // Nearest same-colored node between u and the cycle root.
                        f[u][0] = last[c[u]];
                        sum[u][0] = dep[u] - dep[last[c[u]]];
                    }
                    else{
                        // Otherwise: the first same-colored cycle node strictly after index j.
                        while(!pos[c[u]].empty() and pos[c[u]].back() <= j){
                            pos[c[u]].pop_back();
                        }
                        if (!pos[c[u]].empty()){
                            f[u][0] = cycle[pos[c[u]].back() % m];
                            sum[u][0] = dep[u] + pos[c[u]].back() - j;
                        }
                        else{
                            f[u][0] = 0;
                            sum[u][0] = 1e18;
                        }
                    }

                    const int saved = last[c[u]];
                    last[c[u]] = u;
                    for(int child : g[u]){
                        if (child == fa or oncycle[child]) continue;
                        dep[child] = dep[u] + 1;
                        dfs(dfs, child, u);
                    }
                    last[c[u]] = saved;
                };

                dfs(dfs, cycle[j], -1);
            }

            // Build the remaining lifting levels for this component. Node 0 acts as a sink
            // with f = 0 and sum = 0, so the 1e18 sentinel is added at most once.
            for(int i = 1; i < kLog; i++){
                for(int u : pt){
                    f[u][i] = f[f[u][i - 1]][i - 1];
                    sum[u][i] = sum[u][i - 1] + sum[f[u][i - 1]][i - 1];
                }
            }

            // Perform k - 1 jumps from every node.
            for(int u : pt){
                LL total = 0;
                int p = u;
                for(int i = 0; i < kLog; i++){
                    if (k >> i & 1){
                        total += sum[p][i];
                        p = f[p][i];
                    }
                }
                tim[u] = total;
                ans[u] = c[p];
            }

            // Two backward sweeps around the cycle. Each sweep relaxes
            // tim[x] = min(tim[x], tim[a[x]] + 1), keeping the current answer on ties.
            for(int i = 2 * m - 2; i >= 0; i--){
                int x = cycle[i % m], y = cycle[(i + 1) % m];
                if (tim[y] + 1 < tim[x]){
                    tim[x] = tim[y] + 1;
                    ans[x] = ans[y];
                }
            }

            // Push the relaxed cycle values down into the trees (parent == a[child]),
            // and reset this cycle's color positions for the next component.
            for(int j = 0; j < m; j++){
                pos[c[cycle[j]]].clear();

                auto push_down = [&](auto &&self, int u, int fa) -> void {
                    for(int child : g[u]){
                        if (child == fa or oncycle[child]) continue;
                        if (tim[u] + 1 < tim[child]){
                            tim[child] = tim[u] + 1;
                            ans[child] = ans[u];
                        }
                        self(self, child, u);
                    }
                };

                push_down(push_down, cycle[j], -1);
            }
        }

        LL res = 0;
        for(int i = 1; i <= n; i++){
            res += 1LL * i * ans[i];
        }
        cout << res << '\n';
    }

}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Time Limit Exceeded

input:

2
6

output:


result: