#pragma GCC optimize("O3")
#pragma GCC target("avx2,bmi")
#include <immintrin.h>
#include <stdint.h>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <ctime>
#include <iostream>
#include <string>
#include <utility>
using u32 = uint32_t;
using u64 = uint64_t;
// Thin wrappers around AVX2 intrinsics on top of GCC vector extensions:
// u32x8 / u64x4 support native element-wise +, -, * while the helpers below
// expose loads/stores, shuffles, blends and widening multiplies.
namespace simd {
using i128 = __m128i;
using i256 = __m256i;
using u32x8 = u32 __attribute__((vector_size(32))); // 8 lanes of u32
using u64x4 = u64 __attribute__((vector_size(32))); // 4 lanes of u64
// Aligned 256-bit load; ptr must be 32-byte aligned.
u32x8 load_u32x8(u32 *ptr) {
    return (u32x8)(_mm256_load_si256((i256 *)ptr));
}
// Unaligned 256-bit load.
u32x8 loadu_u32x8(u32 *ptr) {
    return (u32x8)(_mm256_loadu_si256((i256 *)ptr));
}
// Aligned 256-bit store; ptr must be 32-byte aligned.
void store_u32x8(u32 *ptr, u32x8 val) {
    _mm256_store_si256((i256 *)ptr, (i256)(val));
}
void storeu_u32x8(u32 *ptr, u32x8 val) {
    _mm256_storeu_si256((i256 *)ptr, (i256)(val));
}
u64x4 load_u64x4(u64 *ptr) {
    return (u64x4)(_mm256_load_si256((i256 *)ptr));
}
u64x4 loadu_u64x4(u64 *ptr) {
    return (u64x4)(_mm256_loadu_si256((i256 *)ptr));
}
void store_u64x4(u64 *ptr, u64x4 val) {
    _mm256_store_si256((i256 *)ptr, (i256)(val));
}
void storeu_u64x4(u64 *ptr, u64x4 val) {
    _mm256_storeu_si256((i256 *)ptr, (i256)(val));
}
// Broadcast a scalar to all lanes.
u32x8 set1_u32x8(u32 val) {
    return (u32x8)(_mm256_set1_epi32(val));
}
u64x4 set1_u64x4(u64 val) {
    return (u64x4)(_mm256_set1_epi64x(val));
}
// Build a vector from individual lanes (lane 0 first).
u32x8 setr_u32x8(u32 a0, u32 a1, u32 a2, u32 a3, u32 a4, u32 a5, u32 a6, u32 a7) {
    return (u32x8)(_mm256_setr_epi32(a0, a1, a2, a3, a4, a5, a6, a7));
}
u64x4 setr_u64x4(u64 a0, u64 a1, u64 a2, u64 a3) {
    return (u64x4)(_mm256_setr_epi64x(a0, a1, a2, a3));
}
// Shuffle the u32 lanes within each 128-bit half by the imm8 pattern.
template <int imm8>
u32x8 shuffle_u32x8(u32x8 val) {
    return (u32x8)(_mm256_shuffle_epi32((i256)(val), imm8));
}
// Arbitrary cross-lane permutation of the 8 u32 lanes by index vector p.
u32x8 permute_u32x8(u32x8 val, u32x8 p) {
    return (u32x8)(_mm256_permutevar8x32_epi32((i256)(val), (i256)(p)));
}
// Select/permute 128-bit halves of (a, b) according to imm8.
template <int imm8>
u32x8 permute_u32x8_epi128(u32x8 a, u32x8 b) {
    return (u32x8)(_mm256_permute2x128_si256((i256)(a), (i256)(b), imm8));
}
// Per-lane select: bit i of imm8 set -> take lane i from b, else from a.
template <int imm8>
u32x8 blend_u32x8(u32x8 a, u32x8 b) {
    return (u32x8)(_mm256_blend_epi32((i256)(a), (i256)(b), imm8));
}
// Byte shifts within each 128-bit half (imm8 counts bytes).
template <int imm8>
u32x8 shift_left_u32x8_epi128(u32x8 val) {
    return (u32x8)(_mm256_bslli_epi128((i256)(val), imm8));
}
template <int imm8>
u32x8 shift_right_u32x8_epi128(u32x8 val) {
    return (u32x8)(_mm256_bsrli_epi128((i256)(val), imm8));
}
// Bit shifts of each 64-bit lane.
u32x8 shift_left_u32x8_epi64(u32x8 val, int imm8) {
    return (u32x8)(_mm256_slli_epi64((i256)(val), imm8));
}
u32x8 shift_right_u32x8_epi64(u32x8 val, int imm8) {
    return (u32x8)(_mm256_srli_epi64((i256)(val), imm8));
}
// Lane-wise unsigned minimum.
u32x8 min_u32x8(u32x8 a, u32x8 b) {
    return (u32x8)(_mm256_min_epu32((i256)(a), (i256)(b)));
}
// Widening multiply: low u32 of every 64-bit lane of a and b -> full u64 product.
u32x8 mul64_u32x8(u32x8 a, u32x8 b) {
    return (u32x8)(_mm256_mul_epu32((i256)(a), (i256)(b)));
}
}; // namespace simd
using namespace simd;
// Montgomery32
// Montgomery arithmetic modulo an odd `mod` < 2^30 (lazy 32-bit flavor).
// Intermediate values are usually kept in [0, 2*mod); shrink() normalizes.
struct Montgomery {
    uint32_t mod;   // the modulus (odd, < 2^30)
    uint32_t mod2;  // 2 * mod
    uint32_t n_inv; // n_inv * mod == -1 (mod 2^32)
    uint32_t r;     // 2^32 % mod      (Montgomery form of 1)
    uint32_t r2;    // (2^32)^2 % mod  (multiply by it to enter Montgomery form)
    Montgomery() = default;
    Montgomery(uint32_t mod) : mod(mod) {
        assert(mod % 2);
        assert(mod < (1 << 30));
        // Newton iteration for -mod^{-1} mod 2^32; the number of correct
        // bits doubles each step, so five steps cover all 32 bits.
        n_inv = 1;
        for (int step = 5; step > 0; step--) {
            n_inv *= 2u + n_inv * mod;
        }
        assert(n_inv * mod == -1u);
        mod2 = mod * 2;
        r = uint32_t((1ULL << 32) % mod);
        r2 = uint32_t(uint64_t(r) * r % mod);
    }
    // Map [0, 2*mod) -> [0, mod). Relies on unsigned wrap-around: when
    // val < mod, val - mod wraps to a huge value and min() keeps val.
    uint32_t shrink(uint32_t val) const {
        return std::min(val, val - mod);
    }
    // Map [0, 4*mod) -> [0, 2*mod).
    uint32_t shrink2(uint32_t val) const {
        return std::min(val, val - mod2);
    }
    // Map (-mod, mod) taken mod 2^32 -> [0, mod).
    uint32_t shrink_n(uint32_t val) const {
        return std::min(val, val + mod);
    }
    // Map (-2*mod, 2*mod) taken mod 2^32 -> [0, 2*mod).
    uint32_t shrink2_n(uint32_t val) const {
        return std::min(val, val + mod2);
    }
    // Montgomery product: a * b * 2^{-32} (mod `mod`).
    // a * b should be in [0, 2**32 * mod)
    // result in [0, 2 * mod) <false>
    // result in [0, mod) <true>
    template <bool strict = false>
    uint32_t mul(uint32_t a, uint32_t b) const {
        uint64_t prod = uint64_t(a) * b;
        uint32_t res = uint32_t((prod + uint32_t(prod) * n_inv * uint64_t(mod)) >> 32);
        if constexpr (strict) {
            res = shrink(res);
        }
        return res;
    }
    // b^e mod `mod`; `b` is a plain residue, result is plain and in [0, mod).
    [[gnu::noinline]] uint32_t power(uint32_t b, uint32_t e) const {
        uint32_t base = mul(b, r2); // into Montgomery form
        uint32_t acc = 1;           // plain 1; stays plain through mul()
        while (e > 0) {
            if (e & 1) {
                acc = mul(acc, base);
            }
            base = mul(base, base);
            e >>= 1;
        }
        return shrink(acc);
    }
};
// Montgomery32, vectorized: one instance serves 8 u32 lanes (or 4 u64 lanes)
// at once; all constants are broadcast from the scalar Montgomery.
struct Montgomery_simd {
    alignas(32) u32x8 mod;
    alignas(32) u32x8 mod2;  // 2 * mod
    alignas(32) u32x8 n_inv; // n_inv * mod == -1 (mod 2^32)
    alignas(32) u32x8 r;     // 2^32 % mod
    alignas(32) u32x8 r2;    // (2^32) ^ 2 % mod
    Montgomery_simd() = default;
    Montgomery_simd(u32 md) {
        Montgomery mt(md);
        mod = set1_u32x8(mt.mod);
        mod2 = set1_u32x8(mt.mod2);
        n_inv = set1_u32x8(mt.n_inv);
        r = set1_u32x8(mt.r);
        r2 = set1_u32x8(mt.r2);
    }
    // Lane-wise [0, 2*mod) -> [0, mod) (same unsigned-wrap min trick as scalar).
    u32x8 shrink(u32x8 val) const {
        return min_u32x8(val, val - mod);
    }
    u32x8 shrink2(u32x8 val) const {
        return min_u32x8(val, val - mod2);
    }
    u32x8 shrink_n(u32x8 val) const {
        return min_u32x8(val, val + mod);
    }
    u32x8 shrink2_n(u32x8 val) const {
        return min_u32x8(val, val + mod2);
    }
    // Montgomery reduction of four 64-bit lanes: val * 2^{-32}, the result is
    // left in the low 32 bits of each 64-bit lane.
    template <bool strict = false>
    u64x4 reduce(u64x4 val) const {
        val = (u64x4)shift_right_u32x8_epi64(u32x8(val + (u64x4)mul64_u32x8(mul64_u32x8((u32x8)val, n_inv), mod)), 32);
        if constexpr (strict) {
            val = (u64x4)shrink((u32x8)val);
        }
        return val;
    }
    // Reduce eight 64-bit products (even-indexed results in x0246, odd-indexed
    // in x1357) and interleave them into a single u32x8.
    template <bool strict = false>
    u32x8 reduce(u64x4 x0246, u64x4 x1357) const {
        u32x8 x0246_ninv = mul64_u32x8((u32x8)x0246, n_inv);
        u32x8 x1357_ninv = mul64_u32x8((u32x8)x1357, n_inv);
        u32x8 res = blend_u32x8<0b10'10'10'10>(shift_right_u32x8_epi128<4>(u32x8((u64x4)x0246 + (u64x4)mul64_u32x8(x0246_ninv, mod))),
                                               u32x8((u64x4)x1357 + (u64x4)mul64_u32x8(x1357_ninv, mod)));
        if constexpr (strict)
            res = shrink(res);
        return res;
    }
    // Lane-wise Montgomery product.
    // a * b should be in [0, 2**32 * mod)
    // result in [0, 2 * mod) <false>
    // result in [0, mod) <true>
    //
    // set eq_b to <true> to slightly improve performance, if all elements of b are equal
    template <bool strict = false, bool eq_b = false>
    u32x8 mul(u32x8 a, u32x8 b) const {
        u32x8 x0246 = mul64_u32x8(a, b);
        u32x8 b_sh = b;
        if constexpr (!eq_b) {
            b_sh = shift_right_u32x8_epi128<4>(b);
        }
        u32x8 x1357 = mul64_u32x8(shift_right_u32x8_epi128<4>(a), b_sh);
        return reduce<strict>((u64x4)x0246, (u64x4)x1357);
    }
    // Montgomery product of the low halves of the 64-bit lanes.
    // a * b should be in [0, 2**32 * mod)
    // puts result in high 32-bit of each 64-bit word
    // result in [0, 2 * mod) <false>
    // result in [0, mod) <true>
    template <bool strict = false>
    u64x4 mul_to_hi(u64x4 a, u64x4 b) const {
        u32x8 val = mul64_u32x8((u32x8)a, (u32x8)b);
        u32x8 val_ninv = mul64_u32x8(val, n_inv);
        u32x8 res = u32x8(u64x4(val) + u64x4(mul64_u32x8(val_ninv, mod)));
        if constexpr (strict)
            res = shrink(res);
        return (u64x4)res;
    }
    // Montgomery product of the low halves, result in the low halves.
    // a * b should be in [0, 2**32 * mod)
    // result in [0, 2 * mod) <false>
    // result in [0, mod) <true>
    template <bool strict = false>
    u64x4 mul(u64x4 a, u64x4 b) const {
        u32x8 val = mul64_u32x8((u32x8)a, (u32x8)b);
        return reduce<strict>((u64x4)val);
    }
};
// Calls `function` once per index S..., passing each index as
// std::integral_constant<size_t, S * scale> (fold expression, fully unrolled).
template <size_t scale, size_t... S, typename Functor>
__attribute__((always_inline)) constexpr void static_foreach_seq(Functor function, std::index_sequence<S...>) {
    ((function(std::integral_constant<size_t, S * scale>())), ...);
}
// Compile-time for loop: invokes functor with 0, scale, 2*scale, ...,
// (Size-1)*scale, each as a distinct integral_constant usable in constexpr
// context (e.g. as a template argument inside the functor).
template <size_t Size, size_t scale = 1, typename Functor>
__attribute__((always_inline)) constexpr void static_for(Functor functor) {
    return static_foreach_seq<scale>(functor, std::make_index_sequence<Size>());
}
// Scope timer: measures CPU time via clock() and prints
// "<label>: <seconds>" to stderr when destroyed.
struct cum_timer {
    clock_t beg;   // clock() reading at construction / last reset
    std::string s; // label printed together with the elapsed time
    // Fix: the label used to be copied into the member; it is now moved.
    cum_timer(std::string s) : s(std::move(s)) {
        reset();
    }
    // Restart the measurement window.
    void reset() {
        beg = clock();
    }
    // Seconds of CPU time since construction / last reset.
    // When `reset` is true, the window restarts from this same reading.
    double elapsed(bool reset = false) {
        clock_t clk = clock();
        double res = (clk - beg) * 1.0 / CLOCKS_PER_SEC;
        if (reset) {
            beg = clk;
        }
        return res;
    }
    // Print the label and the current elapsed time to stderr.
    void print() {
        std::cerr << s << ": " << elapsed() << std::endl;
    }
    ~cum_timer() {
        print();
    }
};
#include <array>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <vector>
struct NTT {
alignas(32) const Montgomery_simd mts;
const Montgomery mt;
u32 mod, g;
// Modular exponentiation staying in the Montgomery domain: `base` is expected
// in Montgomery form; the accumulator starts at mt.r (Montgomery 1), so the
// result is also in Montgomery form, normalized to [0, mod).
[[gnu::noinline]] u32 power(u32 base, u32 exp) const {
    const auto mt = this->mt; // ! to put Montgomery constants in registers
    u32 res = mt.r;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) {
            res = mt.mul(res, base);
        }
        base = mt.mul(base, base);
    }
    return mt.shrink(res);
}
// Smallest primitive root modulo `mod`.
// mod should be prime
u32 find_pr_root(u32 mod) const {
    u32 m = mod - 1;
    std::vector<u32> vec; // distinct prime factors of mod - 1
    for (u32 i = 2; u64(i) * i <= m; i++) {
        if (m % i == 0) {
            vec.push_back(i);
            do {
                m /= i;
            } while (m % i == 0);
        }
    }
    if (m != 1) {
        vec.push_back(m);
    }
    // g is a primitive root iff g^((mod-1)/f) != 1 for every prime factor f.
    for (u32 i = 2;; i++) {
        if (std::all_of(vec.begin(), vec.end(), [&](u32 f) { return mt.r != power(mt.mul(i, mt.r2), (mod - 1) / f); })) {
            return i;
        }
    }
}
// { 1, x, x^2, ..., x^7 } with x (and the results) in Montgomery form.
u32x8 get_powers_u32x8(u32 x) const {
    u32 x2 = mt.mul<true>(x, x);
    u32 x3 = mt.mul<true>(x, x2);
    u32 x4 = mt.mul<true>(x2, x2);
    return setr_u32x8(mt.r, x, x2, x3, x4, mt.mul<true>(x2, x3), mt.mul<true>(x3, x3), mt.mul<true>(x4, x3));
}
static constexpr int LG = 30;
// Twiddle tables: w holds 8th roots of unity (w_r their inverses); the
// w_cum_* arrays hold per-level incremental twiddle factors consumed by the
// radix-4 butterflies (x4) and the in-register radix-8 stage (x8).
alignas(32) u32 w[4], w_r[4];
alignas(32) u64x4 w_cum_x4[LG], w_rcum_x4[LG];
alignas(32) u32x8 w_cum_x8[LG], w_rcum_x8[LG];
NTT(u32 mod = 998'244'353) : mt(mod), mts(mod), mod(mod) {
    const auto mt = this->mt; // ! to put Montgomery constants in registers
    g = mt.mul<true>(mt.r2, find_pr_root(mod)); // primitive root, Montgomery form
    for (int i = 0; i < LG; i++) {
        u32 f = power(g, (mod - 1) >> i + 3);
        u32 res = mt.mul<true>(f, power(power(f, (1 << i + 1) - 2), mod - 2));
        u32 res_r = power(res, mod - 2); // inverse twiddle step
        w_cum_x4[i] = setr_u64x4(res, power(res, 2), res, power(res, 3));
        w_rcum_x4[i] = setr_u64x4(res_r, power(res_r, 2), res_r, power(res_r, 3));
    }
    for (int i = 0; i < LG; i++) {
        u32 f = power(g, (mod - 1) >> i + 4);
        f = mt.mul<true>(f, power(power(f, (1 << i + 1) - 2), mod - 2));
        u32 f1 = f;
        u32 f2 = mt.mul<true>(f1, f1);
        u32 f4 = mt.mul<true>(f2, f2);
        w_cum_x8[i] = setr_u32x8(f, f2, f, f4, f, f2, f, f4);
        u32 f_r = power(f, mod - 2);
        w_rcum_x8[i][0] = mt.r; // lane 0 starts at Montgomery 1
        for (int j = 1; j < 8; j++) {
            w_rcum_x8[i][j] = mt.mul<true>(w_rcum_x8[i][j - 1], f_r);
        }
    }
    u32 w18 = power(g, (mod - 1) / 8); // primitive 8th root of unity
    w[0] = mt.r, w[1] = power(w18, 2), w[2] = w18, w[3] = power(w18, 3);
    u32 w78 = power(w18, 7); // == w18 ^ (-1)
    w_r[0] = mt.r, w_r[1] = power(w78, 2), w_r[2] = w78, w_r[3] = power(w78, 3);
}
// One radix-4 forward (DIF) butterfly on 8 consecutive u32 lanes per address.
// Inputs may be in [0, 4*mod) (a is shrink2'd, b/c/d go through mul/shrink2);
// outputs are sums of [0, 2*mod) terms, i.e. in [0, 4*mod).
// w_1 rotates the odd branch; w1/w2/w3 are per-block twiddles, expected equal
// across lanes (hence mul<.., eq_b = true>). <trivial> skips the twiddle
// multiplications (all twiddles equal 1).
template <bool trivial = false>
static void butterfly_forward_x4(u32 *addr_0, u32 *addr_1, u32 *addr_2, u32 *addr_3,
                                 u32x8 w_1, u32x8 w1, u32x8 w2, u32x8 w3,
                                 const Montgomery_simd &mts) {
    u32x8 a = load_u32x8(addr_0);
    u32x8 b = load_u32x8(addr_1);
    u32x8 c = load_u32x8(addr_2);
    u32x8 d = load_u32x8(addr_3);
    a = mts.shrink2(a);
    if constexpr (!trivial) {
        b = mts.mul<false, true>(b, w1),
        c = mts.mul<false, true>(c, w2),
        d = mts.mul<false, true>(d, w3);
    } else {
        b = mts.shrink2(b),
        c = mts.shrink2(c),
        d = mts.shrink2(d);
    }
    u32x8 a1 = mts.shrink2(a + c),
          b1 = mts.shrink2(b + d),
          c1 = mts.shrink2_n(a - c),
          d1 = mts.mul<false, true>(b + mts.mod2 - d, w_1);
    store_u32x8(addr_0, a1 + b1);
    store_u32x8(addr_1, a1 + mts.mod2 - b1);
    store_u32x8(addr_2, c1 + d1);
    store_u32x8(addr_3, c1 + mts.mod2 - d1);
}
// One radix-4 inverse (DIT) butterfly, mirroring butterfly_forward_x4.
// <trivial> marks the outermost layer (surrounding twiddles are 1), in which
// case outputs are fully normalized to [0, mod); <multiply_by_cum>
// additionally folds the scale factor `cum` into the first output.
template <bool trivial = false, bool multiply_by_cum = false>
static void butterfly_inverse_x4(u32 *addr_0, u32 *addr_1, u32 *addr_2, u32 *addr_3,
                                 u32x8 w_1, u32x8 w1, u32x8 w2, u32x8 w3,
                                 const Montgomery_simd &mts, u32x8 cum = u32x8()) {
    u32x8 a = load_u32x8(addr_0);
    u32x8 b = load_u32x8(addr_1);
    u32x8 c = load_u32x8(addr_2);
    u32x8 d = load_u32x8(addr_3);
    u32x8 a1 = mts.shrink2(a + b),
          b1 = mts.shrink2_n(a - b),
          c1 = mts.shrink2(c + d),
          d1 = mts.mul<false, true>(c + mts.mod2 - d, w_1);
    if constexpr (!trivial || multiply_by_cum) {
        if (!multiply_by_cum) {
            store_u32x8(addr_0, mts.shrink2(a1 + c1));
        } else {
            store_u32x8(addr_0, mts.mul<false, true>(a1 + c1, cum));
        }
        store_u32x8(addr_1, mts.mul<false, true>(b1 + d1, w1));
        store_u32x8(addr_2, mts.mul<false, true>(a1 + mts.mod2 - c1, w2));
        store_u32x8(addr_3, mts.mul<false, true>(b1 + mts.mod2 - d1, w3));
    } else {
        // Trivial outermost layer: normalize all four outputs to [0, mod).
        store_u32x8(addr_0, mts.shrink(mts.shrink2(a1 + c1)));
        store_u32x8(addr_1, mts.shrink(mts.shrink2(b1 + d1)));
        store_u32x8(addr_2, mts.shrink(mts.shrink2(a1 + mts.mod2 - c1)));
        store_u32x8(addr_3, mts.shrink(mts.shrink2(b1 + mts.mod2 - d1)));
    }
}
// Forward NTT (decimation in frequency): radix-4 passes plus an in-register
// radix-8 final stage; an extra radix-2 pass first when lg is even.
// input data[i] in [0, 2 * mod)
// output data[i] in [0, 4 * mod)
[[gnu::noinline]] __attribute__((optimize("O3"))) void ntt(int lg, u32 *data) const {
    const auto mt = this->mt;   // ! to put Montgomery constants in registers
    const auto mts = this->mts; // ! to put Montgomery constants in registers
    int n = 1 << lg;
    int k = lg;
    // Even lg: one radix-2 pass so the remaining depth is odd (radix-4 steps).
    if (lg % 2 == 0) {
        for (int i = 0; i < n / 2; i += 8) {
            u32x8 a = load_u32x8(data + i);
            u32x8 b = load_u32x8(data + n / 2 + i);
            store_u32x8(data + i, mts.shrink2(a + b));
            store_u32x8(data + n / 2 + i, mts.shrink2_n(a - b));
        }
        k--;
    }
    assert(k % 2 == 1);
    // Radix-4 passes down to blocks of 8 elements.
    for (; k > 4; k -= 2) {
        u64x4 wj_cum = set1_u64x4(mt.r); // running twiddle, Montgomery form
        u32x8 w_1 = set1_u32x8(w[1]);
        for (int i = 0; i < n; i += (1 << k)) {
            u32x8 w1 = shuffle_u32x8<0b00'00'00'00>((u32x8)wj_cum);
            u32x8 w2 = permute_u32x8((u32x8)wj_cum, set1_u32x8(2));
            u32x8 w3 = permute_u32x8((u32x8)wj_cum, set1_u32x8(6));
            // Incremental twiddle update: the number of trailing one bits of
            // the block counter selects the precomputed step factor.
            wj_cum = mts.mul<true>(wj_cum, w_cum_x4[__builtin_ctz(~(i >> k))]);
            for (int j = 0; j < (1 << k - 2); j += 8) {
                butterfly_forward_x4(data + i + 0 * (1 << k - 2) + j,
                                     data + i + 1 * (1 << k - 2) + j,
                                     data + i + 2 * (1 << k - 2) + j,
                                     data + i + 3 * (1 << k - 2) + j,
                                     w_1, w1, w2, w3, mts);
            }
        }
    }
    assert(k == 3);
    {
        // Final radix-8 stage, performed entirely inside one register per block.
        // * { w, w^2, w * w_1, w^4, | w * w_2, w^2 * w_1, w * w_3, w^4 }
        u32x8 cum = setr_u32x8(w[0], w[0], w[1], w[0], w[2], w[1], w[3], w[0]);
        int n_8 = n / 8;
        for (int i = 0; i < n_8; i++) {
            u32x8 vec = load_u32x8(data + i * 8);
            u32x8 wj0 = shuffle_u32x8<0b11'11'11'11>(cum);
            u32x8 wj1 = shuffle_u32x8<0b01'01'01'01>(cum);
            u32x8 wj2 = cum; // no shuffle needed
            u32x8 bw;
            vec = mts.shrink2(vec);
            // Three in-register butterfly layers (stride 4, 2, 1).
            bw = permute_u32x8((u32x8)mts.mul_to_hi((u64x4)wj0, (u64x4)permute_u32x8(vec, setr_u32x8(4, -1, 5, -1, 6, -1, 7, -1))), setr_u32x8(1, 3, 5, 7, 1, 3, 5, 7));
            vec = permute_u32x8_epi128<0>(vec, vec) + blend_u32x8<0b11'11'00'00>(bw, mts.mod2 - bw);
            vec = mts.shrink2(vec);
            bw = shuffle_u32x8<0b11'01'11'01>((u32x8)mts.mul_to_hi((u64x4)wj1, (u64x4)shuffle_u32x8<0b00'11'00'10>(vec)));
            vec = shuffle_u32x8<0b01'00'01'00>(vec) + blend_u32x8<0b11'00'11'00>(bw, mts.mod2 - bw);
            vec = mts.shrink2(vec);
            bw = shuffle_u32x8<0b11'11'01'01>((u32x8)mts.mul_to_hi((u64x4)wj2, (u64x4)shuffle_u32x8<0b00'11'00'01>(vec)));
            vec = shuffle_u32x8<0b10'10'00'00>(vec) + blend_u32x8<0b10'10'10'10>(bw, mts.mod2 - bw);
            store_u32x8(data + i * 8, vec);
            cum = mts.mul(cum, w_cum_x8[__builtin_ctz(~i)]);
        }
    }
}
// Inverse NTT, mirroring ntt(): in-register radix-8 stage first, then radix-4
// passes, plus a leftover radix-2 pass when lg is even.
// input data[i] in [0, 2 * mod)
// output data[i] in [0, mod)
// fc (if specified) should be in [0, mod)
// if fc is specified everything is multiplied by fc
[[gnu::noinline]] __attribute__((optimize("O3"))) void intt(int lg, u32 *data, u32 fc = -1u) const {
    const auto mt = this->mt;   // ! to put Montgomery constants in registers
    const auto mts = this->mts; // ! to put Montgomery constants in registers
    if (fc == -1u) {
        fc = mt.r; // default scale: 1 in Montgomery form
    }
    int n = 1 << lg;
    int k = 1;
    {
        // Radix-8 stage; cum also folds in fc and the 1/n normalization
        // (accumulated as inv_2^lg).
        u32x8 cum0 = setr_u32x8(w_r[0], w_r[0], w_r[0], w_r[0], w_r[0], w_r[0], w_r[1], w_r[1]);
        u32x8 cum1 = setr_u32x8(w_r[0], w_r[0], w_r[0], w_r[1], w_r[0], w_r[2], w_r[0], w_r[3]);
        const u32 inv_2 = mt.mul<true>(mt.r2, (mod + 1) / 2); // 1/2, Montgomery form
        u32x8 cum = set1_u32x8(mt.mul<true>(fc, power(inv_2, lg)));
        int n_8 = n / 8;
        for (int i = 0; i < n_8; i++) {
            u32x8 vec = load_u32x8(data + i * 8);
            vec = mts.mul(cum1, blend_u32x8<0b10'10'10'10>(vec, mts.mod2 - vec) + shuffle_u32x8<0b10'11'00'01>(vec));
            vec = mts.mul(cum0, blend_u32x8<0b11'00'11'00>(vec, mts.mod2 - vec) + shuffle_u32x8<0b01'00'11'10>(vec));
            vec = mts.mul(cum, blend_u32x8<0b11'11'00'00>(vec, mts.mod2 - vec) + permute_u32x8_epi128<1>(vec, vec));
            store_u32x8(data + i * 8, vec);
            cum = mts.mul<true>(cum, w_rcum_x8[__builtin_ctz(~i)]);
        }
        k += 3;
    }
    // Radix-4 passes back up the tree.
    for (; k + 1 <= lg; k += 2) {
        u64x4 wj_cum = set1_u64x4(mt.r);
        u32x8 w_1 = set1_u32x8(w_r[1]);
        for (int i = 0; i < n; i += (1 << k + 1)) {
            u32x8 w1 = shuffle_u32x8<0b00'00'00'00>((u32x8)wj_cum);
            u32x8 w2 = permute_u32x8((u32x8)wj_cum, set1_u32x8(2));
            u32x8 w3 = permute_u32x8((u32x8)wj_cum, set1_u32x8(6));
            wj_cum = mts.mul<true>(wj_cum, w_rcum_x4[__builtin_ctz(~(i >> k + 1))]);
            for (int j = 0; j < (1 << k - 1); j += 8) {
                butterfly_inverse_x4(data + i + 0 * (1 << k - 1) + j,
                                     data + i + 1 * (1 << k - 1) + j,
                                     data + i + 2 * (1 << k - 1) + j,
                                     data + i + 3 * (1 << k - 1) + j,
                                     w_1, w1, w2, w3, mts);
            }
        }
    }
    // Leftover radix-2 pass (even lg) or a plain final shrink to [0, mod).
    if (k == lg) {
        for (int i = 0; i < n / 2; i += 8) {
            u32x8 a = load_u32x8(data + i);
            u32x8 b = load_u32x8(data + n / 2 + i);
            store_u32x8(data + i, mts.shrink(mts.shrink2(a + b)));
            store_u32x8(data + n / 2 + i, mts.shrink(mts.shrink2_n(a - b)));
        }
    } else {
        for (int i = 0; i < n; i += 8) {
            u32x8 ai = load_u32x8(data + i);
            store_u32x8(data + i, mts.shrink(ai));
        }
    }
}
// Running twiddle state per recursion level, shared across the sibling calls
// of ntt_mul_rec_aux (forward factors in w_cum, inverse in w_cum_r).
struct Cum_info {
    alignas(64) u64x4 w_cum[LG], w_cum_r[LG];
};
static constexpr int MUL_HIGH_RADIX = 16; // lg threshold for the radix-64 path (disabled by `|| 1` below)
static constexpr int MUL_ITER = 10;       // at or below this lg, iterate instead of recursing
static constexpr int MUL_QUAD = 4;        // block lg at which the quadratic base-case multiply kicks in
// Base case of the convolution recursion: multiplies X independent polynomial
// pairs (a, b) of 2^lg coefficients each, modulo x^2^lg - w[it], writing the
// products back into a. Quadratic in the block size but fully vectorized;
// 64-bit lane accumulators are periodically re-reduced (reduce_cum) so the
// repeated additions of 62-bit products cannot overflow.
template <int X, int lg, bool mul_by_fc = true>
__attribute__((optimize("O3"))) void aux_mod_mul_x(u32 *__restrict__ a, u32 *__restrict__ b, std::array<u32, X> w, u32 fc, const Montgomery_simd &mts) const {
    static_assert(3 <= lg);
    constexpr int sz = 1 << lg;
    // aux_a[it] holds [a * w | a], so the negacyclic wraparound for outer
    // index i can be read as one contiguous window at aux_a + sz - i.
    alignas(64) u64 aux_a[X][sz * 2];
    u32x8 fc_x8 = set1_u32x8(mt.mul<true>(fc, mt.r2));
    const u32x8 perm_0 = (u32x8)setr_u64x4(0, 1, 2, 3);
    const u32x8 perm_1 = (u32x8)setr_u64x4(4, 5, 6, 7);
    for (int it = 0; it < X; it++) {
        for (int i = 0; i < sz; i += 8) {
            u64x4 w_x4 = (u64x4)set1_u32x8(w[it]);
            {
                // Normalize b in place (optionally multiplying by fc).
                u32x8 vec = load_u32x8(b + sz * it + i);
                if constexpr (mul_by_fc) {
                    vec = mts.mul<true>(vec, fc_x8);
                } else {
                    vec = mts.shrink(mts.shrink2(vec));
                }
                store_u32x8(b + sz * it + i, vec);
            }
            u32x8 vec = load_u32x8(a + sz * it + i);
            vec = mts.shrink(mts.shrink2(vec));
            u64x4 a_lw = (u64x4)permute_u32x8(vec, perm_0);
            u64x4 a_hg = (u64x4)permute_u32x8(vec, perm_1);
            u64x4 aw_lw = mts.mul<true>(a_lw, w_x4);
            u64x4 aw_hg = mts.mul<true>(a_hg, w_x4);
            store_u64x4(aux_a[it] + i + 0, aw_lw);
            store_u64x4(aux_a[it] + i + 4, aw_hg);
            store_u64x4(aux_a[it] + i + sz + 0, a_lw);
            store_u64x4(aux_a[it] + i + sz + 4, a_hg);
        }
    }
    // 6 + 25 = 31  (bit-width budget note from the original author)
    u64x4 mod2_8 = set1_u64x4(mod * 1ULL * mod * 8);
    // Where a lane's top bit is set (value >= 2^63), subtract 8 * mod^2 to
    // bring the accumulator back into range.
    auto reduce_cum = [&](u64x4 val) {
        val -= (u64x4)_mm256_cmpgt_epi64(_mm256_setzero_si256(), (i256)val) & mod2_8;
        return val;
    };
    alignas(64) u64x4 aux[X * sz / 4];
    memset(aux, 0, sizeof(aux));
    for (int i = 0; i < sz; i++) {
        // After the first 16 outer iterations, re-reduce every 8 iterations
        // so the 64-bit additions below stay overflow-free.
        if (i >= 16 && i % 8 == 0) {
            for (int it = 0; it < X; it++) {
                for (int j = 0; j < sz / 4; j++) {
                    aux[it * sz / 4 + j] = reduce_cum(aux[it * sz / 4 + j]);
                }
            }
        }
        for (int it = 0; it < X; it++) {
            for (int j = 0; j < sz; j += 4) {
                u64x4 bi = (u64x4)set1_u32x8(b[i + sz * it]);
                u64x4 aj = loadu_u64x4(aux_a[it] + sz - i + j);
                aux[it * sz / 4 + j / 4] += (u64x4)mul64_u32x8((u32x8)bi, (u32x8)aj);
            }
        }
    }
    // 2 + 6 + 10 + 4 = 22  (bit-width budget note from the original author)
    // Final Montgomery reduction; interleave back to the u32 output layout.
    for (int i = 0; i < sz * X; i += 8) {
        u64x4 a0 = aux[i / 4];
        u64x4 a1 = aux[i / 4 + 1];
        if constexpr (lg >= 4) {
            a0 = reduce_cum(a0);
            a1 = reduce_cum(a1);
        }
        u32x8 b = permute_u32x8(mts.reduce(a0, a1), setr_u32x8(0, 2, 4, 6, 1, 3, 5, 7));
        b = mts.shrink(mts.shrink2(b));
        store_u32x8(a + i, b);
    }
}
// multiplies mod x^2^lg - w
// writes result to a
// input in [0, 4 * mod)
// output in [0, 2 * mod)
// multiplies everything by fc
template <bool mul_by_fc = true>
__attribute__((optimize("O3"))) inline void mul_mod(int lg, u32 *__restrict__ a, u32 *__restrict__ b, u32 w, u32 fc, const Montgomery_simd &mts) const {
    const auto mt = this->mt; // ! to put Montgomery constants in registers
    if (lg >= 3) {
        assert(lg <= 12);
        // Dispatch the runtime lg to the matching compile-time instantiation.
        static_for<10>([&](auto i) {
            if (i + 3 == lg) {
                aux_mod_mul_x<1, i + 3, mul_by_fc>(a, b, std::array<u32, 1>{w}, fc, mts);
            }
        });
    } else {
        // Tiny sizes: scalar schoolbook multiply mod x^2^lg - w.
        u32 buf[1 << 6];
        memset(buf, 0, sizeof(buf));
        for (int i = 0; i < (1 << lg); i++) {
            a[i] = mt.shrink2(a[i]);
            b[i] = mt.shrink2(b[i]);
        }
        for (int i = 0; i < (1 << lg); i++) {
            for (int j = 0; j < (1 << lg); j++) {
                int ind = i + j;
                u32 pr = mt.mul(a[i], b[j]);
                if (ind >= (1 << lg)) {
                    pr = mt.mul(pr, w); // wraparound term picks up the factor w
                    ind -= 1 << lg;
                }
                buf[ind] = mt.shrink2(buf[ind] + pr);
            }
        }
        if constexpr (mul_by_fc) {
            fc = mt.mul<true>(fc, mt.r2);
            for (int i = 0; i < (1 << lg); i++) {
                buf[i] = mt.mul<true>(buf[i], fc);
            }
        } else {
            for (int i = 0; i < (1 << lg); i++) {
                buf[i] = mt.shrink(mt.shrink2(buf[i]));
            }
        }
        memcpy(a, buf, 4 << lg);
    }
}
// Same as mul_mod, but processes four independent pairs at once, with an
// individual modulus root w[it] per pair.
template <bool mul_by_fc = true>
__attribute__((optimize("O3"))) inline void mul_mod_x4(int lg, u32 *__restrict__ a, u32 *__restrict__ b,
                                                       std::array<u32, 4> w, u32 fc, const Montgomery_simd &mts) const {
    if (lg >= 3) {
        assert(lg <= 12);
        static_for<10>([&](auto i) {
            if (i + 3 == lg) {
                aux_mod_mul_x<4, i + 3, mul_by_fc>(a, b, w, fc, mts);
            }
        });
    } else {
        for (int it = 0; it < 4; it++) {
            mul_mod<mul_by_fc>(lg, a + (1 << lg) * it, b + (1 << lg) * it, w[it], fc, mts);
        }
    }
}
// Recursive "forward layers -> block products -> inverse layers" driver used
// by convolve2. It descends with radix-4 forward butterflies applied to both
// operands, multiplies blocks of size 2^MUL_QUAD at the bottom (mul_mod_x4),
// and merges on the way back with inverse butterflies, leaving the product in
// data_a. `ind` is the offset of this subarray inside the full transform
// (used to pick twiddles); cum_info carries running twiddles across sibling
// calls; fc is the accumulated scale factor.
template <int lg, bool first = true>
__attribute__((optimize("O3"))) typename std::enable_if<(lg >= 0), void>::type ntt_mul_rec_aux(int ind, u32 *data_a, u32 *data_b, Cum_info &cum_info, u32 fc) const {
    const auto mts = this->mts; // ! to put Montgomery constants in registers
    const auto mt = this->mt;
    if constexpr (lg <= MUL_ITER) {
        // Small subarray: fully iterative forward layers on both operands...
        int k = lg;
        for (; k > MUL_QUAD; k -= 2) {
            u64x4 wj_cum = first ? (u64x4)mts.r : cum_info.w_cum[k];
            u32x8 w_1 = set1_u32x8(w[1]);
            for (int i = 0; i < (1 << lg); i += (1 << k)) {
                u32x8 w1 = shuffle_u32x8<0b00'00'00'00>((u32x8)wj_cum);
                u32x8 w2 = permute_u32x8((u32x8)wj_cum, set1_u32x8(2));
                u32x8 w3 = permute_u32x8((u32x8)wj_cum, set1_u32x8(6));
                wj_cum = mts.mul<true>(wj_cum, w_cum_x4[__builtin_ctz(~(ind + i >> k))]);
                for (auto data : std::array{data_a, data_b}) {
                    for (int j = 0; j < (1 << k - 2); j += 8) {
                        butterfly_forward_x4(data + i + 0 * (1 << k - 2) + j,
                                             data + i + 1 * (1 << k - 2) + j,
                                             data + i + 2 * (1 << k - 2) + j,
                                             data + i + 3 * (1 << k - 2) + j,
                                             w_1, w1, w2, w3, mts);
                    }
                }
                if (k - 2 <= MUL_QUAD) {
                    // Bottom reached: multiply the four fresh blocks modulo
                    // x^2^(k-2) - f0, + f0, - f1, + f1 (mod - f == -f).
                    u32 f0 = w1[0];
                    u32 f1 = mt.mul<true>(f0, w[1]);
                    std::array<u32, 4> wi = {f0, mod - f0, f1, mod - f1};
                    mul_mod_x4<false>(k - 2, data_a + i, data_b + i, wi, fc, mts);
                }
            }
            cum_info.w_cum[k] = wj_cum;
        }
        fc = mt.mul<true>(fc, mt.r2);
        u32x8 fc_x8 = set1_u32x8(fc);
        // ...then iterative inverse layers on data_a only.
        for (; k + 1 < lg; k += 2) {
            u64x4 wj_cum = cum_info.w_cum_r[k];
            if (first) {
                wj_cum = (u64x4)mts.r;
                if (k <= MUL_QUAD) {
                    wj_cum = (u64x4)fc_x8;
                }
            }
            u32x8 w_1 = set1_u32x8(w_r[1]);
            for (int i = 0; i < (1 << lg); i += (1 << k + 2)) {
                u32x8 w1 = shuffle_u32x8<0b00'00'00'00>((u32x8)wj_cum);
                u32x8 w2 = permute_u32x8((u32x8)wj_cum, set1_u32x8(2));
                u32x8 w3 = permute_u32x8((u32x8)wj_cum, set1_u32x8(6));
                wj_cum = mts.mul<true>(wj_cum, w_rcum_x4[__builtin_ctz(~(ind + i >> k + 2))]);
                if (k <= MUL_QUAD) {
                    // First inverse layer after the base case also applies fc.
                    for (int j = 0; j < (1 << k); j += 8) {
                        butterfly_inverse_x4<false, true>(data_a + i + 0 * (1 << k) + j,
                                                          data_a + i + 1 * (1 << k) + j,
                                                          data_a + i + 2 * (1 << k) + j,
                                                          data_a + i + 3 * (1 << k) + j,
                                                          w_1, w1, w2, w3, mts, fc_x8);
                    }
                    continue;
                }
                for (int j = 0; j < (1 << k); j += 8) {
                    butterfly_inverse_x4(data_a + i + 0 * (1 << k) + j,
                                         data_a + i + 1 * (1 << k) + j,
                                         data_a + i + 2 * (1 << k) + j,
                                         data_a + i + 3 * (1 << k) + j,
                                         w_1, w1, w2, w3, mts);
                }
            }
            cum_info.w_cum_r[k] = wj_cum;
        }
        if constexpr (first) {
            // Topmost call: normalize the final result to [0, mod).
            for (int i = 0; i < (1 << lg); i += 8) {
                u32x8 a = load_u32x8(data_a + i);
                store_u32x8(data_a + i, mts.shrink(a));
            }
        }
    } else if (lg < MUL_HIGH_RADIX || 1) { // `|| 1`: the radix-64 path below is disabled
        {
            // One forward radix-4 layer over the whole subarray.
            int k = lg;
            u64x4 wj_cum = first ? (u64x4)mts.r : cum_info.w_cum[k];
            u32x8 w_1 = set1_u32x8(w[1]);
            u32x8 w1 = shuffle_u32x8<0b00'00'00'00>((u32x8)wj_cum);
            u32x8 w2 = permute_u32x8((u32x8)wj_cum, set1_u32x8(2));
            u32x8 w3 = permute_u32x8((u32x8)wj_cum, set1_u32x8(6));
            wj_cum = mts.mul<true>(wj_cum, w_cum_x4[__builtin_ctz(~(ind >> k))]);
            for (auto data : std::array{data_a, data_b}) {
                for (int j = 0; j < (1 << k - 2); j += 8) {
                    butterfly_forward_x4<first>(data + 0 * (1 << k - 2) + j,
                                                data + 1 * (1 << k - 2) + j,
                                                data + 2 * (1 << k - 2) + j,
                                                data + 3 * (1 << k - 2) + j,
                                                w_1, w1, w2, w3, mts);
                }
            }
            cum_info.w_cum[k] = wj_cum;
        }
        {
            // Recurse into the four quarters.
            constexpr int rd = 2;
            static_for<1 << rd>([&](auto i) {
                ntt_mul_rec_aux<lg - rd, first && i == 0>(ind + (1 << lg - rd) * i,
                                                          data_a + (1 << lg - rd) * i,
                                                          data_b + (1 << lg - rd) * i,
                                                          cum_info, fc);
            });
        }
        {
            // One inverse radix-4 layer to merge the quarters back.
            int k = lg;
            u64x4 wj_cum = first ? (u64x4)mts.r : cum_info.w_cum_r[k];
            u32x8 w_1 = set1_u32x8(w_r[1]);
            u32x8 w1 = shuffle_u32x8<0b00'00'00'00>((u32x8)wj_cum);
            u32x8 w2 = permute_u32x8((u32x8)wj_cum, set1_u32x8(2));
            u32x8 w3 = permute_u32x8((u32x8)wj_cum, set1_u32x8(6));
            wj_cum = mts.mul<true>(wj_cum, w_rcum_x4[__builtin_ctz(~(ind >> k))]);
            for (int j = 0; j < (1 << k - 2); j += 8) {
                butterfly_inverse_x4<first>(data_a + 0 * (1 << k - 2) + j,
                                            data_a + 1 * (1 << k - 2) + j,
                                            data_a + 2 * (1 << k - 2) + j,
                                            data_a + 3 * (1 << k - 2) + j,
                                            w_1, w1, w2, w3, mts);
            }
            cum_info.w_cum_r[k] = wj_cum;
        }
    } else {
        // Dead code while `|| 1` above holds: cache-blocked variant doing
        // three fused radix-4 layers (radix 64) per recursion step.
        constexpr int RD = 6;
        constexpr int BLC = 1024;
        {
            u32x8 w_1 = set1_u32x8(w[1]);
            u32x8 wx_0[3];
            u32x8 wx_1[4][3];
            u32x8 wx_2[16][3];
            // Precompute the twiddle triples for all blocks of all 3 layers.
            for (int t = 0; t < 3; t++) {
                u64x4 wj_cum = first ? (u64x4)mts.r : cum_info.w_cum[lg - 2 * t];
                for (int i = 0; i < (1 << 2 * t); i++) {
                    u32x8 *wi = std::array{wx_0, wx_1[i], wx_2[i]}[t];
                    wi[0] = shuffle_u32x8<0b00'00'00'00>((u32x8)wj_cum);
                    wi[1] = permute_u32x8((u32x8)wj_cum, set1_u32x8(2));
                    wi[2] = permute_u32x8((u32x8)wj_cum, set1_u32x8(6));
                    wj_cum = mts.mul<true>(wj_cum, w_cum_x4[__builtin_ctz(~((ind >> lg - 2 * t) + i))]);
                }
                cum_info.w_cum[lg - 2 * t] = wj_cum;
            }
            for (auto data : std::array{data_a, data_b}) {
                for (int t = 0; t <= 2; t++) {
                    const int sz = 1 << lg - 2 * t - 2;
                    for (int j0 = 0; j0 < sz; j0 += BLC) {
                        for (int i = 0; i < (1 << lg); i += 4 * sz) {
                            u32x8 *wi = std::array{wx_0, wx_1[i / sz / 4], wx_2[i / sz / 4]}[t];
                            u32x8 w1 = wi[0], w2 = wi[1], w3 = wi[2];
                            for (int j = j0; j < j0 + BLC; j += 8) {
                                if (first && i == 0) {
                                    butterfly_forward_x4<true>(data + i + j + 0 * sz,
                                                               data + i + j + 1 * sz,
                                                               data + i + j + 2 * sz,
                                                               data + i + j + 3 * sz,
                                                               w_1, w1, w2, w3, mts);
                                } else {
                                    butterfly_forward_x4(data + i + j + 0 * sz,
                                                         data + i + j + 1 * sz,
                                                         data + i + j + 2 * sz,
                                                         data + i + j + 3 * sz,
                                                         w_1, w1, w2, w3, mts);
                                }
                            }
                        }
                    }
                }
            }
        }
        {
            static_for<1 << RD>([&](auto i) {
                ntt_mul_rec_aux<lg - RD, first && i == 0>(ind + (1 << lg - RD) * i,
                                                          data_a + (1 << lg - RD) * i,
                                                          data_b + (1 << lg - RD) * i,
                                                          cum_info, fc);
            });
        }
        {
            u32x8 w_1 = set1_u32x8(w_r[1]);
            u32x8 wx_0[3];
            u32x8 wx_1[4][3];
            u32x8 wx_2[16][3];
            for (int t = 0; t < 3; t++) {
                u64x4 wj_cum = first ? (u64x4)mts.r : cum_info.w_cum_r[lg - 2 * t];
                for (int i = 0; i < (1 << 2 * t); i++) {
                    u32x8 *wi = std::array{wx_0, wx_1[i], wx_2[i]}[t];
                    wi[0] = shuffle_u32x8<0b00'00'00'00>((u32x8)wj_cum);
                    wi[1] = permute_u32x8((u32x8)wj_cum, set1_u32x8(2));
                    wi[2] = permute_u32x8((u32x8)wj_cum, set1_u32x8(6));
                    wj_cum = mts.mul<true>(wj_cum, w_rcum_x4[__builtin_ctz(~((ind >> lg - 2 * t) + i))]);
                }
                cum_info.w_cum_r[lg - 2 * t] = wj_cum;
            }
            // Inverse layers run in the opposite order (t = 2 .. 0).
            for (int t = 2; t >= 0; t--) {
                const int sz = 1 << lg - 2 * t - 2;
                for (int j0 = 0; j0 < sz; j0 += BLC) {
                    for (int i = 0; i < (1 << lg); i += 4 * sz) {
                        u32x8 *wi = std::array{wx_0, wx_1[i / sz / 4], wx_2[i / sz / 4]}[t];
                        u32x8 w1 = wi[0], w2 = wi[1], w3 = wi[2];
                        for (int j = j0; j < j0 + BLC; j += 8) {
                            if (first && i == 0) {
                                butterfly_inverse_x4<true>(data_a + i + j + 0 * sz,
                                                           data_a + i + j + 1 * sz,
                                                           data_a + i + j + 2 * sz,
                                                           data_a + i + j + 3 * sz,
                                                           w_1, w1, w2, w3, mts);
                            } else {
                                butterfly_inverse_x4(data_a + i + j + 0 * sz,
                                                     data_a + i + j + 1 * sz,
                                                     data_a + i + j + 2 * sz,
                                                     data_a + i + j + 3 * sz,
                                                     w_1, w1, w2, w3, mts);
                            }
                        }
                    }
                }
            }
        }
    }
}
// Recursion terminator for negative lg; never does any work.
template <int lg, bool first>
typename std::enable_if<(lg < 0), void>::type ntt_mul_rec_aux(auto... args) const { ; }
// Cyclic convolution of length 2^lg (classic NTT -> pointwise -> inverse NTT).
// a and b should be 32-byte aligned
// writes (a * b) to a
[[gnu::noinline]] __attribute__((optimize("O3"))) void convolve(int lg, u32 *__restrict__ a, u32 *__restrict__ b) const {
    if (lg <= 4) {
        // Tiny inputs: scalar quadratic cyclic convolution.
        int n = (1 << lg);
        u32 *__restrict__ c = (u32 *)_mm_malloc(n * 4, 4);
        memset(c, 0, 4 * n);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                c[(i + j) & (n - 1)] = mt.shrink2(c[(i + j) & (n - 1)] + mt.mul(a[i], b[j]));
            }
        }
        for (int i = 0; i < n; i++) {
            a[i] = mt.mul<true>(mt.r2, c[i]);
        }
        _mm_free(c);
        return;
    }
    ntt(lg, a);
    ntt(lg, b);
    const auto mts = this->mts; // ! to put Montgomery constants in registers
    for (int i = 0; i < (1 << lg); i += 8) {
        u32x8 ai = load_u32x8(a + i), bi = load_u32x8(b + i);
        store_u32x8(a + i, mts.mul(mts.shrink2(ai), mts.shrink2(bi)));
    }
    intt(lg, a, mt.r2); // fc = r2 cancels the Montgomery factor from the pointwise mul
}
// Cache-friendlier variant built on the fused recursion ntt_mul_rec_aux.
// a and b should be 32-byte aligned
// writes (a * b) to a
[[gnu::noinline]] __attribute__((optimize("O3"))) void convolve2(int lg, u32 *__restrict__ a, u32 *__restrict__ b) const {
    if (lg <= MUL_QUAD) {
        mul_mod(lg, a, b, mt.r, mt.r, mts); // w = 1: plain cyclic product
        for (int i = 0; i < (1 << lg); i++) {
            a[i] = mt.shrink(mt.shrink2(a[i]));
        }
        return;
    }
    Cum_info cum_info;
    // Scale factor: (1/2)^(number of radix layers applied), in Montgomery form.
    u32 f = power(mt.mul(mt.r2, mod + 1 >> 1), (lg - (MUL_QUAD - 1)) / 2 * 2);
    static_for<LG - MUL_QUAD>([&](auto it) {
        if (it + MUL_QUAD + 1 == lg)
            ntt_mul_rec_aux<it + MUL_QUAD + 1>(0, a, b, cum_info, f);
    });
}
// Reference quadratic (linear, non-cyclic) convolution; useful for checking
// the fast paths against.
__attribute__((optimize("O3"))) std::vector<u32> convolve_slow(std::vector<u32> a, std::vector<u32> b) const {
    int sz = std::max(0, (int)a.size() + (int)b.size() - 1);
    const auto mt = this->mt; // ! to put Montgomery constants in registers
    std::vector<u32> c(sz);
    for (int i = 0; i < a.size(); i++) {
        for (int j = 0; j < b.size(); j++) {
            // c[i + j] = (c[i + j] + u64(a[i]) * b[j]) % mod;
            c[i + j] = mt.shrink(c[i + j] + mt.mul<true>(mt.r2, mt.mul(a[i], b[j])));
        }
    }
    return c;
}
// Vector front end: zero-pads both inputs to the next power of two that fits
// the product length, runs the in-place convolve, truncates the result.
__attribute__((optimize("O3"))) std::vector<u32> convolve(const std::vector<u32> &a, const std::vector<u32> &b) const {
    int sz = std::max(0, (int)a.size() + (int)b.size() - 1);
    int lg = std::__lg(std::max(1, sz - 1)) + 1;
    u32 *ap = (u32 *)_mm_malloc(std::max(32, (1 << lg) * 4), 32);
    u32 *bp = (u32 *)_mm_malloc(std::max(32, (1 << lg) * 4), 32);
    memset(ap, 0, 4 << lg);
    memset(bp, 0, 4 << lg);
    std::copy(a.begin(), a.end(), ap);
    std::copy(b.begin(), b.end(), bp);
    convolve(lg, ap, bp);
    std::vector<u32> res(ap, ap + sz);
    _mm_free(ap);
    _mm_free(bp);
    return res;
}
// Same front end, dispatching to convolve2.
__attribute__((optimize("O3"))) std::vector<u32> convolve2(const std::vector<u32> &a, const std::vector<u32> &b) const {
    int sz = std::max(0, (int)a.size() + (int)b.size() - 1);
    int lg = std::__lg(std::max(1, sz - 1)) + 1;
    u32 *ap = (u32 *)_mm_malloc(std::max(32, (1 << lg) * 4), 32);
    u32 *bp = (u32 *)_mm_malloc(std::max(32, (1 << lg) * 4), 32);
    memset(ap, 0, 4 << lg);
    memset(bp, 0, 4 << lg);
    std::copy(a.begin(), a.end(), ap);
    std::copy(b.begin(), b.end(), bp);
    convolve2(lg, ap, bp);
    std::vector<u32> res(ap, ap + sz);
    _mm_free(ap);
    _mm_free(bp);
    return res;
}
};
#include <sys/mman.h>
#include <sys/stat.h>
#include <cstring>
#include <iostream>
// io from https://judge.yosupo.jp/submission/142782
namespace __io {
using u32 = uint32_t;
using u64 = uint64_t;
namespace QIO_base {
constexpr int O_buffer_default_size = 1 << 18;
constexpr int O_buffer_default_flush_threshold = 40;
// Lookup table built at compile time: for each i in [0, 10000),
// tab[4*i .. 4*i+3] holds the four decimal digits of i (zero-padded), so
// integers can be emitted four digits per memcpy.
struct _int_to_char_tab {
    char tab[40000];
    constexpr _int_to_char_tab() : tab() {
        for (int i = 0; i != 10000; ++i) {
            for (int j = 3, n = i; ~j; --j) {
                tab[i * 4 + j] = n % 10 + 48, n /= 10;
            }
        }
    }
} constexpr _otab;
} // namespace QIO_base
namespace QIO_I {
using namespace QIO_base;
// mmap-based reader: maps the whole input file read-only and parses it with a
// raw pointer — no copies and no stdio calls after construction.
struct Qinf {
    FILE *f;
    char *bg, *ed, *p; // mapping begin / end of data / current read position
    struct stat Fl;
    Qinf(FILE *fi) : f(fi) {
        int fd = fileno(f);
        fstat(fd, &Fl);
        // st_size + 1: the byte past EOF (zero-filled within the last mapped
        // page) acts as a terminator for the parsers below.
        // NOTE(review): assumes the file size is not an exact multiple of the
        // page size (otherwise the +1 byte is outside the mapping) — typical
        // for judge inputs, but worth confirming.
        bg = (char *)mmap(0, Fl.st_size + 1, PROT_READ, MAP_PRIVATE, fd, 0);
        p = bg, ed = bg + Fl.st_size;
        madvise(p, Fl.st_size + 1, MADV_SEQUENTIAL);
    }
    ~Qinf() { munmap(bg, Fl.st_size + 1); }
    // Advance past whitespace/control bytes. No bounds check: relies on the
    // input containing a token after any skipped run.
    void skip_space() {
        while (*p <= ' ') {
            ++p;
        }
    }
    char get() { return *p++; }
    char seek() { return *p; }
    bool eof() { return p == ed; }
    Qinf &read(char *s, size_t count) { return memcpy(s, p, count), p += count, *this; }
    Qinf &operator>>(u32 &x) {
        skip_space(), x = 0;
        for (; *p > ' '; ++p) {
            x = x * 10 + (*p & 0xf);
        }
        return *this;
    }
    Qinf &operator>>(int &x) {
        skip_space();
        if (*p == '-') {
            // Accumulate negated so INT_MIN parses without overflow.
            for (++p, x = 48 - *p++; *p > ' '; ++p) {
                x = x * 10 - (*p ^ 48);
            }
        } else {
            for (x = *p++ ^ 48; *p > ' '; ++p) {
                x = x * 10 + (*p ^ 48);
            }
        }
        return *this;
    }
} qin(stdin);
} // namespace QIO_I
namespace QIO_O {
using namespace QIO_base;
// Buffered writer: accumulates bytes in a heap buffer, flushing with
// fwrite_unlocked; integers are emitted via the 4-digit lookup table.
struct Qoutf {
    FILE *f;
    char *bg, *ed, *p; // buffer begin / end / current write position
    char *ed_thre;     // flush threshold: end minus a safety margin
    int fp;
    u64 _fpi;
    Qoutf(FILE *fo, size_t sz = O_buffer_default_size) : f(fo), bg(new char[sz]), ed(bg + sz), p(bg), ed_thre(ed - O_buffer_default_flush_threshold), fp(6), _fpi(1000000ull) {}
    void flush() { fwrite_unlocked(bg, 1, p - bg, f), p = bg; }
    // Flush only when the write position crossed the threshold.
    void chk() {
        if (__builtin_expect(p > ed_thre, 0)) {
            flush();
        }
    }
    ~Qoutf() {
        flush();
        delete[] bg;
    }
    // Emit x < 10000 without leading zeros.
    void put4(u32 x) {
        if (x > 99u) {
            if (x > 999u) {
                memcpy(p, _otab.tab + (x << 2) + 0, 4), p += 4;
            } else {
                memcpy(p, _otab.tab + (x << 2) + 1, 3), p += 3;
            }
        } else {
            if (x > 9u) {
                memcpy(p, _otab.tab + (x << 2) + 2, 2), p += 2;
            } else {
                *p++ = x ^ 48;
            }
        }
    }
    // Emit x < 100 without leading zeros.
    void put2(u32 x) {
        if (x > 9u) {
            memcpy(p, _otab.tab + (x << 2) + 2, 2), p += 2;
        } else {
            *p++ = x ^ 48;
        }
    }
    Qoutf &write(const char *s, size_t count) {
        if (count > 1024 || p + count > ed_thre) {
            // Large or non-fitting payload: flush and bypass the buffer.
            flush(), fwrite_unlocked(s, 1, count, f);
        } else {
            memcpy(p, s, count), p += count, chk();
        }
        return *this;
    }
    Qoutf &operator<<(char ch) { return *p++ = ch, *this; }
    Qoutf &operator<<(u32 x) {
        // Split into 4-digit groups; only the leading group drops zero padding.
        if (x > 99999999u) {
            put2(x / 100000000u), x %= 100000000u;
            memcpy(p, _otab.tab + ((x / 10000u) << 2), 4), p += 4;
            memcpy(p, _otab.tab + ((x % 10000u) << 2), 4), p += 4;
        } else if (x > 9999u) {
            put4(x / 10000u);
            memcpy(p, _otab.tab + ((x % 10000u) << 2), 4), p += 4;
        } else {
            put4(x);
        }
        return chk(), *this;
    }
    Qoutf &operator<<(int x) {
        // NOTE(review): `x = -x` overflows for INT_MIN — assumed never printed here.
        if (x < 0) {
            *p++ = '-', x = -x;
        }
        return *this << static_cast<u32>(x);
    }
} qout(stdout);
} // namespace QIO_O
// Public surface of the fast-I/O layer: the qin/qout globals (bound to
// stdin/stdout) and their types.
namespace QIO {
using QIO_I::qin;
using QIO_I::Qinf;
using QIO_O::qout;
using QIO_O::Qoutf;
} // namespace QIO
using namespace QIO;
}; // namespace __io
using namespace __io;
#include <cassert>
#include <iostream>
using namespace std;
// Global NTT engine over the prime 998244353 (NTT type declared earlier
// in this file); must agree with MOD used by the mint arithmetic below.
NTT fft_(998'244'353);
// Multiplies the polynomials held in aa (n coefficients) and bb
// (m coefficients) with the SIMD NTT and writes the first n coefficients
// of the product back into aa (callers here only need that prefix).
// Scratch buffers are 32-byte aligned for the AVX2 kernels and sized to
// the power-of-two transform length 1 << lg >= n + m - 1.
void fast_mult(int n, int m, vector<int> &aa, const vector<int> &bb) {
    int lg = std::__lg(std::max(1, n + m - 2)) + 1;
    u32 *a = (u32 *)_mm_malloc(std::max(32, (1 << lg) * 4), 32);
    u32 *b = (u32 *)_mm_malloc(std::max(32, (1 << lg) * 4), 32);
    for (int i = 0; i < n; i++) {
        a[i] = aa[i];
    }
    memset(a + n, 0, (4 << lg) - 4 * n); // zero-pad to transform length
    for (int i = 0; i < m; i++) {
        b[i] = bb[i];
    }
    memset(b + m, 0, (4 << lg) - 4 * m);
    fft_.convolve2(lg, a, b); // in-place: product lands in a
    for (int i = 0; i < n; i++) {
        aa[i] = a[i];
    }
    // BUG FIX: both aligned scratch buffers were leaked on every call
    // (this function runs ~N times in main, leaking ~8000 buffers).
    _mm_free(a);
    _mm_free(b);
}
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
#define all(a) begin(a), end(a)
#define len(a) int((a).size())
#ifdef LOCAL
#include "debug.h"
#else
#define dbg(...)
#define dprint(...)
#define debug if constexpr (false)
#define draw_tree(...)
#define draw_array(...)
#endif // LOCAL
/*
! WARNING: MOD must be prime.
! WARNING: MOD must be less than 2^30.
* Use .get() to get the stored value.
*/
// Modular integer in Montgomery form: internally stores x * R (mod mod)
// with R = 2^32, kept lazily in [0, 2*mod); get() fully normalizes.
template<uint32_t mod>
class montgomery {
static_assert(mod < uint32_t(1) << 30, "mod < 2^30");
using mint = montgomery<mod>;
private:
uint32_t value;
// -mod^{-1} (mod 2^32) by Newton-Hensel iteration: for odd mod, x = mod
// inverts mod 2^3, and each step doubles the correct low bits
// (3 -> 6 -> 12 -> 24 -> 48 >= 32); the final negation gives -mod^{-1}.
static constexpr uint32_t inv_neg_mod = []() {
uint32_t x = mod;
for (int i = 0; i < 4; ++i) {
x *= uint32_t(2) - mod * x;
}
return -x;
}();
static_assert(mod * inv_neg_mod == -1);
// neg_mod == 2^64 % mod == R^2 % mod; multiplying by it (then Montgomery-
// reducing) converts an ordinary residue into Montgomery form.
static constexpr uint32_t neg_mod = (-uint64_t(mod)) % mod;
// Montgomery reduction: returns value * R^{-1} (mod mod) in [0, 2*mod).
static uint32_t reduce(const uint64_t &value) {
return (value + uint64_t(uint32_t(value) * inv_neg_mod) * mod) >> 32;
}
inline static const mint ONE = mint(1);
public:
montgomery() : value(0) {}
montgomery(const mint &x) : value(x.value) {}
// Any integral x: fold into [0, 2*mod), then multiply by R^2 and reduce
// to enter Montgomery form (the !x shortcut avoids the reduction for 0).
template<typename T, typename U = std::enable_if_t<std::is_integral<T>::value>>
montgomery(const T &x) : value(!x ? 0 : reduce(int64_t(x % int32_t(mod) + int32_t(mod)) * neg_mod)) {}
static constexpr uint32_t get_mod() {
return mod;
}
// Leave Montgomery form and normalize into [0, mod).
uint32_t get() const {
auto real_value = reduce(value);
return real_value < mod ? real_value : real_value - mod;
}
// Binary exponentiation; the exponent is first reduced modulo mod - 1
// (Fermat), which also makes negative exponents act as inverses.
// NOTE(review): relies on mod being prime and, for degree < 0, on the
// base being nonzero.
template<typename T>
mint power(T degree) const {
degree = (degree % int32_t(mod - 1) + int32_t(mod - 1)) % int32_t(mod - 1);
mint prod = 1, a = *this;
for (; degree > 0; degree >>= 1, a *= a)
if (degree & 1)
prod *= a;
return prod;
}
mint inv() const {
return power(-1);
}
mint& operator=(const mint &x) {
value = x.value;
return *this;
}
// Lazy add: subtract 2*mod up front and add it back on underflow, so the
// result stays in [0, 2*mod) without a division.
mint& operator+=(const mint &x) {
if (int32_t(value += x.value - (mod << 1)) < 0) {
value += mod << 1;
}
return *this;
}
mint& operator-=(const mint &x) {
if (int32_t(value -= x.value) < 0) {
value += mod << 1;
}
return *this;
}
// Product of two Montgomery-form values reduced once keeps the form.
mint& operator*=(const mint &x) {
value = reduce(uint64_t(value) * x.value);
return *this;
}
mint& operator/=(const mint &x) {
return *this *= x.inv();
}
friend mint operator+(const mint &x, const mint &y) {
return mint(x) += y;
}
friend mint operator-(const mint &x, const mint &y) {
return mint(x) -= y;
}
friend mint operator*(const mint &x, const mint &y) {
return mint(x) *= y;
}
friend mint operator/(const mint &x, const mint &y) {
return mint(x) /= y;
}
mint& operator++() {
return *this += ONE;
}
mint& operator--() {
return *this -= ONE;
}
mint operator++(int) {
mint prev = *this;
*this += ONE;
return prev;
}
mint operator--(int) {
mint prev = *this;
*this -= ONE;
return prev;
}
mint operator-() const {
return mint(0) - *this;
}
// Comparisons go through get(), i.e. compare normalized residues.
bool operator==(const mint &x) const {
return get() == x.get();
}
bool operator!=(const mint &x) const {
return get() != x.get();
}
bool operator<(const mint &x) const {
return get() < x.get();
}
template<typename T>
explicit operator T() {
return get();
}
// Reads a decimal number (optional leading '-') digit by digit.
friend std::istream& operator>>(std::istream &in, mint &x) {
std::string s;
in >> s;
x = 0;
bool neg = s[0] == '-';
for (const auto c : s)
if (c != '-')
x = x * 10 + (c - '0');
if (neg)
x *= -1;
return in;
}
friend std::ostream& operator<<(std::ostream &out, const mint &x) {
return out << x.get();
}
// Smallest primitive root of mod: hard-coded for the common primes,
// otherwise found by trial r = 2, 3, ... against the prime factors of
// mod - 1. NOTE(review): the function-local static cache is not
// thread-safe -- fine for single-threaded contest use.
static int32_t primitive_root() {
if constexpr (mod == 1'000'000'007)
return 5;
if constexpr (mod == 998'244'353)
return 3;
if constexpr (mod == 786433)
return 10;
static int root = -1;
if (root != -1)
return root;
std::vector<int> primes;
int value = mod - 1;
for (int i = 2; i * i <= value; i++)
if (value % i == 0) {
primes.push_back(i);
while (value % i == 0)
value /= i;
}
if (value != 1)
primes.push_back(value);
for (int r = 2;; r++) {
bool ok = true;
for (auto p : primes)
if ((mint(r).power((mod - 1) / p)).get() == 1) {
ok = false;
break;
}
if (ok)
return root = r;
}
}
};
// constexpr uint32_t MOD = 1'000'000'007;
// Active modulus: the NTT-friendly prime 998244353 (must match fft_).
constexpr uint32_t MOD = 998'244'353;
using mint = montgomery<MOD>;
/*
! WARNING: MOD must be prime.
* Define modular int class above it.
* No need to run any init function, it dynamically resizes the data.
*/
namespace combinatorics {
std::vector<mint> fact_, ifact_, inv_;
// Grow the three tables so index `size` is valid:
//   fact_[i] = i!, ifact_[i] = (i!)^{-1}, inv_[i] = i^{-1} (mod MOD),
// using the classic recurrence inv[i] = -(MOD / i) * inv[MOD % i].
void resize_data(int size) {
    if (fact_.empty()) {
        fact_ = {mint(1), mint(1)};
        ifact_ = {mint(1), mint(1)};
        inv_ = {mint(0), mint(1)};
    }
    for (int i = int(fact_.size()); i <= size; ++i) {
        fact_.push_back(fact_.back() * mint(i));
        inv_.push_back(-inv_[MOD % i] * mint(MOD / i));
        ifact_.push_back(ifact_.back() * inv_[i]);
    }
}
// Lazily-resizing indexed view over one of the tables above.
struct combinatorics_info {
    std::vector<mint> &data;
    combinatorics_info(std::vector<mint> &data) : data(data) {}
    mint operator[](int pos) {
        if (pos >= int(data.size())) {
            resize_data(pos);
        }
        return data[pos];
    }
} fact(fact_), ifact(ifact_), inv(inv_);
// Binomial coefficient C(n, k); zero when the arguments are out of range.
// Amortized O(max n) across all calls.
mint choose(int n, int k) {
    if (k < 0 || n < 0 || n < k) {
        return mint(0);
    }
    return fact[n] * ifact[k] * ifact[n - k];
}
// Binomial coefficient for n too large to tabulate; O(min(k, n - k)).
mint choose_slow(int64_t n, int64_t k) {
    if (k < 0 || n < 0 || n < k) {
        return mint(0);
    }
    k = std::min(k, n - k);
    mint prod = 1;
    for (int i = k; i >= 1; i--) {
        prod *= (n - i + 1);
        prod *= inv[i];
    }
    return prod;
}
// Ballot-style count: bracket sequences with n '(' and m ')' in which
// every prefix has at least as many '(' as ')'.
mint catalan(int n, int m) {
    if (m < 0 || n < m) {
        return mint(0);
    }
    return choose(n + m, m) - choose(n + m, m - 1);
}
// n-th Catalan number.
mint catalan(int n) {
    return catalan(n, n);
}
} // namespace combinatorics
using namespace combinatorics;
constexpr int N = 8002;
// NOTE(review): each N x N table of 4-byte mints is ~256 MB (~512 MB
// total as static storage) -- confirm this fits the judge's memory limit.
mint f[N][N];
mint ans[N][N];
int main() {
cin.tie(nullptr)->sync_with_stdio(false);
// f[n][k] satisfies f[n][k] = k * (f[n-1][k] + f[n-1][k-1]) with
// f[0][0] = 1 -- the recurrence for the number of surjections from an
// n-set onto a k-set (k! times the Stirling number S(n, k)).
f[0][0] = 1;
for (int n = 1; n < N; n++) {
for (int k = 1; k <= n; k++) {
f[n][k] = k * f[n - 1][k] + k * f[n - 1][k - 1];
}
}
// Precompute, for every n, the row
//   ans[n][m] = m! * sum_k  f[n][k] / (k-1)!  *  1 / (m-k+1)!
// via one NTT convolution per n; q[d] = 1/(d+1)! is shared by all rows.
// (int(...) / the later mint(...) round-trip through montgomery::get().)
vector<int> p(N), q(N);
for (int d = 1; d < N; d++) {
q[d] = int(ifact[d + 1]);
}
for (int n = 1; n < N; n++) {
for (int k = 1; k < N; k++) {
p[k] = int(f[n][k] * ifact[k - 1]);
}
fast_mult(N, N, p, q);
for (int m = 1; m < N; m++) {
ans[n][m] = p[m] * fact[m];
}
}
// Per query: output (f[n][m] * m + ans[n][m]) / m^n -- presumably an
// expectation/probability normalized by the m^n equally likely outcomes;
// the problem statement is not visible here, so verify against it.
int t;
cin >> t;
while (t--) {
int n, m;
cin >> n >> m;
mint cur = f[n][m] * m + ans[n][m];
cout << cur / mint(m).power(n) << '\n';
}
}