QOJ.ac

QOJ

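ID | 题目 (Problem) | 提交者 (Submitter) | 结果 (Result) | 用时 (Time) | 内存 (Memory) | 语言 (Language) | 文件大小 (File size) | 提交时间 (Submitted) | 测评时间 (Judged)
#856801 | #9732. Gathering Mushrooms | ucup-team045 | WA | 43ms | 3712kb | C++20 | 6.0kb | 2025-01-14 16:50:13 | 2025-01-14 16:50:15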

Judging History

你现在查看的是最新测评结果 (You are viewing the latest judging result.)

  • [2025-01-14 16:50:15]
  • 评测
  • 测评结果:WA
  • 用时:43ms
  • 内存:3712kb
  • [2025-01-14 16:50:13]
  • 提交

answer

#include<iostream>
#include<cstring>
#include<vector>
#include<array>
#include<set>
#include<algorithm>
using namespace std;
using LL = long long;

int main(){

#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif

    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(0);

    int T;
    cin >> T;
    while(T--){
        int n, k;
        cin >> n >> k;
        vector<int> c(n + 1), a(n + 1);
        k--;
        for(int i = 1; i <= n; i++) cin >> c[i];
        vector<vector<int> > g(n + 1);
        for(int i = 1; i <= n; i++){
            cin >> a[i];
            g[a[i]].push_back(i);
            g[i].push_back(a[i]);
        } 
        if (k == 0){
            LL sum = 0;
            for(int i = 1; i <= n; i++) sum += a[i];
            cout << sum << '\n';
            continue;
        }
        vector<int> ans(n + 1);
        vector<LL> tim(n + 1, 1e18);
        vector<array<int, 20> > f(n + 1);
        vector<array<LL, 20> > sum(n + 1);
        vector<bool> v(n + 1), oncycle(n + 1);
        vector<vector<int> > pos(n + 1);
        vector<int> dep(n + 1);
        vector<int> last(n + 1);
        for(int i = 1; i <= n; i++){
            if (v[i]) continue;
            int t = i;
            while(!v[t]){
                v[t] = true;
                t = a[t];
            }
            vector<int> cycle{t};
            while(a[cycle.back()] != t){
                cycle.push_back(a[cycle.back()]);
            }
            for(auto x : cycle) oncycle[x] = true;
            const int m = cycle.size();
            set<int> s;
            for(int j = 0; j < m; j++){
                pos[c[cycle[j]]].push_back(j);
                s.insert(c[cycle[j]]);
            }
            for(int j = 0; j < m; j++){
                pos[c[cycle[j]]].push_back(j + m);
            }
            for(auto x : s){
                reverse(pos[x].begin(), pos[x].end());
            }

            vector<int> pt;

            for(int j = 0; j < m; j++){
                auto dfs = [&](auto &&dfs, int u, int fa) -> void {
                    pt.push_back(u);
                    v[u] = true;

                    if (last[c[u]]){
                        f[u][0] = last[c[u]];
                        sum[u][0] = dep[u] - dep[last[c[u]]];
                    }
                    else{
                        while(!pos[c[u]].empty() and pos[c[u]].back() <= j){
                            pos[c[u]].pop_back();
                        }
                        if (!pos[c[u]].empty()){
                            f[u][0] = cycle[pos[c[u]].back() % m];
                            sum[u][0] = dep[u] + pos[c[u]].back() - j;
                        }
                        else{
                            f[u][0] = 0;
                            sum[u][0] = 1e18;
                        }
                    }

                    // LL ss = 0;
                    // for(int i = 1; i <= 19; i++){
                    //     f[u][i] = f[f[u][i - 1]][i - 1];
                    //     sum[u][i] = sum[u][i - 1] + sum[f[u][i - 1]][i - 1];
                    // }
                    // {
                    //     int p = u;
                    //     for(int i = 0; i <= 19; i++){
                    //         if (k >> i & 1){
                    //             ss += sum[p][i];
                    //             p = f[p][i];
                    //         }
                    //     }
                    //     tim[u] = ss;
                    //     ans[u] = c[u];
                    // }

                    int bk = last[c[u]];
                    last[c[u]] = u;
                    
                    for(auto j : g[u]){
                        if (j == fa or oncycle[j]) continue;
                        dep[j] = dep[u] + 1;
                        ans[j] = ans[u];
                        tim[j] = tim[u] + 1;
                        dfs(dfs, j, u);
                    }
                    last[c[u]] = bk;
                };

                dfs(dfs, cycle[j], -1);
            }

            for(int i = 1; i <= 19; i++){
                for(auto u : pt){
                    f[u][i] = f[f[u][i - 1]][i - 1];
                    sum[u][i] = sum[u][i - 1] + sum[f[u][i - 1]][i - 1];
                }
            }

            for(auto u : pt){
                LL ss = 0;
                int p = u;
                for(int i = 0; i <= 19; i++){
                    if (k >> i & 1){
                        ss += sum[p][i];
                        p = f[p][i];
                    }
                }
                tim[u] = ss;
                ans[u] = c[p];
            }

            for(int i = 2 * m - 2; i >= 0; i--){
                int x = cycle[i % m], y = cycle[(i + 1) % m];
                if (tim[y] + 1 < tim[x]){
                    tim[x] = tim[y] + 1;
                    ans[x] = ans[y];
                }
            }

            for(int j = 0; j < m; j++){
                pos[a[cycle[j]]].clear();

                auto dfs = [&](auto &&dfs, int u, int fa) -> void {
                    for(auto j : g[u]){
                        if (j == fa or oncycle[j]) continue;
                        if (tim[u] + 1 < tim[j]){
                            tim[j] = tim[u] + 1;
                            ans[j] = ans[u];
                        }
                        dfs(dfs, j, u);
                    }
                };

                dfs(dfs, cycle[j], -1);
            }
            // for(auto x : cycle){
            //     cout << x << ' ';
            // }
            // cout << '\n';
        }
        // for(int i = 1; i <= n; i++){
        //     cout << ans[i] << ' ' << tim[i] << '\n';
        //     // cout << f[i][0] << ' ' << sum[i][0] << '\n';
        // }
        LL res = 0;
        for(int i = 1; i <= n; i++){
            res += 1LL * i * ans[i];
        }
        cout << res << '\n';
    }

}

详细信息 (Details)

Test #1:

score: 100
Accepted
time: 1ms
memory: 3712kb

input:

3
5 3
2 2 1 3 3
2 5 1 2 4
5 4
2 2 1 3 3
2 5 1 2 4
3 10
1 2 3
1 3 2

output:

41
45
14

result:

ok 3 lines

Test #2:

score: -100
Wrong Answer
time: 43ms
memory: 3712kb

input:

6000
19 48
18 19 18 19 11 9 15 19 12 18 11 18 9 18 9 18 19 11 15
12 14 18 8 1 3 19 5 13 14 15 2 14 5 19 2 19 12 9
15 23
3 1 1 3 6 1 4 1 1 6 6 4 12 4 6
14 1 8 8 6 6 12 14 6 8 5 7 14 2 5
9 140979583
4 5 8 9 2 7 6 8 2
8 9 4 6 9 2 4 7 8
4 976357580
2 3 1 3
2 1 1 4
6 508962809
4 3 4 3 4 4
4 5 4 5 5 6
13 ...

output:

3420
260
254
26
84
759
126
30
1092
1
2493
2422
168
360
298
324
2480
2520
220
228
1107
9
3486
1
796
81
340
272
600
3196
32
495
40
128
140
665
1635
702
68
96
90
288
29
588
16
234
435
2928
140
40
477
1197
19
1994
1082
32
900
672
20
390
32
2204
1907
42
21
926
4
1539
196
420
11
1709
801
720
1
556
40
17
2...

result:

wrong answer 17th lines differ - expected: '2424', found: '2480'