QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#665408 | #6411. Classical FFT Problem | ohwphil | WA | 2ms | 5048kb | C++17 | 17.2kb | 2024-10-22 12:24:05 | 2024-10-22 12:24:06 |
Judging History
answer
#include <bits/stdc++.h>
using namespace std;
// Iterator range spanning a whole container, e.g. sort(all(v)).
#define all(v) (v).begin(), (v).end()
// Container size as a signed int (avoids signed/unsigned comparisons in loops).
#define sz(v) (int)v.size()
using lint = long long;
// Prime modulus 998244353 = 119 * 2^23 + 1 (the large power of two in mod - 1 suits NTT).
const lint mod = (119LL << 23) | 1;
struct mint {
int val;
mint() { val = 0; }
mint(const lint& v) {
val = (-mod <= v && v < mod) ? v : v % mod;
if (val < 0) val += mod;
}
friend ostream& operator<<(ostream& os, const mint& a) { return os << a.val; }
friend bool operator==(const mint& a, const mint& b) { return a.val == b.val; }
friend bool operator!=(const mint& a, const mint& b) { return !(a == b); }
friend bool operator<(const mint& a, const mint& b) { return a.val < b.val; }
mint operator-() const { return mint(-val); }
mint& operator+=(const mint& m) { if ((val += m.val) >= mod) val -= mod; return *this; }
mint& operator-=(const mint& m) { if ((val -= m.val) < 0) val += mod; return *this; }
mint& operator*=(const mint& m) { val = (lint)val*m.val%mod; return *this; }
friend mint ipow(mint a, lint p) {
mint ans = 1; for (; p; p /= 2, a *= a) if (p&1) ans *= a;
return ans;
}
friend mint inv(const mint& a) { assert(a.val); return ipow(a, mod - 2); }
mint& operator/=(const mint& m) { return (*this) *= inv(m); }
friend mint operator+(mint a, const mint& b) { return a += b; }
friend mint operator-(mint a, const mint& b) { return a -= b; }
friend mint operator*(mint a, const mint& b) { return a *= b; }
friend mint operator/(mint a, const mint& b) { return a /= b; }
operator int64_t() const {return val; }
};
// Transform-based polynomial multiplication helpers (complex FFT, NTT over `mod`, and a
// schoolbook fallback).
namespace fft{
using real_t = double;
using base = complex<real_t>;
// In-place iterative radix-2 complex FFT of `a`. a.size() is assumed to be a power of two
// (not checked). With inv == true the inverse transform is computed, including the final
// division by n.
void fft(vector<base> &a, bool inv){
int n = a.size(), j = 0;
vector<base> roots(n/2);
// Bit-reversal permutation of the input.
for(int i=1; i<n; i++){
int bit = (n >> 1);
while(j >= bit){
j -= bit;
bit >>= 1;
}
j += bit;
if(i < j) swap(a[i], a[j]);
}
// roots[i] = exp(+/- 2*pi*i/n * I); each root is computed directly with cos/sin
// (rather than by repeated multiplication) to limit accumulated rounding error.
real_t ang = 2 * acos(real_t(-1)) / n * (inv ? -1 : 1);
for(int i=0; i<n/2; i++){
roots[i] = base(cos(ang * i), sin(ang * i));
}
/*
XOR Convolution : set roots[*] = 1.
OR Convolution : set roots[*] = 1, and do following:
if (!inv) {
a[j + k] = u + v;
a[j + k + i/2] = u;
} else {
a[j + k] = v;
a[j + k + i/2] = u - v;
}
*/
// Butterfly passes over blocks of length i = 2, 4, ..., n; roots[step * k] is the
// i-th root of unity raised to the k-th power.
for(int i=2; i<=n; i<<=1){
int step = n / i;
for(int j=0; j<n; j+=i){
for(int k=0; k<i/2; k++){
base u = a[j+k], v = a[j+k+i/2] * roots[step * k];
a[j+k] = u+v;
a[j+k+i/2] = u-v;
}
}
}
if(inv) for(int i=0; i<n; i++) a[i] /= n; // skip for OR convolution.
}
// In-place number-theoretic transform modulo the global `mod`, using 3 as the primitive
// root. Assumes a.size() is a power of two dividing mod - 1 (not checked). With
// inv == true the inverse transform is computed, including multiplication by 1/n.
template<typename T>
void ntt(vector<T> &a, bool inv){
const int prr = 3; // primitive root
int n = a.size(), j = 0;
vector<T> roots(n/2);
// Bit-reversal permutation of the input.
for(int i=1; i<n; i++){
int bit = (n >> 1);
while(j >= bit){
j -= bit;
bit >>= 1;
}
j += bit;
if(i < j) swap(a[i], a[j]);
}
// ang is a primitive n-th root of unity mod `mod` (inverted for the inverse transform);
// roots[i] = ang^i.
T ang = ipow(T(prr), (mod - 1) / n);
if(inv) ang = T(1) / ang;
for(int i=0; i<n/2; i++){
roots[i] = (i ? (roots[i-1] * ang) : T(1));
}
for(int i=2; i<=n; i<<=1){
int step = n / i;
for(int j=0; j<n; j+=i){
for(int k=0; k<i/2; k++){
T u = a[j+k], v = a[j+k+i/2] * roots[step * k];
a[j+k] = u+v;
a[j+k+i/2] = u-v;
}
}
}
if(inv){
T rev = T(1) / T(n);
for(int i=0; i<n; i++) a[i] *= rev;
}
}
// Product v * w computed with the NTT. The result has length n (the smallest power of
// two >= max(2, |v| + |w|)); high-order zeros are not stripped.
template<typename T>
vector<T> multiply_ntt(vector<T> &v, const vector<T> &w){
vector<T> fv(all(v)), fw(all(w));
int n = 2;
while(n < sz(v) + sz(w)) n <<= 1;
fv.resize(n); fw.resize(n);
ntt(fv, 0); ntt(fw, 0);
for(int i=0; i<n; i++) fv[i] *= fw[i];
ntt(fv, 1);
vector<T> ret(n);
for(int i=0; i<n; i++) ret[i] = fv[i];
return ret;
}
// Integer product v * w via the complex FFT; each coefficient is rounded with llround,
// so the result is only exact while the true coefficients stay within double precision.
// Result length is n (power of two), high-order zeros not stripped.
template<typename T>
vector<T> multiply(vector<T> &v, const vector<T> &w){
vector<base> fv(all(v)), fw(all(w));
int n = 2;
while(n < sz(v) + sz(w)) n <<= 1;
fv.resize(n); fw.resize(n);
fft(fv, 0); fft(fw, 0);
for(int i=0; i<n; i++) fv[i] *= fw[i];
fft(fv, 1);
vector<T> ret(n);
for(int i=0; i<n; i++) ret[i] = (T)llround(fv[i].real());
return ret;
}
// Product v * w for coefficients below 2^30 (e.g. mint values), using four complex FFTs.
// Each coefficient is split as hi * 2^15 + lo and packed as the complex number (hi, lo).
// The four partial convolutions are recombined as
// (hi*hi << 30) + ((hi*lo + lo*hi) << 15) + lo*lo in T's own arithmetic.
// NOTE(review): for T = mint the shifted values are reduced when converted back to mint;
// for a plain integer T the shifts could overflow — confirm before reusing with other T.
// Result length is n (power of two), high-order zeros not stripped.
template<typename T>
vector<T> multiply_mod(vector<T> v, const vector<T> &w){
int n = 2;
while(n < sz(v) + sz(w)) n <<= 1;
vector<base> v1(n), v2(n), r1(n), r2(n);
for(int i=0; i<v.size(); i++){
v1[i] = base(v[i] >> 15, v[i] & 32767);
}
for(int i=0; i<w.size(); i++){
v2[i] = base(w[i] >> 15, w[i] & 32767);
}
fft(v1, 0);
fft(v2, 0);
for(int i=0; i<n; i++){
int j = (i ? (n - i) : i);
// Conjugate symmetry separates the packed transform: ans1/ans2 are the transforms of
// v's high/low halves, ans3/ans4 those of w's high/low halves.
base ans1 = (v1[i] + conj(v1[j])) * base(0.5, 0);
base ans2 = (v1[i] - conj(v1[j])) * base(0, -0.5);
base ans3 = (v2[i] + conj(v2[j])) * base(0.5, 0);
base ans4 = (v2[i] - conj(v2[j])) * base(0, -0.5);
// Pack two real products per complex array: r1 = hi*hi + I*hi*lo, r2 = lo*hi + I*lo*lo.
r1[i] = (ans1 * ans3) + (ans1 * ans4) * base(0, 1);
r2[i] = (ans2 * ans3) + (ans2 * ans4) * base(0, 1);
}
fft(r1, 1);
fft(r2, 1);
vector<T> ret(n);
for(int i=0; i<n; i++){
T av = llround(r1[i].real());
T bv = llround(r1[i].imag()) + llround(r2[i].real());
T cv = llround(r2[i].imag());
av = av << 30;
bv = bv << 15;
ret[i] = av + bv + cv;
}
return ret;
}
// Schoolbook O(|v| * |w|) product. Returns an empty vector if either input is empty,
// otherwise exactly |v| + |w| - 1 coefficients.
template<typename T>
vector<T> multiply_naive(vector<T> v, const vector<T> &w){
if(sz(v) == 0 || sz(w) == 0) return vector<T>();
vector<T> ret(sz(v) + sz(w) - 1);
for(int i=0; i<sz(v); i++){
for(int j=0; j<sz(w); j++){
ret[i + j] += v[i] * w[j];
}
}
return ret;
}
} // namespace fft
// Dense univariate polynomial over T (instantiated with T = mint in this file).
// a[i] holds the coefficient of x^i. Constructors and most operators call normalize(),
// so the zero polynomial is stored as an empty vector and deg() == -1.
template<typename T>
struct poly {
    vector<T> a;

