QOJ.ac

QOJ

ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间
#728610 | #9572. Bingo | ucup-team5062 | TL | 46ms | 11484kb | C++20 | 4.1kb | 2024-11-09 15:31:46 | 2024-11-09 15:31:49

Judging History

你现在查看的是最新测评结果

  • [2024-11-09 15:31:49]
  • 评测
  • 测评结果:TL
  • 用时:46ms
  • 内存:11484kb
  • [2024-11-09 15:31:46]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;

// Prime modulus for every computation: 998244353 = 119 * 2^23 + 1, so it
// supports number-theoretic transforms of length up to 2^23.
constexpr long long mod = 998244353;
// Generator used as the NTT root (3 is a primitive root of 998244353).
constexpr long long base = 3;
// Factorial and inverse-factorial tables, filled by Initialize().
long long Fact[1 << 20];
long long Inv[1 << 20];

// Computes a^b mod m by binary exponentiation.
//   a : base (may be negative; it is reduced into [0, m) first)
//   b : exponent, expected >= 0 (b <= 0 yields 1 % m)
//   m : modulus, expected >= 1 and small enough that (m-1)^2 fits in long long
// Returns a value in [0, m).
// The loop runs until every set bit of b is consumed, so exponents of 2^30 or
// more are handled correctly (a fixed 30-iteration loop would drop high bits).
long long modpow(long long a, long long b, long long m) {
	long long result = 1 % m;  // 1 % m makes m == 1 return 0
	long long square = a % m;
	if (square < 0) square += m;  // normalise negative bases
	while (b > 0) {
		if (b & 1) result = result * square % m;
		square = square * square % m;
		b >>= 1;
	}
	return result;
}

// Returns a / b (mod m) as a * b^(m-2) mod m. By Fermat's little theorem this
// is the modular quotient when m is prime and b is not a multiple of m.
long long Division(long long a, long long b, long long m) {
	const long long inverse_b = modpow(b, m - 2, m);
	return a * inverse_b % m;
}


// ================================================================================= FFT Library
// Number-theoretic transform of `vec` modulo `mod`.
//   typ == 1 : forward transform using root `base`.
//   typ == 2 : inverse transform (inverse root, then divide by the length).
//   any other typ : inverse root but no 1/n scaling (kept as-is).
// The length is expected to be a power of two dividing mod - 1 (at most 2^23);
// the bit-reversal and butterfly indexing below assume that.
vector<long long> ntt(vector<long long> vec, int typ) {
	const long long root = (typ == 1) ? base : Division(1, base, mod);
	const int n = static_cast<int>(vec.size());

	// Maps index i to its bit-reversed position within log2(n) bits.
	auto reversed = [n](int i) {
		int out = 0;
		for (int lo = 1, hi = n / 2; hi >= 1; lo <<= 1, hi >>= 1) {
			if (i & lo) out += hi;
		}
		return out;
	};
	vector<long long> dat(n, 0);
	for (int i = 0; i < n; i++) dat[reversed(i)] = vec[i];

	// Iterative Cooley-Tukey butterflies over blocks of length len.
	for (int len = 2; len <= n; len <<= 1) {
		const int half = len / 2;
		// w[k] = (len-th root of unity)^k for k in [0, len).
		vector<long long> w(len, 1);
		w[1] = modpow(root, (mod - 1) / len, mod);
		for (int k = 2; k < len; k++) w[k] = w[k - 1] * w[1] % mod;

		for (int start = 0; start < n; start += len) {
			for (int k = 0; k < half; k++) {
				const long long even = dat[start + k];
				const long long odd = dat[start + k + half];
				dat[start + k] = (even + w[k] * odd) % mod;
				dat[start + k + half] = (even + w[k + half] * odd) % mod;
			}
		}
	}

	if (typ == 2) {
		const long long scale = Division(1, n, mod);
		for (long long &x : dat) x = x * scale % mod;
	}
	return dat;
}

// Linear convolution of A and B modulo `mod` via NTT.
// Returns a vector whose first A.size() + B.size() - 1 entries are the
// coefficients of A*B; any extra entries (from power-of-two padding) are 0.
// Returns an empty vector if either input is empty.
// The padded length must stay within the NTT limit of 2^23 for this modulus.
vector<long long> convolution(vector<long long> A, vector<long long> B) {
	if (A.empty() || B.empty()) return {};

	// The product has A.size() + B.size() - 1 coefficients; padding to the next
	// power of two >= that avoids cyclic wrap-around. (Padding to
	// A.size() * B.size() would be quadratic in the input size and overflow.)
	const size_t needed = A.size() + B.size() - 1;
	size_t size_ = 1;
	while (size_ < needed) size_ <<= 1;
	A.resize(size_, 0);
	B.resize(size_, 0);

	// Pointwise product in the transformed domain, then inverse transform.
	vector<long long> fa = ntt(A, 1);
	const vector<long long> fb = ntt(B, 1);
	for (size_t i = 0; i < size_; i++) fa[i] = fa[i] * fb[i] % mod;
	return ntt(fa, 2);
}


