QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#125442 | #6317. XOR Tree Path | Energy_is_not_over# | WA | 1337ms | 52200kb | C++17 | 8.1kb | 2023-07-16 18:24:36 | 2023-07-16 18:24:39 |
Judging History
answer
//#pragma GCC optimize("Ofast", "unroll-loops")
//#pragma GCC target("sse", "sse2", "sse3", "ssse3", "sse4")
#ifdef __APPLE__
#include <iostream>
#include <cmath>
#include <algorithm>
#include <stdio.h>
#include <cstdint>
#include <cstring>
#include <string>
#include <cstdlib>
#include <vector>
#include <bitset>
#include <map>
#include <queue>
#include <ctime>
#include <stack>
#include <set>
#include <list>
#include <random>
#include <deque>
#include <functional>
#include <iomanip>
#include <sstream>
#include <fstream>
#include <complex>
#include <numeric>
#include <cassert>
#include <array>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <thread>
#else
#include <bits/stdc++.h>
#endif
// Shorthand macros and aliases used by the solution.
// all(a) expands to TWO comma-separated expressions (begin, end), so it is
// only usable as a pair of function arguments; the argument is not parenthesized.
#define all(a) a.begin(),a.end()
// Container size cast to int (avoids signed/unsigned comparisons at call sites).
#define len(a) (int)(a.size())
#define mp make_pair
#define pb push_back
// Two spellings each for pair members: fir/fi -> first, sec/se -> second.
#define fir first
#define sec second
#define fi first
#define se second
using namespace std;
// Short type aliases.
typedef pair<int, int> pii;
typedef long long ll;
typedef long double ld;
// Lowers `a` to `b` when `b` is strictly smaller.
// Returns true iff `a` was changed.
template<typename T>
bool umin(T &a, T b) {
    const bool improved = b < a;
    if (improved) {
        a = b;
    }
    return improved;
}
// Raises `a` to `b` when `b` is strictly larger.
// Returns true iff `a` was changed.
template<typename T>
bool umax(T &a, T b) {
    const bool improved = a < b;
    if (improved) {
        a = b;
    }
    return improved;
}
// Debug helpers, active only when __APPLE__ is defined to a non-zero value.
#if __APPLE__
// `D stmt;` runs stmt exactly once (a for-loop whose flag is cleared after the first pass).
#define D for (bool _FLAG = true; _FLAG; _FLAG = false)
// Prints the stringified argument list followed by the argument values to cerr.
#define LOG(...) print(#__VA_ARGS__" ::", __VA_ARGS__) << endl
// Streams every argument to cerr, each followed by a space; returns the stream.
template <class ...Ts> auto &print(Ts ...ts) { return ((cerr << ts << " "), ...); }
#else
// Otherwise `D stmt;` never executes stmt and LOG(...) expands to nothing.
#define D while (false)
#define LOG(...)
#endif
// max_n and inf are not referenced anywhere else in the visible code.
constexpr int max_n = -1;
constexpr int inf = 1000111222;
// Modulus for all modular arithmetic; `md` is a second name for the same value.
constexpr int mod = 998244353;
constexpr int md = mod;
/*
* long long version
long long mul(long long a, long long b) {
long long x = (long double) a * b / mod;
long long y = (a * b - x * mod) % mod;
return y < 0 ? y + mod : y;
}
long long power(long long a, long long b) {
if (b == 0) {
return 1 % mod;
}
if (b % 2 == 0) {
return power(mul(a, a), b / 2);
}
return mul(a, power(a, b - 1));
}*/
// Adds y to x in place, subtracting `mod` once if the sum reaches it.
// The result is in [0, mod) when both inputs are in [0, mod).
inline void inc(int &x, int y) {
    const int sum = x + y;
    x = (sum >= mod) ? sum - mod : sum;
}
// Subtracts y from x in place, adding `mod` once if the difference is negative.
// The result is in [0, mod) when both inputs are in [0, mod).
inline void dec(int &x, int y) {
    const int diff = x - y;
    x = (diff < 0) ? diff + mod : diff;
}
// Returns x * y reduced by `mod`; the product is formed in 64 bits so it cannot overflow.
inline int mul(int x, int y) {
    const long long wide = 1LL * x * y;
    return wide % mod;
}
// Binary exponentiation: returns a^b modulo `mod` for b >= 0.
int power(int a, int b) {
    int result = 1 % mod;
    // Walk the exponent from its lowest bit; `a` holds the current repeated square.
    for (; b != 0; b /= 2) {
        if (b % 2 != 0) {
            result = mul(result, a);
        }
        a = mul(a, a);
    }
    return result;
}
// Returns x^(mod - 2) modulo `mod`, i.e. the modular inverse of x
// by Fermat's little theorem (relies on `mod` being prime and x not divisible by it).
int inv(int x) {
    const int exponent = mod - 2;
    return power(x, exponent);
}
string str_fraction(int x, int n = 20) {
stringstream res;
pair<int, int> best(x, 1);
for (int j = 1; j < n; ++j) {
best = min(best, {mul(x, j), j});
best = min(best, {mod - mul(x, j), -j});
}
if (best.second < 0) {
best.first *= -1;
best.second *= -1;
}
res << best.first << "/" << best.second;
return res.str();
}
// Size of the factorial tables.
constexpr int max_f = 2000010;
// f[i] = i! and rf[i] = 1 / i! modulo `mod` once get_all_f() has run.
int f[max_f];
int rf[max_f];
// Fills f[i] = i! and rf[i] = inverse of i! (both modulo `mod`) for 0 <= i < max_f.
void get_all_f() {
    f[0] = 1;
    rf[0] = 1;
    for (int i = 1; i != max_f; ++i) {
        f[i] = mul(i, f[i - 1]);
    }
    // Invert only the largest factorial, then walk down: 1/(i-1)! = i * 1/i!.
    const int last = max_f - 1;
    rf[last] = inv(f[last]);
    for (int i = last; i > 1; --i) {
        rf[i - 1] = mul(i, rf[i]);
    }
}
// Binomial coefficient C(n, k) modulo `mod`, read from the factorial tables.
// Returns 0 when k is outside [0, n] (including any negative n or k), so the
// tables are never indexed with a negative value.
// Precondition: get_all_f() has been called and n < max_f.
int get_c(int n, int k) {
    if (k < 0 || n < k) {
        return 0;
    }
    return mul(f[n], mul(rf[k], rf[n - k]));
}
/**
 * Number-theoretic transform modulo the template parameter `mod`, plus
 * polynomial multiplication built on top of it.
 *
 * The struct carries its own modular helpers (inc/dec/mul/power/inv), so it
 * does not depend on the file-level ones. The forward transform leaves its
 * output in a permuted order and the inverse transform consumes that order,
 * so fft<false> followed by fft<true> is the identity.
 */
template<int mod>
struct NTT {
    /// Number of trailing zero bits of mod - 1.
    static constexpr int max_lev = __builtin_ctz(mod - 1);
    /// Twiddle increments: prod[0] for the forward transform, prod[1] for the inverse.
    int prod[2][max_lev - 1];

    /// Precomputes `prod` from a root of unity of order 2^max_lev.
    NTT() {
        const int root = (mod == 998244353) ? 31 : find_root();
        const int rroot = inv(root);
        // roots[0][i] / roots[1][i]: repeated squares of the root and of its inverse.
        std::vector<std::vector<int>> roots(2, std::vector<int>(max_lev - 1));
        roots[0][max_lev - 2] = root;
        roots[1][max_lev - 2] = rroot;
        for (int tp = 0; tp < 2; ++tp) {
            for (int i = max_lev - 3; i >= 0; --i) {
                roots[tp][i] = mul(roots[tp][i + 1], roots[tp][i + 1]);
            }
        }
        for (int tp = 0; tp < 2; ++tp) {
            int cur = 1;
            for (int i = 0; i < max_lev - 1; ++i) {
                prod[tp][i] = mul(cur, roots[tp][i]);
                cur = mul(cur, roots[tp ^ 1][i]);
            }
        }
    }

    /**
     * In-place transform of a[0 .. 2^lg).
     * @tparam inverse  false: forward transform; true: inverse transform,
     *                  including the final division by 2^lg.
     * @param a   array of 2^lg residues in [0, mod)
     * @param lg  log2 of the transform length, 0 <= lg <= max_lev
     */
    template<bool inverse>
    void fft(int *a, int lg) const {
        const int n = 1 << lg;
        for (int it = 0; it < lg; ++it) {
            // The inverse transform runs the stages in the opposite order.
            const int h = inverse ? lg - 1 - it : it;
            const int shift = (1 << (lg - h - 1));
            const int blocks = 1 << h;
            int coef = 1;
            for (int start = 0; start < blocks; ++start) {
                for (int i = start << (lg - h); i < (start << (lg - h)) + shift; ++i) {
                    if (!inverse) {
                        const int y = mul(a[i + shift], coef);
                        a[i + shift] = a[i];
                        inc(a[i], y);
                        dec(a[i + shift], y);
                    } else {
                        const int y = mul(a[i] + mod - a[i + shift], coef);
                        inc(a[i], a[i + shift]);
                        a[i + shift] = y;
                    }
                }
                // The twiddle for the next block is only needed if there is one;
                // skipping the update after the last block keeps the index
                // __builtin_ctz(~start) <= h - 1 inside `prod` even for lg == max_lev.
                if (start + 1 < blocks) {
                    coef = mul(coef, prod[inverse][__builtin_ctz(~start)]);
                }
            }
        }
        if (inverse) {
            const int rn = inv(n);
            for (int i = 0; i < n; ++i) {
                a[i] = mul(a[i], rn);
            }
        }
    }

    /**
     * Product of polynomials a and b (coefficients modulo mod).
     * @return a.size() + b.size() - 1 coefficients, or an empty vector if
     *         either input is empty.
     */
    std::vector<int> product(std::vector<int> a, std::vector<int> b) const {
        if (a.empty() || b.empty()) {
            return {};
        }
        const int sz = static_cast<int>(a.size() + b.size()) - 1;
        const int lg = ceil_log2(sz), n = 1 << lg;
        a.resize(n);
        b.resize(n);
        fft<false>(a.data(), lg);
        fft<false>(b.data(), lg);
        for (int i = 0; i < n; ++i) {
            a[i] = mul(a[i], b[i]);
        }
        fft<true>(a.data(), lg);
        a.resize(sz);
        return a;
    }

    /**
     * Square of polynomial a; same result as product(a, a) with one forward
     * transform fewer.
     */
    std::vector<int> square(std::vector<int> a) const {
        if (a.empty()) {
            return {};
        }
        const int sz = static_cast<int>(a.size() + a.size()) - 1;
        const int lg = ceil_log2(sz), n = 1 << lg;
        a.resize(n);
        fft<false>(a.data(), lg);
        for (int i = 0; i < n; ++i) {
            a[i] = mul(a[i], a[i]);
        }
        fft<true>(a.data(), lg);
        a.resize(sz);
        return a;
    }

    /// Smallest lg with 2^lg >= sz, for sz >= 1.
    /// sz == 1 is handled separately because __builtin_clz(0) is undefined.
    static int ceil_log2(int sz) {
        return sz > 1 ? 32 - __builtin_clz(sz - 1) : 0;
    }

    /// Smallest value >= 2 whose multiplicative order is exactly 2^max_lev.
    static int find_root() {
        for (int root = 2; ; ++root) {
            if (power(root, (1 << max_lev)) == 1 && power(root, (1 << (max_lev - 1))) != 1) {
                return root;
            }
        }
    }

    /// x = (x + y) mod `mod` for x, y in [0, mod).
    static inline void inc(int &x, int y) {
        x += y;
        if (x >= mod) {
            x -= mod;
        }
    }

    /// x = (x - y) mod `mod` for x, y in [0, mod).
    static inline void dec(int &x, int y) {
        x -= y;
        if (x < 0) {
            x += mod;
        }
    }

    /// x * y mod `mod`, computed in 64 bits.
    static inline int mul(int x, int y) {
        return (1LL * x * y) % mod;
    }

    /// x^y mod `mod` for y >= 0 (recursive square-and-multiply).
    static int power(int x, int y) {
        if (y == 0) {
            return 1;
        }
        if (y % 2 == 0) {
            return power(mul(x, x), y / 2);
        }
        return mul(x, power(x, y - 1));
    }

    /// Modular inverse via x^(mod - 2); relies on `mod` being prime.
    static int inv(int x) {
        return power(x, mod - 2);
    }
};
// Global transform instance for the file-level modulus; main() uses it for polynomial products.
NTT<mod> ntt;
int main() {
// freopen("input.txt", "r", stdin);
// freopen("output.txt", "w", stdout);
ios_base::sync_with_stdio(0);
cin.tie(0);
get_all_f();
int n,m;
cin>>n>>m;
n=m=1000;
vector<int> a(m+1);
for (int i=0;i<=m;i++){
a[i]=1ll*get_c(m,i)*get_c(m,i)%md*f[i]%md;
}
vector<int> res(n*m+1,0);
res[0]=1;
for (int i=0;i<10;i++){
if (n&(1ll<<i)){
res=ntt.product(res,a);
if (len(res)>n*m+1){
res.resize(n*m+1);
}
}
a=ntt.product(a,a);
if (len(a)>n*m+1){
a.resize(n*m+1);
}
}
int ans=0;
for (int i=0;i<=n*m;i++){
int cur=1ll*res[i]*f[n*m-i]%md;
if (i%2==0){
inc(ans,cur);
}
else{
inc(ans,md-cur);
}
}
cout<<ans<<"\n";
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 0
Wrong Answer
time: 1337ms
memory: 52200kb
input:
5 1 0 0 1 0 1 2 1 3 3 4 3 5
output:
720037464
result:
wrong answer 1st numbers differ - expected: '5', found: '720037464'