QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#125442 | #6317. XOR Tree Path | Energy_is_not_over# | WA | 1337ms | 52200kb | C++17 | 8.1kb | 2023-07-16 18:24:36 | 2023-07-16 18:24:39

Judging History

你现在查看的是最新测评结果

  • [2023-08-10 23:21:45]
  • System Update: QOJ starts to keep a history of the judgings of all the submissions.
  • [2023-07-16 18:24:39]
  • 评测
  • 测评结果:WA
  • 用时:1337ms
  • 内存:52200kb
  • [2023-07-16 18:24:36]
  • 提交

answer

//#pragma GCC optimize("Ofast", "unroll-loops")
//#pragma GCC target("sse", "sse2", "sse3", "ssse3", "sse4")

#ifdef __APPLE__
#include <iostream>
#include <cmath>
#include <algorithm>
#include <stdio.h>
#include <cstdint>
#include <cstring>
#include <string>
#include <cstdlib>
#include <vector>
#include <bitset>
#include <map>
#include <queue>
#include <ctime>
#include <stack>
#include <set>
#include <list>
#include <random>
#include <deque>
#include <functional>
#include <iomanip>
#include <sstream>
#include <fstream>
#include <complex>
#include <numeric>
#include <cassert>
#include <array>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <thread>
#else
#include <bits/stdc++.h>
#endif

// Shorthand macros used throughout the file.
// all(a): expands to the begin/end iterator pair of container `a`.
#define all(a) a.begin(),a.end()
// len(a): size of container `a` cast to int.
#define len(a) (int)(a.size())
// Aliases for make_pair / push_back.
#define mp make_pair
#define pb push_back
// Aliases for the members of a pair.
#define fir first
#define sec second
#define fi first
#define se second

using namespace std;

// Short names for commonly used types.
typedef pair<int, int> pii;
typedef long long ll;
typedef long double ld;

// Lowers `a` to `b` when `b` is strictly smaller.
// Returns true iff `a` was changed.
template<typename T>
bool umin(T &a, T b) {
    const bool improved = b < a;
    if (improved) {
        a = b;
    }
    return improved;
}
// Raises `a` to `b` when `b` is strictly larger.
// Returns true iff `a` was changed.
template<typename T>
bool umax(T &a, T b) {
    const bool improved = a < b;
    if (improved) {
        a = b;
    }
    return improved;
}

// Debug helpers, active only when __APPLE__ evaluates to non-zero.
#if __APPLE__
// D: prefix for a statement/block; the for-loop runs its body exactly once.
#define D for (bool _FLAG = true; _FLAG; _FLAG = false)
// LOG(...): writes the stringified argument list followed by " ::" and then
// the argument values to cerr, ending with endl.
#define LOG(...) print(#__VA_ARGS__" ::", __VA_ARGS__) << endl
// Streams every argument to cerr, each followed by a space, and returns the stream.
template <class ...Ts> auto &print(Ts ...ts) { return ((cerr << ts << " "), ...); }
#else
// Elsewhere D makes the following statement dead code and LOG expands to nothing.
#define D while (false)
#define LOG(...)
#endif

// max_n is a placeholder (-1) and is not referenced in the visible source;
// inf is a large sentinel value that fits in int.
const int max_n = -1, inf = 1000111222;

// Modulus used by all modular arithmetic below; md is a second name for it.
const int mod = 998244353;
const int md = mod;

/*
 * long long version

long long mul(long long a, long long b) {
    long long x = (long double) a * b / mod;
    long long y = (a * b - x * mod) % mod;
    return y < 0 ? y + mod : y;
}

long long power(long long a, long long b) {
    if (b == 0) {
        return 1 % mod;
    }
    if (b % 2 == 0) {
        return power(mul(a, a), b / 2);
    }
    return mul(a, power(a, b - 1));
}*/

// Adds y to x in place, then subtracts mod once if the sum reached mod.
// Yields a value in [0, mod) when both operands are already in that range.
inline void inc(int &x, int y) {
    x += y;
    x -= (x >= mod) ? mod : 0;
}

// Subtracts y from x in place, then adds mod once if the difference went negative.
// Yields a value in [0, mod) when both operands are already in that range.
inline void dec(int &x, int y) {
    x -= y;
    x += (x < 0) ? mod : 0;
}

// Returns x * y reduced modulo mod; the product is formed in 64 bits so it cannot overflow.
inline int mul(int x, int y) {
    const long long wide = 1LL * x * y;
    return wide % mod;
}

// Computes a^b modulo mod by binary exponentiation.
// Each iteration consumes the lowest remaining bit of b and squares the base.
int power(int a, int b) {
    int res = 1 % mod;
    for (; b != 0; b /= 2, a = mul(a, a)) {
        if (b % 2 != 0) {
            res = mul(res, a);
        }
    }
    return res;
}

// Returns x^(mod - 2) modulo mod.
// By Fermat's little theorem this is the modular inverse of x, assuming mod is
// prime and x is not a multiple of mod.
int inv(int x) {
    const int exponent = mod - 2;
    return power(x, exponent);
}

// Renders the residue x as a "numerator/denominator" string for debugging.
// Candidates are (x*j mod p, j) and (p - x*j mod p, -j) for every j in [1, n),
// plus the starting pair (x, 1). The smallest pair in std::pair ordering wins.
// If the winner has a negative denominator, both components are negated
// before printing.
string str_fraction(int x, int n = 20) {
    pair<int, int> best(x, 1);
    for (int denom = 1; denom < n; ++denom) {
        const pair<int, int> positive(mul(x, denom), denom);
        if (positive < best) {
            best = positive;
        }
        const pair<int, int> negative(mod - mul(x, denom), -denom);
        if (negative < best) {
            best = negative;
        }
    }
    if (best.second < 0) {
        best.first = -best.first;
        best.second = -best.second;
    }
    return to_string(best.first) + "/" + to_string(best.second);
}

// Size of the factorial tables; valid indices are 0 .. max_f - 1.
const int max_f = 2e6+10;

// f[i] holds i! mod `mod` and rf[i] its modular inverse; both are filled by get_all_f().
int f[max_f], rf[max_f];

// Fills f[] with factorials and rf[] with inverse factorials modulo mod.
// Only one modular inverse is computed, for the largest factorial; the rest
// follow from 1/i! = (i + 1) * 1/(i + 1)!.
void get_all_f() {
    f[0] = 1;
    rf[0] = 1;
    int idx = 1;
    while (idx < max_f) {
        f[idx] = mul(idx, f[idx - 1]);
        ++idx;
    }
    const int last = max_f - 1;
    rf[last] = inv(f[last]);
    idx = last - 1;
    while (idx > 0) {
        rf[idx] = mul(idx + 1, rf[idx + 1]);
        --idx;
    }
}

// Returns the binomial coefficient C(n, k) modulo mod.
// Returns 0 when k is outside [0, n]. A negative k used to index rf[] out of
// bounds; it is now rejected up front.
// Preconditions: get_all_f() has been called, and n < max_f.
int get_c(int n, int k) {
    if (k < 0 || n < k) {
        return 0;
    }
    return mul(f[n], mul(rf[k], rf[n - k]));
}

// Number-theoretic transform over the integers modulo `mod`, plus polynomial
// multiplication built on top of it.
//
// Coefficients are ints in [0, mod). The struct carries its own static
// modular helpers (inc/dec/mul/power/inv), which use the template parameter
// `mod` rather than any global of the same name.
//
// Precondition for fft/product/square: the transform length 2^lg must satisfy
// lg <= max_lev.
template<int mod>
struct NTT {
    // Number of trailing zero bits of mod - 1, i.e. the largest k with 2^k | mod - 1.
    static constexpr int max_lev = __builtin_ctz(mod - 1);

    // prod[0] serves the forward transform, prod[1] the inverse one.
    // prod[tp][i] is the factor by which the butterfly coefficient advances
    // between consecutive blocks `start` -> `start + 1` when
    // __builtin_ctz(~start) == i.
    int prod[2][max_lev - 1];

    // Precomputes the coefficient-advance table `prod`.
    NTT() {
        // 31 is hard-coded for mod 998244353; other moduli search for a root.
        const int root = (mod == 998244353) ? 31 : find_root();
        const int rroot = inv(root);
        // roots[tp][i] is the square of roots[tp][i + 1]; the top entry is the
        // root itself (tp == 0) or its inverse (tp == 1).
        std::vector<std::vector<int>> roots(2, std::vector<int>(max_lev - 1));
        roots[0][max_lev - 2] = root;
        roots[1][max_lev - 2] = rroot;
        for (int tp = 0; tp < 2; ++tp) {
            for (int i = max_lev - 3; i >= 0; --i) {
                roots[tp][i] = mul(roots[tp][i + 1], roots[tp][i + 1]);
            }
        }
        for (int tp = 0; tp < 2; ++tp) {
            int cur = 1;
            for (int i = 0; i < max_lev - 1; ++i) {
                prod[tp][i] = mul(cur, roots[tp][i]);
                cur = mul(cur, roots[tp ^ 1][i]);
            }
        }
    }

    // In-place transform of a[0 .. 2^lg).
    // inverse == false: forward transform.
    // inverse == true: inverse transform, including the final division by 2^lg.
    template<bool inverse>
    void fft(int *a, int lg) const {
        const int n = 1 << lg;
        for (int it = 0; it < lg; ++it) {
            // The inverse transform walks the levels in the opposite order.
            const int h = inverse ? lg - 1 - it : it;
            const int shift = (1 << (lg - h - 1));
            const int blocks = 1 << h;
            int coef = 1;
            for (int start = 0; start < blocks; ++start) {
                const int first = start << (lg - h);
                for (int i = first; i < first + shift; ++i) {
                    if (!inverse) {
                        const int y = mul(a[i + shift], coef);
                        a[i + shift] = a[i];
                        inc(a[i], y);
                        dec(a[i + shift], y);
                    } else {
                        const int y = mul(a[i] + mod - a[i + shift], coef);
                        inc(a[i], a[i + shift]);
                        a[i + shift] = y;
                    }
                }
                // The coefficient is only needed for the next block. Skipping
                // the update after the last block avoids reading prod[...][h],
                // which is past the end of the array when h == max_lev - 1.
                if (start + 1 < blocks) {
                    coef = mul(coef, prod[inverse][__builtin_ctz(~start)]);
                }
            }
        }
        if (inverse) {
            const int rn = inv(n);
            for (int i = 0; i < n; ++i) {
                a[i] = mul(a[i], rn);
            }
        }
    }

    // Returns the product of polynomials a and b (coefficient vectors, lowest
    // degree first) modulo mod. An empty operand yields an empty result.
    std::vector<int> product(std::vector<int> a, std::vector<int> b) const {
        if (a.empty() || b.empty()) {
            return {};
        }
        const int sz = a.size() + b.size() - 1;
        const int lg = ceil_log2(sz), n = 1 << lg;
        a.resize(n);
        b.resize(n);
        fft<false>(a.data(), lg);
        fft<false>(b.data(), lg);
        for (int i = 0; i < n; ++i) {
            a[i] = mul(a[i], b[i]);
        }
        fft<true>(a.data(), lg);
        a.resize(sz);
        return a;
    }

    // Returns a * a modulo mod using a single forward transform.
    // An empty operand yields an empty result.
    std::vector<int> square(std::vector<int> a) const {
        if (a.empty()) {
            return {};
        }
        const int sz = a.size() + a.size() - 1;
        const int lg = ceil_log2(sz), n = 1 << lg;
        a.resize(n);
        fft<false>(a.data(), lg);
        for (int i = 0; i < n; ++i) {
            a[i] = mul(a[i], a[i]);
        }
        fft<true>(a.data(), lg);
        a.resize(sz);
        return a;
    }

    // Smallest lg with 2^lg >= sz, for sz >= 1.
    // Replaces `32 - __builtin_clz(sz - 1)`, which is undefined for sz == 1
    // because __builtin_clz(0) has no defined result.
    static int ceil_log2(int sz) {
        int lg = 0;
        while ((1 << lg) < sz) {
            ++lg;
        }
        return lg;
    }

    // Returns the smallest root >= 2 whose 2^max_lev-th power is 1 while its
    // 2^(max_lev - 1)-th power is not.
    static int find_root() {
        for (int root = 2; ; ++root) {
            if (power(root, (1 << max_lev)) == 1 && power(root, (1 << (max_lev - 1))) != 1) {
                return root;
            }
        }
    }

    // x = (x + y) mod `mod`, for x, y already in [0, mod).
    static inline void inc(int &x, int y) {
        x += y;
        if (x >= mod) {
            x -= mod;
        }
    }

    // x = (x - y) mod `mod`, for x, y already in [0, mod).
    static inline void dec(int &x, int y) {
        x -= y;
        if (x < 0) {
            x += mod;
        }
    }

    // Returns x * y mod `mod`, using a 64-bit intermediate product.
    static inline int mul(int x, int y) {
        return (1LL * x * y) % mod;
    }

    // Returns x^y mod `mod` by recursive binary exponentiation (y >= 0).
    static int power(int x, int y) {
        if (y == 0) {
            return 1;
        }
        if (y % 2 == 0) {
            return power(mul(x, x), y / 2);
        }
        return mul(x, power(x, y - 1));
    }

    // Returns x^(mod - 2) mod `mod`, the modular inverse of x when mod is prime.
    static int inv(int x) {
        return power(x, mod - 2);
    }
};

// Global transform instance for the file-wide modulus; its table is built at static-initialization time.
NTT<mod> ntt;

int main() {
//    freopen("input.txt", "r", stdin);
//    freopen("output.txt", "w", stdout);

    ios_base::sync_with_stdio(0);
    cin.tie(0);

    get_all_f();

    int n,m;
    cin>>n>>m;
n=m=1000;
    vector<int> a(m+1);
    for (int i=0;i<=m;i++){
        a[i]=1ll*get_c(m,i)*get_c(m,i)%md*f[i]%md;
    }
    vector<int> res(n*m+1,0);
    res[0]=1;
    for (int i=0;i<10;i++){
        if (n&(1ll<<i)){
            res=ntt.product(res,a);
            if (len(res)>n*m+1){
                res.resize(n*m+1);
            }
        }
        a=ntt.product(a,a);
        if (len(a)>n*m+1){
            a.resize(n*m+1);
        }
    }
    int ans=0;
    for (int i=0;i<=n*m;i++){
        int cur=1ll*res[i]*f[n*m-i]%md;
        if (i%2==0){
            inc(ans,cur);
        }
        else{
            inc(ans,md-cur);
        }
    }
    cout<<ans<<"\n";
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Wrong Answer
time: 1337ms
memory: 52200kb

input:

5
1 0 0 1 0
1 2
1 3
3 4
3 5

output:

720037464

result:

wrong answer 1st numbers differ - expected: '5', found: '720037464'