QOJ.ac

QOJ

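ID | 题目 (Problem) | 提交者 (Submitter) | 结果 (Result) | 用时 (Time) | 内存 (Memory) | 语言 (Language) | 文件大小 (File size) | 提交时间 (Submitted) | 测评时间 (Judged)
#500810 | #6411. Classical FFT Problem | ucup-team2307 | WA | 8ms | 23136kb | C++20 | 4.1kb | 2024-08-01 20:58:06 | 2024-08-01 20:58:08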

Judging History

你现在查看的是最新测评结果 (You are viewing the latest judging result.)

  • [2024-08-01 20:58:08]
  • 评测
  • 测评结果:WA
  • 用时:8ms
  • 内存:23136kb
  • [2024-08-01 20:58:06]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;

using ll = long long;
#define int ll
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()
using pii = pair<int, int>;
using vi = vector<int>;
#define fi first
#define se second
#define pb push_back

const ll mod = (119 << 23) + 1, root = 62; // = 998244353
// Returns b^e modulo `mod` using square-and-multiply: walk the bits of e
// from least to most significant, folding the current power of b into the
// accumulator whenever the bit is set.
ll modpow(ll b, ll e) {
	ll acc = 1;
	while (e != 0) {
		if (e & 1) acc = acc * b % mod;
		b = b * b % mod;
		e /= 2;
	}
	return acc;
}
// For p < 2^30 there is also e.g. 5 << 25, 7 << 26, 479 << 21
// and 483 << 21 (same root). The last two are > 10^9.
typedef vector<ll> vl;
// In-place forward number-theoretic transform of `a` modulo `mod`
// (KACTL-style). The output is in natural (not bit-reversed) order.
// NOTE(review): assumes sz(a) is a power of two -- the bit-reversal below is
// only a permutation in that case; conv() always passes a power of two.
// conv() gets the inverse transform by index-reversing the spectrum and
// calling ntt() a second time.
void ntt(vl &a) {
	int n = sz(a), L = 31 - __builtin_clz(n);
	// Root table shared by all calls (static): rt[k..2k) holds the twiddle
	// factors used by the butterflies of half-length k. It is only extended
	// when a larger n than any previous call is requested, because `k` and `s`
	// are static too and remember how far the table was built.
	static vl rt(2, 1);
	for (static int k = 2, s = 2; k < n; k *= 2, s++) {
		rt.resize(n);
		// z[1] = root^(mod >> s) is used as the step root for this level;
		// presumably a primitive 2^s-th root of unity for KACTL's choice of root.
		ll z[] = {1, modpow(root, mod >> s)};
		rep(i,k,2*k) rt[i] = rt[i / 2] * z[i & 1] % mod;
	}
	// Bit-reversal permutation of the input indices.
	vi rev(n);
	rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
	rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
	// Iterative radix-2 butterflies. The conditional add/subtract of `mod`
	// keeps values reduced, assuming the inputs are already in [0, mod).
	for (int k = 1; k < n; k *= 2)
		for (int i = 0; i < n; i += 2 * k) rep(j,0,k) {
			ll z = rt[j + k] * a[i + j + k] % mod, &ai = a[i + j];
			a[i + j + k] = ai - z + (z > ai ? mod : 0);
			ai += (ai + z >= mod ? z - mod : z);
		}
}
// Cyclic-free polynomial product of `a` and `b` modulo `mod` via NTT.
// Returns a vector of length sz(a) + sz(b) - 1, or {} if either input is empty.
vl conv(const vl &a, const vl &b) {
	if (a.empty() || b.empty()) return {};
	const int resultLen = sz(a) + sz(b) - 1;
	// Transform length 2^(floor(log2(resultLen)) + 1): a power of two that is
	// strictly larger than resultLen, so the cyclic product does not wrap.
	const int logLen = 32 - __builtin_clz(resultLen);
	const int len = 1 << logLen;
	const int invLen = modpow(len, mod - 2);

	vl fa(a), fb(b);
	fa.resize(len);
	fb.resize(len);
	ntt(fa);
	ntt(fb);

	vl prod(len);
	for (int i = 0; i < len; i++) {
		// Storing at index (-i mod len) reverses the spectrum, so the second
		// forward ntt() below behaves as the inverse transform (scaled by invLen).
		prod[(len - i) & (len - 1)] = fa[i] * fb[i] % mod * invLen % mod;
	}
	ntt(prod);
	return vl(prod.begin(), prod.begin() + resultLen);
}

const int N = 1e6 + 66;
// Tables filled by precalc(): fact[i] = i! mod p and ifact[i] = (i!)^{-1} mod p
// for 0 <= i < N.
int fact[N], ifact[N];
void precalc() {
	fact[0] = 1;
	ifact[0] = 1;
	ifact[1] = 1;
	// Pass 1: ifact[i] temporarily holds the modular inverse of i, using
	// inv(i) = -(p / i) * inv(p mod i)  (mod p), where p mod i < i.
	for (int i = 2; i < N; ++i)
		ifact[i] = mod - mod / i * ifact[mod % i] % mod;
	// Pass 2: accumulate running products to get factorials and their inverses.
	for (int i = 1; i < N; ++i) {
		fact[i] = fact[i - 1] * i % mod;
		ifact[i] = ifact[i - 1] * ifact[i] % mod;
	}
}
// Binomial coefficient C(n, k) mod p from the precalc() tables; 0 when k is
// outside [0, n].
int nck(int n, int k) {
	const bool inRange = 0 <= k && k <= n;
	if (!inRange) return 0;
	return fact[n] * ifact[k] % mod * ifact[n - k] % mod;
}
// Returns a^p mod `mod` by binary exponentiation (intended for p >= 0).
int bp(int a, int p) {
	int result = 1;
	while (p) {
		if (p & 1) result = result * a % mod;
		a = a * a % mod;
		p >>= 1;
	}
	return result;
}
// Multiplies the linear factors (1 + a[i] x) for i in [l, r) by divide and
// conquer and returns the coefficient vector (length r - l + 1).
// Requires l < r; an empty range would recurse forever.
vl solve(vl &a, int l, int r) {
    if (r - l == 1) return {1, a[l]};
    const int mid = l + (r - l) / 2;
    const vl lo = solve(a, l, mid);
    const vl hi = solve(a, mid, r);
    return conv(lo, hi);
}

// Returns a vector of length N + 1 whose i-th entry is
//   i! * [x^i] (e^x - 1)^k   (mod p)   =   k! * S(i, k),
// i.e. the number of surjections from an i-element set onto a k-element set
// (S = Stirling numbers of the second kind). Requires precalc() to have run.
//
// (e^x - 1)^k is computed by binary exponentiation on series truncated to
// degree N. For k == 0 the loop body never runs, so the result used to stay
// the 1-element vector {1}, and callers read entries 1..N out of bounds.
// The result is now always padded to exactly N + 1 coefficients, giving
// {1, 0, ..., 0}.
vl stirling(int k, int N) {
    // base = e^x - 1 truncated to degree N: sum_{i=1..N} x^i / i!.
    vl base(N + 1);
    for (int i = 1; i <= N; i++) base[i] = ifact[i];

    vl res{1};
    for (int t = k; t > 0; t >>= 1) {
        if (t & 1) {
            res = conv(res, base);
            res.resize(N + 1);
        }
        // Square only while higher bits remain; the last squaring is unused.
        if (t > 1) {
            base = conv(base, base);
            base.resize(N + 1);
        }
    }
    // Guarantees the documented length for every k (k == 0 in particular).
    res.resize(N + 1);

    // Convert exponential-generating-function coefficients to counts.
    for (int i = 0; i <= N; i++) res[i] = res[i] * fact[i] % mod;
    return res;
}

// Row lengths, 1-indexed (a[0] is unused). Being a global, the array is
// zero-initialized, so a[n + 1] == 0 stops the `while (a[l] >= l)` scan in
// main(). main() later swaps entries 1..n with b, so the second calc() call
// reads b's values through this array.
int a[1000000];
// Filled in main(): b[j] = largest i with j <= a[i] <= n, for 1 <= j <= n.
// Presumably the column lengths of the conjugate diagram, assuming every
// a[i] <= n and a[] is non-increasing after the reverse in main().
int b[1000000];
int calc(int n, int l)
{
    int h = a[l+1];
    vector<int> v;
    for (int i=1; i<=l; i++)
        v.pb(a[i]-h);
    vector<int> u = solve(v, 0, v.size());
//    for (int i : v)
//        cout<<i<<" ";
//    cout<<"\n";
//    for (int i : u)
//        cout<<i<<" ";
//    cout<<"\n";
    int ans = 0;
    vector<int> stv = stirling(h, l+1);

//    cout<<h<<" : ";
//    for (int i=0; i<=l; i++)
//        cout<<stv[i]<<" ";
//    cout<<"\n";

    for (int i=h; i<=l; i++)
    {
        int cur = u[l-i];
        int st = stv[i];
        ans = (ans + cur*st)%mod;
    }
    return ans;
}

signed main()
{
    precalc();
	cin.tie(0)->sync_with_stdio(0);
	cin.exceptions(cin.failbit);
    cout<<fixed<<setprecision(10);

    int n;
    cin>>n;
    for (int i=1; i<=n; i++)
        cin>>a[i];
    reverse(a+1, a+n+1);
    int l = 1;
    while (a[l] >= l)
        l++;
    l--;
    for (int i=1; i<=n; i++)
        b[a[i]] = max(b[a[i]], i);
    for (int i=n-1; i>=1; i--)
        b[i] = max(b[i], b[i+1]);

    int a1 = calc(n, l);
    for (int i=1; i<=n; i++)
        swap(a[i], b[i]);
    int a2 = calc(n, l);
    int fact = 1;
    for (int i=2; i<=l; i++)
        fact = (fact*i)%mod;
//    cout<<a1<<" "<<a2<<" "<<fact<<"\n";
    cout<<l<<" "<<(a1+a2-fact+mod)%mod<<"\n";
}

詳細信息

Test #1:

score: 100
Accepted
time: 4ms
memory: 23136kb

input:

3
1 2 3

output:

2 6

result:

ok 2 number(s): "2 6"

Test #2:

score: -100
Wrong Answer
time: 8ms
memory: 21504kb

input:

1
1

output:

1 3

result:

wrong answer 2nd numbers differ - expected: '1', found: '3'