// ================================================================================= Solve Function
// Fills Fact[i] = i! mod `mod` and Inv[i] = (i!)^{-1} mod `mod` for
// 0 <= i <= 400000. Entries beyond that index are left untouched.
// Only one modular exponentiation is performed: the inverse of the largest
// factorial is computed once and the rest follow from (i-1)!^{-1} = i!^{-1} * i.
void Initialize() {
	const int limit = 400000;
	Fact[0] = 1;
	for (int i = 1; i <= limit; i++) Fact[i] = Fact[i - 1] * i % mod;

	Inv[limit] = Division(1, Fact[limit], mod);
	for (int i = limit; i >= 1; i--) Inv[i - 1] = Inv[i] * i % mod;
}

// Binomial coefficient C(n, r) modulo `mod`, using the Fact/Inv tables.
// Yields 0 when r is negative or exceeds n.
long long ncr(int n, int r) {
	if (r < 0 || r > n) return 0;
	const long long partial = Fact[n] * Inv[r] % mod;
	return partial * Inv[n - r] % mod;
}

// Computes the answer (mod `mod`) for one H x W grid whose H*W values are in A.
// - Single-row or single-column grids: returns min(A) * (H*W)!.
// - Otherwise: builds a signed coefficient table indexed by r*c, convolves it
//   with inverse factorials, rescales, and weights each sorted value A[i]
//   (1 <= i <= H*W-2) by a difference of adjacent entries of that result.
// NOTE(review): presumably the weight of A[i] counts arrangements in which
// A[i] is the value that first completes a row or column; confirm against
// the problem statement.
// Requires Initialize() to have filled Fact/Inv up to at least H*W.
long long Solve(int H, int W, vector<int> A) {
	const int cells = H * W;
	sort(A.begin(), A.end());

	if (H == 1 || W == 1) {
		return (1LL * A[0] * Fact[cells]) % mod;
	}

	// Step 1: coef[r*c] += (r*c)! * C(H,r) * C(W,c), negated when H+W-r-c is odd.
	vector<long long> coef(cells + 1, 0);
	for (int r = 1; r <= H; r++) {
		for (int c = 1; c <= W; c++) {
			long long term = Fact[r * c] * ncr(H, r) % mod * ncr(W, c) % mod;
			if ((H + W - r - c) % 2 != 0) term = (mod - term) % mod;
			coef[r * c] = (coef[r * c] + term) % mod;
		}
	}

	// Step 2: convolve with Inv[cells - k], then scale entry cells+k by (cells-k)!.
	vector<long long> invTail(cells + 1, 0);
	for (int k = 0; k <= cells; k++) invTail[k] = Inv[cells - k];
	vector<long long> conv = convolution(coef, invTail);
	const size_t needed = static_cast<size_t>(2 * cells + 1);
	if (conv.size() < needed) conv.resize(needed, 0);
	for (int k = 0; k <= cells; k++) {
		conv[cells + k] = (conv[cells + k] * Fact[cells - k]) % mod;
	}

	// Step 3: weight each sorted value by a difference of adjacent entries.
	long long total = 0;
	for (int i = 1; i + 1 < cells; i++) {
		const long long ways = (conv[2 * cells - i] - conv[2 * cells - i - 1] + mod) % mod;
		total = (total + ways * A[i]) % mod;
	}
	return total;
}


// ================================================================================= Main Function
// Reads T test cases ("H W" followed by H*W values) from stdin and prints one
// answer per line. Uses C stdio throughout and '\n' instead of std::endl so
// the output is not flushed after every test case.
int main() {
	int T;
	if (scanf("%d", &T) != 1) return 0;  // no input: nothing to do
	Initialize();
	for (int t = 1; t <= T; t++) {
		int H, W;
		if (scanf("%d %d", &H, &W) != 2) break;  // truncated input
		vector<int> A(H * W, 0);
		for (int &x : A) scanf("%d", &x);
		printf("%lld\n", Solve(H, W, A));
	}
	return 0;
}

详细

Test #1:

score: 100
Accepted
time: 46ms
memory: 11484kb

input:

4
2 2
1 3 2 4
3 1
10 10 10
1 3
20 10 30
3 4
1 1 4 5 1 4 1 9 1 9 8 10

output:

56
60
60
855346687

result:

ok 4 number(s): "56 60 60 855346687"

Test #2:

score: 0
Accepted
time: 46ms
memory: 11440kb

input:

1
2 2
0 0 998244352 998244352

output:

998244345

result:

ok 1 number(s): "998244345"

Test #3:

score: -100
Time Limit Exceeded

input:

900
1 1
810487041
1 2
569006976 247513378
1 3
424212910 256484544 661426830
1 4
701056586 563296095 702740883 723333858
1 5
725786515 738465053 821758167 170452477 34260723
1 6
204184507 619884535 208921865 898995024 768400582 369477346
1 7
225635227 321139203 724076812 439129905 405005469 369864252...

output:

810487041
495026756
540662911
541929691
118309348
270925149
575366228
709974238
761347712
304011276
14811741
366145628
638305530
240546928
484276475
603344008
926633861
161768904
239961447
329781933
315752584
578075668
259941994
600147169
402221164
890998500
154285994
181862417
47930994
273729619
64...

result: