QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#659138 | #9476. 012 Grid | ucup-team4435# | AC ✓ | 137ms | 15312kb | C++20 | 35.6kb | 2024-10-19 18:53:56 | 2024-10-19 18:53:57 |
Judging History
answer
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
#define all(a) begin(a), end(a)
#define len(a) int((a).size())
/*
 ! WARNING: MOD must be prime.
 ! WARNING: MOD must be less than 2^30.
 * Use .get() to get the stored value.
 *
 * Modular integer kept in Montgomery form: `value` is congruent to x * 2^32 (mod `mod`)
 * and is only lazily reduced (it stays below 2 * mod); get() converts back to [0, mod).
 */
template<uint32_t mod>
class montgomery {
    static_assert(mod < uint32_t(1) << 30, "mod < 2^30");

    using mint = montgomery<mod>;

private:
    // Montgomery representation of the stored residue.
    uint32_t value;

    // -mod^{-1} modulo 2^32, found by Newton iteration (each step doubles the number of correct bits).
    static constexpr uint32_t inv_neg_mod = []() {
        uint32_t x = mod;
        for (int i = 0; i < 4; ++i) {
            x *= uint32_t(2) - mod * x;
        }
        return -x;
    }();
    static_assert(mod * inv_neg_mod == -1);

    // 2^64 modulo mod; multiplying by it and reducing once converts a plain residue to Montgomery form.
    static constexpr uint32_t neg_mod = (-uint64_t(mod)) % mod;

    // Montgomery reduction: returns a value congruent to value * 2^{-32} modulo mod.
    static uint32_t reduce(const uint64_t &value) {
        return (value + uint64_t(uint32_t(value) * inv_neg_mod) * mod) >> 32;
    }

    inline static const mint ONE = mint(1);

public:
    montgomery() : value(0) {}

    montgomery(const mint &x) : value(x.value) {}

    // Builds the residue of any integral x (negative values are supported).
    template<typename T, typename U = std::enable_if_t<std::is_integral<T>::value>>
    montgomery(const T &x) : value(!x ? 0 : reduce(int64_t(x % int32_t(mod) + int32_t(mod)) * neg_mod)) {}

    static constexpr uint32_t get_mod() {
        return mod;
    }

    // Returns the canonical representative in [0, mod).
    uint32_t get() const {
        auto real_value = reduce(value);
        return real_value < mod ? real_value : real_value - mod;
    }

    // Returns (*this)^degree. Negative degrees are supported for non-zero bases through
    // Fermat's little theorem (the exponent is reduced modulo mod - 1).
    // A zero base yields 1 for degree == 0 and 0 for every other degree.
    template<typename T>
    mint power(T degree) const {
        // The exponent reduction below is only valid for non-zero bases: without this guard
        // 0^(k * (mod - 1)) would be reported as 1 instead of 0.
        if (get() == 0) {
            return degree == 0 ? mint(1) : mint(0);
        }
        degree = (degree % int32_t(mod - 1) + int32_t(mod - 1)) % int32_t(mod - 1);
        mint prod = 1, a = *this;
        for (; degree > 0; degree >>= 1, a *= a)
            if (degree & 1)
                prod *= a;

        return prod;
    }

    // Multiplicative inverse (mod must be prime). Returns 0 for the non-invertible value 0.
    mint inv() const {
        return power(-1);
    }

    mint& operator=(const mint &x) {
        value = x.value;
        return *this;
    }

    mint& operator+=(const mint &x) {
        // Subtract 2 * mod first; a negative result (as int32) means no wrap was needed, so undo it.
        if (int32_t(value += x.value - (mod << 1)) < 0) {
            value += mod << 1;
        }
        return *this;
    }

    mint& operator-=(const mint &x) {
        if (int32_t(value -= x.value) < 0) {
            value += mod << 1;
        }
        return *this;
    }

    mint& operator*=(const mint &x) {
        value = reduce(uint64_t(value) * x.value);
        return *this;
    }

    mint& operator/=(const mint &x) {
        return *this *= x.inv();
    }

    friend mint operator+(const mint &x, const mint &y) {
        return mint(x) += y;
    }

    friend mint operator-(const mint &x, const mint &y) {
        return mint(x) -= y;
    }

    friend mint operator*(const mint &x, const mint &y) {
        return mint(x) *= y;
    }

    friend mint operator/(const mint &x, const mint &y) {
        return mint(x) /= y;
    }

    mint& operator++() {
        return *this += ONE;
    }

    mint& operator--() {
        return *this -= ONE;
    }

    mint operator++(int) {
        mint prev = *this;
        *this += ONE;
        return prev;
    }

    mint operator--(int) {
        mint prev = *this;
        *this -= ONE;
        return prev;
    }

    mint operator-() const {
        return mint(0) - *this;
    }

    // Comparisons work on canonical representatives because `value` is not uniquely reduced.
    bool operator==(const mint &x) const {
        return get() == x.get();
    }

    bool operator!=(const mint &x) const {
        return get() != x.get();
    }

    bool operator<(const mint &x) const {
        return get() < x.get();
    }

    // Explicit conversion to any type constructible from get(); const so it also works on const values.
    template<typename T>
    explicit operator T() const {
        return get();
    }

    // Reads a (possibly negative) decimal integer of arbitrary length and reduces it digit by digit.
    friend std::istream& operator>>(std::istream &in, mint &x) {
        std::string s;
        in >> s;
        x = 0;
        bool neg = s[0] == '-';
        for (const auto c : s)
            if (c != '-')
                x = x * 10 + (c - '0');

        if (neg)
            x *= -1;

        return in;
    }

    friend std::ostream& operator<<(std::ostream &out, const mint &x) {
        return out << x.get();
    }

    // Returns a primitive root modulo mod: hard-coded for three common moduli, otherwise found
    // by trial over r = 2, 3, ... using the prime factors of mod - 1 and cached in a static.
    static int32_t primitive_root() {
        if constexpr (mod == 1'000'000'007)
            return 5;
        if constexpr (mod == 998'244'353)
            return 3;
        if constexpr (mod == 786433)
            return 10;

        static int root = -1;
        if (root != -1)
            return root;

        // Collect the distinct prime divisors of mod - 1.
        std::vector<int> primes;
        int value = mod - 1;
        for (int i = 2; i * i <= value; i++)
            if (value % i == 0) {
                primes.push_back(i);
                while (value % i == 0)
                    value /= i;
            }

        if (value != 1)
            primes.push_back(value);

        // r is a primitive root iff r^((mod - 1) / p) != 1 for every prime p dividing mod - 1.
        for (int r = 2;; r++) {
            bool ok = true;
            for (auto p : primes)
                if ((mint(r).power((mod - 1) / p)).get() == 1) {
                    ok = false;
                    break;
                }

            if (ok)
                return root = r;
        }
    }
};
// Modulus used for every computation in this program.
// 998'244'353 - 1 = 119 * 2^23, so it passes polynom_t's RANK >= 15 static_assert below;
// the alternative 1'000'000'007 (kept commented out) would not, since 1'000'000'006 is only divisible by 2 once.
// constexpr uint32_t MOD = 1'000'000'007;
constexpr uint32_t MOD = 998'244'353;
using mint = montgomery<MOD>;
/*
! WARNING: (MOD - 1) must be divisible by (2n), where n is max degree of polynomials.
* Include static_modular_int or montgomery (faster) to use it.
* Don't need to care about precomputing primitive root.
* Use it like std::vector<mint> with extra methods.
*/
template<typename mint>
class polynom_t : public std::vector<mint> {
public:
static constexpr int RANK = __builtin_ctz(mint::get_mod() - 1);
static_assert(RANK >= 15, "MOD doesn't seem fft-friendly.");
using value_type = mint;
using std::vector<mint>::empty;
using std::vector<mint>::back;
using std::vector<mint>::pop_back;
using std::vector<mint>::size;
using std::vector<mint>::clear;
using std::vector<mint>::begin;
using std::vector<mint>::end;
private:
// Size thresholds at or below which the direct (quadratic) algorithms are used instead of the FFT-based ones.
// multipoint_evaluation(): evaluate point by point when min(#coefficients, #points) <= EVAL_BRUTE_SIZE.
static constexpr int EVAL_BRUTE_SIZE = 1 << 4;
// operator*=(polynom): schoolbook product when the shorter operand has at most MUL_MIN_CUT coefficients...
static constexpr int MUL_MIN_CUT = 20;
// ...or when the longer operand has at most MUL_MAX_CUT coefficients.
static constexpr int MUL_MAX_CUT = 1 << 6;
// operator/=(polynom): long division when the normalized dividend has at most DIV_N_CUT coefficients...
static constexpr int DIV_N_CUT = 1 << 7;
// ...or when the normalized divisor has at most DIV_M_CUT coefficients.
static constexpr int DIV_M_CUT = 1 << 6;
// inv(): number of leading coefficients computed directly before the doubling steps start.
static constexpr int INV_BRUTE_FORCE_SIZE = 1 << 3;
class fft_precalc {
private:
std::vector<mint> inv_{0, 1};
public:
mint root;
// roots[x] = root.power((MOD - 1) / 2^x)
// inv_roots[x] = roots[x].inv()
std::array<mint, RANK> roots, inv_roots;
// inv_l[x] = 2^{-x}
std::array<mint, RANK> inv_l;
// r_transition[x] = root.power((MOD - 1) * (3 - 2^{x + 1}) / 2^{x + 2})
// inv_r_transition[x] = r_transition[x].inv()
std::array<mint, RANK> r_transition, inv_r_transition;
fft_precalc() {
root = mint::primitive_root();
for (int x = 0; x < RANK; x++) {
roots[x] = mint(root).power((mint::get_mod() - 1) >> x);
inv_roots[x] = roots[x].inv();
inv_l[x] = mint(1 << x).inv();
r_transition[x] = root.power(((mint::get_mod() - 1) >> (x + 2)) * (3 + (1 << (x + 1))));
inv_r_transition[x] = r_transition[x].inv();
}
}
mint inv(int x) {
for (int i = inv_.size(); i <= x; i++) {
inv_.push_back(-inv_[mint::get_mod() % i] * mint(mint::get_mod() / i));
}
return inv_[x];
}
};
inline static fft_precalc fft_data;
public:
/*
* In-place forward transform.
* Let r be the primitive root of the given MOD.
* Let rev(i) be the reversed i as a binary mask of size log(n).
* Let R[i] = r.power((MOD - 1) / a.size() * rev(i)).
* It replaces a[i] with a.eval(R[i]).
* The index arithmetic assumes a.size() is a power of two (l is taken from its trailing zero count).
*/
static void fft(polynom_t<mint> &a) {
assert(!a.empty());
int n = a.size(), l = __builtin_ctz(n);
// Layer `len` processes 2^len consecutive blocks of 2 * m elements each.
for (int len = 0; len < l; len++) {
int m = n >> (len + 1);
// Twiddle factor of the current block; starts at 1 and is advanced incrementally below.
mint r = 1;
for (int i = 0; i < (1 << len); i++) {
// rot = r^((MOD - 1) / n * rev(2 * i))
int id = i << (l - len);
// Butterfly between the two halves of the block starting at id.
for (int j = 0; j < m; j++) {
auto u = a[id + j];
auto v = a[id + j + m] * r;
a[id + j] = u + v;
a[id + j + m] = u - v;
}
if (i + 1 < (1 << len)) {
// Move to the next block's twiddle; the table entry is selected by the number of trailing one bits of i.
r *= fft_data.r_transition[__builtin_ctz(~i)];
}
}
}
}
// In-place inverse of fft(): inv_fft(fft(a)) = a
// Runs the layers of fft() in reverse order with the inverse twiddles, then scales by 2^{-l}.
// Like fft(), the index arithmetic assumes a.size() is a power of two.
static void inv_fft(polynom_t<mint> &a) {
assert(!a.empty());
int n = a.size(), l = __builtin_ctz(n);
for (int len = l - 1; len >= 0; len--) {
int m = n >> (len + 1);
// Inverse twiddle factor of the current block.
mint ir = 1;
for (int i = 0; i < (1 << len); i++) {
int id = i << (l - len);
// Inverse butterfly: the difference is multiplied by the inverse twiddle.
for (int j = 0; j < m; j++) {
auto u = a[id + j];
auto v = a[id + j + m];
a[id + j] = u + v;
a[id + j + m] = (u - v) * ir;
}
if (i + 1 < (1 << len)) {
ir *= fft_data.inv_r_transition[__builtin_ctz(~i)];
}
}
}
// Each of the l layers doubled the values, so divide everything by 2^l.
for (auto &value : a) {
value *= fft_data.inv_l[l];
}
}
private:
// Completes fft assuming that a is the right half of the polynomial after the first iteration of the fft.
// Same layer structure as fft(), except that each layer starts from the twiddle
// fft_data.roots[len + 2] instead of 1.
static void right_half_fft(polynom_t<mint> &a) {
assert(!a.empty());
int n = a.size(), l = __builtin_ctz(n);
for (int len = 0; len < l; len++) {
int m = n >> (len + 1);
// Starting twiddle of this layer (the only difference from fft()).
mint r = fft_data.roots[len + 2];
for (int i = 0; i < (1 << len); i++) {
int id = i << (l - len);
for (int j = 0; j < m; j++) {
auto u = a[id + j];
auto v = a[id + j + m] * r;
a[id + j] = u + v;
a[id + j + m] = u - v;
}
if (i + 1 < (1 << len)) {
r *= fft_data.r_transition[__builtin_ctz(~i)];
}
}
}
}
// In-place inverse of right_half_fft(): inv_right_half_fft(right_half_fft(a)) = a
// Same layer structure as inv_fft(), except that each layer starts from the inverse twiddle
// fft_data.inv_roots[len + 2] instead of 1.
static void inv_right_half_fft(polynom_t<mint> &a) {
assert(!a.empty());
int n = a.size(), l = __builtin_ctz(n);
for (int len = l - 1; len >= 0; len--) {
int m = n >> (len + 1);
// Starting inverse twiddle of this layer (the only difference from inv_fft()).
mint ir = fft_data.inv_roots[len + 2];
for (int i = 0; i < (1 << len); i++) {
int id = i << (l - len);
for (int j = 0; j < m; j++) {
auto u = a[id + j];
auto v = a[id + j + m];
a[id + j] = u + v;
a[id + j + m] = (u - v) * ir;
}
if (i + 1 < (1 << len)) {
ir *= fft_data.inv_r_transition[__builtin_ctz(~i)];
}
}
}
// Undo the factor 2^l accumulated over the l layers.
for (auto &value : a) {
value *= fft_data.inv_l[l];
}
}
// Product tree over the evaluation points, stored level by level, shared by
// multipoint evaluation and interpolation.
class multipoint_evaluation_tree {
private:
// n_points: number of real points; l: number of tree levels above the leaves (the tree has 2^l leaf slots).
int n_points, l;
// segtree[len] holds 2^(l + 1) values; the blocks of level len are 2^(len + 1) values long.
std::vector<polynom_t<mint>> segtree;
public:
// n is the required capacity (callers pass max(#coefficients, #points)); l becomes the
// smallest value >= 1 with 2^l >= n. Leaf slots beyond points.size() are padded.
multipoint_evaluation_tree(int n, const std::vector<mint> &points) : n_points(points.size()), l(1) {
while ((1 << l) < n) {
l++;
}
polynom_t<mint> aux;
aux.reserve(1 << l);
segtree.assign(l + 1, polynom_t<mint>(1 << (l + 1)));
// Level 0: the size-2 fft of (x - points[i]) for real points and of x for padding slots.
for (int i = 0; i < (1 << l); i++) {
aux = {i < n_points ? -points[i] : 0, 1};
fft(aux);
segtree[0][i << 1] = aux[0];
segtree[0][i << 1 | 1] = aux[1];
}
// Each block of level len + 1 is built from the pointwise product of two adjacent blocks of level len.
for (int len = 0; len < l; len++) {
aux.resize(1 << (len + 1));
for (int i = 0; i < (1 << (l + 1)); i += (1 << (len + 2))) {
for (int j = 0; j < static_cast<int>(aux.size()); j++) {
aux[j] = segtree[len][i + j] * segtree[len][i + j + aux.size()];
}
if (len + 1 < l) {
// First half of the parent block: the product values themselves.
std::copy(aux.begin(), aux.end(), segtree[len + 1].begin() + i);
// Second half: back to coefficients, adjust the constant term, then right_half_fft.
// NOTE(review): presumably this extends the transform of the product to twice as many
// points (the -2 compensating for the wrapped leading coefficient) - confirm.
inv_fft(aux);
aux[0] -= 2;
right_half_fft(aux);
std::copy(aux.begin(), aux.end(), segtree[len + 1].begin() + i + aux.size());
} else {
// Top level is stored in coefficient form rather than as transform values.
// NOTE(review): the decrement of aux[0] and the increment at index 2^l look like the
// un-wrapping of the leading coefficient of the monic product - confirm.
inv_fft(aux);
aux[0]--;
std::copy(aux.begin(), aux.end(), segtree[len + 1].begin() + i);
segtree[len + 1][1 << (len + 1)]++;
}
}
}
}
// Returns n_points values computed from f by descending the tree; used by
// polynom_t::multipoint_evaluation to obtain f at the stored points.
std::vector<mint> evaluate(polynom_t<mint> f) const {
// Reduce f modulo the top-level polynomial when it has more than 2^l coefficients.
if (static_cast<int>(f.size()) > (1 << l)) {
f %= segtree[l];
}
assert(static_cast<int>(f.size()) <= (1 << l));
f.resize(1 << (l + 1));
// Cyclic shift by one position to the right (the last element moves to the front).
std::rotate(f.begin(), std::prev(f.end()), f.end());
fft(f);
// g: reversed top-level polynomial, inverted modulo x^(2^l), reversed again and transformed.
auto g = segtree[l];
std::reverse(g.begin(), g.begin() + 1 + (1 << l));
g.resize(1 << l);
g = g.inv(1 << l);
std::reverse(g.begin(), g.end());
g.resize(1 << (l + 1));
fft(g);
for (int i = 0; i < (1 << (l + 1)); i++) {
g[i] *= f[i];
}
polynom_t<mint> aux;
aux.reserve(1 << l);
// Walk down the tree; at each level every block is split between its two children.
for (int len = l - 1; len >= 0; len--) {
aux.resize(1 << (len + 1));
for (int i = 0; i < (1 << (l + 1)); i += (1 << (len + 2))) {
// Re-express the second half of the block: undo right_half_fft, then apply a plain fft.
for (int j = 0; j < static_cast<int>(aux.size()); j++) {
aux[j] = g[i + j + (1 << (len + 1))];
}
inv_right_half_fft(aux);
fft(aux);
std::copy(aux.begin(), aux.end(), g.begin() + i + (1 << (len + 1)));
// Each child receives the difference multiplied by the sibling's stored values.
for (int j = 0; j < (1 << (len + 1)); j++) {
auto x = g[i + j] - g[i + j + (1 << (len + 1))];
g[i + j + (1 << (len + 1))] = x * segtree[len][i + j];
g[i + j] = x * segtree[len][i + j + (1 << (len + 1))];
}
}
}
// Each leaf holds two values; their scaled difference is the result for that point.
std::vector<mint> eval(n_points);
for (int i = 0; i < n_points; i++) {
eval[i] = (g[2 * i] - g[2 * i + 1]) * fft_data.inv_l[l + 1];
}
return eval;
}
// Builds the polynomial with n_points coefficients used by polynom_t::interpolate
// (documented there as f with f(x[i]) = y[i]) from the values y.
polynom_t<mint> interpolate(const std::vector<mint> &y) const {
// Drop the lowest 2^l - n_points coefficients of the top-level polynomial
// (presumably removing the power of x contributed by the padding leaves).
auto prod = segtree[l];
prod.erase(prod.begin(), prod.begin() + ((1 << l) - n_points));
// Derivative of that product at every point: the denominators of the leaf weights.
auto values = evaluate(prod.derivative());
polynom_t<mint> aux;
aux.reserve(1 << l);
polynom_t<mint> result(1 << (l + 1));
// Leaves: the constant y[i] / values[i] (0 for padding slots) in size-2 fft form.
for (int i = 0; i < (1 << l); i++) {
aux = {i < n_points ? y[i] / values[i] : 0, 0};
fft(aux);
result[i << 1] = aux[0];
result[i << 1 | 1] = aux[1];
}
polynom_t<mint> next(1 << (l + 1));
// Combine bottom-up: parent = left_tree * right_result + right_tree * left_result.
for (int len = 0; len < l; len++) {
aux.resize(1 << (len + 1));
for (int i = 0; i < (1 << (l + 1)); i += (1 << (len + 2))) {
for (int j = 0; j < (1 << (len + 1)); j++) {
aux[j] = segtree[len][i + j] * result[i + j + aux.size()]
+ segtree[len][i + j + aux.size()] * result[i + j];
}
if (len + 1 < l) {
std::copy(aux.begin(), aux.end(), next.begin() + i);
inv_fft(aux);
right_half_fft(aux);
std::copy(aux.begin(), aux.end(), next.begin() + i + (1 << (len + 1)));
} else {
// Top level: keep plain coefficients.
inv_fft(aux);
std::copy(aux.begin(), aux.end(), next.begin());
}
}
result.swap(next);
}
// Remove the same number of low coefficients as was removed from prod, then keep n_points coefficients.
result.erase(result.begin(), result.begin() + ((1 << l) - n_points));
return result.resize(n_points);
}
};
public:
// Empty polynomial (no coefficients).
polynom_t() : std::vector<mint>() {}
// Polynomial with n coefficients, all equal to value (zero by default).
polynom_t(size_t n, mint value = 0) : std::vector<mint>(n, value) {}
// Coefficients listed explicitly, lowest degree first (index i is the coefficient used with x^i by eval()).
polynom_t(std::initializer_list<mint> values) : std::vector<mint>(values.begin(), values.end()) {}
// Coefficients copied from the iterator range [first, last).
template<typename Iterator>
polynom_t(Iterator first, Iterator last) : std::vector<mint>(first, last) {}
// Resizes to n coefficients like std::vector::resize, but returns *this so that calls can be chained.
polynom_t<mint>& resize(int n) {
std::vector<mint>::resize(n);
return *this;
}
// Removes extra zeroes.
void normalize() {
while (!empty() && back() == mint(0)) {
pop_back();
}
}
// Index of the highest non-zero coefficient, or -1 when every coefficient is zero.
// Scans from the back, so it is linear in the number of trailing zeroes (not O(1)!).
int degree() const {
    for (int i = static_cast<int>(size()) - 1; i >= 0; i--) {
        if ((*this)[i] != mint(0)) {
            return i;
        }
    }
    return -1;
}
// Evaluates the polynomial at x in O(n), accumulating coefficient * x^i term by term.
mint eval(const mint &x) const {
    mint result = 0, x_power = 1;
    for (const mint &coefficient : *this) {
        result += coefficient * x_power;
        x_power *= x;
    }
    return result;
}
// Calculates eval at the given points.
// O(n log^2).
std::vector<mint> multipoint_evaluation(const std::vector<mint> &points) const {
if (std::min(size(), points.size()) <= EVAL_BRUTE_SIZE) {
std::vector<mint> eval(points.size());
for (int i = 0; i < static_cast<int>(points.size()); i++) {
eval[i] = this->eval(points[i]);
}
return eval;
}
return multipoint_evaluation_tree(std::max(size(), points.size()), points).evaluate(*this);
}
// Interpolates polynomial f, such that f(x[i]) = y[i].
// O(n log^2).
// x and y must have the same length (asserted). The work is delegated to
// multipoint_evaluation_tree built over x; the division by the derivative values there
// presumably requires the x[i] to be pairwise distinct - TODO confirm.
static polynom_t<mint> interpolate(const std::vector<mint> &x, const std::vector<mint> &y) {
assert(x.size() == y.size());
return multipoint_evaluation_tree(x.size(), x).interpolate(y);
}
// Unary minus: returns a copy with every coefficient multiplied by -1.
polynom_t<mint> operator-() const {
return polynom_t(*this) * mint(-1);
}
// Coefficient-wise addition; *this grows (zero padded) when `another` is longer.
polynom_t<mint>& operator+=(const polynom_t<mint> &another) {
    const int other_size = static_cast<int>(another.size());
    if (static_cast<int>(size()) < other_size) {
        resize(other_size);
    }
    for (int i = 0; i < other_size; i++) {
        (*this)[i] += another[i];
    }
    return *this;
}
// Coefficient-wise subtraction; *this grows (zero padded) when `another` is longer.
polynom_t<mint>& operator-=(const polynom_t<mint> &another) {
    const int other_size = static_cast<int>(another.size());
    if (static_cast<int>(size()) < other_size) {
        resize(other_size);
    }
    for (int i = 0; i < other_size; i++) {
        (*this)[i] -= another[i];
    }
    return *this;
}
// Polynomial product. The result has size() + another.size() - 1 coefficients, or is empty
// when either operand is empty. Small operands use the schoolbook product, larger ones the fft.
polynom_t<mint>& operator*=(const polynom_t<mint> &another) {
    if (empty() || another.empty()) {
        clear();
        return *this;
    }

    const int lhs_size = static_cast<int>(size());
    const int rhs_size = static_cast<int>(another.size());
    const int real_size = lhs_size + rhs_size - 1;

    const bool schoolbook = std::min(lhs_size, rhs_size) <= MUL_MIN_CUT
                            || std::max(lhs_size, rhs_size) <= MUL_MAX_CUT;
    if (schoolbook) {
        polynom_t<mint> product(real_size);
        for (int i = 0; i < lhs_size; i++) {
            for (int j = 0; j < rhs_size; j++) {
                product[i + j] += (*this)[i] * another[j];
            }
        }
        product.swap(*this);
        return *this;
    }

    // Smallest power of two that can hold the full product.
    int n = 1;
    while (n < real_size) {
        n <<= 1;
    }

    // Equal operands (including `another` aliasing *this) need only one forward transform.
    // The comparison must happen before *this is resized.
    const bool squaring = (*this) == another;
    if (squaring) {
        resize(n);
        fft(*this);
        for (int i = 0; i < n; i++) {
            (*this)[i] *= (*this)[i];
        }
    } else {
        resize(n);
        polynom_t rhs = another;
        rhs.resize(n);
        fft(*this);
        fft(rhs);
        for (int i = 0; i < n; i++) {
            (*this)[i] *= rhs[i];
        }
    }
    inv_fft(*this);
    return resize(real_size);
}
// Quotient of the division with remainder; *this becomes the (normalized) quotient.
// The divisor must be non-zero (asserted). O(n log) for large inputs, long division for small ones.
polynom_t<mint>& operator/=(const polynom_t<mint> &another) {
    polynom_t<mint> dividend(*this), divisor(another);
    dividend.normalize();
    divisor.normalize();
    assert(!divisor.empty());

    int n = dividend.size(), m = divisor.size();
    if (n < m) {
        return *this = {};
    }

    polynom_t<mint> quotient;
    if (n <= DIV_N_CUT || m <= DIV_M_CUT) {
        // Long division, highest quotient coefficient first.
        quotient.resize(n - m + 1);
        const mint lead_inverse = mint(1) / divisor.back();
        for (int pos = n - m; pos >= 0; pos--) {
            quotient[pos] = dividend[pos + m - 1] * lead_inverse;
            for (int j = 0; j < m; j++) {
                dividend[pos + j] -= divisor[j] * quotient[pos];
            }
        }
    } else {
        // Reversed quotient = reversed dividend * (reversed divisor)^{-1} modulo x^{n - m + 1}.
        std::reverse(dividend.begin(), dividend.end());
        std::reverse(divisor.begin(), divisor.end());
        quotient = dividend * divisor.inv(n - m + 1);
        quotient.resize(n - m + 1);
        std::reverse(quotient.begin(), quotient.end());
    }
    quotient.normalize();
    return *this = quotient;
}
// Remainder of the division by `another`: subtracts quotient * another, then normalizes.
// O(n log).
polynom_t<mint>& operator%=(const polynom_t<mint> &another) {
    const polynom_t<mint> quotient = (*this) / another;
    *this -= quotient * another;
    normalize();
    return *this;
}
// Returns the formal derivative; it has one coefficient fewer (and is empty for an empty polynomial).
// O(n).
polynom_t<mint> derivative() const {
    const int n = static_cast<int>(size());
    polynom_t<mint> der(std::max(0, n - 1));
    for (int k = 1; k < n; k++) {
        der[k - 1] = mint(k) * (*this)[k];
    }
    return der;
}
// Returns the formal antiderivative whose constant term is `constant`; one coefficient longer than *this.
// O(n).
polynom_t<mint> integral(const mint &constant = mint(0)) const {
    polynom_t<mint> in(size() + 1);
    in[0] = constant;
    const int n = static_cast<int>(size());
    for (int i = 0; i < n; i++) {
        in[i + 1] = (*this)[i] * fft_data.inv(i + 1);
    }
    return in;
}
// Returns p^{-1} modulo x^{degree}.
// O(n log).
// The constant term must be non-zero (asserted). The first min(degree, INV_BRUTE_FORCE_SIZE)
// coefficients are computed directly; after that every pass doubles the number of known coefficients.
polynom_t<mint> inv(int degree) const {
    assert(!empty() && (*this)[0] != mint(0) && "polynom is not invertable");
    int result_size = 1;
    while (result_size < degree) {
        result_size <<= 1;
    }

    polynom_t<mint> inv(result_size);
    int brute_calculation = std::min(degree, INV_BRUTE_FORCE_SIZE);
    // Inverse of the constant term, computed exactly once: the brute-force loop below writes
    // inv[0] itself on its first iteration, so no separate initialization of inv[0] is needed.
    mint start_inv = (*this)[0].inv();
    // have = (*this) * (coefficients of inv found so far), truncated to brute_calculation terms.
    polynom_t<mint> have(brute_calculation);
    for (int i = 0; i < brute_calculation; i++) {
        // Pick inv[i] so that coefficient i of the product becomes [i == 0].
        inv[i] = ((i == 0 ? 1 : 0) - have[i]) * start_inv;
        int steps = std::min<int>(size(), have.size() - i);
        for (int j = 0; j < steps; j++) {
            have[i + j] += inv[i] * (*this)[j];
        }
    }

    polynom_t<mint> this_copy, inv_copy;
    this_copy.reserve(result_size);
    inv_copy.reserve(result_size);
    for (int power = brute_calculation; power < degree; power <<= 1) {
        // this_copy = first 2 * power coefficients of *this (zero padded).
        this_copy.resize(power << 1);
        std::fill(this_copy.begin() + std::min<int>(size(), power << 1), this_copy.end(), 0);
        std::copy(begin(), begin() + std::min<int>(size(), power << 1), this_copy.begin());
        // inv_copy = the `power` coefficients known so far (zero padded).
        inv_copy.resize(power << 1);
        std::copy(inv.begin(), inv.begin() + power, inv_copy.begin());

        fft(this_copy);
        fft(inv_copy);
        for (int i = 0; i < (power << 1); i++) {
            this_copy[i] *= inv_copy[i];
        }
        inv_fft(this_copy);

        // Only the upper half of (*this) * inv is needed; clear the lower half and
        // multiply by inv once more.
        std::fill(this_copy.begin(), this_copy.begin() + power, 0);
        fft(this_copy);
        for (int i = 0; i < (power << 1); i++) {
            this_copy[i] *= inv_copy[i];
        }
        inv_fft(this_copy);

        // The negated upper half gives the next `power` coefficients of the inverse.
        for (int i = power; i < (power << 1); i++) {
            inv[i] = -this_copy[i];
        }
    }
    return inv.resize(degree);
}
// Returns log(p) modulo x^{degree}, computed as the integral of p' * p^{-1}.
// The constant term must be 1 (asserted).
// O(n log).
polynom_t<mint> log(int degree) const {
    assert(!empty() && (*this)[0] == mint(1) && "log is not defined");
    polynom_t<mint> numerator = derivative();
    numerator.resize(std::min<int>(degree, size()));
    polynom_t<mint> quotient = numerator * inv(degree);
    quotient.resize(degree);
    polynom_t<mint> result = quotient.integral(mint(0));
    result.resize(degree);
    return result;
}
// Returns exp(p) modulo x^{degree}.
// O(n log).
// The constant term must be 0 (asserted). Each pass of the main loop extends `exp` from
// `power` to 2 * power coefficients while also maintaining inv_exp, an inverse of exp
// to lower precision; exp_fft / inv_exp_fft cache their transforms.
polynom_t<mint> exp(int degree) const {
assert(!empty() && (*this)[0] == mint(0) && "exp is not defined");
int result_size = 1;
while (result_size < degree) {
result_size <<= 1;
}
polynom_t<mint> exp, exp_fft, inv_exp, inv_exp_fft;
exp = exp_fft = inv_exp = inv_exp_fft = {1};
// h and q are scratch buffers reused by every pass.
polynom_t<mint> h(1), q(1);
// Derivative of the input, zero padded to result_size.
auto diff = derivative().resize(result_size);
for (int power = 1; power < degree; power <<= 1) {
exp_fft = polynom_t<mint>(exp).resize(power << 1);
fft(exp_fft);
// Step 1: extend inv_exp to `power` coefficients.
for (int i = 0; i < power; i++) {
h[i] = exp_fft[i] * inv_exp_fft[i];
}
inv_fft(h);
std::fill(h.begin(), h.begin() + ((power + 1) >> 1), 0);
fft(h);
for (int i = 0; i < power; i++) {
h[i] *= -inv_exp_fft[i];
}
inv_fft(h);
for (int i = (power + 1) / 2; i < power; i++) {
inv_exp.push_back(h[i]);
}
inv_exp_fft = polynom_t<mint>(inv_exp).resize(power << 1);
fft(inv_exp_fft);
// Step 2: q = exp * (first power - 1 coefficients of diff), computed cyclically with size `power`.
h.assign(power, 0);
std::copy(diff.begin(), diff.begin() + power - 1, h.begin());
fft(h);
for (int i = 0; i < power; i++) {
q[i] = exp_fft[i] * h[i];
}
inv_fft(q);
// Step 3: h = (i * exp[i]) minus q shifted by one position; q's last entry wraps around to h[0].
h.resize(power << 1);
for (int i = 1; i < power; i++) {
h[i] = exp[i] * i;
}
for (int i = 0; i < power; i++) {
h[i + 1] -= q[i];
}
h[0] = -q[power - 1];
fft(h);
// Step 4: multiply by inv_exp.
q.resize(power << 1);
for (int i = 0; i < (power << 1); i++) {
q[i] = inv_exp_fft[i] * h[i];
}
inv_fft(q);
// Step 5: h = coefficients [power, 2 * power) of the input minus q[i] / (i + power).
if (static_cast<int>(size()) >= power) {
h.assign(begin() + power, begin() + std::min<int>(power << 1, size()));
} else {
h.clear();
}
h.resize(power);
for (int i = power - 1; i >= 0; i--) {
h[i] -= q[i] * fft_data.inv(i + power);
}
// Step 6: the next `power` coefficients of exp are the low half of exp * h.
h.resize(power << 1);
fft(h);
for (int i = 0; i < (power << 1); i++) {
q[i] = exp_fft[i] * h[i];
}
inv_fft(q);
exp.resize(power << 1);
for (int i = 0; i < power; i++) {
exp[power + i] = q[i];
}
}
return exp.resize(degree);
}
// Returns p^{d} modulo x^{degree}.
// O(n log).
// Computed as x^{d * pos} * c^d * exp(d * log(p / (c * x^{pos}))), where pos is the index of
// the first non-zero coefficient and c is that coefficient.
// NOTE(review): negative d gets no special handling here; presumably callers only pass d >= 0 - confirm.
polynom_t<mint> power(int64_t d, int degree) const {
// p^0 = 1 (truncated to `degree` coefficients); degree == 0 yields an empty result.
if (!d || !degree) {
return polynom_t<mint>{1}.resize(degree);
}
// Skip the leading zero coefficients.
int pos = 0;
while (pos < static_cast<int>(size()) && (*this)[pos] == mint(0)) {
pos++;
}
// Zero polynomial, or the lowest term x^{d * pos} already falls outside x^{degree}: the result is all zeroes.
if (pos == static_cast<int>(size()) || pos >= (degree + d - 1) / d) {
return polynom_t<mint>(degree);
}
// Number of coefficients that remain after the shift by d * pos.
int coeffs_left = degree - d * pos;
polynom_t<mint> result = ((polynom_t<mint>(begin() + pos, end()) / (*this)[pos]).log(coeffs_left)
* mint(d)).exp(coeffs_left) * (*this)[pos].power(d);
result.resize(degree);
// Shift the coefficients up by degree - coeffs_left (= d * pos) positions, back to front...
for (int i = degree - 1; i - (degree - coeffs_left) >= 0; i--) {
result[i] = result[i - (degree - coeffs_left)];
}
// ...and zero the vacated low positions.
std::fill(result.begin(), result.end() - coeffs_left, mint(0));
return result;
}
// Returns f(x + c) (Taylor shift), using one polynomial product.
// O(n log).
polynom_t<mint> change_of_variable(mint c) const {
    const int n = size();
    // scaled[n - 1 - i] = f[i] * i!, shift_series[i] = c^i / i!.
    polynom_t<mint> scaled(n), shift_series(n);
    std::vector<mint> inverse_factorials(n);
    mint factorial = 1, inverse_factorial = 1, c_power = 1;
    for (int i = 0; i < n; i++) {
        inverse_factorials[i] = inverse_factorial;
        scaled[n - 1 - i] = (*this)[i] * factorial;
        shift_series[i] = c_power * inverse_factorial;
        factorial *= i + 1;
        inverse_factorial *= fft_data.inv(i + 1);
        c_power *= c;
    }

    polynom_t<mint> result = scaled * shift_series;
    result.resize(n);
    std::reverse(result.begin(), result.end());
    // Divide coefficient i by i! again.
    for (int i = 0; i < n; i++) {
        result[i] *= inverse_factorials[i];
    }
    return result;
}
// Multiplies every coefficient by the scalar `value` (any type other than polynom_t itself).
template<typename V>
std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>&> operator*=(const V &value) {
    const int n = static_cast<int>(size());
    for (int i = 0; i < n; i++) {
        (*this)[i] *= value;
    }
    return *this;
}
// Divides every coefficient by the scalar `value` (any type other than polynom_t itself).
// The modular inverse of `value` is computed once and reused, instead of once per
// coefficient as a plain `x /= value` in the loop would do.
template<typename V>
std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>&> operator/=(const V &value) {
    if (empty()) {
        return *this;
    }
    const mint divisor = value;
    const mint inverse = divisor.inv();
    for (auto &x : *this) {
        x *= inverse;
    }
    return *this;
}
friend std::ostream& operator<<(std::ostream &out, const polynom_t<mint> &p) {
for (int i = 0; i < static_cast<int>(p.size()); i++) {
if (i) out << ' ';
out << p[i];
}
return out;
}
// Binary polynomial operators: each one copies the left operand and applies the matching
// compound assignment (+=, -=, *=, /=, %=) with the right operand.
friend polynom_t<mint> operator+(const polynom_t<mint> &p, const polynom_t<mint> &q) {
return polynom_t(p) += q;
}
friend polynom_t<mint> operator-(const polynom_t<mint> &p, const polynom_t<mint> &q) {
return polynom_t(p) -= q;
}
friend polynom_t<mint> operator*(const polynom_t<mint> &p, const polynom_t<mint> &q) {
return polynom_t(p) *= q;
}
// Quotient of the division with remainder.
friend polynom_t<mint> operator/(const polynom_t<mint> &p, const polynom_t<mint> &q) {
return polynom_t(p) /= q;
}
// Remainder of the division with remainder.
friend polynom_t<mint> operator%(const polynom_t<mint> &p, const polynom_t<mint> &q) {
return polynom_t(p) %= q;
}
// Scalar * polynomial: multiplies every coefficient of p by value.
template<typename V>
friend std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>>
operator*(const V &value, const polynom_t<mint> &p) {
return polynom_t<mint>(p) *= value;
}
// Polynomial * scalar: multiplies every coefficient of p by value.
template<typename V>
friend std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>>
operator*(const polynom_t<mint> &p, const V &value) {
return polynom_t<mint>(p) *= value;
}
// NOTE(review): despite the operand order (value / p), this divides every coefficient of p
// by value, i.e. it computes p / value and not value times the inverse of p.
template<typename V>
friend std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>>
operator/(const V &value, const polynom_t<mint> &p) {
return polynom_t<mint>(p) /= value;
}
// Polynomial / scalar: divides every coefficient of p by value.
template<typename V>
friend std::enable_if_t<!std::is_same_v<V, polynom_t<mint>>, polynom_t<mint>>
operator/(const polynom_t<mint> &p, const V &value) {
return polynom_t<mint>(p) /= value;
}
};
using polynom = polynom_t<mint>;
/*
 ! WARNING: MOD must be prime.
 * Define modular int class above it.
 * No need to run any init function, it dynamically resizes the data.
 * fact[i], ifact[i] and inv[i] give i!, 1 / i! and 1 / i; a negative index yields 0.
 */
namespace combinatorics {
// Backing tables; all three always have the same length.
std::vector<mint> fact_, ifact_, inv_;

// Extends the three tables so that index `size` becomes valid.
void resize_data(int size) {
    if (fact_.empty()) {
        fact_ = {mint(1), mint(1)};
        ifact_ = {mint(1), mint(1)};
        inv_ = {mint(0), mint(1)};
    }
    for (int pos = fact_.size(); pos <= size; pos++) {
        fact_.push_back(fact_.back() * mint(pos));
        // pos^{-1} = -(MOD / pos) * (MOD % pos)^{-1}
        inv_.push_back(-inv_[MOD % pos] * mint(MOD / pos));
        ifact_.push_back(ifact_.back() * inv_[pos]);
    }
}

// Read-only view of one table that grows the tables on demand.
struct combinatorics_info {
    std::vector<mint> &data;

    combinatorics_info(std::vector<mint> &data) : data(data) {}

    mint operator[](int pos) {
        // A negative index used to read data[pos] out of bounds (undefined behaviour).
        // Returning 0 matches the convention 1 / (-k)! = 0 that callers of ifact rely on.
        if (pos < 0) {
            return mint(0);
        }
        if (pos >= static_cast<int>(data.size())) {
            resize_data(pos);
        }
        return data[pos];
    }
} fact(fact_), ifact(ifact_), inv(inv_);

// From n choose k.
// O(max(n)) in total.
mint choose(int n, int k) {
    if (n < k || k < 0 || n < 0) {
        return mint(0);
    }
    return fact[n] * ifact[k] * ifact[n - k];
}

// From n choose k.
// O(min(k, n - k)).
mint choose_slow(int64_t n, int64_t k) {
    if (n < k || k < 0 || n < 0) {
        return mint(0);
    }
    k = std::min(k, n - k);
    mint result = 1;
    for (int i = k; i >= 1; i--) {
        result *= (n - i + 1);
        result *= inv[i];
    }
    return result;
}

// Number of balanced bracket sequences with n open and m closing brackets.
mint catalan(int n, int m) {
    if (m > n || m < 0) {
        return mint(0);
    }
    return choose(n + m, m) - choose(n + m, m - 1);
}

// Number of balanced bracket sequences with n open and closing brackets.
mint catalan(int n) {
    return catalan(n, n);
}
} // namespace combinatorics
using namespace combinatorics;
// Returns C(n + m - 2, n - 1)^2 - C(n + m - 2, n) * C(n + m - 2, m).
mint solve(int n, int m) {
    const int total = n + m - 2;
    const mint middle = choose(total, n - 1);
    return middle * middle - choose(total, n) * choose(total, m);
}
// Sum of solve(y2 - y, m) over all pairs 0 <= y < y2 <= n.
// A difference d = y2 - y occurs for n - d + 1 pairs, so the double loop collapses to one over d.
mint solve_pairs(int n, int m) {
    mint total = 0;
    int d = 1;
    while (d <= n) {
        const mint multiplicity = n - d + 1;
        total += multiplicity * solve(d, m);
        d++;
    }
    return total;
}
// Returns the sum over 1 <= y <= n, 1 <= x < m of
//     C(x + y - 2, x - 1)^2 - C(x + y - 2, x) * C(x + y - 2, y).
// Writing both terms as (x + y - 2)!^2 times a product of a y-only and an x-only factor turns
// each double sum into a convolution indexed by s = x + y, which one polynomial product evaluates.
mint solve_borders(int n, int m) {
    mint ans = 0;
    // for (int y = 1; y <= n; y++) {
    //     for (int x = 1; x < m; x++) {
    //         ans += choose(x + y - 2, x - 1) * choose(x + y - 2, x - 1);
    //         ans -= choose(x + y - 2, x) * choose(x + y - 2, y);
    //     }
    // }
    {
        // C(x + y - 2, x - 1)^2 = (x + y - 2)!^2 * [1 / (y - 1)!^2] * [1 / (x - 1)!^2]
        polynom p(n + 1);
        for (int y = 1; y <= n; y++) {
            p[y] = ifact[y - 1] * ifact[y - 1];
        }
        polynom q(m);
        for (int x = 1; x < m; x++) {
            q[x] = ifact[x - 1] * ifact[x - 1];
        }
        auto f = p * q;
        for (int s = 2; s < len(f); s++) {
            ans += f[s] * fact[s - 2] * fact[s - 2];
        }
    }
    {
        // C(x + y - 2, x) * C(x + y - 2, y) = (x + y - 2)!^2 * [1 / ((y - 2)! y!)] * [1 / ((x - 2)! x!)]
        // The factor for y = 1 (and x = 1) is 1 / (-1)! = 0, so those entries stay zero and the
        // loops start at 2; starting at 1 would index the inverse-factorial table at -1.
        polynom p(n + 1);
        for (int y = 2; y <= n; y++) {
            p[y] = ifact[y - 2] * ifact[y];
        }
        polynom q(m);
        for (int x = 2; x < m; x++) {
            q[x] = ifact[x - 2] * ifact[x];
        }
        auto f = p * q;
        for (int s = 2; s < len(f); s++) {
            ans -= f[s] * fact[s - 2] * fact[s - 2];
        }
    }
    return ans;
}
int main() {
cin.tie(nullptr)->sync_with_stdio(false);
int n, m;
cin >> n >> m;
mint ans = solve_pairs(n, m) + solve_pairs(m, n) + solve_borders(n, m) + solve_borders(m, n);
// for (int y = 0; y < n; y++) {
// for (int x = 1; x < m; x++) {
// ans += solve(x, n - y);
// }
// for (int y2 = y + 1; y2 <= n; y2++) {
// ans += solve(m, y2 - y);
// }
// }
// for (int x = 1; x < m; x++) {
// for (int y = 1; y < n; y++) {
// ans += solve(m - x, y);
// }
// for (int x2 = x + 1; x2 <= m; x2++) {
// ans += solve(x2 - x, n);
// }
// }
for (int y = 1; y < n; y++) {
ans -= solve(m, y);
}
for (int x2 = 1; x2 <= m; x2++) {
ans -= solve(x2, n);
}
cout << ans + 2 << '\n';
}
详细
Test #1:
score: 100
Accepted
time: 0ms
memory: 3540kb
input:
2 2
output:
11
result:
ok "11"
Test #2:
score: 0
Accepted
time: 0ms
memory: 3736kb
input:
20 23
output:
521442928
result:
ok "521442928"
Test #3:
score: 0
Accepted
time: 133ms
memory: 15288kb
input:
200000 200000
output:
411160917
result:
ok "411160917"
Test #4:
score: 0
Accepted
time: 0ms
memory: 3584kb
input:
8 3
output:
2899
result:
ok "2899"
Test #5:
score: 0
Accepted
time: 0ms
memory: 3568kb
input:
10 9
output:
338037463
result:
ok "338037463"
Test #6:
score: 0
Accepted
time: 0ms
memory: 3528kb
input:
3 3
output:
64
result:
ok "64"
Test #7:
score: 0
Accepted
time: 0ms
memory: 3792kb
input:
9 4
output:
39733
result:
ok "39733"
Test #8:
score: 0
Accepted
time: 0ms
memory: 3536kb
input:
36 33
output:
545587245
result:
ok "545587245"
Test #9:
score: 0
Accepted
time: 0ms
memory: 3744kb
input:
35 39
output:
62117944
result:
ok "62117944"
Test #10:
score: 0
Accepted
time: 0ms
memory: 3536kb
input:
48 10
output:
264659761
result:
ok "264659761"
Test #11:
score: 0
Accepted
time: 0ms
memory: 3572kb
input:
46 30
output:
880000821
result:
ok "880000821"
Test #12:
score: 0
Accepted
time: 0ms
memory: 3536kb
input:
25 24
output:
280799864
result:
ok "280799864"
Test #13:
score: 0
Accepted
time: 0ms
memory: 3584kb
input:
17 10
output:
624958192
result:
ok "624958192"
Test #14:
score: 0
Accepted
time: 4ms
memory: 3712kb
input:
4608 9241
output:
322218996
result:
ok "322218996"
Test #15:
score: 0
Accepted
time: 4ms
memory: 3724kb
input:
3665 6137
output:
537704652
result:
ok "537704652"
Test #16:
score: 0
Accepted
time: 0ms
memory: 3752kb
input:
4192 6186
output:
971816887
result:
ok "971816887"
Test #17:
score: 0
Accepted
time: 4ms
memory: 3660kb
input:
4562 4403
output:
867628411
result:
ok "867628411"
Test #18:
score: 0
Accepted
time: 4ms
memory: 3756kb
input:
8726 3237
output:
808804305
result:
ok "808804305"
Test #19:
score: 0
Accepted
time: 4ms
memory: 3736kb
input:
5257 8166
output:
488829288
result:
ok "488829288"
Test #20:
score: 0
Accepted
time: 4ms
memory: 3980kb
input:
8013 7958
output:
215666893
result:
ok "215666893"
Test #21:
score: 0
Accepted
time: 4ms
memory: 3896kb
input:
8837 5868
output:
239261227
result:
ok "239261227"
Test #22:
score: 0
Accepted
time: 4ms
memory: 3768kb
input:
8917 5492
output:
706653412
result:
ok "706653412"
Test #23:
score: 0
Accepted
time: 4ms
memory: 3664kb
input:
9628 5378
output:
753685501
result:
ok "753685501"
Test #24:
score: 0
Accepted
time: 122ms
memory: 14200kb
input:
163762 183794
output:
141157510
result:
ok "141157510"
Test #25:
score: 0
Accepted
time: 58ms
memory: 8464kb
input:
83512 82743
output:
114622013
result:
ok "114622013"
Test #26:
score: 0
Accepted
time: 54ms
memory: 8072kb
input:
84692 56473
output:
263907717
result:
ok "263907717"
Test #27:
score: 0
Accepted
time: 33ms
memory: 6140kb
input:
31827 74195
output:
200356808
result:
ok "200356808"
Test #28:
score: 0
Accepted
time: 137ms
memory: 14268kb
input:
189921 163932
output:
845151158
result:
ok "845151158"
Test #29:
score: 0
Accepted
time: 69ms
memory: 9280kb
input:
27157 177990
output:
847356039
result:
ok "847356039"
Test #30:
score: 0
Accepted
time: 58ms
memory: 8652kb
input:
136835 39390
output:
962822638
result:
ok "962822638"
Test #31:
score: 0
Accepted
time: 64ms
memory: 7964kb
input:
118610 18795
output:
243423874
result:
ok "243423874"
Test #32:
score: 0
Accepted
time: 60ms
memory: 8084kb
input:
122070 19995
output:
531055604
result:
ok "531055604"
Test #33:
score: 0
Accepted
time: 70ms
memory: 9528kb
input:
20031 195670
output:
483162363
result:
ok "483162363"
Test #34:
score: 0
Accepted
time: 135ms
memory: 15312kb
input:
199992 199992
output:
262099623
result:
ok "262099623"
Test #35:
score: 0
Accepted
time: 134ms
memory: 15292kb
input:
200000 199992
output:
477266520
result:
ok "477266520"
Test #36:
score: 0
Accepted
time: 135ms
memory: 15300kb
input:
199999 199996
output:
165483205
result:
ok "165483205"
Test #37:
score: 0
Accepted
time: 0ms
memory: 3608kb
input:
1 1
output:
3
result:
ok "3"
Test #38:
score: 0
Accepted
time: 6ms
memory: 5596kb
input:
1 100000
output:
8828237
result:
ok "8828237"
Test #39:
score: 0
Accepted
time: 3ms
memory: 5572kb
input:
100000 2
output:
263711286
result:
ok "263711286"
Test #40:
score: 0
Accepted
time: 0ms
memory: 3540kb
input:
50 50
output:
634767411
result:
ok "634767411"
Extra Test:
score: 0
Extra Test Passed