QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#659138 | #9476. 012 Grid | ucup-team4435# | AC ✓ | 137ms | 15312kb | C++20 | 35.6kb | 2024-10-19 18:53:56 | 2024-10-19 18:53:57

Judging History

你现在查看的是最新测评结果

  • [2024-10-19 18:53:57]
  • 评测
  • 测评结果:AC
  • 用时:137ms
  • 内存:15312kb
  • [2024-10-19 18:53:56]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;

// Shorthand for long long.
using ll = long long;
// Shorthand for long double.
using ld = long double;

// Expands to the begin/end pair of a container, e.g. sort(all(v)).
#define all(a) begin(a), end(a)
// Size of a container converted to int.
#define len(a) int((a).size())

/*
 ! WARNING: MOD must be prime.
 ! WARNING: MOD must be less than 2^30.
 * Use .get() to get the stored value.
 *
 * Modular integer stored in Montgomery form with R = 2^32: the member `value`
 * holds (x * R) mod `mod`, lazily reduced to the range [0, 2 * mod).
 * `mod` must be odd (enforced by the static_assert on inv_neg_mod).
 */
template<uint32_t mod>
class montgomery {
    static_assert(mod < uint32_t(1) << 30, "mod < 2^30");
    using mint = montgomery<mod>;

private:
    // Montgomery representation of the residue, kept in [0, 2 * mod).
    uint32_t value;

    // -mod^{-1} modulo 2^32, obtained by Newton iteration (each step doubles the
    // number of correct low bits; x = mod is already correct modulo 8 for odd mod).
    static constexpr uint32_t inv_neg_mod = []() {
        uint32_t x = mod;
        for (int i = 0; i < 4; ++i) {
            x *= uint32_t(2) - mod * x;
        }
        return -x;
    }();
    static_assert(mod * inv_neg_mod == -1);

    // 2^64 modulo mod (i.e. R^2 mod mod); multiplying by it and reducing
    // converts an ordinary residue into Montgomery form.
    static constexpr uint32_t neg_mod = (-uint64_t(mod)) % mod;

    // Montgomery reduction: returns value / 2^32 modulo mod, in [0, 2 * mod)
    // for the inputs produced inside this class.
    static uint32_t reduce(const uint64_t &value) {
        return (value + uint64_t(uint32_t(value) * inv_neg_mod) * mod) >> 32;
    }

    inline static const mint ONE = mint(1);

public:
    // Zero.
    montgomery() : value(0) {}
    montgomery(const mint &x) : value(x.value) {}

    // Builds the residue of any integral x (negative values are handled).
    template<typename T, typename U = std::enable_if_t<std::is_integral<T>::value>>
    montgomery(const T &x) : value(!x ? 0 : reduce(int64_t(x % int32_t(mod) + int32_t(mod)) * neg_mod)) {}

    static constexpr uint32_t get_mod() {
        return mod;
    }

    // Returns the canonical representative in [0, mod).
    uint32_t get() const {
        auto real_value = reduce(value);
        return real_value < mod ? real_value : real_value - mod;
    }

    // Returns (*this)^degree. Negative exponents are interpreted via Fermat's
    // little theorem (exponent taken modulo mod - 1).
    // A zero base yields 1 for degree == 0 and 0 for every other exponent
    // (including negative ones, where the true inverse does not exist).
    template<typename T>
    mint power(T degree) const {
        // The exponent may only be reduced modulo (mod - 1) for a non-zero base:
        // otherwise 0^(k * (mod - 1)) would wrongly collapse to 0^0 = 1.
        if (degree != 0 && get() == 0) {
            return mint();
        }

        degree = (degree % int32_t(mod - 1) + int32_t(mod - 1)) % int32_t(mod - 1);
        mint prod = 1, a = *this;
        for (; degree > 0; degree >>= 1, a *= a)
            if (degree & 1)
                prod *= a;

        return prod;
    }

    // Multiplicative inverse (0 is mapped to 0).
    mint inv() const {
        return power(-1);
    }

    mint& operator=(const mint &x) {
        value = x.value;
        return *this;
    }

    mint& operator+=(const mint &x) {
        // Subtract 2 * mod up front and add it back only if the result went negative.
        if (int32_t(value += x.value - (mod << 1)) < 0) {
            value += mod << 1;
        }
        return *this;
    }

    mint& operator-=(const mint &x) {
        if (int32_t(value -= x.value) < 0) {
            value += mod << 1;
        }
        return *this;
    }

    mint& operator*=(const mint &x) {
        value = reduce(uint64_t(value) * x.value);
        return *this;
    }

    mint& operator/=(const mint &x) {
        return *this *= x.inv();
    }

    friend mint operator+(const mint &x, const mint &y) {
        return mint(x) += y;
    }

    friend mint operator-(const mint &x, const mint &y) {
        return mint(x) -= y;
    }

    friend mint operator*(const mint &x, const mint &y) {
        return mint(x) *= y;
    }

    friend mint operator/(const mint &x, const mint &y) {
        return mint(x) /= y;
    }

    mint& operator++() {
        return *this += ONE;
    }

    mint& operator--() {
        return *this -= ONE;
    }

    mint operator++(int) {
        mint prev = *this;
        *this += ONE;
        return prev;
    }

    mint operator--(int) {
        mint prev = *this;
        *this -= ONE;
        return prev;
    }

    mint operator-() const {
        return mint(0) - *this;
    }

    // Comparisons work on canonical representatives, because the internal
    // value is not unique.
    bool operator==(const mint &x) const {
        return get() == x.get();
    }

    bool operator!=(const mint &x) const {
        return get() != x.get();
    }

    bool operator<(const mint &x) const {
        return get() < x.get();
    }

    // Explicit conversion to any type constructible from the canonical value.
    template<typename T>
    explicit operator T() const {
        return get();
    }

    // Reads a (possibly negative, arbitrarily long) decimal integer.
    friend std::istream& operator>>(std::istream &in, mint &x) {
        std::string s;
        in >> s;
        x = 0;
        bool neg = s[0] == '-';
        for (const auto c : s)
            if (c != '-')
                x = x * 10 + (c - '0');

        if (neg)
            x *= -1;

        return in;
    }

    friend std::ostream& operator<<(std::ostream &out, const mint &x) {
        return out << x.get();
    }

