QOJ.ac

QOJ

IDProblemSubmitterResultTimeMemoryLanguageFile sizeSubmit timeJudge time
#697259#3840. Pass the Ball!yzj123#RE 2ms13860kbC++206.3kb2024-11-01 12:32:482024-11-01 12:32:48

Judging History

This is the latest submission verdict.

  • [2024-11-01 12:32:48]
  • Judged
  • Verdict: RE
  • Time: 2ms
  • Memory: 13860kb
  • [2024-11-01 12:32:48]
  • Submitted

answer

#include<bits/stdc++.h>
using i64 = long long;
using u64 = unsigned long long;
#define int i64
using Poly = std::vector<int>;
//4179340454199820289
const int G = 3, mod = 998244353, Maxn = 2e6 + 10;
int qmi(int a, int b = mod - 2) {
	int res = 1;
	while (b) {
		if (b & 1) {
			res = 1ll * res * a  % mod;
		}
		a = 1ll * a * a % mod;
		b >>= 1;
	}
	return res;
}
const int invG = qmi(G);
int tr[Maxn << 1], tf;
void tpre(int n) {
	if (tf == n) return ;
	tf = n;
	for (int i = 0; i < n; i ++) {
		tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);
	}
}

void NTT(int n, int *g, bool op) {
	tpre(n);
	static u64 f[Maxn << 1], w[Maxn << 1];
	w[0] = 1;
	for (int i = 0; i < n; i ++) {
		f[i] = (((i64)mod << 5) + g[tr[i]]) % mod;
	}
	for (int l = 1; l < n; l <<= 1) {
    	u64 tG = qmi(op ? G : invG, (mod - 1) / (l + l));
    	for (int i = 1; i < l; i ++) w[i] = w[i - 1] * tG % mod;
    	for (int k = 0; k < n; k += l + l)
      		for (int p = 0; p < l; p ++) {
        		int tt = w[p] * f[k | l | p] % mod;
        		f[k | l | p] = f[k | p] + mod - tt;
        		f[k | p] += tt;
      		}      
    	if (l == (1 << 10))
      		for (int i = 0; i < n; i ++) f[i] %= mod;
  	}
	if (! op) {
    	u64 invn = qmi(n);
    	for(int i = 0; i < n; ++ i) {
    		g[i] = f[i] % mod * invn % mod;
    	}
  	} else {
  		for (int i = 0; i < n; ++ i) {
  			g[i] = f[i] % mod;
  		}
  	}
}

void px(int n, int *f, int *g) {
	for (int i = 0; i < n; ++ i) {
		f[i] = 1ll * f[i] * g[i] % mod;
	}
}

