QOJ.ac

QOJ

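| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|----|---------|-----------|--------|------|--------|----------|-----------|-------------|------------|
| #449877 | #7561. Digit DP | pandapythoner | WA | 0ms | 3552kb | C++20 | 4.2kb | 2024-06-21 18:37:13 | 2024-06-21 18:37:13 |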

Judging History

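You are viewing the latest judging result (你现在查看的是最新测评结果).

  • [2024-06-21 18:37:13]
  • Judged (评测)
  • Result: WA (测评结果:WA)
  • Time: 0ms (用时:0ms)
  • Memory: 3552kb (内存:3552kb)
  • [2024-06-21 18:37:13]
  • Submitted (提交)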

answer

#include <bits/stdc++.h>


using namespace std;

// Short type names and loop/range helpers used throughout this solution.
// lll: signed 128-bit integer, used for segment-tree positions in [0, 2^n).
#define lll __int128_t
#define ll long long
#define flt double
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
// rep(i, n): i runs over 0 .. n-1.
#define rep(i, n) for(int i = 0; i < n; i += 1)


// Seeded RNG. NOTE(review): not referenced anywhere else in this file.
mt19937 rnd(234);
// Prime modulus for all counter arithmetic; rev() inverts via Fermat's little
// theorem (a^(mod-2)), which is only valid because this modulus is prime.
const ll mod = 998244353;


// Reads one whitespace-delimited binary string (most significant bit first) into a
// signed 128-bit integer (`lll`). Any character other than '1' counts as a 0 bit.
// Strings of up to 127 significant bits fit without overflow.
std::istream& operator>>(std::istream& in, __int128_t& a) {
    std::string s;
    in >> s;
    a = 0;
    for (char c : s) {
        // Horner's scheme: shift in one bit at a time in 128-bit arithmetic.
        // (Building `1 << k` on a 32-bit int overflows once the string is longer
        // than 31 bits, which positions in [0, 2^n) easily are.)
        a = a * 2 + (c == '1' ? 1 : 0);
    }
    return in;
}


// Computes x^n modulo `mod` by binary exponentiation, processing the exponent's
// bits from the most significant one down (square, then multiply if the bit is set).
ll bin_pow(ll x, ll n) {
    // Record the parity of each successive halving of n, then replay them from the
    // top: this is the same square/multiply sequence as the recursive formulation.
    bool odd[64];
    int levels = 0;
    for (; n != 0; n /= 2) {
        odd[levels++] = (n & 1) != 0;
    }
    ll result = 1;
    while (levels-- > 0) {
        result = (result * result) % mod;
        if (odd[levels]) {
            result = (result * x) % mod;
        }
    }
    return result;
}


// Modular inverse of a: by Fermat's little theorem a^(mod-2) == a^(-1) (mod is prime).
ll rev(ll a) {
    const ll fermat_exponent = mod - 2;
    return bin_pow(a, fermat_exponent);
}

// Modular inverses of 2 and 3; counter::get_triplets multiplies by both to divide by 6.
const ll rv2 = rev(2), rv3 = rev(3);

// Problem state read in main: n = number of weights (bits), q = number of queries,
// m = 2^n (size of the position space), a = per-bit weights (reversed in main).
int n = 0;
int q = 0;
lll m = 0;
vector<ll> a{};


// Summary of a multiset of values modulo `mod`: how many values there are and their
// first three power sums. This is enough to shift every value by a constant (add)
// and to evaluate the sum of products over unordered triples (get_triplets).
struct counter {
    // cnt = number of values, s1 = sum v, s2 = sum v^2, s3 = sum v^3 (all mod `mod`).
    ll cnt = 0, s1 = 0, s2 = 0, s3 = 0;

    counter() {}
    counter(ll cnt, ll s1, ll s2, ll s3) : cnt(cnt% mod), s1(s1% mod), s2(s2% mod), s3(s3% mod) {}

    // Replaces every value v by v + d using the binomial expansion, e.g.
    // sum (v+d)^3 = s3 + 3 d s2 + 3 d^2 s1 + cnt d^3.
    // The sums are updated s3 -> s2 -> s1 so each line still reads the old lower sums.
    void add(ll d) {
        // d can come straight from input (a weight or an update amount), whose size is
        // not bounded here; reduce it first so every product below fits in 64 bits.
        d %= mod;
        if (d < 0) {
            d += mod;
        }
        ll d2 = (d * d) % mod;
        ll d3 = (d2 * d) % mod;
        s3 = (s3 + 3 * s2 * d % mod + 3 * s1 * d2 % mod + cnt * d3 % mod) % mod;
        s2 = (s2 + 2 * s1 * d % mod + cnt * d2 % mod) % mod;
        s1 = (s1 + cnt * d) % mod;
    }


    // Sum over unordered triples of distinct elements of their product, i.e. the
    // elementary symmetric polynomial e3 = (s1^3 - 3 s1 s2 + 2 s3) / 6 (Newton's
    // identities), computed here as ((s1^3 - s3) - 3 (s1 s2 - s3)) / 6.
    ll get_triplets() const {
        ll rs = s1 * s1 % mod * s1 % mod;
        rs = (rs + mod - s3) % mod;
        ll bad = (s1 * s2 + mod - s3) % mod;
        rs = (rs + 3 * mod - 3 * bad) % mod;
        rs = rs * rv2 % mod * rv3 % mod;
        return rs;
    }
};


// Union of two multisets: counts and power sums add component-wise
// (the counter constructor reduces each total modulo `mod`).
counter operator+(const counter& a, const counter& b) {
    const ll total_cnt = a.cnt + b.cnt;
    const ll total_s1 = a.s1 + b.s1;
    const ll total_s2 = a.s2 + b.s2;
    const ll total_s3 = a.s3 + b.s3;
    return counter(total_cnt, total_s1, total_s2, total_s3);
}


// counters[i]: summary of a never-updated subtree rooted at depth i (filled in main).
vector<counter> counters;


struct node {
    int depth;
    counter c;
    ll d = 0;
    int l = -1;
    int r = -1;

    node(int _depth) {
        depth = _depth;
        c = counters[depth];
    }

    void apply(ll x) {
        c.add(x);
        d = (d + x) % mod;
    }
};


// Node pool; the root lives at index 0 (set up in main). Because push_back/emplace_back
// may reallocate, references into t must not be held across new_node calls.
vector<node> t;


// Appends a fresh node at the given depth and returns its index in t.
int new_node(int depth) {
    t.emplace_back(depth);
    return (int)t.size() - 1;
}

void upd(int v) {
    if (t[v].l == -1 || t[v].r == -1) {
        return;
    }
    auto x = t[t[v].l].c;
    auto y = t[t[v].r].c;
    y.add(a[t[v].depth]);
    t[v].c = x + y;
}


void push(int v) {
    assert(t[v].depth < n);
    if (t[v].l == -1 || t[v].r == -1) {
        t[v].l = new_node(t[v].depth + 1);
        t[v].r = new_node(t[v].depth + 1);
    }
    t[t[v].l].apply(t[v].d);
    t[t[v].r].apply(t[v].d);
    t[v].d = 0;
}

// Returns the summary of the values at positions [l, r] intersected with node v's
// range [tl, tr]. Missing children are created on the way down (via push).
counter get(int v, lll tl, lll tr, lll l, lll r) {
    if (tr < l or r < tl) {
        return counter();
    }
    if (l <= tl and tr <= r) {
        return t[v].c;
    }
    lll tm = (tl + tr) / 2;
    push(v);
    auto x = get(t[v].l, tl, tm, l, r);
    // The right child covers [tm + 1, tr] (not [tl, tm]); its values are the left
    // half's pattern shifted by a[depth], so apply that shift to its summary.
    auto y = get(t[v].r, tm + 1, tr, l, r);
    y.add(a[t[v].depth]);
    return x + y;
}


void add(int v, lll tl, lll tr, lll l, lll r, ll x) {
    if (tr < l or r < tl) {
        return;
    }
    if (l <= tl and tr <= r) {
        t[v].apply(x);
        return;
    }
    lll tm = (tl + tr) / 2;
    push(v);
    add(t[v].l, tl, tm, l, r, x);
    add(t[v].r, tm + 1, tr, l, r, x);
    upd(v);
}

int32_t main() {
    cin >> n >> q;
    a.resize(n);
    for (int i = 0; i < n; i += 1) {
        cin >> a[i];
    }
    reverse(all(a));
    counters.resize(n + 1);
    counters[n] = counter(1, 0, 0, 0);
    for (int i = n - 1; i >= 0; i -= 1) {
        auto x = counters[i + 1];
        auto y = x;
        y.add(a[i]);
        counters[i] = x + y;
    }
    t = { node(0) };
    m = 1;
    for (int i = 0; i < n; i += 1) {
        m *= 2;
    }
    for (int itr = 0; itr < q; itr += 1) {
        int t;
        cin >> t;
        if (t == 1) {
            lll l, r;
            ll x;
            cin >> l >> r >> x;
            add(0, 0, m - 1, l, r, x);
        } else if (t == 2) {
            lll l, r;
            cin >> l >> r;
            auto f = get(0, 0, m - 1, l, r);
            // cerr << f.cnt << " " << f.s1 << " " << f.s2 << " " << f.s3 << "\n";
            cout << f.get_triplets() << "\n";
        }
    }
    return 0;
}

/*
Sample input (same as test #1 below; expected output: 1960, then 3040):
3 3
1 2 4
2 000 111
1 010 101 1
2 000 111
*/

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 0ms
memory: 3512kb

input:

3 3
1 2 4
2 000 111
1 010 101 1
2 000 111

output:

1960
3040

result:

ok 2 number(s): "1960 3040"

Test #2:

score: -100
Wrong Answer
time: 0ms
memory: 3552kb

input:

2 2
1 1
2 00 10
2 00 11

output:

2
2

result:

wrong answer 1st numbers differ - expected: '0', found: '2'