QOJ.ac

QOJ

ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间
#854469 | #9732. Gathering Mushrooms | ucup-team1134# | WA | 0ms | 11856kb | C++23 | 14.6kb | 2025-01-12 01:41:31 | 2025-01-12 01:41:32

Judging History

你现在查看的是最新测评结果

  • [2025-01-12 01:41:32]
  • 评测
  • 测评结果:WA
  • 用时:0ms
  • 内存:11856kb
  • [2025-01-12 01:41:31]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Raises `a` to `b` when `b` is strictly greater; reports whether `a` changed.
template<class T>bool chmax(T &a, const T &b) {
    const bool grew = a < b;
    if (grew) a = b;
    return grew;
}
// Lowers `a` to `b` when `b` is strictly smaller; reports whether `a` changed.
template<class T>bool chmin(T &a, const T &b) {
    const bool shrank = b < a;
    if (shrank) a = b;
    return shrank;
}
// Shorthand names for frequently used container types.
#define vi vector<int>
#define vl vector<ll>
#define vii vector<pair<int,int>>
#define vll vector<pair<ll,ll>>
#define vvi vector<vector<int>>
#define vvl vector<vector<ll>>
#define vvii vector<vector<pair<int,int>>>
#define vvll vector<vector<pair<ll,ll>>>
#define vst vector<string>
#define pii pair<int,int>
#define pll pair<ll,ll>
// Shorthand for common member names and helper expressions.
#define pb push_back
// all(x): the full iterator range of container x.
#define all(x) (x).begin(),(x).end()
// mkunique(x): sorts x and drops duplicates. It expands to TWO statements,
// so it must not be used as the body of an unbraced if/for/while.
#define mkunique(x) sort(all(x));(x).erase(unique(all(x)),(x).end())
#define fi first
#define se second
#define mp make_pair
// si(x): container size converted to a signed int.
#define si(x) int(x.size())
// MAX sizes the global per-node arrays; mod and INF are not referenced by
// the code visible in this file.
const int mod=998244353,MAX=200005,INF=15<<26;

//Wavelet matrix

// https://kopricky.github.io/code/DataStructure_Advanced/wavelet_matrix.html

// Static bit vector with constant-time rank (prefix popcount) queries.
//
// Usage: resize(n), then set() every 1 bit, then build() once; after that
// rank0/rank1 may be queried. Calling resize() again starts over from an
// all-zero vector.
struct BitRank {
    // block: the bits, packed 64 per word.
    // count[i]: number of 1 bits in words [0, i), filled in by build().
    vector<unsigned long long> block;
    vector<unsigned int> count;
    BitRank(){}
    // Allocates room for `num` bits and clears every bit and counter.
    // assign() (rather than resize()) is used on purpose: set() can only turn
    // bits on, so words surviving from an earlier use would otherwise leak
    // stale 1 bits into the new contents.
    void resize(const unsigned int num){
        block.assign(((num + 1) >> 6) + 1, 0);
        count.assign(block.size(), 0);
    }
    // ORs `val` (expected to be 0 or 1) into bit i; a bit that is already 1
    // is never cleared by this call.
    void set(const unsigned int i, const unsigned long long val){
        block[i >> 6] |= (val << (i & 63));
    }
    // Computes the per-word prefix counts; call after the last set().
    void build(){
        for(unsigned int i = 1; i < block.size(); i++){
            count[i] = count[i - 1] + __builtin_popcountll(block[i - 1]);
        }
    }
    // Number of 1 bits in [0, i).
    unsigned int rank1(const unsigned int i) const {
        return count[i >> 6] +
        __builtin_popcountll(block[i >> 6] & ((1ULL << (i & 63)) - 1ULL));
    }
    // Number of 1 bits in [i, j).
    unsigned int rank1(const unsigned int i, const unsigned int j) const {
        return rank1(j) - rank1(i);
    }
    // Number of 0 bits in [0, i).
    unsigned int rank0(const unsigned int i) const {
        return i - rank1(i);
    }
    // Number of 0 bits in [i, j).
    unsigned int rank0(const unsigned int i, const unsigned int j) const {
        return rank0(j) - rank0(i);
    }
};

// Wavelet matrix over a sequence of non-negative integers.
// One BitRank per bit level is stored, most significant bit first; after
// each level the working sequence is stably partitioned into "bit 0" values
// followed by "bit 1" values.
class WaveletMatrix
{
private:
    // Number of bit levels.
    unsigned int height;
    // B[i]: bits of level i, taken in the sequence order of that level.
    vector<BitRank> B;
    // pos[i]: number of values whose level-i bit is 0 (start of the "1" half
    // on the next level).
    vector<int> pos;
public:
    WaveletMatrix(){}
    // Builds from vec using max(vec) + 1 as the alphabet size.
    // NOTE(review): max_element is dereferenced, so vec is presumably
    // expected to be non-empty.
    WaveletMatrix(vector<int> vec) :
    WaveletMatrix(vec, *max_element(vec.begin(), vec.end()) + 1){}
    // sigma: alphabet size; values are assumed to lie in [0, sigma)
    WaveletMatrix(vector<int> vec, const unsigned int sigma){
        init(vec, sigma);
    }
    // (Re)builds the structure. `vec` is reordered in place by the per-level
    // stable partition, so the caller's vector does not keep its order.
    void init(vector<int>& vec, const unsigned int sigma){
        // Bits needed to represent sigma - 1 (at least one level).
        height = (sigma == 1) ? 1 : (64 - __builtin_clzll(sigma - 1));
        B.resize(height), pos.resize(height);
        for(unsigned int i = 0; i < height; ++i){
            B[i].resize(vec.size());
            for(unsigned int j = 0; j < vec.size(); ++j) {
                B[i].set(j, get(vec[j], height - i - 1));
            }
            B[i].build();
            // Values with a 0 at this bit go first, relative order preserved.
            auto it = stable_partition(vec.begin(), vec.end(), [&](int c){
                return !get(c, height - i - 1);
            });
            pos[i] = it - vec.begin();
        }
    }
    // Returns bit i of val (0 or 1).
    int get(const int val, const int i){
        return val >> i & 1;
    }
    // Number of occurrences of val in positions [l, r).
    int rank(const int val, const int l, const int r){
        return rank(val, r) - rank(val, l);
    }
    // Number of occurrences of val in positions [0, i).
    int rank(int val, int i){
        // p follows position 0 and i follows the query end through the
        // levels; their final difference is the count.
        int p = 0;
        for(unsigned int j = 0; j < height; ++j){
            if(get(val, height - j - 1)){
                p = pos[j] + B[j].rank1(p);
                i = pos[j] + B[j].rank1(i);
            }else{
                p = B[j].rank0(p);
                i = B[j].rank0(i);
            }
        }
        return i - p;
    }
    // Returns the k-th smallest value (k = 0, 1, 2, ...) among positions [l, r).
    int quantile(int k, int l, int r){
        int res = 0;
        for(unsigned int i = 0; i < height; ++i){
            // j: how many values of the current range have a 0 at this level.
            const int j = B[i].rank0(l, r);
            if(j > k){
                l = B[i].rank0(l);
                r = B[i].rank0(r);
            }else{
                l = pos[i] + B[i].rank1(l);
                r = pos[i] + B[i].rank1(r);
                k -= j;
                res |= (1 << (height - i - 1));
            }
        }
        return res;
    }
    // Recursive helper: [i, j) is the position range on level x, [l, r) the
    // value range covered by the current node, [a, b) the queried value range.
    int rangefreq(const int i, const int j, const int a, const int b,
                  const int l, const int r, const int x){
        if(i == j || r <= a || b <= l) return 0;
        const int mid = (l + r) >> 1;
        if(a <= l && r <= b){
            return j - i;
        }else{
            const int left = rangefreq(B[x].rank0(i), B[x].rank0(j), a, b, l, mid, x + 1);
            const int right = rangefreq(pos[x] + B[x].rank1(i), pos[x] + B[x].rank1(j),
                                        a, b, mid, r, x + 1);
            return left + right;
        }
    }
    // Number of positions in [l, r) whose value lies in [a, b).
    int rangefreq(const int l, const int r, const int a, const int b){
        return rangefreq(l, r, a, b, 0, 1 << height, 0);
    }
    // Recursive helper: same range conventions as the rangefreq helper;
    // `val` accumulates the value prefix of the current node.
    int rangemin(const int i, const int j, const int a, const int b,
                 const int l, const int r, const int x, const int val){
        if(i == j || r <= a || b <= l) return -1;
        if(r - l == 1) return val;
        const int mid = (l + r) >> 1;
        // Try the smaller half first; fall back to the larger half.
        const int res = rangemin(B[x].rank0(i),B[x].rank0(j),a,b,l,mid,x+1,val);
        if(res < 0) return rangemin(pos[x] + B[x].rank1(i), pos[x] + B[x].rank1(j),
                                    a, b, mid, r, x + 1, val + (1 << (height - x - 1)));
        else return res;
    }
    // Smallest value within [a, b) that occurs in positions [l, r), or -1 if
    // there is none.
    int rangemin(int l, int r, int a, int b){
        return rangemin(l, r, a, b, 0, 1 << height, 0, 0);
    }
};

// Counts points inside axis-aligned rectangles. Points are sorted by
// (x, y), the y coordinates are compressed to ranks, and the rank sequence
// is indexed by a wavelet matrix.
template<typename T> class OrthogonalRangeCount
{
private:
    using ptt = pair<T, T>;
    const int sz;
    // X: x coordinates in sorted point order; Y: sorted distinct y coordinates.
    vector<T> X, Y;
    WaveletMatrix wm;

    // Index of the first element of the sorted vector `keys` that is >= v.
    static int lowerIndex(const vector<T>& keys, const T& v){
        return lower_bound(keys.begin(), keys.end(), v) - keys.begin();
    }

public:
    OrthogonalRangeCount(vector<ptt> candidate)
    : sz((int)candidate.size()), X(sz), Y(sz){
        sort(candidate.begin(), candidate.end());
        for(int idx = 0; idx < sz; ++idx){
            X[idx] = candidate[idx].first;
            Y[idx] = candidate[idx].second;
        }
        sort(Y.begin(), Y.end());
        Y.erase(unique(Y.begin(), Y.end()), Y.end());
        // Replace every y by its rank among the distinct y values.
        vector<int> ranks(sz);
        for(int idx = 0; idx < sz; ++idx){
            ranks[idx] = lowerIndex(Y, candidate[idx].second);
        }
        wm.init(ranks, Y.size());
    }
    // Number of points inside the rectangle [lx, rx) x [ly, ry).
    int query(const T lx, const T ly, const T rx, const T ry){
        const int xBegin = lowerIndex(X, lx);
        const int xEnd = lowerIndex(X, rx);
        const int yBegin = lowerIndex(Y, ly);
        const int yEnd = lowerIndex(Y, ry);
        if(xBegin < xEnd && yBegin < yEnd){
            return wm.rangefreq(xBegin, xEnd, yBegin, yEnd);
        }
        return 0;
    }
};

// G[u]: nodes i outside every cycle whose successor A[i] is u, i.e. the
// children of u in the trees hanging off the cycles (filled in main).
vector<int> G[MAX];
// sz: reset to 0 for every test case in main; not otherwise used by the
// visible code.
ll sz[MAX];

// Disjoint-set forest that also tracks, per root, the component size and the
// number of unite() calls (edges) charged to the component.
struct UF{
    int n;
    vector<int> par,size,edge;

    // Creates n singleton components with no edges.
    void init(int n_){
        n=n_;
        size.assign(n,1);
        edge.assign(n,0);
        par.assign(n,-1);
        for(int v=0;v<n;++v) par[v]=v;
    }

    // Representative of a's component; compresses the walked path.
    int root(int a){
        int top=a;
        while(par[top]!=top) top=par[top];
        while(par[a]!=top){
            const int up=par[a];
            par[a]=top;
            a=up;
        }
        return top;
    }

    // Records one edge on a's component, then merges b's component into it
    // (a's root stays the root) when the two differ.
    void unite(int a,int b){
        const int ra=root(a);
        edge[ra]++;
        const int rb=root(b);
        if(ra==rb) return;
        size[ra]+=size[rb];
        edge[ra]+=edge[rb];
        par[rb]=ra;
    }

    // True when a and b are in the same component.
    bool check(int a,int b){
        const int ra=root(a);
        return ra==root(b);
    }
};