Poly operator +(const Poly &A, const Poly &B) {
	Poly C = A;
	C.resize(std::max(A.size(), B.size()));
	for (int i = 0; i < B.size(); i ++) {
		C[i] = (C[i] + B[i]) % mod;
	}
	return C;
}
Poly operator -(const Poly &A, const Poly &B) {
	Poly C = A;
	C.resize(std::max(A.size(),B.size()));
	for (int i = 0; i < B.size(); i ++) {
		C[i] = (C[i] + mod - B[i]) % mod;
	}
	return C;
}
Poly operator *(const int c, const Poly &A) {
	Poly C;
	C.resize(A.size());
	for (int i = 0; i < A.size(); i ++) {
		C[i] = 1ll * c * A[i] % mod;
	}
	return C;
}
int lim; // set.
Poly operator *(const Poly &A, const Poly &B) {
	static int a[Maxn << 1], b[Maxn << 1];
	for (int i = 0; i < A.size(); i ++) a[i] = A[i];
	for (int i = 0; i < B.size(); i ++) b[i] = B[i];
	Poly C;
	C.resize(std::min(lim, (int)(A.size() + B.size() - 1)));
	int n = 1; 
	for(n; n < A.size() + B.size() - 1; n <<= 1);
	NTT(n, a, 1);
	NTT(n, b, 1);
	px(n, a, b);
	NTT(n, a, 0);
	for (int i = 0; i < C.size(); i ++) {
		C[i] = a[i];
	}
	for (int i = 0; i <= n; i ++) {
		a[i] = 0;
		b[i] = 0;
	}
	return C;
}
void pinv(int n, const Poly &A, Poly &B) {
	if (n == 1) B.push_back(qmi(A[0]));
	else if (n & 1){
		pinv(-- n, A, B);
		int sav = 0;
		for (int i = 0; i < n; i ++) {
			sav = (sav + 1ll * B[i] * A[n - i] % mod) % mod;
		}
		B.push_back(1ll * sav * qmi(mod - A[0]) % mod);
	} else {
		pinv(n / 2, A, B);
		Poly sA;
		sA.resize(n);
		for (int i = 0; i < n; i ++) {
			sA[i] = A[i];
		}
		B = 2 * B - B * B * sA;
		B.resize(n);
	}
}
Poly pinv(const Poly &A) {	// P-inv
	Poly C;
	pinv(A.size(), A, C);
	return C;
}
int inv[Maxn];
void Init() {
	inv[1] = 1;
	for (int i = 2; i <= lim; i ++) {
		inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
	}
}
Poly dao(const Poly &A) {	// P-qiu-dao
	Poly C = A;
	for (int i = 1; i < C.size(); i ++) {
		C[i - 1] = 1ll * C[i] * i % mod;
	}
	C.pop_back();
	return C;
}
Poly ints(const Poly &A) { // P-ji-fen
	Poly C = A;
	for (int i = C.size() - 1; i; i --)
		C[i] = 1ll * C[i - 1] * inv[i] % mod;
	C[0] = 0;
	return C;
}
Poly ln(const Poly &A) {	// P-ln
	return ints(dao(A) * pinv(A));
}
void pexp(int n, const Poly &A, Poly &B) {
	if (n == 1) B.push_back(1);
	else if (n & 1) {
		pexp(n - 1, A, B);
		n -= 2;
	 	int sav = 0;
		for (int i = 0; i <= n; i ++) {
			sav = (sav + 1ll * (i + 1) * A[i + 1] % mod * B[n - i] % mod) % mod;
		}
		B.push_back(1ll * sav * inv[n + 1] % mod);
	} else {
		pexp(n / 2, A, B);
		Poly lnB = B;
		lnB.resize(n);
		lnB = ln(lnB);
		for (int i = 0; i < lnB.size(); i ++) {
			lnB[i] = (mod + A[i] - lnB[i]) % mod;
		}
		lnB[0] ++;
		B = B * lnB;
		B.resize(n);
	}
}
Poly pexp(const Poly &A) {	// P-exp
	Poly C;
	pexp(A.size(), A, C);
	return C;
}

using pii = std::array<int, 2>;
void solve() {
	int n, q;
    std::cin >> n >> q;
    std::vector<int> a(n + 1);
    for (int i = 1; i <= n; i ++) {
        std::cin >> a[i];
    }
    std::vector<int> ans(q + 1), qu(n + 1);
    for (int i = 1; i <= q; i ++) {
        std::cin >> qu[i];
    }
    int B = std::sqrt(n);
    std::vector<std::vector<int> > cir(B + 1, std::vector<int>(B + 1));
    std::vector<bool> vis(n + 1);
    for (int i = 1; i <= n; i ++) {
        if (vis[i]) continue;
        int j = i;
        std::vector<pii> st;
        while (! vis[j]) {
            st.push_back({j, a[j]});
            vis[j] = 1;
            j = a[j];
        }
        int siz = st.size();
        Poly f(siz * 2), g(siz);
        for (int i = 0; i < siz; i ++) {
            f[i] = st[i][0];
            f[i + siz] = f[i];
            // std::cout << f[i] << ' ';
        }
        // std::cout << '\n';
        for (int i = 0; i < siz; i ++) {
            g[i] = st[siz - i - 1][1];
            // std::cout << g[i] << ' ';
        }
        // std::cout << '\n';
        lim = 4 * siz;
        auto res = f * g;


        // std::cout << siz << '\n';
        // for (int j = 0; j <= siz + siz + 1; j ++) {
        //     std::cout << res[j] << ' ';
        // }
        // std::cout << '\n';

        if (st.size() > B) {
            for (int j = 1; j <= q; j ++) {
                int o = qu[j];
                o %= siz;
                ans[j] += res[siz + o]; // 1 2 3 0
            }
        } else {
            // cir[siz]
            for (int j = 0; j < siz; j ++) {
                cir[siz][j] += res[siz + j];
            }
        }
    }
    for (int i = 1; i <= q; i ++) {
        for (int j = 1; j <= B; j ++) {
            ans[i] += cir[j][qu[i] % j];
        }
    }
    for (int i = 1; i <= q; i ++) {
        std::cout << ans[i] << '\n';
    }
}

signed main() {
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	int t = 1;
	//std::cin >> t;
	while (t --) {
		solve();
	}
	return 0;
}
/*
4 4
2 4 1 3
1
2
3
4
*/

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 2ms
memory: 13860kb

input:

4 4
2 4 1 3
1
2
3
4

output:

25
20
25
30

result:

ok 4 lines

Test #2:

score: -100
Runtime Error

input:

3 6
2 3 1
1
2
3
999999998
999999999
1000000000

output:


result: