QOJ.ac

QOJ

ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间
#854465 | #9732. Gathering Mushrooms | ucup-team1134 | WA | 128ms | 9768kb | C++23 | 16.8kb | 2025-01-12 01:38:55 | 2025-01-12 01:38:56

Judging History

你现在查看的是最新测评结果

  • [2025-01-12 01:38:56]
  • 评测
  • 测评结果:WA
  • 用时:128ms
  • 内存:9768kb
  • [2025-01-12 01:38:55]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Replaces a with b when b is strictly greater; reports whether a changed.
template <class T> bool chmax(T &a, const T &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Replaces a with b when b is strictly smaller; reports whether a changed.
template <class T> bool chmin(T &a, const T &b) {
    if (!(b < a)) return false;
    a = b;
    return true;
}
// Shorthand container/pair aliases (as macros) used throughout the solution.
#define vi vector<int>
#define vl vector<ll>
#define vii vector<pair<int,int>>
#define vll vector<pair<ll,ll>>
#define vvi vector<vector<int>>
#define vvl vector<vector<ll>>
#define vvii vector<vector<pair<int,int>>>
#define vvll vector<vector<pair<ll,ll>>>
#define vst vector<string>
#define pii pair<int,int>
#define pll pair<ll,ll>
// Shorthand for common member/function names.
#define pb push_back
// Expands to the iterator pair of a container, e.g. sort(all(v)).
#define all(x) (x).begin(),(x).end()
// Sorts x and removes duplicate elements. NOTE(review): expands to two
// statements, so it misbehaves as the body of an un-braced if/for/while.
#define mkunique(x) sort(all(x));(x).erase(unique(all(x)),(x).end())
#define fi first
#define se second
#define mp make_pair
// Size of a container as int.
#define si(x) int(x.size())
// mod = 998244353 and INF = 15<<26 (= 1006632960) are not referenced in the
// visible code; MAX is the capacity of the global per-vertex arrays, so N must
// not exceed it.
const int mod=998244353,MAX=200005,INF=15<<26;

//Wavelet matrix

// https://kopricky.github.io/code/DataStructure_Advanced/wavelet_matrix.html

// Rank dictionary over a bit vector. After build(), rank1/rank0 report how many
// 1/0 bits occur in a prefix [0, i) or a range [i, j) in O(1).
struct BitRank {
    // block: the bits packed into 64-bit words.
    // count: count[w] = number of 1 bits in words [0, w), filled by build().
    vector<unsigned long long> block;
    vector<unsigned int> count;
    BitRank(){}
    // Prepares storage for `num` bits, all cleared. assign() (rather than
    // resize()) zeroes words left over from a previous use: set() only ORs
    // bits in, so stale words would otherwise leak into the new contents.
    void resize(const unsigned int num){
        block.assign(((num + 1) >> 6) + 1, 0);
        count.assign(block.size(), 0);
    }
    // ORs `val` (expected to be 0 or 1) into bit i.
    void set(const unsigned int i, const unsigned long long val){
        block[i >> 6] |= (val << (i & 63));
    }
    // Computes the per-word prefix popcounts; call once after all set() calls.
    void build(){
        for(unsigned int i = 1; i < block.size(); i++){
            count[i] = count[i - 1] + __builtin_popcountll(block[i - 1]);
        }
    }
    // Number of 1 bits in [0, i).
    unsigned int rank1(const unsigned int i) const {
        return count[i >> 6] +
        __builtin_popcountll(block[i >> 6] & ((1ULL << (i & 63)) - 1ULL));
    }
    // Number of 1 bits in [i, j).
    unsigned int rank1(const unsigned int i, const unsigned int j) const {
        return rank1(j) - rank1(i);
    }
    // Number of 0 bits in [0, i).
    unsigned int rank0(const unsigned int i) const {
        return i - rank1(i);
    }
    // Number of 0 bits in [i, j).
    unsigned int rank0(const unsigned int i, const unsigned int j) const {
        return rank0(j) - rank0(i);
    }
};

// Wavelet matrix over a sequence of non-negative ints smaller than sigma.
// Level i (0-based) stores bit (height - 1 - i) of every element, after the
// elements have been stably partitioned by the higher bits.
// Based on https://kopricky.github.io/code/DataStructure_Advanced/wavelet_matrix.html
class WaveletMatrix
{
private:
    unsigned int height;  // number of bit levels
    vector<BitRank> B;    // B[i]: bit vector of level i
    vector<int> pos;      // pos[i]: number of 0 bits at level i (start of the 1-part)
public:
    WaveletMatrix(){}
    // Builds from vec with sigma = max(vec) + 1. An empty vec is built with
    // sigma = 1 instead of dereferencing max_element of an empty range.
    WaveletMatrix(vector<int> vec) :
    WaveletMatrix(vec, vec.empty() ? 1u : static_cast<unsigned int>(*max_element(vec.begin(), vec.end()) + 1)){}
    // sigma: alphabet size; every value of vec must lie in [0, sigma).
    WaveletMatrix(vector<int> vec, const unsigned int sigma){
        init(vec, sigma);
    }
    // (Re)builds the matrix from vec. NOTE: vec is reordered in place
    // (stable_partition at every level); pass a copy if its order matters.
    // sigma <= 1 uses a single level; previously sigma == 0 wrapped around in
    // `sigma - 1` and produced height 32, making `1 << height` undefined.
    void init(vector<int>& vec, const unsigned int sigma){
        height = (sigma <= 1) ? 1 : (64 - __builtin_clzll(sigma - 1));
        B.resize(height), pos.resize(height);
        for(unsigned int i = 0; i < height; ++i){
            B[i].resize(vec.size());
            for(unsigned int j = 0; j < vec.size(); ++j) {
                B[i].set(j, get(vec[j], height - i - 1));
            }
            B[i].build();
            auto it = stable_partition(vec.begin(), vec.end(), [&](int c){
                return !get(c, height - i - 1);
            });
            pos[i] = it - vec.begin();
        }
    }
    // Bit i (0 or 1) of val.
    int get(const int val, const int i){
        return val >> i & 1;
    }
    // Number of occurrences of val in [l, r).
    int rank(const int val, const int l, const int r){
        return rank(val, r) - rank(val, l);
    }
    // Number of occurrences of val in [0, i).
    int rank(int val, int i){
        // Values outside [0, 2^height) never occur. Without this guard only the
        // low `height` bits of val would be examined, so e.g. val + 2^height (or
        // a negative val) would be counted as a different, stored value.
        if(val < 0 || (static_cast<long long>(val) >> height) != 0) return 0;
        int p = 0;
        for(unsigned int j = 0; j < height; ++j){
            if(get(val, height - j - 1)){
                p = pos[j] + B[j].rank1(p);
                i = pos[j] + B[j].rank1(i);
            }else{
                p = B[j].rank0(p);
                i = B[j].rank0(i);
            }
        }
        return i - p;
    }
    // k-th (0-based) smallest value in positions [l, r). k must be < r - l;
    // this is not checked.
    int quantile(int k, int l, int r){
        int res = 0;
        for(unsigned int i = 0; i < height; ++i){
            const int j = B[i].rank0(l, r);
            if(j > k){
                l = B[i].rank0(l);
                r = B[i].rank0(r);
            }else{
                l = pos[i] + B[i].rank1(l);
                r = pos[i] + B[i].rank1(r);
                k -= j;
                res |= (1 << (height - i - 1));
            }
        }
        return res;
    }
    // Recursive helper: [i, j) is the index range at level x and [l, r) the
    // value range covered by the current node.
    int rangefreq(const int i, const int j, const int a, const int b,
                  const int l, const int r, const int x){
        if(i == j || r <= a || b <= l) return 0;
        const int mid = (l + r) >> 1;
        if(a <= l && r <= b){
            return j - i;
        }else{
            const int left = rangefreq(B[x].rank0(i), B[x].rank0(j), a, b, l, mid, x + 1);
            const int right = rangefreq(pos[x] + B[x].rank1(i), pos[x] + B[x].rank1(j),
                                        a, b, mid, r, x + 1);
            return left + right;
        }
    }
    // Number of positions in [l, r) whose value lies in [a, b).
    int rangefreq(const int l, const int r, const int a, const int b){
        return rangefreq(l, r, a, b, 0, 1 << height, 0);
    }
    // Recursive helper for rangemin; val is the value prefix of the current node.
    int rangemin(const int i, const int j, const int a, const int b,
                 const int l, const int r, const int x, const int val){
        if(i == j || r <= a || b <= l) return -1;
        if(r - l == 1) return val;
        const int mid = (l + r) >> 1;
        const int res = rangemin(B[x].rank0(i),B[x].rank0(j),a,b,l,mid,x+1,val);
        if(res < 0) return rangemin(pos[x] + B[x].rank1(i), pos[x] + B[x].rank1(j),
                                    a, b, mid, r, x + 1, val + (1 << (height - x - 1)));
        else return res;
    }
    // Smallest value in [a, b) occurring in positions [l, r), or -1 if none.
    int rangemin(int l, int r, int a, int b){
        return rangemin(l, r, a, b, 0, 1 << height, 0, 0);
    }
};

// Static point set answering "how many points lie in [lx,rx) x [ly,ry)?".
// Points are sorted by (x, y); each point's y is replaced by its index among the
// sorted distinct y values, and that sequence (in x order) is stored in a
// wavelet matrix, so a query becomes one rangefreq over an x-index range.
template<typename T> class OrthogonalRangeCount
{
private:
    using ptt = pair<T, T>;
    const int sz;        // number of points
    vector<T> X, Y;      // X: x of each point in sorted order; Y: sorted distinct y values
    WaveletMatrix wm;    // compressed y index of each point, in x order

public:
    OrthogonalRangeCount(vector<ptt> candidate)
    : sz((int)candidate.size()), X(sz), Y(sz){
        sort(candidate.begin(), candidate.end());
        for(int idx = 0; idx < sz; ++idx){
            X[idx] = candidate[idx].first;
            Y[idx] = candidate[idx].second;
        }
        sort(Y.begin(), Y.end());
        Y.erase(unique(Y.begin(), Y.end()), Y.end());
        vector<int> compressed;
        compressed.reserve(sz);
        for(const ptt& point : candidate){
            compressed.push_back(int(lower_bound(Y.begin(), Y.end(), point.second) - Y.begin()));
        }
        wm.init(compressed, Y.size());
    }
    // Number of points inside the half-open rectangle [lx, rx) x [ly, ry).
    int query(const T lx, const T ly, const T rx, const T ry){
        auto firstNotBelow = [](const vector<T>& values, const T& key){
            return int(lower_bound(values.begin(), values.end(), key) - values.begin());
        };
        const int xFrom = firstNotBelow(X, lx);
        const int xTo = firstNotBelow(X, rx);
        const int yFrom = firstNotBelow(Y, ly);
        const int yTo = firstNotBelow(Y, ry);
        if(xFrom < xTo && yFrom < yTo) return wm.rangefreq(xFrom, xTo, yFrom, yTo);
        return 0;
    }
};

// G[v]: children of v in the in-forest, i.e. the non-cycle vertices u with
// A[u] == v (filled in main).
vector<int> G[MAX];
// sz[v]: reset to 0 for each test case in main; not otherwise read in the
// visible code.
ll sz[MAX];

// Disjoint-set union with path compression (no union by size/rank).
// size[r]: number of vertices in the set rooted at r.
// edge[r]: number of unite() calls charged to the set rooted at r (every call
//          counts, including ones whose endpoints are already connected).
struct UF{
    int n;
    vector<int> par,size,edge;

    // Resets to n singleton sets {0}, {1}, ..., {n-1}.
    void init(int n_){
        n = n_;
        par.assign(n, 0);
        iota(par.begin(), par.end(), 0);
        size.assign(n, 1);
        edge.assign(n, 0);
    }

    // Representative of a's set; every vertex on the path is re-pointed at it.
    int root(int a){
        int top = a;
        while(par[top] != top) top = par[top];
        while(par[a] != top){
            const int next = par[a];
            par[a] = top;
            a = next;
        }
        return top;
    }

    // Records the edge (a, b) and merges b's set into a's set if they differ.
    void unite(int a,int b){
        const int ra = root(a);
        edge[ra]++;
        const int rb = root(b);
        if(ra == rb) return;
        size[ra] += size[rb];
        edge[ra] += edge[rb];
        par[rb] = ra;
    }

    // True when a and b are in the same set.
    bool check(int a,int b){
        return root(a) == root(b);
    }
};

// cyc[r]: the vertices of the cycle whose union-find root is r, listed by
// following A[] from r (filled in main).
vector<int> cyc[MAX];
// ans[v]: pair (fi, se). se == -1 means "no type fixed yet"; the meaning of fi
// changes between phases (DFS, solveX, main's cycle passes, solve).
pll ans[MAX];
// ty[v]: type of vertex v as read from input, shifted to be 0-based.
int ty[MAX];
// depth[v]: number of steps from v to the cycle vertex its tree hangs from
// (0 for cycle vertices).
int depth[MAX];
// retu[t]: scratch stack used by DFS (depths of type-t vertices on the current path).
vi retu[MAX];
// cycpos[v]: index of cycle vertex v inside its cyc[] list.
int cycpos[MAX];
// K: required number of equal-type vertices (read per test case).
ll K;
// roo[v]: cycle vertex whose in-tree contains v (set by DFS).
int roo[MAX];

// Depth-first walk over the in-tree hanging from cycle vertex `ro` (main calls
// it with u == ro for every cycle vertex; cycle vertices have depth 0). Sets
// roo[] for every vertex of the tree and, for every non-cycle vertex u
// (depth[u] > 0), looks at the type-ty[u] vertices on the path from u up to,
// but excluding, the cycle vertex:
//   - if there are at least K of them (u included), ans[u] = (number of
//     vertices from u through the K-th such vertex, inclusive, ty[u]);
//   - otherwise ans[u] = (how many there are, -1).
// retu[t] is used as a stack of the depths of type-t vertices on the current path.
// NOTE(review): recursion depth equals the tree height, which can be large.
void DFS(int u,int ro){
    roo[u]=ro;
    if(depth[u]){
        retu[ty[u]].pb(depth[u]);
        if(si(retu[ty[u]])>=K){
            // The K-th type-ty[u] vertex met when walking from u toward the cycle.
            ans[u].fi=depth[u]-retu[ty[u]][si(retu[ty[u]])-K]+1;
            ans[u].se=ty[u];
        }else{
            ans[u].fi=si(retu[ty[u]]);
            ans[u].se=-1;
        }
    }
    for(int to:G[u]){
        // Children sit one step farther from the cycle.
        depth[to]=depth[u]+1;
        DFS(to,ro);
    }
    if(depth[u]){
        // Leave u's entry off the stack before returning to the parent.
        retu[ty[u]].pop_back();
    }
}

void solve(int a){
    for(int b:G[a]){
        if(ans[b].se==-1){
            ans[b]=mp(ans[a].fi+1,ans[a].se);
        }else if(ans[a].se==ans[b].se){
            
        }else if(ans[a].fi+1<ans[b].fi){
            ans[b]=mp(ans[a].fi+1,ans[a].se);
        }
        solve(b);
    }
}

// Union-find over vertices; main unites each cycle vertex with its successor,
// so every cycle forms one set.
UF uf;
// WM[r]: wavelet matrix over the types along cyc[r], indexed by cycle position
// (built in main for union-find roots r of cycles).
vector<WaveletMatrix> WM;

// For every vertex i in the in-tree below ii (recursively), finishes the
// own-type answer that DFS could not settle inside the tree.
// Precondition (from DFS): if ans[i].se == -1, ans[i].fi is the number of
// type-ty[i] vertices already met on the tree path from i to the cycle, so
// rem = K - ans[i].fi more must come from the cycle. The walk enters the cycle
// at roo[i] after depth[i] steps. A binary search finds the fewest cycle
// vertices r that supply rem of them, giving ans[i] = (depth[i] + r, ty[i]).
// If ty[i] does not occur on the cycle at all, ans[i] becomes (-1, -1).
// Vertices that DFS already resolved (ans[i].se != -1) are left unchanged.
// Their .fi is a step count, not a collected count, so a rem computed from it
// was meaningless. The al == 0 branch also used to erase a valid answer
// (e.g. K == 1 with a type absent from the cycle).
void solveX(int ii){
    for(int i:G[ii]){
        if(ans[i].se!=-1){
            solveX(i);
            continue;
        }
        ll sude=depth[i],rem=K-ans[i].fi;
        int ro=roo[i];
        int cr=uf.root(ro);
        int M=si(cyc[cr]);
        // al: occurrences of ty[i] in one full lap of the cycle.
        ll al=WM[cr].rank(ty[i],0,M);
        
        if(al==0){
            ans[i]=mp(-1,-1);
            solveX(i);
            continue;
        }
        
        // Each full lap yields al matches, so the minimal r lies in (l, r].
        ll l=max(0LL,rem/al-1)*M,r=(rem/al+2)*M;
        
        while(r-l>1){
            ll m=(l+r)/2;
            // can: matches among the m cycle vertices starting at cycpos[ro].
            ll can=0;
            if(cycpos[ro]+m>=M){
                can+=WM[cr].rank(ty[i],cycpos[ro],M);
                can+=al*((m-(M-cycpos[ro]))/M);
                can+=WM[cr].rank(ty[i],0,(m-(M-cycpos[ro]))%M);
            }else{
                can+=WM[cr].rank(ty[i],cycpos[ro],cycpos[ro]+m);
            }
            
            if(can>=rem) r=m;
            else l=m;
        }
        
        ans[i]=mp(sude+r,ty[i]);
        solveX(i);
    }
}


// Reads Q test cases. Each gives N, K, the types ty[0..N-1] and the successors
// A[0..N-1] (both 1-based in the input). For every start vertex v it decides a
// type ans[v].se and prints sum over v of (v + 1) * (ans[v].se + 1).
int main(){
    
    // NOTE(review): local-testing redirect. With libstdc++ the
    // ios::sync_with_stdio(false) call below appears to re-install a
    // stdin-backed buffer in cin, overriding this redirect; confirm before
    // relying on text.txt (do not reorder these lines without checking).
    std::ifstream in("text.txt");
    std::cin.rdbuf(in.rdbuf());
    cin.tie(0);
    ios::sync_with_stdio(false);
    int Q;cin>>Q;
    while(Q--){
        ll N;cin>>N;
        cin>>K;
        // Reset the per-vertex global state left by the previous test case.
        for(int i=0;i<N;i++){
            G[i].clear();
            cyc[i].clear();
            sz[i]=0;
            ans[i]=mp(0,-1);
            ty[i]=0;
            depth[i]=0;
            retu[i].clear();
            cycpos[i]=0;
            roo[i]=0;
        }
        for(int i=0;i<N;i++){
            cin>>ty[i];ty[i]--;
            if(ty[i]<0) ty[i]+=N;
        }
        // Successor array; deg[] is the in-degree, used to peel off tree vertices.
        vector<int> A(N),deg(N),seen(N);
        for(int i=0;i<N;i++){
            int x;cin>>x;x--;
            A[i]=x;
            deg[x]++;
        }
        // Repeatedly remove in-degree-0 vertices; afterwards deg[v] != 0
        // exactly for the vertices lying on a cycle.
        queue<int> que;
        for(int i=0;i<N;i++){
            if(deg[i]==0) que.push(i);
        }
        
        while(!que.empty()){
            int u=que.front();que.pop();
            deg[A[u]]--;
            if(deg[A[u]]==0) que.push(A[u]);
        }
        
        // Each cycle becomes one union-find set.
        uf.init(N);
        for(int i=0;i<N;i++){
            if(deg[i]) uf.unite(i,A[i]);
        }
        
        // In-forest: every non-cycle vertex is a child of its successor.
        for(int i=0;i<N;i++){
            if(deg[i]==0) G[A[i]].push_back(i);
        }
        
        vl CN(N);
        
        WM=vector<WaveletMatrix>(N);
        
        for(int i=0;i<N;i++){
            if(deg[i]&&uf.root(i)==i){
                // Collect the cycle starting from its union-find root i.
                int now=i;
                vi ss;
                while(!seen[now]){
                    seen[now]=true;
                    cycpos[now]=si(cyc[i]);
                    cyc[i].push_back(now);
                    ss.pb(ty[now]);
                    now=A[now];
                }
                WM[i].init(ss,N);
                // Binary search the fewest vertices m, walking the cycle from i,
                // after which some type has >= K of them; xx records that type.
                // The answer never exceeds len*K <= N*K, but r is only checked
                // when the loop assigns it. With r = N*K the case N == 1 (answer
                // exactly N*K) never ran the check and left xx == -1, so the
                // upper bound is N*K + 1.
                ll len=si(cyc[i]);
                ll l=0,r=N*K+1;
                ll xx=-1;
                while(r-l>1){
                    ll m=(l+r)/2;
                    
                    for(int j=0;j<len;j++){
                        if(j<m%len) CN[ty[cyc[i][j]]]++;
                        CN[ty[cyc[i][j]]]+=m/len;
                    }
                    
                    bool ok=false;
                    for(int j=0;j<len;j++){
                        if(CN[ty[cyc[i][j]]]>=K){
                            ok=true;
                            xx=ty[cyc[i][j]];
                        }
                        CN[ty[cyc[i][j]]]=0;
                    }
                    
                    if(ok) r=m;
                    else l=m;
                }
                
                ans[i]=mp(r,xx);
            }
        }
        
        for(int i=0;i<N;i++){
            if(deg[i]){
                DFS(i,i);
            }
        }
        
        // Remaining cycle vertices, walking backwards from the root; la is the
        // current answer of the successor of the vertex being processed.
        for(int ii=0;ii<N;ii++){
            if(deg[ii]&&uf.root(ii)==ii){
                ll la=ans[ii].fi;
                for(int j=si(cyc[ii])-1;j>=1;j--){
                    int i=cyc[ii][j];
                    
                    ll sude=depth[i],rem=K-ans[i].fi;
                    int ro=roo[i];
                    int M=si(cyc[uf.root(ro)]);
                    ll al=WM[uf.root(ro)].rank(ty[i],0,M);
                    if(al==0){
                        ans[i]=mp(-1,-1);
                        la++;
                        continue;
                    }
                    
                    ll l=max(0LL,rem/al-1)*M,r=(rem/al+2)*M;
                    if(l>la+1){
                        ans[i]=mp(-1,-1);
                        la++;
                        continue;
                    }
                    chmin(r,la+2);
                    
                    while(r-l>1){
                        ll m=(l+r)/2;
                        ll can=0;
                        if(cycpos[ro]+m>=M){
                            can+=WM[uf.root(ro)].rank(ty[i],cycpos[ro],M);
                            can+=al*((m-(M-cycpos[ro]))/M);
                            can+=WM[uf.root(ro)].rank(ty[i],0,(m-(M-cycpos[ro]))%M);
                        }else{
                            can+=WM[uf.root(ro)].rank(ty[i],cycpos[ro],cycpos[ro]+m);
                        }
                        
                        if(can>=rem) r=m;
                        else l=m;
                    }
                    if(r==la+2){
                        ans[i]=mp(-1,-1);
                        la++;
                        continue;
                    }else{
                        ans[i]=mp(sude+r,ty[i]);
                        la=ans[i].fi;
                    }
                }
            }
        }
        
        for(int i=0;i<N;i++){
            if(deg[i]){
                solveX(i);
            }
        }
        
        // Propagate around each cycle: a vertex adopts its successor's result
        // one step later unless it already has an answer of the same type or
        // one that is not later. Two laps so values can travel all the way round.
        for(int i=0;i<N;i++){
            if(deg[i]&&uf.root(i)==i){
                const int len=si(cyc[i]);
                for(int lap=0;lap<2;lap++){
                    for(int j=len-1;j>=0;j--){
                        int a=cyc[i][(j+1)%len],b=cyc[i][j];
                        if(ans[b].se==-1){
                            ans[b]=mp(ans[a].fi+1,ans[a].se);
                        }else if(ans[a].se==ans[b].se){
                            // Same type: keep b's own value.
                        }else if(ans[a].fi+1<ans[b].fi){
                            ans[b]=mp(ans[a].fi+1,ans[a].se);
                        }
                    }
                }
            }
        }
        
        for(int i=0;i<N;i++){
            if(deg[i]){
                solve(i);
            }
        }
        
        ll res=0;
        for(int i=0;i<N;i++){
            res+=(ll)(i+1)*(ans[i].se+1);
        }
        cout<<res<<"\n";
    }
    
}

详细信息

Test #1:

score: 100
Accepted
time: 0ms
memory: 9768kb

input:

3
5 3
2 2 1 3 3
2 5 1 2 4
5 4
2 2 1 3 3
2 5 1 2 4
3 10
1 2 3
1 3 2

output:

41
45
14

result:

ok 3 lines

Test #2:

score: -100
Wrong Answer
time: 128ms
memory: 8160kb

input:

6000
19 48
18 19 18 19 11 9 15 19 12 18 11 18 9 18 9 18 19 11 15
12 14 18 8 1 3 19 5 13 14 15 2 14 5 19 2 19 12 9
15 23
3 1 1 3 6 1 4 1 1 6 6 4 12 4 6
14 1 8 8 6 6 12 14 6 8 5 7 14 2 5
9 140979583
4 5 8 9 2 7 6 8 2
8 9 4 6 9 2 4 7 8
4 976357580
2 3 1 3
2 1 1 4
6 508962809
4 3 4 3 4 4
4 5 4 5 5 6
13 ...

output:

3420
260
254
26
84
759
126
30
1092
0
2493
2422
168
360
298
324
2424
2520
220
228
1107
9
3486
0
796
81
340
272
600
3196
32
495
40
128
140
665
1635
702
68
96
90
288
29
588
16
234
445
2928
140
40
477
1197
19
1994
1082
32
522
672
20
390
32
2204
1938
42
21
885
4
1539
196
420
11
1709
801
720
0
556
40
17
2...

result:

wrong answer 10th lines differ - expected: '1', found: '0'