    // Returns a primitive root modulo mod. Well-known moduli are answered
    // immediately; otherwise the root is searched once and cached.
    static int32_t primitive_root() {
        if constexpr (mod == 1'000'000'007)
            return 5;
        if constexpr (mod == 998'244'353)
            return 3;
        if constexpr (mod == 786433)
            return 10;

        static int root = -1;
        if (root != -1)
            return root;

        // Distinct prime divisors of mod - 1.
        std::vector<int> primes;
        int rest = mod - 1;
        for (int i = 2; i * i <= rest; i++)
            if (rest % i == 0) {
                primes.push_back(i);
                while (rest % i == 0)
                    rest /= i;
            }

        if (rest != 1)
            primes.push_back(rest);

        // r is a primitive root iff r^((mod - 1) / p) != 1 for every prime p | mod - 1.
        for (int r = 2;; r++) {
            bool ok = true;
            for (auto p : primes)
                if ((mint(r).power((mod - 1) / p)).get() == 1) {
                    ok = false;
                    break;
                }

            if (ok)
                return root = r;
        }
    }
};

// Modulus used by the whole solution; the 1e9+7 alternative is kept commented out.
// constexpr uint32_t MOD = 1'000'000'007;
constexpr uint32_t MOD = 998'244'353;
// Modular integer type specialised for MOD.
using mint = montgomery<MOD>;

/*
 ! WARNING: (MOD - 1) must be divisible by (2n), where n is max degree of polynomials.
 * Include static_modular_int or montgomery (faster) to use it.
 * Don't need to care about precomputing primitive root.
 * Use it like std::vector<mint> with extra methods.
 *
 * Coefficients are stored lowest degree first: (*this)[i] is the coefficient of x^i.
 * `mint` must provide static constexpr get_mod(), static primitive_root(),
 * power(), inv(), the arithmetic operators and construction from integers.
 */
template<typename mint>
class polynom_t : public std::vector<mint> {
public:
    // Number of trailing zero bits of MOD - 1.
    static constexpr int RANK = __builtin_ctz(mint::get_mod() - 1);
    static_assert(RANK >= 15, "MOD doesn't seem fft-friendly.");

    using value_type = mint;

    using std::vector<mint>::empty;
    using std::vector<mint>::back;
    using std::vector<mint>::pop_back;
    using std::vector<mint>::size;
    using std::vector<mint>::clear;
    using std::vector<mint>::begin;
    using std::vector<mint>::end;

private:
    // Size thresholds below which the quadratic algorithms are used.
    static constexpr int EVAL_BRUTE_SIZE = 1 << 4;
    static constexpr int MUL_MIN_CUT = 20;
    static constexpr int MUL_MAX_CUT = 1 << 6;
    static constexpr int DIV_N_CUT = 1 << 7;
    static constexpr int DIV_M_CUT = 1 << 6;
    static constexpr int INV_BRUTE_FORCE_SIZE = 1 << 3;

    // Tables of roots of unity and a lazily grown table of inverses of small integers.
    class fft_precalc {
    private:
        std::vector<mint> inv_{0, 1};

    public:
        mint root;
        // roots[x] = root.power((MOD - 1) / 2^x)
        // inv_roots[x] = roots[x].inv()
        std::array<mint, RANK> roots, inv_roots;
        // inv_l[x] = 2^{-x}
        std::array<mint, RANK> inv_l;
        // r_transition[x] = root.power((MOD - 1) * (3 + 2^{x + 1}) / 2^{x + 2}),
        // which is the same as the exponent (MOD - 1) * (3 - 2^{x + 1}) / 2^{x + 2}
        // because the two differ by exactly MOD - 1.
        // inv_r_transition[x] = r_transition[x].inv()
        // NOTE(review): for x + 2 > RANK the shift below is no longer an exact
        // division; those entries do not look meaningful - confirm they are unused.
        std::array<mint, RANK> r_transition, inv_r_transition;

        fft_precalc() {
            root = mint::primitive_root();
            for (int x = 0; x < RANK; x++) {
                roots[x] = mint(root).power((mint::get_mod() - 1) >> x);
                inv_roots[x] = roots[x].inv();
                inv_l[x] = mint(1 << x).inv();
                r_transition[x] = root.power(((mint::get_mod() - 1) >> (x + 2)) * (3 + (1 << (x + 1))));
                inv_r_transition[x] = r_transition[x].inv();
            }
        }

        // Returns 1 / x, extending the table with the recurrence
        // inv[i] = -inv[MOD % i] * (MOD / i) when needed.
        mint inv(int x) {
            for (int i = inv_.size(); i <= x; i++) {
                inv_.push_back(-inv_[mint::get_mod() % i] * mint(mint::get_mod() / i));
            }
            return inv_[x];
        }
    };

    inline static fft_precalc fft_data;

public:
    /*
     * Let r be the primitive root of the given MOD.
     * Let rev(i) be the reversed i as a binary mask of size log(n).
     * Let R[i] = r.power((MOD - 1) / a.size() * rev(i)).
     * It replaces a[i] with a.eval(R[i]).
     * a.size() must be a power of two.
     */
    static void fft(polynom_t<mint> &a) {
        assert(!a.empty());
        int n = a.size(), l = __builtin_ctz(n);
        for (int len = 0; len < l; len++) {
            int m = n >> (len + 1);
            mint r = 1;
            for (int i = 0; i < (1 << len); i++) {
                // rot = r^((MOD - 1) / n * rev(2 * i))
                int id = i << (l - len);
                for (int j = 0; j < m; j++) {
                    auto u = a[id + j];
                    auto v = a[id + j + m] * r;
                    a[id + j] = u + v;
                    a[id + j + m] = u - v;
                }
                // Move the twiddle factor to the one of the next block.
                if (i + 1 < (1 << len)) {
                    r *= fft_data.r_transition[__builtin_ctz(~i)];
                }
            }
        }
    }

    // inv_fft(fft(a)) = a
    static void inv_fft(polynom_t<mint> &a) {
        assert(!a.empty());
        int n = a.size(), l = __builtin_ctz(n);
        for (int len = l - 1; len >= 0; len--) {
            int m = n >> (len + 1);
            mint ir = 1;
            for (int i = 0; i < (1 << len); i++) {
                int id = i << (l - len);
                for (int j = 0; j < m; j++) {
                    auto u = a[id + j];
                    auto v = a[id + j + m];
                    a[id + j] = u + v;
                    a[id + j + m] = (u - v) * ir;
                }
                if (i + 1 < (1 << len)) {
                    ir *= fft_data.inv_r_transition[__builtin_ctz(~i)];
                }
            }
        }
        // Final scaling by 1 / n.
        for (auto &value : a) {
            value *= fft_data.inv_l[l];
        }
    }

private:
    // Completes fft assuming that a is the right half of the polynomial after the first iteration of the fft.
    static void right_half_fft(polynom_t<mint> &a) {
        assert(!a.empty());
        int n = a.size(), l = __builtin_ctz(n);
        for (int len = 0; len < l; len++) {
            int m = n >> (len + 1);
            mint r = fft_data.roots[len + 2];
            for (int i = 0; i < (1 << len); i++) {
                int id = i << (l - len);
                for (int j = 0; j < m; j++) {
                    auto u = a[id + j];
                    auto v = a[id + j + m] * r;
                    a[id + j] = u + v;
                    a[id + j + m] = u - v;
                }
                if (i + 1 < (1 << len)) {
                    r *= fft_data.r_transition[__builtin_ctz(~i)];
                }
            }
        }
    }

    // inv_right_half_fft(right_half_fft(a)) = a
    static void inv_right_half_fft(polynom_t<mint> &a) {
        assert(!a.empty());
        int n = a.size(), l = __builtin_ctz(n);
        for (int len = l - 1; len >= 0; len--) {
            int m = n >> (len + 1);
            mint ir = fft_data.inv_roots[len + 2];
            for (int i = 0; i < (1 << len); i++) {
                int id = i << (l - len);
                for (int j = 0; j < m; j++) {
                    auto u = a[id + j];
                    auto v = a[id + j + m];
                    a[id + j] = u + v;
                    a[id + j + m] = (u - v) * ir;
                }
                if (i + 1 < (1 << len)) {
                    ir *= fft_data.inv_r_transition[__builtin_ctz(~i)];
                }
            }
        }
        for (auto &value : a) {
            value *= fft_data.inv_l[l];
        }
    }

    // Product tree over the linear factors (x - points[i]), with every level
    // stored in transformed (fft) form. Shared by multipoint evaluation and
    // interpolation.
    class multipoint_evaluation_tree {
    private:
        // n_points: number of real points; the tree has 2^l leaves (missing
        // points are padded with the factor x).
        int n_points, l;
        std::vector<polynom_t<mint>> segtree;

    public:
        multipoint_evaluation_tree(int n, const std::vector<mint> &points) : n_points(points.size()), l(1) {
            while ((1 << l) < n) {
                l++;
            }
            polynom_t<mint> aux;
            aux.reserve(1 << l);

            segtree.assign(l + 1, polynom_t<mint>(1 << (l + 1)));
            // Leaves: transformed (x - points[i]), two values each.
            for (int i = 0; i < (1 << l); i++) {
                aux = {i < n_points ? -points[i] : 0, 1};
                fft(aux);
                segtree[0][i << 1] = aux[0];
                segtree[0][i << 1 | 1] = aux[1];
            }

            // Each level multiplies adjacent children pointwise.
            for (int len = 0; len < l; len++) {
                aux.resize(1 << (len + 1));
                for (int i = 0; i < (1 << (l + 1)); i += (1 << (len + 2))) {
                    for (int j = 0; j < static_cast<int>(aux.size()); j++) {
                        aux[j] = segtree[len][i + j] * segtree[len][i + j + aux.size()];
                    }
                    if (len + 1 < l) {
                        std::copy(aux.begin(), aux.end(), segtree[len + 1].begin() + i);
                        inv_fft(aux);
                        aux[0] -= 2;
                        right_half_fft(aux);
                        std::copy(aux.begin(), aux.end(), segtree[len + 1].begin() + i + aux.size());
                    } else {
                        // Root level is kept as plain coefficients.
                        inv_fft(aux);
                        aux[0]--;
                        std::copy(aux.begin(), aux.end(), segtree[len + 1].begin() + i);
                        segtree[len + 1][1 << (len + 1)]++;
                    }
                }
            }
        }

        // Returns the values of f at the points the tree was built from.
        std::vector<mint> evaluate(polynom_t<mint> f) const {
            if (static_cast<int>(f.size()) > (1 << l)) {
                f %= segtree[l];
            }
            assert(static_cast<int>(f.size()) <= (1 << l));
            f.resize(1 << (l + 1));
            std::rotate(f.begin(), std::prev(f.end()), f.end());
            fft(f);

            auto g = segtree[l];
            std::reverse(g.begin(), g.begin() + 1 + (1 << l));
            g.resize(1 << l);
            g = g.inv(1 << l);
            std::reverse(g.begin(), g.end());
            g.resize(1 << (l + 1));
            fft(g);
            for (int i = 0; i < (1 << (l + 1)); i++) {
                g[i] *= f[i];
            }

            // Walk the tree from the root down to the leaves.
            polynom_t<mint> aux;
            aux.reserve(1 << l);
            for (int len = l - 1; len >= 0; len--) {
                aux.resize(1 << (len + 1));
                for (int i = 0; i < (1 << (l + 1)); i += (1 << (len + 2))) {
                    for (int j = 0; j < static_cast<int>(aux.size()); j++) {
                        aux[j] = g[i + j + (1 << (len + 1))];
                    }
                    inv_right_half_fft(aux);
                    fft(aux);
                    std::copy(aux.begin(), aux.end(), g.begin() + i + (1 << (len + 1)));
                    for (int j = 0; j < (1 << (len + 1)); j++) {
                        auto x = g[i + j] - g[i + j + (1 << (len + 1))];
                        g[i + j + (1 << (len + 1))] = x * segtree[len][i + j];
                        g[i + j] = x * segtree[len][i + j + (1 << (len + 1))];
                    }
                }
            }

            std::vector<mint> eval(n_points);
            for (int i = 0; i < n_points; i++) {
                eval[i] = (g[2 * i] - g[2 * i + 1]) * fft_data.inv_l[l + 1];
            }
            return eval;
        }

        // Returns the polynomial of size n_points taking value y[i] at the i-th point.
        polynom_t<mint> interpolate(const std::vector<mint> &y) const {
            // Drop the padding factors so that prod is the product over real points only.
            auto prod = segtree[l];
            prod.erase(prod.begin(), prod.begin() + ((1 << l) - n_points));
            auto values = evaluate(prod.derivative());

            polynom_t<mint> aux;
            aux.reserve(1 << l);
            polynom_t<mint> result(1 << (l + 1));
            for (int i = 0; i < (1 << l); i++) {
                aux = {i < n_points ? y[i] / values[i] : 0, 0};
                fft(aux);
                result[i << 1] = aux[0];
                result[i << 1 | 1] = aux[1];
            }

            polynom_t<mint> next(1 << (l + 1));
            for (int len = 0; len < l; len++) {
                aux.resize(1 << (len + 1));
                for (int i = 0; i < (1 << (l + 1)); i += (1 << (len + 2))) {
                    for (int j = 0; j < (1 << (len + 1)); j++) {
                        aux[j] = segtree[len][i + j] * result[i + j + aux.size()]
                               + segtree[len][i + j + aux.size()] * result[i + j];
                    }
                    if (len + 1 < l) {
                        std::copy(aux.begin(), aux.end(), next.begin() + i);
                        inv_fft(aux);
                        right_half_fft(aux);
                        std::copy(aux.begin(), aux.end(), next.begin() + i + (1 << (len + 1)));
                    } else {
                        inv_fft(aux);
                        std::copy(aux.begin(), aux.end(), next.begin());
                    }
                }
                result.swap(next);
            }
            result.erase(result.begin(), result.begin() + ((1 << l) - n_points));
            return result.resize(n_points);
        }
    };

public:
    polynom_t() : std::vector<mint>() {}
    polynom_t(size_t n, mint value = 0) : std::vector<mint>(n, value) {}
    polynom_t(std::initializer_list<mint> values) : std::vector<mint>(values.begin(), values.end()) {}

    template<typename Iterator>
    polynom_t(Iterator first, Iterator last) : std::vector<mint>(first, last) {}

    // Same as std::vector::resize, but returns *this so that calls can be chained.
    polynom_t<mint>& resize(int n) {
        std::vector<mint>::resize(n);
        return *this;
    }

    // Removes extra zeroes.
    void normalize() {
        while (!empty() && back() == mint(0)) {
            pop_back();
        }
    }

    // Returns -1 if all coefficients are zeroes (not O(1)!).
    int degree() const {
        int deg = static_cast<int>(size()) - 1;
        while (deg >= 0 && (*this)[deg] == mint(0)) {
            deg--;
        }
        return deg;
    }

    // Value of the polynomial at x, O(n).
    mint eval(const mint &x) const {
        mint power = 1, value = 0;
        for (int i = 0; i < static_cast<int>(size()); i++, power *= x) {
            value += (*this)[i] * power;
        }
        return value;
    }

    // Calculates eval at the given points.
    // O(n log^2).
    std::vector<mint> multipoint_evaluation(const std::vector<mint> &points) const {
        if (std::min(size(), points.size()) <= EVAL_BRUTE_SIZE) {
            std::vector<mint> eval(points.size());
            for (int i = 0; i < static_cast<int>(points.size()); i++) {
                eval[i] = this->eval(points[i]);
            }
            return eval;
        }
        return multipoint_evaluation_tree(std::max(size(), points.size()), points).evaluate(*this);
    }

    // Interpolates polynomial f, such that f(x[i]) = y[i].
    // O(n log^2).
    static polynom_t<mint> interpolate(const std::vector<mint> &x, const std::vector<mint> &y) {
        assert(x.size() == y.size());
        return multipoint_evaluation_tree(x.size(), x).interpolate(y);
    }

    polynom_t<mint> operator-() const {
        return polynom_t(*this) * mint(-1);
    }

    polynom_t<mint>& operator+=(const polynom_t<mint> &another) {
        if (size() < another.size()) {
            resize(another.size());
        }
        for (int i = 0; i < static_cast<int>(another.size()); i++) {
            (*this)[i] += another[i];
        }
        return *this;
    }

    polynom_t<mint>& operator-=(const polynom_t<mint> &another) {
        if (size() < another.size()) {
            resize(another.size());
        }
        for (int i = 0; i < static_cast<int>(another.size()); i++) {
            (*this)[i] -= another[i];
        }
        return *this;
    }

    // Polynomial product; the result has size() + another.size() - 1 coefficients
    // (or is empty when either operand is empty).
    polynom_t<mint>& operator*=(const polynom_t<mint> &another) {
        if (empty() || another.empty()) {
            clear();
            return *this;
        }

        // Schoolbook multiplication for small inputs.
        if (std::min(size(), another.size()) <= MUL_MIN_CUT
            || std::max(size(), another.size()) <= MUL_MAX_CUT) {
            polynom_t<mint> product(int(size() + another.size()) - 1);
            for (int i = 0; i < static_cast<int>(size()); i++) {
                for (int j = 0; j < static_cast<int>(another.size()); j++) {
                    product[i + j] += (*this)[i] * another[j];
                }
            }
            product.swap(*this);
            return *this;
        }

        int real_size = static_cast<int>(size() + another.size()) - 1, n = 1;
        while (n < real_size) {
            n <<= 1;
        }

        if ((*this) == another) {
            // Squaring needs only one forward transform.
            resize(n);
            fft(*this);
            for (int i = 0; i < n; i++) {
                (*this)[i] *= (*this)[i];
            }
        } else {
            resize(n);
            polynom_t b = another;
            b.resize(n);
            fft(*this), fft(b);
            for (int i = 0; i < n; i++) {
                (*this)[i] *= b[i];
            }
        }
        inv_fft(*this);
        return resize(real_size);
    }

    // Division with remainder.
    // O(n log).
    // The divisor must not be the zero polynomial; the quotient is normalized.
    polynom_t<mint>& operator/=(const polynom_t<mint> &another) {
        polynom_t<mint> a(*this), b(another);
        a.normalize(), b.normalize();
        assert(!b.empty());
        int n = a.size(), m = b.size();
        if (n < m) {
            return *this = {};
        }
        // Long division for small inputs.
        if (n <= DIV_N_CUT || m <= DIV_M_CUT) {
            polynom_t<mint> quotient(n - m + 1);
            mint inv_b = mint(1) / b.back();
            for (int i = n - 1; i >= m - 1; i--) {
                int pos = i - m + 1;
                quotient[pos] = a[i] * inv_b;
                for (int j = 0; j < m; j++) {
                    a[pos + j] -= b[j] * quotient[pos];
                }
            }
            quotient.normalize();
            return *this = quotient;
        }

        // Otherwise divide the reversed polynomials through a power series inverse.
        std::reverse(a.begin(), a.end());
        std::reverse(b.begin(), b.end());
        polynom_t<mint> quotient = a * b.inv(n - m + 1);
        quotient.resize(n - m + 1);
        std::reverse(quotient.begin(), quotient.end());
        quotient.normalize();
        return *this = quotient;
    }

    // O(n log).
    polynom_t<mint>& operator%=(const polynom_t<mint> &another) {
        *this -= (*this) / another * another;
        normalize();
        return *this;
    }

    // Returns derivative.
    // O(n).
    polynom_t<mint> derivative() const {
        polynom_t<mint> der(std::max(0, static_cast<int>(size()) - 1));
        for (int i = 0; i < static_cast<int>(der.size()); i++) {
            der[i] = mint(i + 1) * (*this)[i + 1];
        }
        return der;
    }

    // Returns integral.
    // O(n).
    polynom_t<mint> integral(const mint &constant = mint(0)) const {
        polynom_t<mint> in(size() + 1);
        in[0] = constant;
        for (int i = 1; i < static_cast<int>(in.size()); i++) {
            in[i] = (*this)[i - 1] * fft_data.inv(i);
        }
        return in;
    }

    // Returns p^{-1} modulo x^{degree}.
    // O(n log).
    polynom_t<mint> inv(int degree) const {
        assert(!empty() && (*this)[0] != mint(0) && "polynom is not invertable");

        int result_size = 1;
        while (result_size < degree) {
            result_size <<= 1;
        }
        // Inverse of the constant term, computed once and reused below.
        const mint start_inv = (*this)[0].inv();
        polynom_t<mint> inv(result_size);
        inv[0] = start_inv;

        int brute_calculation = std::min(degree, INV_BRUTE_FORCE_SIZE);

        // First few coefficients directly: have = inv * (*this) must equal 1.
        polynom_t<mint> have(brute_calculation);
        for (int i = 0; i < brute_calculation; i++) {
            inv[i] = ((i == 0 ? 1 : 0) - have[i]) * start_inv;
            int steps = std::min<int>(size(), have.size() - i);
            for (int j = 0; j < steps; j++) {
                have[i + j] += inv[i] * (*this)[j];
            }
        }

        polynom_t<mint> this_copy, inv_copy;
        this_copy.reserve(result_size);
        inv_copy.reserve(result_size);

        // Each round doubles the number of known coefficients.
        for (int power = brute_calculation; power < degree; power <<= 1) {
            this_copy.resize(power << 1);
            std::fill(this_copy.begin() + std::min<int>(size(), power << 1), this_copy.end(), 0);
            std::copy(begin(), begin() + std::min<int>(size(), power << 1), this_copy.begin());
            inv_copy.resize(power << 1);
            std::copy(inv.begin(), inv.begin() + power, inv_copy.begin());

            fft(this_copy);
            fft(inv_copy);
            for (int i = 0; i < (power << 1); i++) {
                this_copy[i] *= inv_copy[i];
            }
            inv_fft(this_copy);
            std::fill(this_copy.begin(), this_copy.begin() + power, 0);
            fft(this_copy);
            for (int i = 0; i < (power << 1); i++) {
                this_copy[i] *= inv_copy[i];
            }
            inv_fft(this_copy);
            for (int i = power; i < (power << 1); i++) {
                inv[i] = -this_copy[i];
            }
        }
        return inv.resize(degree);
    }

    // Returns log(p) modulo x^{degree}.
    // O(n log).
    // Requires the constant term to be 1.
    polynom_t<mint> log(int degree) const {
        assert(!empty() && (*this)[0] == mint(1) && "log is not defined");
        return (derivative().resize(std::min<int>(degree, size())) * inv(degree)).resize(degree).integral(mint(0)).resize(degree);
    }

    // Returns exp(p) modulo x^{degree}.
    // O(n log).
    // Requires the constant term to be 0. Maintains exp and its inverse, and
    // doubles the number of known coefficients in every round.
    polynom_t<mint> exp(int degree) const {
        assert(!empty() && (*this)[0] == mint(0) && "exp is not defined");

        int result_size = 1;
        while (result_size < degree) {
            result_size <<= 1;
        }

        polynom_t<mint> exp, exp_fft, inv_exp, inv_exp_fft;
        exp = exp_fft = inv_exp = inv_exp_fft = {1};
        polynom_t<mint> h(1), q(1);
        auto diff = derivative().resize(result_size);

        for (int power = 1; power < degree; power <<= 1) {
            exp_fft = polynom_t<mint>(exp).resize(power << 1);
            fft(exp_fft);
            for (int i = 0; i < power; i++) {
                h[i] = exp_fft[i] * inv_exp_fft[i];
            }
            inv_fft(h);

            std::fill(h.begin(), h.begin() + ((power + 1) >> 1), 0);
            fft(h);
            for (int i = 0; i < power; i++) {
                h[i] *= -inv_exp_fft[i];
            }
            inv_fft(h);

            // Extend the inverse of exp to `power` coefficients.
            for (int i = (power + 1) / 2; i < power; i++) {
                inv_exp.push_back(h[i]);
            }
            inv_exp_fft = polynom_t<mint>(inv_exp).resize(power << 1);
            fft(inv_exp_fft);

            h.assign(power, 0);
            std::copy(diff.begin(), diff.begin() + power - 1, h.begin());
            fft(h);
            for (int i = 0; i < power; i++) {
                q[i] = exp_fft[i] * h[i];
            }
            inv_fft(q);

            h.resize(power << 1);
            for (int i = 1; i < power; i++) {
                h[i] = exp[i] * i;
            }
            for (int i = 0; i < power; i++) {
                h[i + 1] -= q[i];
            }
            h[0] = -q[power - 1];
            fft(h);

            q.resize(power << 1);
            for (int i = 0; i < (power << 1); i++) {
                q[i] = inv_exp_fft[i] * h[i];
            }
            inv_fft(q);

            if (static_cast<int>(size()) >= power) {
                h.assign(begin() + power, begin() + std::min<int>(power << 1, size()));
            } else {
                h.clear();
            }
            h.resize(power);

            for (int i = power - 1; i >= 0; i--) {
                h[i] -= q[i] * fft_data.inv(i + power);
            }
            h.resize(power << 1);
            fft(h);
            for (int i = 0; i < (power << 1); i++) {
                q[i] = exp_fft[i] * h[i];
            }
            inv_fft(q);

            // Append the newly computed upper half.
            exp.resize(power << 1);
            for (int i = 0; i < power; i++) {
                exp[power + i] = q[i];
            }
        }
        return exp.resize(degree);
    }

