QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#449877 | #7561. Digit DP | pandapythoner | WA | 0ms | 3552kb | C++20 | 4.2kb | 2024-06-21 18:37:13 | 2024-06-21 18:37:13 |
Judging History
answer
#include <bits/stdc++.h>
using namespace std;
// Shorthand names used throughout the solution.
using lll = __int128_t;  // 128-bit integer: holds positions in [0, 2^n)
using ll = long long;
using flt = double;
#define all(c) c.begin(), c.end()
#define rall(c) c.rbegin(), c.rend()
#define rep(idx, cnt) for (int idx = 0; idx < cnt; ++idx)
// Prime modulus for all arithmetic on values.
const ll mod = 998244353;
// Fixed-seed generator (deterministic across runs).
mt19937 rnd(234);
// Reads a whitespace-delimited binary string (most significant bit first) into a
// 128-bit integer. Any character other than '1' is treated as a 0 bit.
istream& operator>>(istream& in, lll& a) {
    string s;
    in >> s;
    a = 0;
    for (char ch : s) {
        // Shift in one bit at a time using 128-bit arithmetic. The previous
        // `1 << (n - i - 1)` was an int shift, which overflows (undefined
        // behaviour) for strings longer than 31 characters.
        a = a * 2 + (ch == '1' ? 1 : 0);
    }
    return in;
}
// Binary exponentiation: returns x^n reduced modulo `mod` (1 when n == 0).
// Recurses on n / 2, squares the result, and multiplies in x once more for odd n.
ll bin_pow(ll x, ll n) {
    if (!n) return 1;
    const ll half = bin_pow(x, n / 2);
    const ll squared = half * half % mod;
    return (n & 1) ? squared * x % mod : squared;
}
// Modular inverse of a modulo the prime `mod`, using Fermat's little theorem:
// a^(mod - 2) == a^(-1) (mod `mod`) for a not divisible by `mod`.
ll rev(ll a) {
    const ll fermat_exponent = mod - 2;
    return bin_pow(a, fermat_exponent);
}
// Inverses of 2 and 3 modulo `mod`; together they divide by 6 in counter::get_triplets.
const ll rv2 = rev(2);
const ll rv3 = rev(3);
// Number of bits and number of queries, read in main.
int n;
int q;
// Size of the position space, 2^n (computed in main).
lll m;
// Per-bit weights; main reverses them so a[depth] is the weight of the bit
// that splits a tree node at that depth.
vector<ll> a;
// Aggregate of a multiset of values v, all reduced modulo `mod`:
//   cnt = number of values, s1 = sum v, s2 = sum v^2, s3 = sum v^3.
struct counter {
    ll cnt = 0, s1 = 0, s2 = 0, s3 = 0;
    counter() = default;
    counter(ll cnt, ll s1, ll s2, ll s3)
        : cnt(cnt % mod), s1(s1 % mod), s2(s2 % mod), s3(s3 % mod) {}
    // Adds d to every value in the multiset, updating the power sums with the
    // binomial expansion of (v + d)^k.
    void add(ll d) {
        // Normalise d into [0, mod) first. Callers pass raw input values
        // (weights a[i], update amounts x); an unreduced d would overflow
        // d * d or produce negative sums.
        d %= mod;
        if (d < 0) d += mod;
        const ll d2 = (d * d) % mod;
        const ll d3 = (d2 * d) % mod;
        // Order matters: s3 uses the old s2/s1, and s2 uses the old s1.
        s3 = (s3 + 3 * s2 * d % mod + 3 * s1 * d2 % mod + cnt * d3 % mod) % mod;
        s2 = (s2 + 2 * s1 * d % mod + cnt * d2 % mod) % mod;
        s1 = (s1 + cnt * d) % mod;
    }
    // Sum of v_i * v_j * v_k over unordered triples of distinct elements, i.e. the
    // elementary symmetric polynomial e3 = (s1^3 - 3*s1*s2 + 2*s3) / 6 (mod `mod`).
    ll get_triplets() const {
        ll rs = s1 * s1 % mod * s1 % mod;
        rs = (rs + mod - s3) % mod;                 // s1^3 - s3
        const ll bad = (s1 * s2 + mod - s3) % mod;  // s1*s2 - s3
        rs = (rs + 3 * mod - 3 * bad) % mod;        // s1^3 - 3*s1*s2 + 2*s3
        rs = rs * rv2 % mod * rv3 % mod;            // divide by 6
        return rs;
    }
};
// Merges two aggregates: counts and power sums simply add. The counter
// constructor reduces each component modulo `mod`.
counter operator+(const counter& lhs, const counter& rhs) {
    const ll total_cnt = lhs.cnt + rhs.cnt;
    const ll total_s1 = lhs.s1 + rhs.s1;
    const ll total_s2 = lhs.s2 + rhs.s2;
    const ll total_s3 = lhs.s3 + rhs.s3;
    return counter(total_cnt, total_s1, total_s2, total_s3);
}
vector<counter> counters;
struct node {
int depth;
counter c;
ll d = 0;
int l = -1;
int r = -1;
node(int _depth) {
depth = _depth;
c = counters[depth];
}
void apply(ll x) {
c.add(x);
d = (d + x) % mod;
}
};
vector<node> t;
// Appends a fresh node at the given depth to the pool `t` and returns its index.
int new_node(int depth) {
    const int id = static_cast<int>(t.size());
    t.emplace_back(depth);
    return id;
}
void upd(int v) {
if (t[v].l == -1 || t[v].r == -1) {
return;
}
auto x = t[t[v].l].c;
auto y = t[t[v].r].c;
y.add(a[t[v].depth]);
t[v].c = x + y;
}
void push(int v) {
assert(t[v].depth < n);
if (t[v].l == -1 || t[v].r == -1) {
t[v].l = new_node(t[v].depth + 1);
t[v].r = new_node(t[v].depth + 1);
}
t[t[v].l].apply(t[v].d);
t[t[v].r].apply(t[v].d);
t[v].d = 0;
}
// Returns the aggregate of positions in [l, r] ∩ [tl, tr], where node v covers
// the positions [tl, tr]. Missing children are created on demand by push().
counter get(int v, lll tl, lll tr, lll l, lll r) {
    if (tr < l or r < tl) {
        return counter();
    }
    if (l <= tl and tr <= r) {
        return t[v].c;
    }
    lll tm = (tl + tr) / 2;
    push(v);
    auto x = get(t[v].l, tl, tm, l, r);
    // The right child covers [tm + 1, tr]. Passing [tl, tm] here (as before)
    // counted the left half twice and never looked at the real right half.
    auto y = get(t[v].r, tm + 1, tr, l, r);
    // Positions in the right half have the splitting bit set, so their values
    // exceed the right child's stored aggregate by a[depth of v].
    y.add(a[t[v].depth]);
    return x + y;
}
void add(int v, lll tl, lll tr, lll l, lll r, ll x) {
if (tr < l or r < tl) {
return;
}
if (l <= tl and tr <= r) {
t[v].apply(x);
return;
}
lll tm = (tl + tr) / 2;
push(v);
add(t[v].l, tl, tm, l, r, x);
add(t[v].r, tm + 1, tr, l, r, x);
upd(v);
}
int32_t main() {
cin >> n >> q;
a.resize(n);
for (int i = 0; i < n; i += 1) {
cin >> a[i];
}
reverse(all(a));
counters.resize(n + 1);
counters[n] = counter(1, 0, 0, 0);
for (int i = n - 1; i >= 0; i -= 1) {
auto x = counters[i + 1];
auto y = x;
y.add(a[i]);
counters[i] = x + y;
}
t = { node(0) };
m = 1;
for (int i = 0; i < n; i += 1) {
m *= 2;
}
for (int itr = 0; itr < q; itr += 1) {
int t;
cin >> t;
if (t == 1) {
lll l, r;
ll x;
cin >> l >> r >> x;
add(0, 0, m - 1, l, r, x);
} else if (t == 2) {
lll l, r;
cin >> l >> r;
auto f = get(0, 0, m - 1, l, r);
// cerr << f.cnt << " " << f.s1 << " " << f.s2 << " " << f.s3 << "\n";
cout << f.get_triplets() << "\n";
}
}
return 0;
}
/*
3 3
1 2 4
2 000 111
1 010 101 1
2 000 111
*/
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 0ms
memory: 3512kb
input:
3 3
1 2 4
2 000 111
1 010 101 1
2 000 111
output:
1960 3040
result:
ok 2 number(s): "1960 3040"
Test #2:
score: -100
Wrong Answer
time: 0ms
memory: 3552kb
input:
2 2
1 1
2 00 10
2 00 11
output:
2 2
result:
wrong answer 1st numbers differ - expected: '0', found: '2'