QOJ.ac

QOJ

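| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|----|---------|-----------|--------|------|--------|----------|-----------|-------------|------------|
| #508453 | #7632. Balanced Arrays | pandapythoner | AC ✓ | 1030ms | 88692kb | C++23 | 15.4kb | 2024-08-07 15:44:20 | 2024-08-07 15:44:22 |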

Judging History

你现在查看的是最新测评结果 (You are viewing the latest judging result.)

  • [2024-08-07 15:44:22]
  • 评测
  • 测评结果:AC
  • 用时:1030ms
  • 内存:88692kb
  • [2024-08-07 15:44:20]
  • 提交

answer

#include <bits/stdc++.h>

using namespace std;


using ll = long long;

// Shorthand macros; flt, all and rall are not used anywhere in this file.
#define flt double
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define rep(i, n) for(int i = 0; i < n; i += 1)
#define len(a) ((int)(a).size())

// Fixed seed so that stress() runs are reproducible.
mt19937 rnd(234);
// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
const ll mod = 998244353;

// Returns x^n modulo mod by binary exponentiation (returns 1 when n <= 0).
// x is used as given; callers pass values already in [0, mod).
ll bin_pow(ll x, ll n) {
    ll result = 1;
    ll base = x;
    // Consume the exponent one bit at a time, lowest bit first.
    for (ll rest = n; rest > 0; rest >>= 1) {
        if (rest & 1) {
            result = result * base % mod;
        }
        base = base * base % mod;
    }
    return result;
}

// Modular inverse of x via Fermat's little theorem: x^(mod - 2), valid because
// mod is prime. inv(0) evaluates to 0.
ll inv(ll x) {
    const ll fermat_exponent = mod - 2;
    return bin_pow(x, fermat_exponent);
}

namespace fft {
    int mxpw;
    int mxn;
    ll w;

    void build_w() {
        ll phi = mod - 1;
        ll f = phi;
        vector<ll> p;
        for (ll i = 2; i * i <= f; i += 1) {
            if (f % i == 0) {
                p.push_back(i);
                while (f % i == 0) {
                    f /= i;
                }
            }
        }
        if (f > 1) {
            p.push_back(f);
        }
        for (int i = 1; i < mod; i += 1) {
            bool ok = true;
            for (auto q : p) {
                if (bin_pow(i, phi / q) == 1) {
                    ok = false;
                    break;
                }
            }
            if (ok) {
                w = bin_pow(i, phi / (1 << mxpw));
                break;
            }
        }
    }

    vector<ll> rvx;

    void build_rvx(int n) {
        rvx.resize(n + 1);
        for (int i = 1; i <= n; i += 1) {
            rvx[i] = inv(i);
        }
    }

    vector<ll> rvi, wpws;

    void build(int _mxpw) {
        mxpw = _mxpw;
        mxn = (1 << mxpw);
        build_w();
        int n = (1 << mxpw);
        rvi.resize(n);
        rvi[0] = 0;
        for (int i = 1; i < n; i += 1) {
            rvi[i] = (rvi[i >> 1] >> 1);
            if (i & 1) {
                rvi[i] += (1 << (mxpw - 1));
            }
        }
        wpws.resize(n + 1);
        wpws[0] = 1;
        for (int i = 1; i <= n; i += 1) {
            wpws[i] = (wpws[i - 1] * w) % mod;
        }
        build_rvx(mxn);
    }

    void fft(vector<ll>& a, int nk) {
        int n = (1 << nk);
        for (int i = 0; i < n; i += 1) {
            int mrv = (rvi[i] >> (mxpw - nk));
            if (mrv < i) {
                swap(a[mrv], a[i]);
            }
        }
        for (int ln = 1; ln < n; ln *= 2) {
            int ln2 = ln + ln;
            for (int i = 0; i < n; i += ln2) {
                for (int j = 0; j < ln; j += 1) {
                    ll mw = wpws[mxn / ln2 * j];
                    int u = i + j;
                    int v = u + ln;
                    ll y = a[v] * mw % mod;
                    a[v] = a[u] - y;
                    if (a[v] < 0) {
                        a[v] += mod;
                    }
                    a[u] += y;
                    if (a[u] >= mod) {
                        a[u] -= mod;
                    }
                }
            }
        }
    }

    void rev_fft(vector<ll>& a, int nk) {
        int n = (1 << nk);
        fft(a, nk);
        ll rvn = inv(n);
        reverse(a.begin() + 1, a.end());
        for (int i = 0; i < n; i += 1) {
            a[i] = (a[i] * rvn) % mod;
        }
    }

    vector<ll> square(vector<ll> a) {
        int nk = 0;
        while ((1 << nk) < (int)a.size() + (int)a.size() - 1) {
            nk += 1;
        }
        int n = (1 << nk);
        a.resize(n, 0);
        fft(a, nk);
        for (int i = 0; i < n; i += 1) {
            a[i] = (a[i] * a[i]) % mod;
        }
        rev_fft(a, nk);
        while (!a.empty() && a.back() == 0) {
            a.pop_back();
        }
        return a;
    }

    vector<ll> mul(vector<ll> a, vector<ll> b) {
        int nk = 0;
        while ((1 << nk) < (int)a.size() + (int)b.size() - 1) {
            nk += 1;
        }
        int n = (1 << nk);
        a.resize(n, 0);
        b.resize(n, 0);
        fft(a, nk);
        fft(b, nk);
        for (int i = 0; i < n; i += 1) {
            a[i] = (a[i] * b[i]) % mod;
        }
        rev_fft(a, nk);
        while (!a.empty() && a.back() == 0) {
            a.pop_back();
        }
        return a;
    }

    void add_inplace(vector<ll>& a, const vector<ll>& b, ll k = 1) {
        a.resize(max(a.size(), b.size()), 0);
        for (int i = 0; i < (int)b.size(); i += 1) {
            a[i] = (a[i] + b[i] * k) % mod;
        }
    }

    vector<ll> add(vector<ll> a, const vector<ll>& b, ll k = 1) {
        a.resize(max(a.size(), b.size()), 0);
        for (int i = 0; i < (int)b.size(); i += 1) {
            a[i] = (a[i] + b[i] * k) % mod;
        }
        return a;
    }

    vector<ll> sub(vector<ll> a, const vector<ll>& b, ll k = 1) {
        a.resize(max(a.size(), b.size()), 0);
        for (int i = 0; i < (int)b.size(); i += 1) {
            a[i] = (a[i] + mod - b[i] * k % mod) % mod;
        }
        return a;
    }

    vector<ll> replace_x_slow(vector<ll>& a, const vector<ll>& b) {
        vector<ll> rs = {};
        vector<ll> bpw = { 1 };
        for (int i = 0; i < (int)a.size(); i += 1) {
            if (i > 0) {
                bpw = mul(bpw, b);
            }
            add_inplace(rs, bpw, a[i]);
        }
        return rs;
    }

    vector<ll> replace_x(vector<ll>& a, const vector<ll>& b) {
        vector<ll> rs = {};
        vector<ll> bpw = b;
        int n = a.size();
        vector<vector<ll>> d(n);
        for (int i = 0; i < n; i += 1) {
            d[i] = { a[i] };
        }
        while (n > 1) {
            int m = (n + 1) / 2;
            vector<vector<ll>> nd(m);
            for (int i = 0; i < n; i += 1) {
                if (i % 2 == 0) {
                    nd[i / 2] = d[i];
                } else {
                    add_inplace(nd[i / 2], mul(d[i], bpw));
                }
            }
            n = m;
            d.swap(nd);
            if (n != 1) {
                bpw = square(bpw);
            }
        }
        return d[0];
    }

    vector<ll> shift_x(vector<ll> a, ll t) {
        if (a.empty()) {
            return {};
        }
        int n = (int)a.size() - 1;
        vector<ll> f(n + 1), rf(n + 1);
        f[0] = rf[0] = 1;
        for (int i = 1; i <= n; i += 1) {
            f[i] = (f[i - 1] * i) % mod;
            rf[i] = inv(f[i]);
        }
        vector<ll> b(n + 1), c(n + 1);
        ll tpw = 1;
        for (int i = 0; i <= n; i += 1) {
            b[i] = (a[i] * tpw % mod * f[i] % mod);
            tpw = (tpw * t) % mod;
        }
        for (int i = 0; i <= n; i += 1) {
            c[n - i] = rf[i];
        }
        a = mul(b, c);
        vector<ll> d(n + 1);
        ll rvt = inv(t);
        ll rvt_pw = 1;
        for (int i = 0; i <= n; i += 1) {
            d[i] = rvt_pw * rf[i] % mod * a[i + n] % mod;
            rvt_pw = (rvt_pw * rvt) % mod;
        }
        return d;
    }

    vector<ll> rev_polynom(const vector<ll>& a, int n) {
        int sz = a.size();
        vector<ll> b = { inv(a[0]) };
        int m = 1;
        int mk = 0;
        while (m < n) {
            int m2 = m + m;
            int m4 = m2 + m2;
            b.resize(m4);
            fft(b, mk + 2);
            vector<ll> nb(m4);
            for (int i = 0; i < sz && i < m2; i += 1) {
                nb[i] = a[i];
            }
            fft(nb, mk + 2);
            for (int i = 0; i < m4; i += 1) {
                nb[i] = (2 * b[i] - nb[i] * b[i] % mod * b[i]) % mod;
                if (nb[i] < 0) {
                    nb[i] += mod;
                }
            }
            rev_fft(nb, mk + 2);
            nb.resize(m2);
            b.swap(nb);
            m = m2;
            mk += 1;
        }
        b.resize(n);
        return b;
    }

    vector<ll> square_root(const vector<ll>& a, int n) {
        ll sz = a.size();
        ll rv2 = inv(2);
        vector<ll> b = { 1 };
        int m = 1;
        while (m < n) {
            ll m2 = m + m;
            vector<ll> rvb = rev_polynom(b, m2);
            vector<ll> ab(m2);
            for (int i = 0; i < m2 && i < sz; i += 1) {
                ab[i] = a[i];
            }
            ab = mul(ab, rvb);
            ab.resize(m2);
            b.resize(m2);
            for (int i = 0; i < m2; i += 1) {
                b[i] = (rv2 * ((b[i] + ab[i]) % mod)) % mod;
            }
            m = m2;
        }
        b.resize(n);
        return b;
    }

    vector<ll> derivative(vector<ll> a) {
        int n = a.size();
        if (n == 0) {
            return {};
        }
        for (int i = 0; i + 1 < n; i += 1) {
            a[i] = (a[i + 1] * (i + 1)) % mod;
        }
        a.resize(n - 1);
        return a;
    }

    vector<ll> integrate(vector<ll> a) {
        int n = a.size();
        a.resize(n + 1);
        for (int i = n; i > 0; i -= 1) {
            a[i] = (a[i - 1] * rvx[i]) % mod;
        }
        a[0] = 0;
        return a;
    }

    vector<ll> sin_polynomial(int n) {
        vector<ll> a(n, 0);
        ll fct = 1;
        for (int i = 0; i < n; i += 1) {
            if (i != 0) {
                fct = (fct * i) % mod;
            }
            if (i % 2 == 1) {
                int sign = 1;
                if ((i / 2) % 2 == 1) {
                    sign = -1;
                }
                a[i] = (mod + sign * inv(fct)) % mod;
            }
        }
        return a;
    }

    vector<ll> cos_polynomial(int n) {
        vector<ll> a(n, 0);
        ll fct = 1;
        for (int i = 0; i < n; i += 1) {
            if (i != 0) {
                fct = (fct * i) % mod;
            }
            if (i % 2 == 0) {
                int sign = 1;
                if ((i / 2) % 2 == 1) {
                    sign = -1;
                }
                a[i] = (mod + sign * inv(fct)) % mod;
            }
        }
        return a;
    }

    vector<ll> super_cos_polynomial(int n, int k) {
        vector<ll> a(n, 0);
        ll fct = 1;
        for (int i = 0; i < n; i += 1) {
            if (i != 0) {
                fct = (fct * i) % mod;
            }
            if (i % k == 0) {
                int sign = 1;
                if ((i / k) % 2 == 1) {
                    sign = -1;
                }
                a[i] = (mod + sign * inv(fct)) % mod;
            }
        }
        return a;
    }

    vector<ll> logarithm(const vector<ll>& a, int n) {
        if (n == 0) {
            return {};
        }
        vector<ll> b = integrate(mul(derivative(a), rev_polynom(a, n)));
        b.resize(n);
        return b;
    }

    vector<ll> exponent(const vector<ll>& a, int n) {
        vector<ll> b = { 1 };
        int m = 1;
        while (m < n) {
            int m2 = m + m;
            vector<ll> t = logarithm(b, m2);
            for (int i = 0; i < m2 && i < (int)a.size(); i += 1) {
                t[i] = (t[i] - a[i]);
                if (t[i] < 0) {
                    t[i] += mod;
                }
            }
            vector<ll> q = fft::mul(t, b);
            q.resize(m2);
            b.resize(m2);
            for (int i = 0; i < m2; i += 1) {
                b[i] -= q[i];
                if (b[i] < 0) {
                    b[i] += mod;
                }
            }
            m = m2;
        }
        b.resize(n);
        return b;
    }

    vector<ll> solve_differential(const vector<ll>& a, const vector<ll>& b, int n) {
        vector<ll> e = exponent(integrate(a), n);
        vector<ll> result = mul(e, integrate(mul(b, rev_polynom(e, n))));
        result.resize(n);
        return result;
    }

    vector<ll> pure_exponent(int n, ll k = 1) {
        if (n == 0) {
            return {};
        }
        k %= mod;
        if (k < 0) {
            k += mod;
        }
        vector<ll> rs(n);
        rs[0] = 1;
        ll rv_fct = 1;
        for (int i = 1; i < n; i += 1) {
            rv_fct = (rv_fct * rvx[i]) % mod * k % mod;
            rs[i] = rv_fct;
        }
        return rs;
    }
}  // namespace fft



vector<ll> f, invf;  // f[i] = i! mod mod, invf[i] = (i!)^{-1} mod mod; filled by build_f

// Fills f and invf for indices 0..n (n < 0 is treated as 0).
// Only one modular exponentiation is needed: invf[n] = inv(n!), and then
// invf[i - 1] = invf[i] * i walks back down, since (i-1)!^{-1} = i!^{-1} * i.
void build_f(int n) {
    if (n < 0) {
        n = 0;
    }
    f.resize(n + 1);
    invf.resize(n + 1);
    f[0] = 1;
    for (int i = 1; i <= n; i += 1) {
        f[i] = (f[i - 1] * i) % mod;
    }
    invf[n] = inv(f[n]);
    for (int i = n; i > 0; i -= 1) {
        invf[i - 1] = (invf[i] * i) % mod;
    }
}

// Binomial coefficient C(n, k) modulo mod, read from the factorial tables.
// Returns 0 unless 0 <= k <= n. Requires build_f to have covered index n.
ll cnk(ll n, ll k) {
    const bool in_range = n >= 0 && k >= 0 && k <= n;
    if (!in_range) {
        return 0;
    }
    const ll inv_denominator = invf[k] * invf[n - k] % mod;
    return f[n] * inv_denominator % mod;
}

// n-th Catalan number, C(2n, n) / (n + 1), modulo mod.
ll catalan(ll n) {
    const ll central_binomial = cnk(2 * n, n);
    return central_binomial * inv(n + 1) % mod;
}


// Number of multisets of size k drawn from n kinds, C(n + k - 1, k), modulo mod.
// With zero kinds, only the empty multiset exists. Plain cnk(-1, 0) would
// report 0 for (n, k) = (0, 0), so that case is handled explicitly.
ll choose_repeating(ll n, ll k) {
    if (n == 0) {
        return k == 0 ? 1 : 0;
    }
    return cnk(n + k - 1, k);
}


// Direct double-loop evaluation of the answer modulo mod, summing the closed form
// term by term over t in [1, m] and cnt. solve() evaluates the same sum with
// convolutions. Not called from main; useful for cross-checking.
ll solve_smart(int n, int m) {
    // The largest factorial index used is n + 2m - 1, from
    // choose_repeating(2m + 1, n - 1) = cnk(2m + n - 1, n - 1). Size the tables
    // from that, plus the same slack as before.
    build_f(n + 2 * m + 100000);
    ll result = 1;
    for (int t = 1; t <= m; t += 1) {
        result = (result + choose_repeating(2 * t + 1, n - 1)) % mod;
        // Independent of cnt, so compute it once per t.
        const ll inv_odd = inv(2 * t + 1);
        for (int cnt = 1; 2 * cnt + 1 <= n and cnt <= t - 1; cnt += 1) {
            ll coeff = (cnk(t, cnt) * cnk(t - 1, cnt) % mod * 2 +
                cnk(t, cnt + 1) * cnk(t - 1, cnt) % mod +
                cnk(t - 1, cnt + 1) * cnk(t, cnt) % mod) % mod;
            coeff = coeff * inv_odd % mod;
            result = (result + coeff * cnk(2 * t + n - 2 * cnt - 1, n - 2 * cnt - 1)) % mod;
        }
    }
    return result;
}




// Computes the answer modulo mod. It evaluates the same sum as solve_smart:
//   1 + sum_{t=1..m} choose_repeating(2t + 1, n - 1)
//     + sum_{t, cnt} coeff(t, cnt) * C(2t + n - 2cnt - 1, n - 2cnt - 1) / (2t + 1),
// where coeff(t, cnt) = 2 C(t,cnt) C(t-1,cnt) + C(t,cnt+1) C(t-1,cnt) + C(t-1,cnt+1) C(t,cnt).
// Each of the three products in coeff is expanded into factorials. They split
// into a part that depends only on cnt, a part that depends only on
// t_cnt = t - cnt, and a part that depends only on t. That turns the inner sum
// over cnt into a convolution computed with fft::mul. Requires fft::build to
// allow transforms of length >= (n - 1) / 2 + m.
ll solve(int n, int m) {
    // Largest factorial index used is n + 2m - 1 (from choose_repeating(2m + 1, n - 1));
    // this bound covers it with slack.
    build_f(2 * n + 2 * m + 1e5);
    ll result = 1;
    for (int t = 1; t <= m; t += 1) {
        result = (result + choose_repeating(2 * t + 1, n - 1)) % mod;
    }
    // mask 0, 1, 2 selects (x, y) = (0, 0), (1, 0), (0, 1): the term
    // C(t, cnt + x) * C(t - 1, cnt + y) of coeff; the (0, 0) term is counted twice.
    for (int mask = 0; mask <= 2; mask += 1) {
        int x = mask % 2;
        int y = (mask / 2) % 2;
        // solve_smart requires 2 * cnt + 1 <= n.
        int mx_cnt = (n - 1) / 2;
        vector<ll> a(mx_cnt + 1);
        // a[cnt] = factors depending only on cnt:
        //   1/(cnt + x)! * 1/(cnt + y)! * 1/(n - 2cnt - 1)!  (doubled when mask == 0).
        for (int cnt = 1; cnt <= mx_cnt; cnt += 1) {
            a[cnt] = invf[cnt + x] * invf[cnt + y] % mod;
            if (mask == 0) {
                a[cnt] = (a[cnt] * 2) % mod;
            }
            a[cnt] = a[cnt] * invf[n - 2 * cnt - 1] % mod;
        }
        // b[t_cnt] = factors depending only on t_cnt = t - cnt:
        //   1/(t_cnt - x)! * 1/(t_cnt - 1 - y)! * (n - 1 + 2 t_cnt)!.
        // Starting at y + 1 keeps t_cnt - 1 - y >= 0. For smaller t_cnt, C(t - 1, cnt + y) is 0.
        vector<ll> b(m);
        for (int t_cnt = y + 1; t_cnt <= m - 1; t_cnt += 1) {
            b[t_cnt] = invf[t_cnt - x] * invf[t_cnt - 1 - y] % mod;
            b[t_cnt] = b[t_cnt] * f[n - 1 + 2 * t_cnt] % mod;
        }
        // c[t] = sum over cnt of a[cnt] * b[t - cnt]. mul() trims trailing zeros,
        // hence the t < len(c) bound below.
        auto c = fft::mul(a, b);
        // Multiply back the factors depending only on t: t! * (t - 1)! / (2t)! / (2t + 1).
        for (int t = 1; t <= m and t < len(c); t += 1) {
            result += c[t] * f[t] % mod * f[t - 1] % mod * invf[2 * t] % mod * inv(2 * t + 1) % mod;
        }
        result %= mod;
    }
    return result;
}



// Brute-force DP used by stress() to validate solve() on small inputs;
// O(n * m^3) time and O(n * m^2) memory.
// Only states with x + y <= m are computed. The transition is
//   dp[i + 1][x][y] = sum_{k < x} dp[i][k][y] + sum_{k = y..m} dp[i][x][k],
// starting from the single state dp[0][0][m] = 1. The answer is the sum of
// dp[n][x][y] over all computed states.
ll solve_slow(int n, int m) {
    vector<vector<vector<ll>>> dp(n + 1, vector<vector<ll>>(m + 1, vector<ll>(m + 1)));
    // Single starting state.
    dp[0][0][m] = 1;
    rep(i, n) {
        rep(x, m + 1) rep(y, m + 1 - x) {
            rep(k, x) dp[i + 1][x][y] += dp[i][k][y];
            for (int k = m; k >= y; k -= 1) dp[i + 1][x][y] += dp[i][x][k];
            dp[i + 1][x][y] %= mod;
        }
    }
    ll result = 0;
    rep(x, m + 1) rep(y, m + 1 - x) result += dp[n][x][y];
    result %= mod;
    return result;
}


void stress() {
    int c = 0;
    while (1) {
        cout << ++c << "\n";
        int n = rnd() % 100 + 1;
        int m = rnd() % 100 + 1;
        auto my_rs = solve(n, m);
        auto right_rs = solve_slow(n, m);
        if (my_rs != right_rs) {
            cout << n << " " << m << "\n";
            break;
        }
    }
}


int32_t main() {
    fft::build(20);
    // stress();
    if (1) {
        ios::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);
    }
    int n, m;
    cin >> n >> m;
    cout << solve(n, m) << "\n";
    return 0;
}

这程序好像有点Bug,我给组数据试试? (Hack prompt: "This program seems to have a bug — shall I try it on some test data?")

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 132ms
memory: 29444kb

input:

2 2

output:

9

result:

ok 1 number(s): "9"

Test #2:

score: 0
Accepted
time: 1028ms
memory: 88676kb

input:

500000 500000

output:

984531374

result:

ok 1 number(s): "984531374"

Test #3:

score: 0
Accepted
time: 332ms
memory: 54260kb

input:

500000 1

output:

219705876

result:

ok 1 number(s): "219705876"

Test #4:

score: 0
Accepted
time: 456ms
memory: 62260kb

input:

1 500000

output:

500001

result:

ok 1 number(s): "500001"

Test #5:

score: 0
Accepted
time: 953ms
memory: 81840kb

input:

500000 353535

output:

33730077

result:

ok 1 number(s): "33730077"

Test #6:

score: 0
Accepted
time: 985ms
memory: 83052kb

input:

353535 500000

output:

182445298

result:

ok 1 number(s): "182445298"

Test #7:

score: 0
Accepted
time: 1015ms
memory: 88684kb

input:

499999 499999

output:

933220940

result:

ok 1 number(s): "933220940"

Test #8:

score: 0
Accepted
time: 1030ms
memory: 88692kb

input:

499999 499998

output:

899786345

result:

ok 1 number(s): "899786345"

Test #9:

score: 0
Accepted
time: 925ms
memory: 79208kb

input:

377773 400009

output:

206748715

result:

ok 1 number(s): "206748715"

Test #10:

score: 0
Accepted
time: 520ms
memory: 61820kb

input:

499999 100001

output:

482775773

result:

ok 1 number(s): "482775773"

Test #11:

score: 0
Accepted
time: 971ms
memory: 86092kb

input:

444445 488884

output:

70939759

result:

ok 1 number(s): "70939759"

Test #12:

score: 0
Accepted
time: 1003ms
memory: 85576kb

input:

488885 444449

output:

591315327

result:

ok 1 number(s): "591315327"

Test #13:

score: 0
Accepted
time: 337ms
memory: 54192kb

input:

500000 111

output:

313439156

result:

ok 1 number(s): "313439156"

Test #14:

score: 0
Accepted
time: 920ms
memory: 79628kb

input:

333333 444444

output:

403492103

result:

ok 1 number(s): "403492103"

Test #15:

score: 0
Accepted
time: 909ms
memory: 81232kb

input:

499994 343433

output:

334451699

result:

ok 1 number(s): "334451699"

Test #16:

score: 0
Accepted
time: 941ms
memory: 83776kb

input:

477774 411113

output:

63883341

result:

ok 1 number(s): "63883341"

Test #17:

score: 0
Accepted
time: 621ms
memory: 63952kb

input:

123456 432109

output:

238795570

result:

ok 1 number(s): "238795570"

Test #18:

score: 0
Accepted
time: 916ms
memory: 73000kb

input:

131331 467777

output:

834790039

result:

ok 1 number(s): "834790039"

Test #19:

score: 0
Accepted
time: 350ms
memory: 54256kb

input:

500000 2

output:

304727284

result:

ok 1 number(s): "304727284"

Test #20:

score: 0
Accepted
time: 133ms
memory: 29380kb

input:

1111 111

output:

98321603

result:

ok 1 number(s): "98321603"

Test #21:

score: 0
Accepted
time: 962ms
memory: 85156kb

input:

416084 493105

output:

916827025

result:

ok 1 number(s): "916827025"

Test #22:

score: 0
Accepted
time: 320ms
memory: 42024kb

input:

53888 138663

output:

57263952

result:

ok 1 number(s): "57263952"

Test #23:

score: 0
Accepted
time: 608ms
memory: 65892kb

input:

219161 382743

output:

304889787

result:

ok 1 number(s): "304889787"

Test #24:

score: 0
Accepted
time: 558ms
memory: 60140kb

input:

181392 318090

output:

12528742

result:

ok 1 number(s): "12528742"

Test #25:

score: 0
Accepted
time: 607ms
memory: 64496kb

input:

135930 422947

output:

554153000

result:

ok 1 number(s): "554153000"

Test #26:

score: 0
Accepted
time: 529ms
memory: 58372kb

input:

280507 210276

output:

812816587

result:

ok 1 number(s): "812816587"

Test #27:

score: 0
Accepted
time: 934ms
memory: 75344kb

input:

253119 420465

output:

124024302

result:

ok 1 number(s): "124024302"

Test #28:

score: 0
Accepted
time: 480ms
memory: 59544kb

input:

446636 97448

output:

150388382

result:

ok 1 number(s): "150388382"

Test #29:

score: 0
Accepted
time: 460ms
memory: 54636kb

input:

284890 126665

output:

786559507

result:

ok 1 number(s): "786559507"

Test #30:

score: 0
Accepted
time: 239ms
memory: 39928kb

input:

186708 28279

output:

607509026

result:

ok 1 number(s): "607509026"

Extra Test:

score: 0
Extra Test Passed