// cyc[r]: nodes of one cycle in traversal order, stored at the index r of the
// cycle's union-find root.
vector<int> cyc[MAX];
// ans[u]: (number of steps, type) pair for start node u; the type is -1 while
// it is still undecided.
pll ans[MAX];
// ty[u]: 0-based mushroom type of node u.
int ty[MAX];
// depth[u]: DFS depth of u below the cycle node its tree hangs from
// (0 for cycle nodes).
int depth[MAX];
// retu[t]: stack of the depths of the type-t nodes on the current DFS path.
vi retu[MAX];
// cycpos[u]: index of cycle node u inside its cyc[] vector.
int cycpos[MAX];
// K: required number of mushrooms of one type, read per test case.
ll K;
// roo[u]: cycle node from which the DFS that reached u was started.
int roo[MAX];

// Walks the tree hanging below cycle node `ro`, visiting `u` and everything
// under it. For every non-cycle node (depth > 0) it records in ans[u] either
//   (steps, ty[u])  when K nodes of type ty[u] lie on the path from u up to,
//                   but excluding, the cycle node, or
//   (count, -1)     with the number of such nodes found so far otherwise.
// The caller must have set depth[u] before the call.
void DFS(int u,int ro){
    roo[u]=ro;
    const bool belowCycle=(depth[u]!=0);
    if(belowCycle){
        auto &sameType=retu[ty[u]];
        sameType.pb(depth[u]);
        const int have=si(sameType);
        if(have<K){
            ans[u].fi=have;
            ans[u].se=-1;
        }else{
            // Distance from u up to the K-th nearest node of the same type.
            ans[u].fi=depth[u]-sameType[have-K]+1;
            ans[u].se=ty[u];
        }
    }
    const auto &children=G[u];
    for(auto it=children.begin();it!=children.end();++it){
        depth[*it]=depth[u]+1;
        DFS(*it,ro);
    }
    // Undo this node's contribution before returning to the parent.
    if(belowCycle){
        retu[ty[u]].pop_back();
    }
}

// Pushes answers from node `a` down to its tree children. A child adopts
// (ans[a].steps + 1, ans[a].type) when it has no type yet, or when the
// parent's type differs from its own and arrives in strictly fewer steps.
void solve(int a){
    for(const int b:G[a]){
        const auto viaParent=mp(ans[a].fi+1,ans[a].se);
        const bool undecided=(ans[b].se==-1);
        const bool otherTypeSooner=
            (viaParent.se!=ans[b].se)&&(viaParent.fi<ans[b].fi);
        if(undecided||otherTypeSooner){
            ans[b]=viaParent;
        }
        solve(b);
    }
}


int main(){
    
    std::ifstream in("text.txt");
    std::cin.rdbuf(in.rdbuf());
    cin.tie(0);
    ios::sync_with_stdio(false);
    mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
    int Q;cin>>Q;
    while(Q--){
        ll N;cin>>N;
        cin>>K;
        for(int i=0;i<N;i++){
            G[i].clear();
            cyc[i].clear();
            sz[i]=0;
            ans[i]=mp(0,-1);
            ty[i]=0;
            depth[i]=0;
            retu[i].clear();
            cycpos[i]=0;
            roo[i]=0;
        }
        vvi color(N);
        for(int i=0;i<N;i++){
            cin>>ty[i];ty[i]--;
            //ty[i]=i;
            if(ty[i]<0) ty[i]+=N;
            color[ty[i]].pb(i);
        }
        vector<int> A(N),deg(N),seen(N);
        for(int i=0;i<N;i++){
            int x;cin>>x;x--;
            //x=(i+1)%N;
            A[i]=x;
            deg[x]++;
        }
        queue<int> Q;
        for(int i=0;i<N;i++){
            if(deg[i]==0) Q.push(i);
        }
        
        while(!Q.empty()){
            int u=Q.front();Q.pop();
            deg[A[u]]--;
            if(deg[A[u]]==0) Q.push(A[u]);
        }
        
        UF uf;uf.init(N);
        for(int i=0;i<N;i++){
            if(deg[i]) uf.unite(i,A[i]);
        }
        
        for(int i=0;i<N;i++){
            if(deg[i]==0) G[A[i]].push_back(i);
        }
        
        vl CN(N);
        
        vector<WaveletMatrix> WM(N);
        
        for(int i=0;i<N;i++){
            if(deg[i]){
                if(uf.root(i)==i){
                    int now=i;
                    vi ss;
                    while(1){
                        if(!seen[now]){
                            seen[now]=true;
                            cycpos[now]=si(cyc[i]);
                            cyc[i].push_back(now);
                            ss.pb(ty[now]);
                            now=A[now];
                        }else{
                            break;
                        }
                    }
                    //WM[i]=WaveletMatrix(ss);
                    WM[i].init(ss,N);
                    ll len=si(cyc[i]);
                    /*
                     ll l=0,r=N*K;
                     ll xx=-1;
                     while(r-l>1){
                     ll m=(l+r)/2;
                     
                     for(int j=0;j<len;j++){
                     if(j<m%len) CN[ty[cyc[i][j]]]++;
                     CN[ty[cyc[i][j]]]+=m/len;
                     }
                     
                     bool ok=false;
                     for(int j=0;j<len;j++){
                     if(CN[ty[cyc[i][j]]]>=K){
                     ok=true;
                     xx=ty[cyc[i][j]];
                     }
                     CN[ty[cyc[i][j]]]=0;
                     }
                     
                     if(ok) r=m;
                     else l=m;
                     }
                     
                     ans[i]=mp(r,xx);
                     */
                }
            }
        }
        
        for(int i=0;i<N;i++){
            if(deg[i]){
                DFS(i,i);
            }
        }
        /*
         for(int a=0;a<N;a++){
         for(int l=0;l<3;l++){
         for(int r=l;r<3;r++){
         cout<<WM[3].rank(a,l,r)<<" ";
         }
         }
         cout<<endl;
         }
         */
        for(int i=0;i<N;i++){
            //if(deg[i]&&uf.root(i)==i) continue;
            if(ans[i].se==-1){
                ll sude=depth[i],rem=K-ans[i].fi;
                //cout<<i<<" "<<rem<<" ! "<<endl;
                int ro=roo[i];
                int M=si(cyc[uf.root(ro)]);
                //cout<<ro<<" "<<ty[i]<<" "<<M<<endl;
                ll al=WM[uf.root(ro)].rank(ty[i],0,M);
                
                if(al==0){
                    ans[i]=mp(-1,-1);
                    continue;
                }
                
                ll l=max(0LL,rem/al*M+rem%al-1),r=(rem+al-1)/al*al;
                
                while(r-l>1){
                    ll m=(l+r)/2;
                    ll can=0;
                    if(cycpos[ro]+m>=M){
                        can+=WM[uf.root(ro)].rank(ty[i],cycpos[ro],M);
                        can+=al*((m-(M-cycpos[ro]))/M);
                        can+=WM[uf.root(ro)].rank(ty[i],0,(m-(M-cycpos[ro]))%M);
                    }else{
                        can+=WM[uf.root(ro)].rank(ty[i],cycpos[ro],cycpos[ro]+m);
                    }
                    
                    if(can>=rem) r=m;
                    else l=m;
                }
                
                ans[i]=mp(sude+r,ty[i]);
            }
        }
        
        for(int i=0;i<N;i++){
            if(deg[i]&&uf.root(i)==i){
                for(int j=si(cyc[i])-1;j>=0;j--){
                    int a=cyc[i][(j+1+si(cyc[i]))%si(cyc[i])],b=cyc[i][j];
                    if(ans[a].se==ans[b].se){
                        
                    }else if(ans[a].fi+1<ans[b].fi){
                        ans[b]=mp(ans[a].fi+1,ans[a].se);
                    }
                }
                for(int j=si(cyc[i])-1;j>=0;j--){
                    int a=cyc[i][(j+1+si(cyc[i]))%si(cyc[i])],b=cyc[i][j];
                    if(ans[a].se==ans[b].se){
                        
                    }else if(ans[a].fi+1<ans[b].fi){
                        ans[b]=mp(ans[a].fi+1,ans[a].se);
                    }
                }
            }
        }
        
        for(int i=0;i<N;i++){
            if(deg[i]){
                solve(i);
            }
        }
        
        ll res=0;
        for(int i=0;i<N;i++){
            res+=(ll)(i+1)*(ans[i].se+1);
        }
        cout<<res<<"\n";
    }
    
}

詳細信息

Test #1:

score: 0
Wrong Answer
time: 0ms
memory: 11856kb

input:

3
5 3
2 2 1 3 3
2 5 1 2 4
5 4
2 2 1 3 3
2 5 1 2 4
3 10
1 2 3
1 3 2

output:

39
39
14

result:

wrong answer 1st lines differ - expected: '41', found: '39'