    // Returns p^{d} modulo x^{degree}.
    // O(n log).
    polynom_t<mint> power(int64_t d, int degree) const {
        if (!d || !degree) {
            return polynom_t<mint>{1}.resize(degree);
        }
        // pos = index of the lowest non-zero coefficient.
        int pos = 0;
        while (pos < static_cast<int>(size()) && (*this)[pos] == mint(0)) {
            pos++;
        }
        // Zero polynomial, or the lowest term of the power is already x^{degree} or higher.
        if (pos == static_cast<int>(size()) || pos >= (degree + d - 1) / d) {
            return polynom_t<mint>(degree);
        }

        // p = c * x^pos * g with g[0] = 1, hence p^d = c^d * x^{d * pos} * exp(d * log(g)).
        int coeffs_left = degree - d * pos;
        polynom_t<mint> result = ((polynom_t<mint>(begin() + pos, end()) / (*this)[pos]).log(coeffs_left)
                               * mint(d)).exp(coeffs_left) * (*this)[pos].power(d);
        result.resize(degree);
        // Shift by d * pos positions and clear the vacated low coefficients.
        for (int i = degree - 1; i - (degree - coeffs_left) >= 0; i--) {
            result[i] = result[i - (degree - coeffs_left)];
        }
        std::fill(result.begin(), result.end() - coeffs_left, mint(0));
        return result;
    }

    // Returns f(x + c).
    // O(n log).
    polynom_t<mint> change_of_variable(mint c) const {
        int n = size();
        // p[n - 1 - i] = f[i] * i!, q[i] = c^i / i!.
        polynom_t<mint> p(n), q(n);
        mint fact = 1, ifact = 1, power = 1;
        for (int i = 0; i < n; i++) {
            p[n - 1 - i] = (*this)[i] * fact;
            q[i] = power * ifact;
            fact *= i + 1;
            ifact *= fft_data.inv(i + 1);
            power *= c;
        }

        auto result = p * q;
        result.resize(n);
        std::reverse(result.begin(), result.end());
        ifact = 1;
        for (int i = 0; i < n; i++) {
            result[i] *= ifact;
            ifact *= fft_data.inv(i + 1);
        }
        return result;
    }

    // Multiplies every coefficient by a scalar.
    template<typename V>
    std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>&> operator*=(const V &value) {
        for (auto &x : *this) {
            x *= value;
        }
        return *this;
    }

    // Divides every coefficient by a scalar.
    template<typename V>
    std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>&> operator/=(const V &value) {
        // Invert the scalar once instead of performing a modular division per coefficient.
        mint inverse = 1;
        inverse /= value;
        for (auto &x : *this) {
            x *= inverse;
        }
        return *this;
    }

    // Prints the coefficients separated by single spaces.
    friend std::ostream& operator<<(std::ostream &out, const polynom_t<mint> &p) {
        for (int i = 0; i < static_cast<int>(p.size()); i++) {
            if (i) out << ' ';
            out << p[i];
        }
        return out;
    }

    friend polynom_t<mint> operator+(const polynom_t<mint> &p, const polynom_t<mint> &q) {
        return polynom_t(p) += q;
    }

    friend polynom_t<mint> operator-(const polynom_t<mint> &p, const polynom_t<mint> &q) {
        return polynom_t(p) -= q;
    }

    friend polynom_t<mint> operator*(const polynom_t<mint> &p, const polynom_t<mint> &q) {
        return polynom_t(p) *= q;
    }

    friend polynom_t<mint> operator/(const polynom_t<mint> &p, const polynom_t<mint> &q) {
        return polynom_t(p) /= q;
    }

    friend polynom_t<mint> operator%(const polynom_t<mint> &p, const polynom_t<mint> &q) {
        return polynom_t(p) %= q;
    }

    template<typename V>
    friend std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>>
        operator*(const V &value, const polynom_t<mint> &p) {
        return polynom_t<mint>(p) *= value;
    }

    template<typename V>
    friend std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>>
        operator*(const polynom_t<mint> &p, const V &value) {
        return polynom_t<mint>(p) *= value;
    }

    // NOTE(review): this returns p / value, i.e. the same as the overload below,
    // not "scalar divided by polynomial" - confirm this is what callers expect.
    template<typename V>
    friend std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>>
        operator/(const V &value, const polynom_t<mint> &p) {
        return polynom_t<mint>(p) /= value;
    }

    template<typename V>
    friend std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>>
        operator/(const polynom_t<mint> &p, const V &value) {
        return polynom_t<mint>(p) /= value;
    }
};

// Polynomial with coefficients of type mint.
using polynom = polynom_t<mint>;

/*
 ! WARNING: MOD must be prime.
 * Define modular int class above it.
 * No need to run any init function, it dynamically resizes the data.
 * Access the tables through fact[i], ifact[i] and inv[i]; a negative index yields 0.
 */
namespace combinatorics {
    // fact_[i] = i!, ifact_[i] = 1 / i!, inv_[i] = 1 / i (inv_[0] is a placeholder 0).
    std::vector<mint> fact_, ifact_, inv_;

    // Grows all three tables together so that every index in [0, size] is valid.
    // Does nothing beyond the initial seeding if the tables are already large enough.
    void resize_data(int size) {
        if (fact_.empty()) {
            fact_ = {mint(1), mint(1)};
            ifact_ = {mint(1), mint(1)};
            inv_ = {mint(0), mint(1)};
        }
        for (int pos = static_cast<int>(fact_.size()); pos <= size; pos++) {
            fact_.push_back(fact_.back() * mint(pos));
            // MOD = (MOD / pos) * pos + MOD % pos, hence 1 / pos = -(MOD / pos) / (MOD % pos).
            inv_.push_back(-inv_[MOD % pos] * mint(MOD / pos));
            ifact_.push_back(ifact_.back() * inv_[pos]);
        }
    }

    // Lazy, auto-growing view over one of the tables above.
    struct combinatorics_info {
        std::vector<mint> &data;

        combinatorics_info(std::vector<mint> &data) : data(data) {}

        // Returns the table entry at pos, extending the tables on demand.
        // A negative pos returns 0 instead of indexing before the start of the
        // vector (which would be undefined behaviour). For ifact this agrees with
        // the convention 1 / k! = 0 for k < 0; fact and inv have no meaningful
        // value at negative indices.
        mint operator[](int pos) {
            if (pos < 0) {
                return mint(0);
            }
            if (pos >= static_cast<int>(data.size())) {
                resize_data(pos);
            }
            return data[pos];
        }
    } fact(fact_), ifact(ifact_), inv(inv_);

    // From n choose k.
    // O(max(n)) in total.
    // Returns 0 when k < 0, n < 0 or k > n.
    mint choose(int n, int k) {
        if (n < k || k < 0 || n < 0) {
            return mint(0);
        }
        return fact[n] * ifact[k] * ifact[n - k];
    }

    // From n choose k.
    // O(min(k, n - k)).
    // Returns 0 when k < 0, n < 0 or k > n. min(k, n - k) must fit in int because
    // it is used as an index into the inverse table.
    mint choose_slow(int64_t n, int64_t k) {
        if (n < k || k < 0 || n < 0) {
            return mint(0);
        }
        k = std::min(k, n - k);
        mint result = 1;
        // Multiplies (n - i + 1) / i for i = k, ..., 1.
        for (int i = static_cast<int>(k); i >= 1; i--) {
            result *= (n - i + 1);
            result *= inv[i];
        }
        return result;
    }

    // Number of balanced bracket sequences with n open and m closing brackets.
    // Returns 0 when m > n or m < 0.
    mint catalan(int n, int m) {
        if (m > n || m < 0) {
            return mint(0);
        }
        return choose(n + m, m) - choose(n + m, m - 1);
    }

