QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#370074 | #6513. Expression 3 | PlentyOfPenalty# | RE | 6ms | 5176kb | C++20 | 5.2kb | 2024-03-28 21:40:51 | 2024-03-28 21:40:53 |
Judging History
answer
#include <bits/stdc++.h>
#define all(x) begin(x), end(x)
using namespace std;
using ll = long long;

// Array capacity (room for n up to 2e5) and the NTT-friendly prime
// 998244353 = 119 * 2^23 + 1.
constexpr int N = 200000 + 9;
constexpr int mod = 998244353;

// Input size, input values, and factorial / inverse-factorial tables (filled in main).
int n, a[N], fac[N], ifac[N];
// Operator string read from input.
string s;

// Maps a residue in [0, mod) to its symmetric representative in (-mod/2, mod/2];
// it appears only in commented-out debug output.
constexpr int norm(int x) { return x >= (mod >> 1) ? x - mod : x; }

// Modular subtraction for x, y in [0, mod).
constexpr int sub(int x, int y) {
  const int d = x - y;
  return d >= 0 ? d : d + mod;
}

// Modular addition for x, y in [0, mod).
constexpr int add(int x, int y) {
  const int total = x + y;
  return total < mod ? total : total - mod;
}

// Binary exponentiation: a^b mod `mod`, for b >= 0.
constexpr int power(int a, int b) {
  int result = 1;
  while (b != 0) {
    if (b & 1) result = static_cast<int>(static_cast<ll>(result) * a % mod);
    a = static_cast<int>(static_cast<ll>(a) * a % mod);
    b >>= 1;
  }
  return result;
}
vector<int> rev, rt;
void getRevRoot(int n) {
int m = __lg(n); // log(n)/log(2)+1e-8;
rev.resize(n);
for (int i = 1; i < n; i++) {
rev[i] = rev[i >> 1] >> 1 | (i & 1) << (m - 1);
}
static int len = 1;
if (len < n) {
rt.resize(n);
for (; len < n; len <<= 1) {
int uni = power(3, (mod - 1) / (len << 1));
rt[len] = 1;
for (int i = 1; i < len; i++) {
rt[i + len] = (ll)rt[i + len - 1] * uni % mod;
}
}
}
}
void ntt(vector<int> &f, int n) {
f.resize(n);
for (int i = 0; i < n; i++) {
if (i < rev[i]) swap(f[i], f[rev[i]]);
}
for (int len = 1; len < n; len *= 2) {
for (int i = 0; i < n; i += len * 2) {
for (int j = 0; j < len; j++) {
int x = f[i + j];
int y = (ll)f[i + j + len] * rt[j + len] % mod;
f[i + j] = add(x, y);
f[i + j + len] = sub(x, y);
}
}
}
}
// Coefficient-wise difference f - g modulo mod; the result has max(|f|, |g|) coefficients.
vector<int> operator-(vector<int> f, const vector<int> &g) {
  if (f.size() < g.size()) f.resize(g.size());
  for (size_t i = 0; i < g.size(); ++i) {
    f[i] = sub(f[i], g[i]);
  }
  return f;
}
vector<int> operator*(vector<int> f, vector<int> g) {
if (f.empty()) return f;
if (g.empty()) return g;
int n = 1, m = (int)(f.size() + g.size()) - 1;
while (n < m) n <<= 1;
int invn = power(n, mod - 2);
getRevRoot(n), ntt(f, n), ntt(g, n);
for (int i = 0; i < n; i++) f[i] = (ll)f[i] * g[i] % mod;
reverse(f.begin() + 1, f.end()), ntt(f, n);
f.resize(m);
for (int i = 0; i < m; i++) f[i] = (ll)f[i] * invn % mod;
return f;
}
// Power-series inverse: returns h with exactly n coefficients such that
// f * h == 1 (mod x^n). Requires f[0] != 0 (mod p); returns {} when n <= 0.
// The Newton step halves n on each recursion level and uses NTTs of length n,
// which only works for powers of two. Any other n is rounded up to the next
// power of two and the result truncated. This is exact because the first n
// coefficients of an inverse depend only on the first n coefficients of f.
vector<int> inv(vector<int> f, int n) {
  if (n <= 0) return {};
  if (n == 1) return {power(f[0], mod - 2)};
  if (n & (n - 1)) {
    int m = 1;
    while (m < n) m <<= 1;
    vector<int> h = inv(f, m);
    h.resize(n);
    return h;
  }
  f.resize(n);
  // g = f^{-1} mod x^{n/2}; h takes g as its lower half and gets the upper half below.
  vector<int> g = inv(f, n / 2), h(n);
  g.resize(n);
  for (int i = 0; i < n / 2; i++) h[i] = g[i];
  int invn = power(n, mod - 2);
  getRevRoot(n), ntt(f, n), ntt(g, n);
  // e = f * g as a length-n cyclic convolution. Its upper half [n/2, n) is exact.
  // The exact product's lower half is 1, 0, ..., 0 (g inverts f mod x^{n/2}),
  // but the cyclic one has wrap-around there, so it is overwritten rather than scaled.
  for (int i = 0; i < n; i++) f[i] = (ll)f[i] * g[i] % mod;
  reverse(f.begin() + 1, f.end()), ntt(f, n);
  for (int i = 1; i < n / 2; i++) f[i] = 0;
  for (int i = n / 2; i < n; i++) f[i] = (ll)f[i] * invn % mod;
  f[0] = 1, ntt(f, n);
  // Newton step h = g * (2 - f g): in degrees [n/2, n) this is -(e * g).
  for (int i = 0; i < n; i++) f[i] = (ll)f[i] * g[i] % mod;
  reverse(f.begin() + 1, f.end()), ntt(f, n);
  for (int i = n / 2; i < n; i++) h[i] = sub(0, (ll)f[i] * invn % mod);
  return h;
}
// Polynomial remainder a mod b (coefficients mod `mod`, lowest degree first).
// Returns a unchanged when |a| < |b|; otherwise exactly |b| - 1 coefficients.
// Assumes b is non-empty and its leading coefficient b.back() is invertible mod p.
vector<int> operator%(vector<int> a, vector<int> b) {
  if (a.size() < b.size()) return a;
  // Number of quotient coefficients: deg a - deg b + 1 (>= 1 here).
  const int len = (int)a.size() - (int)b.size() + 1;
  // Reversal trick: rev(quotient) == rev(a) * rev(b)^{-1} (mod x^len).
  vector<int> ra(a.rbegin(), a.rend());
  ra.resize(len);
  vector<int> rb(b.rbegin(), b.rend());
  rb.resize(len);
  // inv() is a power-of-two Newton iteration; a non-power-of-two length makes its
  // NTTs index past the end of their buffers. Request a power-of-two precision
  // and keep only the len terms that are needed.
  int precision = 1;
  while (precision < len) precision <<= 1;
  vector<int> rbInv = inv(rb, precision);
  rbInv.resize(len);
  vector<int> q = ra * rbInv;
  q.resize(len);
  reverse(q.begin(), q.end());
  vector<int> r = a - q * b;
  r.resize(b.size() - 1);
  return r;
}
int g[N];
vector<int> p[N << 2], q[N << 2], f[N << 2];
void build(int u, int l, int r) {
if (l == r) {
p[u] = s[l] == '-' ? vector<int>{mod - l - 1, 1} : vector<int>{mod - l + 1, 1};
q[u] = {mod - l, 1};
} else {
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
p[u] = p[u << 1] * p[u << 1 | 1];
q[u] = q[u << 1] * q[u << 1 | 1];
}
// cerr << "build u=" << u << " l=" << l << " r=" << r << endl;
// cerr << "p[" << u << "]=";
// for (int x : p[u]) cerr << norm(x) << ",";
// cerr << endl;
// cerr << "q[" << u << "]=";
// for (int x : q[u]) cerr << norm(x) << ",";
// cerr << endl;
}
// Top-down pass over the tree built by build(). On entry, f[u] holds node u's
// polynomial (main seeds f[1] = {1}).
//   Left child receives f[u] mod q[left].
//   Right child receives (f[u] * p[left]) mod q[right].
// At leaf l, q = x - l, so f[u][0] is the remainder mod (x - l). It is multiplied by
// (l + p[u][0]) mod p, which is -1 when s[l] == '-' and +1 otherwise (see build()).
void solve(int u, int l, int r) {
  if (l == r) {
    g[l] = static_cast<int>(static_cast<long long>(f[u][0]) * (mod + l + p[u][0]) % mod);
    return;
  }
  const int mid = (l + r) >> 1;
  const int left = u << 1;
  const int right = u << 1 | 1;
  f[left] = f[u] % q[left];
  solve(left, l, mid);
  f[right] = (f[u] * p[left]) % q[right];
  solve(right, mid + 1, r);
}
int main() {
#ifdef memset0
  freopen("M.in", "r", stdin);
#endif
  cin.tie(0)->sync_with_stdio(0);
  // fac[i] = i! and ifac[i] = 1 / i! (mod p) for i < N.
  // The first loop stores the plain inverse of i via inv(i) = -(mod / i) * inv(mod % i).
  // The second loop turns those into inverse factorials by prefix multiplication.
  fac[0] = ifac[0] = ifac[1] = 1;
  for (int i = 2; i < N; i++) {
    ifac[i] = (ll)(mod - mod / i) * ifac[mod % i] % mod;
  }
  for (int i = 1; i < N; i++) {
    fac[i] = (ll)fac[i - 1] * i % mod;
    ifac[i] = (ll)ifac[i - 1] * ifac[i] % mod;
  }
  cin >> n;
  for (int i = 0; i < n; i++) {
    cin >> a[i];
  }
  cin >> s;
  s.insert(s.begin(), 0);  // shift so operators are read as s[1], s[2], ... (presumably n - 1 of them)
  --n;                     // n now counts positions 1..n handled by the tree
  int ans = a[0];
  // With a single input value the range [1, n] is empty: build(1, 1, 0) would recurse
  // without end, so the tree pass is skipped.
  if (n > 0) {
    build(1, 1, n);
    f[1] = {1};
    solve(1, 1, n);
    for (int i = 1; i <= n; i++) {
      ans = (ans + (long long)g[i] * a[i] % mod * ifac[i]) % mod;
    }
  }
  ans = (long long)ans * fac[n] % mod;
  cout << ans << '\n';
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 6ms
memory: 5176kb
input:
4 9 1 4 1 -+-
output:
46
result:
ok 1 number(s): "46"
Test #2:
score: 0
Accepted
time: 5ms
memory: 5164kb
input:
5 1 2 3 4 5 +-+-
output:
998244313
result:
ok 1 number(s): "998244313"
Test #3:
score: -100
Runtime Error
input:
100000 664815434 205025136 871445392 797947979 379688564 336946672 231295524 401655676 526374414 670533644 156882283 372427821 700299596 166140732 677498490 44858761 185182210 559696133 813911251 842364231 681916958 114039865 222372111 784286397 437994571 152137641 650875922 613727135 209302742 5321...