    // Removes trailing zero coefficients (the polynomial's leading zeroes).
    void normalize() {
        while (!a.empty() && a.back() == T(0)) {
            a.pop_back();
        }
    }

    poly() {}
    poly(T a0) { a = {a0}; normalize(); }
    poly(vector<T> t) : a(t) { normalize(); }

    // Degree, or -1 for the empty (zero) polynomial.
    int deg() const { return sz(a) - 1; }
    // Highest stored coefficient, or 0 if empty.
    T lead() const { return sz(a) ? a.back() : T(0); }
    // Coefficient of x^idx; indices outside [0, size) read as 0.
    T operator[](int idx) const {
        return idx >= (int)a.size() || idx < 0 ? T(0) : a[idx];
    }
    // Mutable reference to a stored coefficient; idx must be < a.size().
    T& coef(size_t idx) {
        return a[idx];
    }
    // Coefficients in reverse order, normalized afterwards.
    poly reversed() const {
        vector<T> b = a;
        reverse(all(b));
        return poly(b);
    }
    // The polynomial mod x^n (its first n coefficients).
    poly trim(int n) const {
        n = min(n, sz(a));
        vector<T> b(a.begin(), a.begin() + n);
        return poly(b);
    }
    // Scalar multiplication / division.
    poly operator*=(const T &x) {
        for (auto &it : a) {
            it *= x;
        }
        normalize();
        return *this;
    }
    poly operator/=(const T &x) {
        return *this *= (T(1) / T(x));
    }
    poly operator*(const T &x) const { return poly(*this) *= x; }
    poly operator/(const T &x) const { return poly(*this) /= x; }
    poly operator+=(const poly &p) {
        a.resize(max(sz(a), sz(p.a)));
        for (int i = 0; i < sz(p.a); i++) {
            a[i] += p.a[i];
        }
        normalize();
        return *this;
    }
    poly operator-=(const poly &p) {
        a.resize(max(sz(a), sz(p.a)));
        for (int i = 0; i < sz(p.a); i++) {
            a[i] -= p.a[i];
        }
        normalize();
        return *this;
    }
    // Polynomial product via fft::multiply_mod (complex FFT with 15-bit coefficient split).
    poly operator*=(const poly &p) {
        *this = poly(fft::multiply_mod(a, p.a));
        normalize();
        return *this;
    }
    // Multiplicative inverse mod x^n by Newton iteration q <- q * (2 - p * q),
    // doubling the precision each round. Requires a non-empty polynomial with a[0] != 0.
    poly inv(int n) {
        poly q(T(1) / a[0]);
        for (int i = 1; i < n; i <<= 1) {
            poly p = poly(2) - q * trim(i * 2);
            q = (p * q).trim(i * 2);
        }
        return q.trim(n);
    }
    // Schoolbook long division {quotient, remainder}; used when divisor or quotient is small.
    pair<poly, poly> divmod_slow(const poly &b) const {
        vector<T> A(a);
        vector<T> res;
        while (A.size() >= b.a.size()) {
            res.push_back(A.back() / b.a.back());
            if (res.back() != T(0)) {
                for (size_t i = 0; i < b.a.size(); i++) {
                    A[A.size() - i - 1] -= res.back() * b.a[b.a.size() - i - 1];
                }
            }
            A.pop_back();
        }
        reverse(all(res));
        return {res, A};
    }
    // Polynomial quotient. Large cases use reversal + power-series inverse:
    // rev(a / b) = rev(a) * rev(b)^{-1} mod x^k with k = deg(a) - deg(b) + 1.
    poly operator/=(const poly &b) {
        if (deg() < b.deg()) return *this = poly();
        if (min(deg(), b.deg()) < 256) return *this = divmod_slow(b).first;
        int k = deg() - b.deg() + 1;
        poly ra = reversed().trim(k);
        poly rb = b.reversed().trim(k).inv(k);
        *this = (ra * rb).trim(k);
        while (sz(a) < k) a.push_back(T(0));
        reverse(all(a));
        normalize();
        return *this;
    }
    // Polynomial remainder: a - (a / b) * b.
    poly operator%=(const poly &b) {
        if (deg() < b.deg()) return *this;
        if (min(deg(), b.deg()) < 256) return *this = divmod_slow(b).second;
        poly foo = poly(a); foo /= b; foo *= b;
        *this = poly(*this) -= foo;
        normalize();
        return *this;
    }
    poly operator+(const poly &p) const { return poly(*this) += p; }
    poly operator-(const poly &p) const { return poly(*this) -= p; }
    poly operator*(const poly &p) const { return poly(*this) *= p; }
    poly operator/(const poly &p) const { return poly(*this) /= p; }
    poly operator%(const poly &p) const { return poly(*this) %= p; }
    // Formal derivative.
    poly deriv() {
        vector<T> res;
        for (int i = 1; i <= deg(); i++) {
            res.push_back(T(i) * a[i]);
        }
        return res;
    }
    // Formal integral with constant term 0.
    poly integr() {
        vector<T> res = {0};
        for (int i = 0; i <= deg(); i++) {
            res.push_back(a[i] / T(i + 1));
        }
        return res;
    }
    // log(p) mod x^n = integral(p' / p); requires a[0] == 1.
    poly ln(int n) {
        assert(sz(a) > 0 && a[0] == T(1));
        return (deriv() * inv(n)).integr().trim(n);
    }
    // exp(p) mod x^n by Newton iteration q <- q * (1 + p - ln q); requires a[0] == 0.
    poly exp(int n) {
        if (sz(a) == 0) {
            return poly({T(1)});
        }
        assert(sz(a) > 0 && a[0] == T(0));
        poly q(1);
        for (int i = 1; i < n; i <<= 1) {
            poly p = poly(1) + trim(2 * i) - q.ln(2 * i);
            q = (q * p).trim(2 * i);
        }
        return q.trim(n);
    }
    // p^k mod x^n: factor out the lowest non-zero term c * x^ptr, then use
    // (p / (c x^ptr))^k = exp(k * ln(...)) and multiply back by c^k * x^(ptr * k).
    poly power(int n, int k) {
        if (sz(a) == 0) return poly();
        if (k == 0) return poly(T(1)).trim(n);
        if (k == 1) return trim(n);
        int ptr = 0;
        while (ptr < sz(a) && a[ptr] == T(0)) ptr++;
        if (1ll * ptr * k >= n) return poly();
        n -= ptr * k;
        poly p(vector<T>(a.begin() + ptr, a.end()));
        T coeff = a[ptr];
        p /= coeff;
        p = p.ln(n);
        p *= k;
        p = p.exp(n);
        p *= ipow(coeff, k);
        vector<T> q(ptr * k, T(0));
        for (int i = 0; i <= p.deg(); i++) q.push_back(p[i]);
        return poly(q);
    }
    // k-th root q with q^k == p (mod x^n), normalized so q[0] == 1, via Newton iteration
    // q <- ((k - 1) * q + p * q^{-(k-1)}) / k. Requires a[0] == 1 and k >= 2.
    // NOTE: the k > 2 path has not been exercised by tests.
    poly root(int n, int k = 2) {
        assert(sz(a) > 0 && a[0] == T(1) && k >= 2);
        poly q(1);
        for (int i = 1; i < n; i <<= 1) {
            if (k == 2) q += trim(2 * i) * q.inv(2 * i);
            // (q^{-1})^{k-1} mod x^{2i}: power() is a member taking (n, k), called on q^{-1}.
            else q = q * T(k - 1) + trim(2 * i) * q.inv(2 * i).power(2 * i, k - 1);
            q = q.trim(2 * i) / T(k);
        }
        return q.trim(n);
    }
};
// Polynomial Taylor shift: returns g with g(x) = f(x + a).
// With F_j = f_j * j!, the coefficients satisfy
//   g_i * i! = sum_{m >= 0} F_{i+m} * a^m / m!,
// which is a single convolution once F is reversed. The result always stores exactly
// f.a.size() coefficients (it is resized, not normalized).
poly<mint> taylor_shift(poly<mint>& f, mint a) {
    const int len = f.a.size();

    // rev.a[k] = f_{len-1-k} * (len-1-k)!  (reversed coefficients scaled by factorials)
    poly<mint> rev = f;
    reverse(rev.a.begin(), rev.a.end());
    mint fact = 1;
    for (int k = len - 2; k >= 0; --k) {
        rev.a[k] *= fact;          // fact == (len-1-k)! at this point
        fact *= mint(len - k);
    }
    // fact == len! from here on.

    // ex.a[m] = a^m / m!
    poly<mint> ex;
    ex.a.assign(len, mint(1));
    for (int m = 1; m < len; ++m) {
        ex.a[m] = ex.a[m - 1] * a;
        ex.a[m] /= mint(m);
    }

    poly<mint> shifted = (rev * ex).trim(len);
    shifted.a.resize(len);
    reverse(shifted.a.begin(), shifted.a.end());

    // Divide coefficient i by i!, walking down from 1/(len-1)!.
    mint inv_fact = mint(1) / fact;
    for (int i = len - 1; i >= 0; --i) {
        inv_fact *= mint(i + 1);
        shifted.a[i] *= inv_fact;
    }
    return shifted;
}
// Multipoint evaluation: returns {f(x_0), f(x_1), ..., f(x_{n-1})}.
// Builds a product tree of the linear factors (X - x_i), then pushes f down the tree by
// taking remainders, so each leaf ends up holding f mod (X - x_i) = f(x_i).
// Both inputs are only read, hence taken by const reference.
vector<mint> multipoint_evaluation(const poly<mint>& f, const vector<mint>& x) {
    const int n = x.size();
    int seg_n = 1;
    while (seg_n < n) seg_n <<= 1;
    // mul_seg[v] = prod over node v's leaf range of (X - x_i); padding leaves hold 1.
    vector<poly<mint>> mul_seg(2 * seg_n);
    // res_seg[v] = f mod mul_seg[v].
    vector<poly<mint>> res_seg(2 * seg_n);
    for (int i = 0; i < n; i++) {
        mul_seg[seg_n + i] = poly<mint>({-x[i], 1});
    }
    for (int i = n; i < seg_n; i++) {
        mul_seg[seg_n + i] = poly<mint>(1);
    }
    for (int i = seg_n - 1; i >= 1; i--) {
        // Every real leaf is monic of degree 1, so a degree-0 right child is pure
        // padding (the constant 1) and the multiplication can be skipped.
        if (mul_seg[i * 2 + 1].deg() == 0) {
            mul_seg[i] = mul_seg[i * 2];
        } else {
            mul_seg[i] = mul_seg[i * 2] * mul_seg[i * 2 + 1];
        }
    }
    res_seg[1] = f % mul_seg[1];
    for (int i = 1; i < seg_n; i++) {
        res_seg[i * 2] = res_seg[i] % mul_seg[i * 2];
        res_seg[i * 2 + 1] = res_seg[i] % mul_seg[i * 2 + 1];
    }
    vector<mint> res(n);
    for (int i = 0; i < n; i++) {
        // poly::operator[] yields 0 for an empty (zero) remainder, i.e. f(x_i) == 0.
        res[i] = res_seg[seg_n + i][0];
    }
    return res;
}
// Returns the product of all polynomials in `polynomials` (the empty product is 1).
// Factors are combined pairwise along a balanced binary tree so that the FFT
// multiplications operate on operands of similar size.
poly<mint> multiply_polynomials(const vector<poly<mint>>& polynomials) {
    const int n = polynomials.size();
    int seg_n = 1;
    while (seg_n < n) seg_n <<= 1;
    vector<poly<mint>> seg(2 * seg_n);
    for (int i = 0; i < n; i++) {
        seg[seg_n + i] = polynomials[i];
    }
    // Unused leaves are padded with the multiplicative identity.
    for (int i = n; i < seg_n; i++) {
        seg[seg_n + i] = poly<mint>(1);
    }
    for (int i = seg_n - 1; i >= 1; i--) {
        const poly<mint>& right = seg[i * 2 + 1];
        // Skip the multiplication only when the right child is exactly 1 (e.g. padding).
        // A genuine constant factor c != 1 also has degree 0 and must still be multiplied in.
        if (right.deg() == 0 && right[0] == mint(1)) {
            seg[i] = seg[i * 2];
        } else {
            seg[i] = seg[i * 2] * right;
        }
    }
    return seg[1];
}
int N;                        // number of columns read from input
vector<mint> points;          // column heights; main() later reverses them
vector<mint> dual_points;     // filled by main(): dual_points[w - 1] = #columns with height >= w
mint factorial[131073];       // factorial[i] = i!, filled by main() for i <= 131072
mint inv_factorial[131073];   // inv_factorial[i] = 1 / i!
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
factorial[0] = 1;
for (int i = 1; i <= 131072; i++) {
factorial[i] = factorial[i - 1] * mint(i);
}
inv_factorial[131072] = mint(1) / factorial[131072];
for (int i = 131071; i >= 0; i--) {
inv_factorial[i] = inv_factorial[i + 1] * mint(i + 1);
}
cin >> N;
points.resize(N);
dual_points.resize(N);
for (int i = 0; i < N; i++) {
cin >> points[i].val;
}
int rook_cnt = 1;
while (rook_cnt <= N) {
if (rook_cnt >= points[N - rook_cnt].val) {
break;
}
++rook_cnt;
}
cout << rook_cnt << " ";
int curr_idx = N - 1;
for (int curr_width = N; curr_width > 0; --curr_width) {
while (curr_idx >= 0) {
if (points[curr_idx].val < curr_width) {
break;
}
--curr_idx;
}
dual_points[curr_width - 1] = N - 1 - curr_idx;
}
reverse(points.begin(), points.end());
vector<poly<mint>> polynomials(rook_cnt), dual_polynomials(rook_cnt);
for (int i = 0; i < rook_cnt; ++i) {
polynomials[i] = poly<mint>({points[i], -1});
dual_polynomials[i] = poly<mint>({dual_points[i], -1});
}
poly<mint> prod = multiply_polynomials(polynomials);
poly<mint> dual_prod = multiply_polynomials(dual_polynomials);
mint ans = mint(0);
int t = (rook_cnt == N ? 0 : points[rook_cnt].val);
vector<mint> points_t(t + 1);
for (int i = 0; i <= t; ++i) {
points_t[i] = mint(i);
}
vector<mint> values = multipoint_evaluation(prod, points_t);
for (int i = 0; i <= t; ++i) {
mint val = values[i];
if (i % 2) {
val = -val;
}
val *= factorial[t];
val *= inv_factorial[i];
val *= inv_factorial[t - i];
ans += val;
}
int dual_t = (rook_cnt == N ? 0 : dual_points[rook_cnt].val);
vector<mint> dual_points_t(dual_t + 1);
for (int i = 0; i <= dual_t; ++i) {
dual_points_t[i] = mint(i);
}
vector<mint> dual_values = multipoint_evaluation(dual_prod, dual_points_t);
for (int i = 0; i <= dual_t; ++i) {
mint val = dual_values[i];
if (i % 2) {
val = -val;
}
val *= factorial[dual_t];
val *= inv_factorial[i];
val *= inv_factorial[dual_t - i];
ans += val;
}
ans -= factorial[rook_cnt];
cout << ans << '\n';
return 0;
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 2ms
memory: 4796kb
input:
3 1 2 3
output:
2 6
result:
ok 2 number(s): "2 6"
Test #2:
score: 0
Accepted
time: 0ms
memory: 4720kb
input:
1 1
output:
1 1
result:
ok 2 number(s): "1 1"
Test #3:
score: 0
Accepted
time: 2ms
memory: 4956kb
input:
2 1 1
output:
1 2
result:
ok 2 number(s): "1 2"
Test #4:
score: 0
Accepted
time: 2ms
memory: 4836kb
input:
2 2 2
output:
2 6
result:
ok 2 number(s): "2 6"
Test #5:
score: 0
Accepted
time: 2ms
memory: 4836kb
input:
3 1 1 1
output:
1 3
result:
ok 2 number(s): "1 3"
Test #6:
score: 0
Accepted
time: 2ms
memory: 4848kb
input:
3 2 2 2
output:
2 9
result:
ok 2 number(s): "2 9"
Test #7:
score: 0
Accepted
time: 0ms
memory: 4816kb
input:
3 3 3 3
output:
3 48
result:
ok 2 number(s): "3 48"
Test #8:
score: 0
Accepted
time: 2ms
memory: 4784kb
input:
5 1 1 3 3 4
output:
3 47
result:
ok 2 number(s): "3 47"
Test #9:
score: 0
Accepted
time: 2ms
memory: 5048kb
input:
10 2 4 5 5 5 5 6 8 8 10
output:
5 864
result:
ok 2 number(s): "5 864"
Test #10:
score: -100
Wrong Answer
time: 0ms
memory: 4848kb
input:
30 6 8 9 9 9 10 13 14 15 15 16 17 17 18 20 22 22 23 23 24 24 25 25 25 27 28 28 29 29 30
output:
18 383552179
result:
wrong answer 1st numbers differ - expected: '17', found: '18'