QOJ.ac

QOJ

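| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|---|---|---|---|---|---|---|---|---|---|
| #527200 | #6323. Range NEQ | solar_express | WA | 58ms | 26208kb | C++23 | 1.9kb | 2024-08-22 11:53:36 | 2024-08-22 11:53:36 |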

Judging History

You are viewing the latest judging result (你现在查看的是最新测评结果).

  • [2024-08-22 11:53:36]
  • Judged (评测)
  • Result (测评结果): WA
  • Time (用时): 58ms
  • Memory (内存): 26208kb
  • [2024-08-22 11:53:36]
  • Submitted (提交)

answer

#include <algorithm>
#include <cstdint>
#include <iostream>
// Capacity of every transform/table array below (maximum transform length).
constexpr int N = 1 << 20;
using namespace std;
using u64 = std::uint64_t;
// NTT-friendly prime: mod - 1 = 119 * 2^23.
constexpr int mod = 998244353;
// Transform state: bit-reversal permutation, twiddle table,
// current transform length, and its inverse modulo `mod`.
int rev[N];
int wn[N];
int lim;
int invlim;
int pow(int a, int b, int ans = 1) {
	for(;b;b >>= 1, a = (u64) a * a % mod) if(b & 1)
		ans = (u64) ans * a % mod;
	return ans;
}
void init(int len) {
	lim = 2 << std::__lg(len - 1);
	invlim = mod - (mod - 1) / lim;
	for(static int i = 1;i < lim;i += i) {
		wn[i] = 1;
		const int w = pow(3, mod / i / 2);
		for(int j = 1;j < i;++j) {
			wn[i + j] = (u64) wn[i + j - 1] * w % mod;
		}
	}
	for(int i = 1;i < lim;++i) {
		rev[i] = rev[i >> 1] >> 1 | (i % 2u * lim / 2);
	}
}
void DFT(int * a) {
	static u64 t[N];
	for(int i = 0;i < lim;++i) t[i] = a[rev[i]];
	for(int i = 1;i < lim;i += i) {
		for(int k = i & (1 << 20);k--;) 
			if(t[k] >= mod * 9ull) t[k] -= mod * 9ull;
		for(int j = 0;j < lim;j += i + i) {
			for(int k = 0;k < i;++k) {
				const u64 x = t[i + j + k] * wn[i + k] % mod;
				t[i + j + k] = t[k + j] + (mod - x), t[k + j] += x;
			}
		}
	}
	for(int i = 0;i < lim;++i) a[i] = t[i] % mod;
}
void IDFT(int * a) {
	DFT(a), std::reverse(a + 1, a + lim);
	for(int i = 0;i < lim;++i)
		a[i] = (u64) a[i] * invlim % mod;
}

int fac[N];  // fac[i] = i! mod `mod`, filled for i <= n*m
int pol[N];  // per-block inclusion-exclusion polynomial, then its n-th power

// Reads n and m and prints, modulo 998244353, the number of permutations p of
// 1..n*m such that no p_i falls in the same block of m consecutive values as i.
// The count is computed by inclusion-exclusion:
//   answer = sum_k [x^k] P(x)^n * (n*m - k)!,
//   P(x)   = sum_{k=0..m} (-1)^k * C(m, k)^2 * k! * x^k.
// NOTE(review): assumes n, m >= 1 and n*m < N so fac/pol fit;
// confirm against the problem constraints.
int main() {
    int n, m;
    cin >> n >> m;
    const int total = n * m;
    fac[0] = 1;
    for (int i = 1; i <= total; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
    for (int k = 0; k <= m; ++k) {
        // C(m, k) = m! / ((m - k)! * k!) via Fermat inverses.
        const int binom = pow(fac[k], mod - 2, pow(fac[m - k], mod - 2, fac[m]));
        pol[k] = 1ll * binom * binom % mod * fac[k] % mod;
        if (k & 1) pol[k] = (mod - pol[k]) % mod;
    }
    // P^n has degree exactly n*m, so a cyclic transform of length > n*m
    // reproduces it without wrap-around; no need for the full N points.
    init(total + 1);
    DFT(pol);
    for (int i = 0; i < lim; ++i) pol[i] = pow(pol[i], n);
    IDFT(pol);
    int ans = 0;
    for (int k = 0; k <= total; ++k) {
        ans = (ans + 1ll * pol[k] * fac[total - k]) % mod;
    }
    cout << ans << '\n';
    return 0;
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Wrong Answer
time: 58ms
memory: 26208kb

input:

2 2

output:

55977174

result:

wrong answer 1st numbers differ - expected: '4', found: '55977174'