QOJ.ac

QOJ

IDProblemSubmitterResultTimeMemoryLanguageFile sizeSubmit timeJudge time
#253037#6325. Peaceful Resultsstd_abs#RE 241ms92692kbC++144.5kb2023-11-16 16:53:532023-11-16 16:53:53

Judging History

你现在查看的是最新测评结果

  • [2023-11-16 16:53:53]
  • 评测
  • 测评结果:RE
  • 用时:241ms
  • 内存:92692kb
  • [2023-11-16 16:53:53]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define pb push_back
#define all(a) a.begin(), a.end()
#define sz(a) ((int)a.size())
const int mod = 998244353, N = 1 << 21, G = 3;

ll add(ll a, ll b) {
    a += b;
    if (a >= mod) {
        a -= mod;
    }
    return a;
}

ll sub(ll a, ll b) {
    a -= b;
    if (a < 0) {
        a += mod;
    }
    return a;
}

ll mul(ll a, ll b) {
    return a * b % mod;
}

ll mpow(ll a, ll b) {
    ll ans = 1;
    for (; b; b >>= 1, a = mul(a, a)) {
        if (b & 1) {
            ans = mul(ans, a);
        }
    }
    return ans;
}

struct NTT {
    ll w[N];
    NTT() {
        ll dw = mpow(G, (mod - 1) / N);
        w[0] = 1;
        for (int i = 1; i < N; ++i) {
            w[i] = mul(w[i - 1], dw);
        }
    }
    void operator () (vector <ll> &a, bool inv = false) {
        int x = 0, n = a.size();
        for (int j = 1; j < n - 1; ++j) {
            for (int k = n >> 1; (x ^= k) < k; k >>= 1);
            if (j < x) {
                swap(a[x], a[j]);
            }
        }
        for (int L = 2; L <= n; L <<= 1) {
            int dx = N / L, dl = L >> 1;
            for (int i = 0; i < n; i += L) {
                for (int j = i, x = 0; j < i + dl; ++j, x += dx) {
                    ll tmp = mul(a[j + dl], w[x]);
                    a[j + dl] = sub(a[j], tmp);
                    a[j] = add(a[j], tmp);
                }
            }
        }
        if (inv) {
            reverse(1 + all(a));
            ll invn = mpow(n, mod - 2);
            for (int i = 0; i < n; ++i) a[i] = mul(a[i], invn);
        }
    }
} ntt;

vector <ll> Mul(vector <ll> a, vector <ll> b, int bound = N) {
    int m = a.size() + b.size() - 1, n = 1;
    while (n < m) n <<= 1;
    a.resize(n), b.resize(n);
    ntt(a), ntt(b);
    vector <ll> out(n);
    for (int i = 0; i < n; ++i) out[i] = mul(a[i], b[i]);
    ntt(out, true), out.resize(min(m, bound));
    return out;
}

const int M=N*3/2;
int fac[M],inv[M],ifac[M];
int n,ar,ap,as,br,bp,bs,cr,cp,cs;

int main() {
    ios::sync_with_stdio(false), cin.tie(0);
    fac[0]=inv[1]=ifac[0]=1;
    for(int i=1; i<M; ++i) fac[i]=mul(fac[i-1],i);
    for(int i=2; i<M; ++i) inv[i]=mul(inv[mod%i],mod-mod/i);
    for(int i=1; i<M; ++i) ifac[i]=mul(ifac[i-1],inv[i]);
    cin >> n;
    cin >> ar >> ap >> as >> br >> bp >> bs >> cr >> cp >> cs;
    int t1=cr-cp-ap+ar-bp+br;
    if(t1%3){
        cout << "0\n";
        return 0;
    }
    t1/=3;
    int t2=ar+br-t1-cp;
    int base[3][3];
    base[0][0]=base[0][1]=0;
    base[1][1]=base[0][0]-t1;
    base[1][0]=t2-base[0][0]-base[0][1];
    base[0][2]=ar-base[0][0]-base[0][1];
    base[1][2]=ap-base[1][0]-base[1][1];
    base[2][0]=br-base[0][0]-base[1][0];
    base[2][1]=bp-base[0][1]-base[1][1];
    base[2][2]=as-base[2][0]-base[2][1];
    /*
    for(int i=0; i<3; ++i){
        for(int j=0; j<3; ++j) cout << base[i][j] << ' ';
        cout << "\n";
    }
    */
    int t[3];
    for(int i=0; i<3; ++i) t[i]=-min({base[0][i],base[1][(i+1)%3],base[2][(i+2)%3]});
    auto calc=[&](int i, int x){
        int res=1,tot=0;
        for(int j=0; j<3; ++j){
            int tmp=base[j][(j+i)%3]+x;
            if(tmp<0) return 0;
            tot+=tmp;
            if(tot>n) return -1;
            res=mul(res,ifac[tmp]);
        }
        return res;
    };
    auto count=[&](int i, int x){
        int tot=0;
        for(int j=0; j<3; ++j){
            int tmp=base[j][(j+i)%3]+x;
            if(tmp<0) return -1;
            tot+=tmp;
            if(tot>n) return -1;
        }
        return tot;
    };
    vector<ll> P[3];
    for(int j=0; j<3; ++j){
        for(int i=0; i<=n; ++i){
            int tmp=calc(j,t[j]+i);
            if(tmp<0) break;
            P[j].pb(tmp);
        }
        //for(auto i: P[j]) cout << i << ' ';
        //cout << endl;
    }
    //cout << t[0] << ' ' << t[1] << ' ' << t[2] << endl;
    vector<ll> Q=Mul(P[0],P[1]);
    int res=0;
    for(int i=0; i<sz(Q); ++i) if(Q[i]){
        //int u0=count(0,i+t[0]),u1=count(1,i+t[1]);
        //if(u0<0||u1<0) continue;
        int lft=n-count(0,t[0])-count(0,t[1])-i*3;
        int off=(lft-base[0][2]-base[1][0]-base[2][1]);
        assert(off%3==0);
        off/=3;
        int u2=count(2,off);
        if(u2<0) continue;
        res=add(res,mul(Q[i],P[2][off-t[2]]));
    }
    cout << mul(res,fac[n]) << "\n";
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 41ms
memory: 56660kb

input:

2
2 0 0
1 1 0
1 0 1

output:

2

result:

ok 1 number(s): "2"

Test #2:

score: 0
Accepted
time: 52ms
memory: 56680kb

input:

3
0 1 2
3 0 0
1 1 1

output:

0

result:

ok 1 number(s): "0"

Test #3:

score: 0
Accepted
time: 89ms
memory: 65108kb

input:

333333
111111 111111 111111
111111 111111 111111
111111 111111 111111

output:

383902959

result:

ok 1 number(s): "383902959"

Test #4:

score: 0
Accepted
time: 241ms
memory: 92692kb

input:

1500000
500000 500000 500000
500000 500000 500000
500000 500000 500000

output:

355543262

result:

ok 1 number(s): "355543262"

Test #5:

score: -100
Runtime Error

input:

1499999
499999 499999 500001
499999 499999 500001
499999 499999 500001

output:


result: