QOJ.ac

QOJ

ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间
#924895 | #9572. Bingo | Hiraethsoul | WA | 0ms | 3712kb | C++20 | 14.1kb | 2025-03-03 19:34:16 | 2025-03-03 19:34:17

Judging History

This is the latest submission verdict.

  • [2025-03-03 19:34:17]
  • Judged
  • Verdict: WA
  • Time: 0ms
  • Memory: 3712kb
  • [2025-03-03 19:34:16]
  • Submitted

answer

#include <bits/stdc++.h>

// Contest shorthand: from here on every `int` in this file is a 64-bit `long long`
// (which is why the entry point is spelled `signed main`).
#define int long long

// Prime modulus for all modular arithmetic in this file; the NTT below uses 3 as its primitive root.
const int P = 998244353; // If the verdict is WA, the modulus may need to be changed.

/**
 * Polynomial / truncated power series with coefficients modulo P.
 *
 * Coefficients are stored lowest degree first in the underlying std::vector,
 * so element i is the coefficient of x^i.  All operations are provided as
 * friend functions (dft, idft, operator*, deriv, integr, inv, ln, exp, power).
 * Coefficients are expected to be already reduced into [0, P).
 */
class Polynomial : public std::vector<int>
{
    /// (a * b) mod P for operands already reduced into [0, P).
    static int multip(int a, int b)
    {
        return a * b % P;
    }
    // Alternative multiplication based on a long double quotient estimate,
    // kept disabled for moduli where the plain product would not fit:
    // static int multip(int a, int b)
    // {
    //     int res = a * b - (int)(1.L * a * b / P) * P;
    //     res %= P;
    //     if (res < 0)
    //     {
    //         res += P;
    //     }
    //     return res;
    // }

    /// (a + b) mod P for operands in [0, P).
    static int add(int a, int b)
    {
        a += b;
        a -= (a >= P ? P : 0);
        return a;
    }

    /// (a - b) mod P for operands in [0, P).
    static int sub(int a, int b)
    {
        a -= b;
        a += (a < 0 ? P : 0);
        return a;
    }

    /// a^b mod P by binary exponentiation (b >= 0).
    static int qmi(int a, int b)
    {
        int res = 1;
        while (b)
        {
            if (b & 1)
            {
                res = res * a % P;
            }
            a = a * a % P;
            b >>= 1;
        }
        return res;
    }

    /// Shared twiddle table: for every power of two k below the table size,
    /// w[k + j] (0 <= j < k) is the j-th power of a primitive 2k-th root of unity.
    static std::vector<int> w;

    /// Makes sure the twiddle table covers transforms of length 2^_log.
    /// A larger existing table is reused because its layout is size independent.
    static void initW(int _log)
    {
        const int r = 1 << _log;
        if (w.size() >= r)
        {
            return;
        }
        w.assign(r, 0);
        w[r >> 1] = 1;
        int s = qmi(3, (P - 1) >> _log); // 3 is a primitive root modulo P

        // Upper half: consecutive powers of the primitive r-th root.
        for (int i = r / 2 + 1; i < r; i++)
        {
            w[i] = w[i - 1] * s % P;
        }

        // Lower levels: every second entry of the level above.
        for (int i = r / 2 - 1; i > 0; i--)
        {
            w[i] = w[i * 2];
        }
    }

public:
    using std::vector<int>::vector;

    /// In-place forward transform.  a.size() must be a power of two and initW
    /// must already cover that length.  Butterflies run from the widest stride
    /// down to 1 and no reordering pass is performed afterwards.
    friend void dft(Polynomial &a)
    {
        const int n = a.size();
        for (int k = n >> 1; k; k >>= 1)
        {
            for (int i = 0; i < n; i += k << 1)
            {
                for (int j = 0; j < k; j++)
                {
                    int v = a[i + j + k];
                    a[i + j + k] = multip(sub(a[i + j], v), w[k + j]);
                    a[i + j] = add(a[i + j], v);
                }
            }
        }
    }

    /// In-place inverse of dft (same size requirements).  The butterflies reuse
    /// the forward twiddles, so the result is scaled by 1/n and then the range
    /// [1, n) is reversed to undo the resulting index reversal.
    friend void idft(Polynomial &a)
    {
        const int n = a.size();
        for (int k = 1; k < n; k <<= 1)
        {
            for (int i = 0; i < n; i += k << 1)
            {
                for (int j = 0; j < k; j++)
                {
                    int u = a[i + j];
                    int v = multip(a[i + j + k], w[j + k]);
                    a[i + j + k] = sub(u, v);
                    a[i + j] = add(u, v);
                }
            }
        }
        // n divides P - 1, hence n * (P - (P - 1) / n) == 1 (mod P).
        int val = P - (P - 1) / n;
        for (int i = 0; i < n; i++)
        {
            a[i] = multip(a[i], val);
        }
        std::reverse(a.begin() + 1, a.end());
    }

    /// Product of two polynomials (size a.size() + b.size() - 1, or empty if
    /// either operand is empty).  Falls back to schoolbook multiplication when
    /// the transform length does not divide P - 1 or when the shorter operand
    /// has fewer than 128 coefficients.
    friend Polynomial operator*(Polynomial a, Polynomial b)
    {
        if (a.size() == 0 or b.size() == 0)
        {
            return Polynomial();
        }
        int n = a.size() + b.size() - 1;
        int _log = std::__lg(2 * n - 1); // s = 2^_log is the smallest power of two >= n
        int s = 1 << _log;
        if (((P - 1) & (s - 1)) != 0 or std::min(a.size(), b.size()) < 128)
        {
            Polynomial res(n);
            for (int i = 0; i < a.size(); i++)
            {
                for (int j = 0; j < b.size(); j++)
                {
                    res[i + j] = add(res[i + j], multip(a[i], b[j]));
                }
            }
            return res;
        }

        initW(_log);
        a.resize(s), b.resize(s);
        dft(a), dft(b);
        for (int i = 0; i < s; i++)
        {
            a[i] = multip(a[i], b[i]);
        }
        idft(a);
        return a.resize(n), a;
    }

    /// Formal derivative; returns an empty polynomial for inputs of size <= 1.
    friend Polynomial deriv(const Polynomial &a)
    {
        int n = a.size();
        if (n <= 1)
        {
            return Polynomial();
        }
        Polynomial p(n - 1);
        for (int i = 1; i < n; i++)
        {
            p[i - 1] = multip(i, a[i]);
        }
        return p;
    }

    /// Formal integral with constant term 0; the result has a.size() + 1 coefficients.
    friend Polynomial integr(const Polynomial &a)
    {
        int n = a.size();
        Polynomial p(n + 1);
        if (n == 0)
        {
            // Nothing to integrate; also avoids touching _inv[1] of a one-element table.
            return p;
        }

        // Linear-time table of modular inverses of 1..n.
        std::vector<int> _inv(n + 1);
        _inv[1] = 1;
        for (int i = 2; i <= n; i++)
        {
            _inv[i] = multip(_inv[P % i], (P - P / i));
        }
        for (int i = 0; i < n; ++i)
        {
            p[i + 1] = multip(a[i], _inv[i + 1]);
        }
        return p;
    }

    /// Multiplicative inverse modulo x^a.size().  Requires a[0] != 0.
    /// Newton iteration: b <- b * (2 - a * b), doubling the precision each level.
    friend Polynomial inv(const Polynomial &a)
    {
        int n = a.size();
        if (n == 0)
        {
            // An empty input would otherwise recurse forever.
            return Polynomial();
        }
        if (n == 1)
        {
            return {qmi(a[0], P - 2)};
        }
        Polynomial half(a.begin(), a.begin() + (n + 1) / 2);
        Polynomial b = inv(half), c = a * b;
        for (auto &x : c)
        {
            x = (x == 0 ? 0 : P - x); // negate modulo P
        }
        c[0] = add(c[0], 2);
        c = c * b;
        c.resize(n);
        return c;
    }

    /// Logarithm modulo x^a.size(), computed as the integral of a' / a.
    /// Requires a[0] == 1; the constant term of the result is 0.
    friend Polynomial ln(const Polynomial &a)
    {
        int n = a.size();
        if (n <= 1)
        {
            // ln of a constant series is 0.  Returning early also avoids the
            // out-of-range write to _inv[1] that a table of size <= 1 would cause.
            return Polynomial(n, 0);
        }

        // b = a' (kept at length n), then b = a' / a.
        Polynomial b(n, 0);
        for (int i = 1; i < n; i++)
        {
            b[i - 1] = multip(i, a[i]);
        }
        b = b * inv(a);
        b.resize(n);

        // Integrate in place using a table of inverses of 1..n-1.
        std::vector<int> _inv(n);
        _inv[1] = 1;
        for (int i = 2; i < n; i++)
        {
            _inv[i] = multip(P - P / i, _inv[P % i]);
        }
        for (int i = n - 1; i; i--)
        {
            b[i] = multip(b[i - 1], _inv[i]);
        }
        b[0] = 0;
        return b;
    }

    /// Exponential modulo x^a.size().  Requires a[0] == 0.
    /// Newton iteration: b <- b * (1 + a - ln b).
    friend Polynomial exp(const Polynomial &a)
    {
        int n = a.size();
        if (n == 0)
        {
            // An empty input would otherwise recurse forever.
            return Polynomial();
        }
        if (n == 1)
        {
            return {1};
        }
        Polynomial half(a.begin(), a.begin() + (n + 1) / 2);
        Polynomial b = exp(half);
        b.resize(n);
        Polynomial c = ln(b);
        for (int i = 0; i < n; i++)
        {
            c[i] = sub(a[i], c[i]);
        }
        c[0] = add(c[0], 1);
        c = c * b;
        c.resize(n);
        return c;
    }

    /// F^k modulo x^F.size(), where the non-negative exponent k is given as a
    /// decimal string (it may be far larger than 64 bits).
    ///
    /// F is overwritten with the result, which is also returned by value.
    /// Leading zero coefficients are factored out as x^pos, the remaining
    /// series is normalised to constant term 1, raised via exp(k * ln(.)),
    /// and finally rescaled by val^k and shifted back by pos * k.
    friend Polynomial power(Polynomial &F, std::string s)
    {
        int n = F.size();
        if (n == 0)
        {
            return F;
        }

        int k1 = 0; // k mod P        (multiplier for the logarithm)
        int k2 = 0; // k mod (P - 1)  (exponent for the leading coefficient)
        int k3 = 0; // k itself, saturated once it reaches 10^8 (only compared with n)
        for (char ch : s)
        {
            int digit = ch - '0';
            k1 = (k1 * 10 + digit) % P;
            k2 = (k2 * 10 + digit) % (P - 1);
            // Saturating instead of reading a fixed number of characters keeps
            // the comparison below correct for exponents with leading zeros.
            if (k3 < 100000000)
            {
                k3 = k3 * 10 + digit;
            }
        }

        // Index of the lowest non-zero coefficient (n if F is identically zero).
        int pos = 0;
        while (pos < n and F[pos] == 0)
        {
            ++pos;
        }
        if (pos == n)
        {
            // Zero series: F^0 is 1 (empty product), every other power is 0.
            if (k3 == 0)
            {
                F[0] = 1;
            }
            return F;
        }
        if (pos > 0 and k3 >= n)
        {
            // The lowest term x^(pos * k) already lies beyond the truncation.
            F.assign(n, 0);
            return F;
        }

        if (pos)
        {
            // Divide by x^pos.
            for (int i = pos; i < n; ++i)
            {
                F[i - pos] = F[i];
                F[i] = 0;
            }
        }
        int val = F[0];
        int cur = qmi(val, P - 2);
        for (int i = 0; i < F.size(); ++i)
        {
            F[i] = F[i] * cur % P;
        }
        F = ln(F);
        for (int i = 0; i < F.size(); ++i)
        {
            F[i] = F[i] * k1 % P;
        }
        F = exp(F);
        cur = qmi(val, k2);
        for (int i = 0; i < F.size(); ++i)
        {
            F[i] = F[i] * cur % P;
        }
        if (pos)
        {
            int shift = std::min(1ll * pos * k1, n);
            // A zero shift (k == 0) must leave the series alone; running the
            // loop below with shift 0 would clear every coefficient.
            if (shift > 0)
            {
                for (int i = n - 1; i >= 0; --i)
                {
                    if (i + shift < n)
                    {
                        F[i + shift] = F[i];
                    }
                    F[i] = 0;
                }
            }
        }
        return F;
    }

    /// F^b by repeated squaring, truncating every intermediate product to at
    /// most m coefficients (m is the length of the wanted result).
    friend Polynomial power(const Polynomial &F, int b, int m)
    {
        Polynomial res = {1};
        Polynomial G = F;
        while (b)
        {
            if (b & 1)
            {
                res = res * G;
                if (res.size() > m)
                {
                    res.resize(m);
                }
            }
            G = G * G;
            if (G.size() > m)
            {
                G.resize(m);
            }
            b >>= 1;
        }
        return res;
    }
};

// Out-of-class definition of the twiddle-factor table shared by all Polynomial transforms.
std::vector<int> Polynomial::w;
// Short alias for Polynomial.
using Poly = Polynomial;

// Fast modular power: returns a^b reduced modulo the global P, for b >= 0.
// Scans the bits of b from least to most significant, squaring a each step.
int qmi(int a, int b)
{
    int acc = 1;
    for (; b; b >>= 1)
    {
        if (b & 1)
        {
            acc = acc * a % P;
        }
        a = a * a % P;
    }
    return acc;
}
using i64 = long long;

// Generic binary exponentiation: returns a raised to the b-th power using
// only T's operator*= and construction from 1.  Yields T{1} when b == 0.
template <class T>
constexpr T power(T a, i64 b)
{
    T result{1};
    while (b)
    {
        if (b % 2)
        {
            result *= a;
        }
        b /= 2;
        a *= a;
    }
    return result;
}

// Returns (a * b) mod p in [0, p) for moduli too large for a plain 64-bit
// product.  The quotient a * b / p is estimated in long double and the
// remainder is recovered from the low 64 bits of the exact difference.
// Expects p > 0 and operands already reduced modulo p.
constexpr i64 mul(i64 a, i64 b, i64 p)
{
    using u64 = unsigned long long;
    const u64 quotient = static_cast<u64>(static_cast<i64>(1.L * a * b / p));
    // The wrap-around is intentional, so do it in unsigned arithmetic where it
    // is well defined; overflowing the signed product would be undefined behaviour.
    i64 res = static_cast<i64>(static_cast<u64>(a) * static_cast<u64>(b) - quotient * static_cast<u64>(p));
    res %= p;
    if (res < 0)
    {
        res += p;
    }
    return res;
}

// Integer modulo a fixed or run-time modulus.
// With a positive template argument P the modulus is P itself; otherwise the
// static member Mod (changeable through setMod) is used.  The stored value x
// is kept in [0, modulus).
template <i64 P>
struct MInt
{
    i64 x;

    constexpr MInt() : x{0} {}
    // Accepts any i64 (including negatives) and reduces it into range.
    constexpr MInt(i64 x) : x{norm(x % getMod())} {}

    static i64 Mod;

    // Modulus currently in effect for this instantiation.
    constexpr static i64 getMod()
    {
        if (P <= 0)
        {
            return Mod;
        }
        return P;
    }
    constexpr static void setMod(i64 Mod_)
    {
        Mod = Mod_;
    }

    // Brings a value from (-modulus, 2 * modulus) back into [0, modulus).
    constexpr i64 norm(i64 x) const
    {
        const i64 mod = getMod();
        if (x < 0)
        {
            x += mod;
        }
        if (x >= mod)
        {
            x -= mod;
        }
        return x;
    }
    constexpr i64 val() const
    {
        return x;
    }

    // Additive inverse.
    constexpr MInt operator-() const
    {
        MInt negated;
        negated.x = norm(getMod() - x);
        return negated;
    }
    // Multiplicative inverse computed as this^(modulus - 2).
    constexpr MInt inv() const
    {
        return power(*this, getMod() - 2);
    }

    // Small moduli use a direct 64-bit product; larger ones go through mul().
    constexpr MInt &operator*=(MInt rhs) &
    {
        const i64 mod = getMod();
        if (mod < (1ULL << 31))
        {
            x = x * rhs.x % mod;
        }
        else
        {
            x = mul(x, rhs.x, mod);
        }
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) &
    {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) &
    {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) &
    {
        return *this *= rhs.inv();
    }

    // Binary operators are built on the compound assignments above;
    // lhs is a by-value copy, so it can be updated and returned directly.
    friend constexpr MInt operator*(MInt lhs, MInt rhs)
    {
        lhs *= rhs;
        return lhs;
    }
    friend constexpr MInt operator+(MInt lhs, MInt rhs)
    {
        lhs += rhs;
        return lhs;
    }
    friend constexpr MInt operator-(MInt lhs, MInt rhs)
    {
        lhs -= rhs;
        return lhs;
    }
    friend constexpr MInt operator/(MInt lhs, MInt rhs)
    {
        lhs /= rhs;
        return lhs;
    }

    // Reads an i64 and reduces it; writes the stored residue.
    friend constexpr std::istream &operator>>(std::istream &is, MInt &a)
    {
        i64 raw;
        is >> raw;
        a = MInt(raw);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a)
    {
        return os << a.val();
    }

    // Comparisons act on the stored residues.
    friend constexpr bool operator==(MInt lhs, MInt rhs)
    {
        return lhs.x == rhs.x;
    }
    friend constexpr bool operator!=(MInt lhs, MInt rhs)
    {
        return lhs.x != rhs.x;
    }
    friend constexpr bool operator<(MInt lhs, MInt rhs)
    {
        return lhs.x < rhs.x;
    }
};
// Initial run-time modulus for MInt<0>, the instantiation whose getMod() returns Mod.
template <>
i64 MInt<0>::Mod = 998244353;

// Modular integer fixed to the global modulus P.
using Z = MInt<P>;

struct Comb
{
    int n;

    std::vector<Z> _fac;
    std::vector<Z> _invfac;
    std::vector<Z> _inv;

    Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}

    Comb(int n) : Comb()
    {
        init(n);
    }

    void init(int m)
    {
        m = std::min<int>(m, Z::getMod() - 1);
        if (m <= n)
            return;
        _fac.resize(m + 1);
        _invfac.resize(m + 1);
        _inv.resize(m + 1);

        for (int i = n + 1; i <= m; i++)
        {
            _fac[i] = _fac[i - 1] * i;
        }
        _invfac[m] = _fac[m].inv();
        for (int i = m; i > n; i--)
        {
            _invfac[i - 1] = _invfac[i] * i;
            _inv[i] = _invfac[i] * _fac[i - 1];
        }
        n = m;
    }

    Z fac(int m)
    {
        if (m > n)
            init(2 * m);
        return _fac[m];
    }
    Z invfac(int m)
    {
        if (m > n)
            init(2 * m);
        return _invfac[m];
    }
    Z inv(int m)
    {
        if (m > n)
            init(2 * m);
        return _inv[m];
    }
    Z A(int a, int b)
    {
        if (a < b or b < 0)
            return 0;
        return fac(a) * invfac(a - b);
    }
    Z C(int n, int m)
    {
        if (n < m || m < 0)
            return 0;
        return fac(n) * invfac(m) * invfac(n - m);
    }
} comb;
void solve()
{
    int n, m;
    std::cin >> n >> m;
    std::vector<int> a(n * m + 1, 0);
    for (int i = 1; i <= n * m; ++i)
    {
        std::cin >> a[i];
    }
    std::sort(begin(a) + 1, end(a));
    std::vector<int> fac(n * m + 1, 0);
    std::vector<int> ifac(n * m + 1, 0);

    fac[0] = 1;
    for (int i = 1; i <= n * m; ++i)
    {
        fac[i] = fac[i - 1] * i % P;
    }
    ifac[n] = qmi(fac[n], P - 2);
    for (int i = n; i >= 1; i--)
    {
        ifac[i - 1] = i * ifac[i] % P;
    }
    Poly f(n * m + 1, 0);
    for (int i = 1; i <= n * m; ++i)
    {
        f[i] = a[i] * fac[i - 1] % P;
    }
    Poly g(n * m + 1, 0);
    for (int i = 0; i <= n * m; ++i)
    {
        g[i] = ifac[n * m - i];
    }
    f = f * g;

    std::vector<Z> c(n * m + 1, 0);
    for (int i = 0; i <= n * m; ++i)
    {
        c[i] = (Z)(fac[n * m - i] * i % P * f[n * m + i] % P);
    }

    Z ans = 0;
    for (int x = 0; x <= n; ++x)
    {
        for (int y = 0; y <= m; ++y)
        {
            int C = x * m + y * n - x * y;
            if ((x + y) & 1)
            {
                ans += comb.C(n, x) * comb.C(m, y) * c[C];
            }
            else
            {
                ans -= comb.C(n, x) * comb.C(m, y) * c[C];
            }
        }
    }
    std::cout << ans << '\n';
}
// Entry point: switches to unsynchronised stream I/O, reads the number of
// test cases and runs solve() once per case.
signed main()
{
    std::ios::sync_with_stdio(0);
    std::cin.tie(0);

    int T;
    std::cin >> T;
    for (; T != 0; --T)
    {
        solve();
    }
}

詳細信息

Test #1:

score: 0
Wrong Answer
time: 0ms
memory: 3712kb

input:

4
2 2
1 3 2 4
3 1
10 10 10
1 3
20 10 30
3 4
1 1 4 5 1 4 1 9 1 9 8 10

output:

56
60
998244233
356425214

result:

wrong answer 3rd numbers differ - expected: '60', found: '998244233'