    // Number of balanced bracket sequences with n open and closing brackets.
    mint catalan(int n) {
        return catalan(n, n);
    }
} // namespace combinatorics

// Makes fact / ifact / inv / choose / catalan usable unqualified in the code below.
using namespace combinatorics;

// Returns C(n+m-2, n-1)^2 - C(n+m-2, n) * C(n+m-2, m) modulo MOD.
// The value is symmetric in (n, m), since C(n+m-2, n-1) = C(n+m-2, m-1).
// NOTE(review): this has the shape of a 2x2 Lindstrom-Gessel-Viennot determinant
// (pairs of non-intersecting lattice paths) - confirm against the problem statement.
mint solve(int n, int m) {
    const int steps = n + m - 2;
    mint diagonal = choose(steps, n - 1);
    mint off_diagonal = choose(steps, n) * choose(steps, m);
    return diagonal * diagonal - off_diagonal;
}

// Returns the sum of solve(y2 - y, m) over all 0 <= y < y2 <= n.
// Each difference d = y2 - y occurs for exactly n - d + 1 pairs, so the double
// loop collapses to a single weighted sum over d.
mint solve_pairs(int n, int m) {
    mint total = 0;
    int multiplicity = n;  // equals n - d + 1 throughout the loop
    for (int d = 1; d <= n; ++d, --multiplicity) {
        total += multiplicity * solve(d, m);
    }
    return total;
}

// Returns the sum of solve(x, y) over 1 <= y <= n and 1 <= x < m, i.e.
//
//     for (int y = 1; y <= n; y++)
//         for (int x = 1; x < m; x++) {
//             ans += choose(x + y - 2, x - 1) * choose(x + y - 2, x - 1);
//             ans -= choose(x + y - 2, x) * choose(x + y - 2, y);
//         }
//
// Both terms factor as ((x + y - 2)!)^2 times a function of y times a function
// of x, so each double sum becomes one polynomial product whose coefficient at
// s = x + y is then weighted by ((s - 2)!)^2.
// Relies on polynom(size) creating zero coefficients, as index 0 is never written.
mint solve_borders(int n, int m) {
    mint ans = 0;
    {
        // C(x+y-2, x-1)^2 = ((x+y-2)!)^2 / ((y-1)!^2 * (x-1)!^2).
        polynom p(n + 1);
        for (int y = 1; y <= n; y++) {
            p[y] = ifact[y - 1] * ifact[y - 1];
        }
        polynom q(m);
        for (int x = 1; x < m; x++) {
            q[x] = ifact[x - 1] * ifact[x - 1];
        }
        auto f = p * q;
        for (int s = 2; s < len(f); s++) {
            ans += f[s] * fact[s - 2] * fact[s - 2];
        }
    }
    {
        // C(x+y-2, x) * C(x+y-2, y) = ((x+y-2)!)^2 / ((y-2)! * y! * (x-2)! * x!).
        // The y = 1 and x = 1 terms are zero (1 / (-1)! = 0), so both loops start
        // at 2 and leave those coefficients at 0. Starting at 1 would evaluate
        // ifact[-1], which indexes before the start of the table.
        polynom p(n + 1);
        for (int y = 2; y <= n; y++) {
            p[y] = ifact[y - 2] * ifact[y];
        }
        polynom q(m);
        for (int x = 2; x < m; x++) {
            q[x] = ifact[x - 2] * ifact[x];
        }
        auto f = p * q;
        for (int s = 2; s < len(f); s++) {
            ans -= f[s] * fact[s - 2] * fact[s - 2];
        }
    }
    return ans;
}

int main() {
    cin.tie(nullptr)->sync_with_stdio(false);

    int n, m;
    cin >> n >> m;

    mint ans = solve_pairs(n, m) + solve_pairs(m, n) + solve_borders(n, m) + solve_borders(m, n);
    // for (int y = 0; y < n; y++) {
    //     for (int x = 1; x < m; x++) {
    //         ans += solve(x, n - y);
    //     }
    //     for (int y2 = y + 1; y2 <= n; y2++) {
    //         ans += solve(m, y2 - y);
    //     }
    // }

    // for (int x = 1; x < m; x++) {
    //     for (int y = 1; y < n; y++) {
    //         ans += solve(m - x, y);
    //     }
    //     for (int x2 = x + 1; x2 <= m; x2++) {
    //         ans += solve(x2 - x, n);
    //     }
    // }

    for (int y = 1; y < n; y++) {
        ans -= solve(m, y);
    }
    for (int x2 = 1; x2 <= m; x2++) {
        ans -= solve(x2, n);
    }

    cout << ans + 2 << '\n';
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 0ms
memory: 3540kb

input:

2 2

output:

11

result:

ok "11"

Test #2:

score: 0
Accepted
time: 0ms
memory: 3736kb

input:

20 23

output:

521442928

result:

ok "521442928"

Test #3:

score: 0
Accepted
time: 133ms
memory: 15288kb

input:

200000 200000

output:

411160917

result:

ok "411160917"

Test #4:

score: 0
Accepted
time: 0ms
memory: 3584kb

input:

8 3

output:

2899

result:

ok "2899"

Test #5:

score: 0
Accepted
time: 0ms
memory: 3568kb

input:

10 9

output:

338037463

result:

ok "338037463"

Test #6:

score: 0
Accepted
time: 0ms
memory: 3528kb

input:

3 3

output:

64

result:

ok "64"

Test #7:

score: 0
Accepted
time: 0ms
memory: 3792kb

input:

9 4

output:

39733

result:

ok "39733"

Test #8:

score: 0
Accepted
time: 0ms
memory: 3536kb

input:

36 33

output:

545587245

result:

ok "545587245"

Test #9:

score: 0
Accepted
time: 0ms
memory: 3744kb

input:

35 39

output:

62117944

result:

ok "62117944"

Test #10:

score: 0
Accepted
time: 0ms
memory: 3536kb

input:

48 10

output:

264659761

result:

ok "264659761"

Test #11:

score: 0
Accepted
time: 0ms
memory: 3572kb

input:

46 30

output:

880000821

result:

ok "880000821"

Test #12:

score: 0
Accepted
time: 0ms
memory: 3536kb

input:

25 24

output:

280799864

result:

ok "280799864"

Test #13:

score: 0
Accepted
time: 0ms
memory: 3584kb

input:

17 10

output:

624958192

result:

ok "624958192"

Test #14:

score: 0
Accepted
time: 4ms
memory: 3712kb

input:

4608 9241

output:

322218996

result:

ok "322218996"

Test #15:

score: 0
Accepted
time: 4ms
memory: 3724kb

input:

3665 6137

output:

537704652

result:

ok "537704652"

Test #16:

score: 0
Accepted
time: 0ms
memory: 3752kb

input:

4192 6186

output:

971816887

result:

ok "971816887"

Test #17:

score: 0
Accepted
time: 4ms
memory: 3660kb

input:

4562 4403

output:

867628411

result:

ok "867628411"

Test #18:

score: 0
Accepted
time: 4ms
memory: 3756kb

input:

8726 3237

output:

808804305

result:

ok "808804305"

Test #19:

score: 0
Accepted
time: 4ms
memory: 3736kb

input:

5257 8166

output:

488829288

result:

ok "488829288"

Test #20:

score: 0
Accepted
time: 4ms
memory: 3980kb

input:

8013 7958

output:

215666893

result:

ok "215666893"

Test #21:

score: 0
Accepted
time: 4ms
memory: 3896kb

input:

8837 5868

output:

239261227

result:

ok "239261227"

Test #22:

score: 0
Accepted
time: 4ms
memory: 3768kb

input:

8917 5492

output:

706653412

result:

ok "706653412"

Test #23:

score: 0
Accepted
time: 4ms
memory: 3664kb

input:

9628 5378

output:

753685501

result:

ok "753685501"

Test #24:

score: 0
Accepted
time: 122ms
memory: 14200kb

input:

163762 183794

output:

141157510

result:

ok "141157510"

Test #25:

score: 0
Accepted
time: 58ms
memory: 8464kb

input:

83512 82743

output:

114622013

result:

ok "114622013"

Test #26:

score: 0
Accepted
time: 54ms
memory: 8072kb

input:

84692 56473

output:

263907717

result:

ok "263907717"

Test #27:

score: 0
Accepted
time: 33ms
memory: 6140kb

input:

31827 74195

output:

200356808

result:

ok "200356808"

Test #28:

score: 0
Accepted
time: 137ms
memory: 14268kb

input:

189921 163932

output:

845151158

result:

ok "845151158"

Test #29:

score: 0
Accepted
time: 69ms
memory: 9280kb

input:

27157 177990

output:

847356039

result:

ok "847356039"

Test #30:

score: 0
Accepted
time: 58ms
memory: 8652kb

input:

136835 39390

output:

962822638

result:

ok "962822638"

Test #31:

score: 0
Accepted
time: 64ms
memory: 7964kb

input:

118610 18795

output:

243423874

result:

ok "243423874"

Test #32:

score: 0
Accepted
time: 60ms
memory: 8084kb

input:

122070 19995

output:

531055604

result:

ok "531055604"

Test #33:

score: 0
Accepted
time: 70ms
memory: 9528kb

input:

20031 195670

output:

483162363

result:

ok "483162363"

Test #34:

score: 0
Accepted
time: 135ms
memory: 15312kb

input:

199992 199992

output:

262099623

result:

ok "262099623"

Test #35:

score: 0
Accepted
time: 134ms
memory: 15292kb

input:

200000 199992

output:

477266520

result:

ok "477266520"

Test #36:

score: 0
Accepted
time: 135ms
memory: 15300kb

input:

199999 199996

output:

165483205

result:

ok "165483205"

Test #37:

score: 0
Accepted
time: 0ms
memory: 3608kb

input:

1 1

output:

3

result:

ok "3"

Test #38:

score: 0
Accepted
time: 6ms
memory: 5596kb

input:

1 100000

output:

8828237

result:

ok "8828237"

Test #39:

score: 0
Accepted
time: 3ms
memory: 5572kb

input:

100000 2

output:

263711286

result:

ok "263711286"

Test #40:

score: 0
Accepted
time: 0ms
memory: 3540kb

input:

50 50

output:

634767411

result:

ok "634767411"

Extra Test:

score: 0
Extra Test Passed