QOJ.ac

QOJ

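| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|----|---------|-----------|--------|------|--------|----------|-----------|-------------|------------|
| #125912 | #6322. Forestry | Energy_is_not_over | WA | 0ms | 26868kb | C++17 | 10.8kb | 2023-07-17 21:50:16 | 2023-07-17 21:50:19 |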

Judging History

You are viewing the latest judging result.

  • [2023-08-10 23:21:45]
  • System Update: QOJ starts to keep a history of the judgings of all the submissions.
  • [2023-07-17 21:50:19]
  • Judged
  • Result: WA
  • Time: 0ms
  • Memory: 26868kb
  • [2023-07-17 21:50:16]
  • Submitted

answer

//#pragma GCC optimize("Ofast", "unroll-loops")
//#pragma GCC target("sse", "sse2", "sse3", "ssse3", "sse4")

#ifdef __APPLE__
#include <iostream>
#include <cmath>
#include <algorithm>
#include <stdio.h>
#include <cstdint>
#include <cstring>
#include <string>
#include <cstdlib>
#include <vector>
#include <bitset>
#include <map>
#include <queue>
#include <ctime>
#include <stack>
#include <set>
#include <list>
#include <random>
#include <deque>
#include <functional>
#include <iomanip>
#include <sstream>
#include <fstream>
#include <complex>
#include <numeric>
#include <cassert>
#include <array>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <thread>
#else
#include <bits/stdc++.h>
#endif

// Competitive-programming shorthands used throughout the file.
// all(a): begin/end pair for STL algorithms; len(a): size as int.
#define all(a) a.begin(),a.end()
#define len(a) (int)(a.size())
#define mp make_pair
#define pb push_back
// Both spellings of pair member access are used in the file.
#define fir first
#define sec second
#define fi first
#define se second

using namespace std;

// Common type aliases.
typedef pair<int, int> pii;
typedef long long ll;
typedef long double ld;

/// Replaces `a` with `b` when `b` is strictly smaller.
/// @return true iff `a` was changed.
template<typename T>
bool umin(T &a, T b) {
    const bool improves = b < a;
    if (improves) a = b;
    return improves;
}
/// Replaces `a` with `b` when `b` is strictly larger.
/// @return true iff `a` was changed.
template<typename T>
bool umax(T &a, T b) {
    const bool improves = a < b;
    if (improves) a = b;
    return improves;
}

// Debug helpers. __APPLE__ is used as the marker for the author's local
// machine:
//   D { ... }   runs the braced block exactly once;
//   LOG(x, y)   prints the stringified argument list followed by " ::", then
//               each value followed by a space, to cerr, and ends with endl.
// On other platforms D never executes its block (it still has to compile),
// and LOG expands to nothing, so its arguments are not evaluated.
#if __APPLE__
#define D for (bool _FLAG = true; _FLAG; _FLAG = false)
#define LOG(...) print(#__VA_ARGS__" ::", __VA_ARGS__) << endl
template <class ...Ts> auto &print(Ts ...ts) { return ((cerr << ts << " "), ...); }
#else
#define D while (false)
#define LOG(...)
#endif

// Array capacity bound and a large "infinity" sentinel.
const int max_n = 3e5 + 10;
const int inf = 1000111222;

// All counts are reduced modulo this prime.
const int md = 998244353;

// In-place modular addition: a = (a + b) mod md, assuming a and b are
// already in [0, md) so a single conditional subtraction suffices.
void inc(int& a,int b)
{
    const int total = a + b;
    a = total < md ? total : total - md;
}

// Ranking of vertices by (a[v], v), filled in main():
//   id_in_sorted[v]    - rank of vertex v in that order;
//   value_of_sorted[r] - weight a[] of the vertex with rank r.
// Ranks are distinct even for equal weights, so each vertex owns one leaf of
// the segment tree below.
int id_in_sorted[max_n];
int value_of_sorted[max_n];

// Segment tree over sorted ranks [0, n). Each rank r stores a count modulo md.
// In this solution that count is the number of edge configurations whose
// current component has minimum value_of_sorted[r]; see dfs1 and
// merge_state_into_st. Operations:
//   * lazy range multiplication of counts (via f);
//   * point assignment of a count;
//   * range sums of counts, and of value_of_sorted[r] * count;
//   * extraction of all inserted entries followed by clearing.
// Every recursive method takes (v, tl, tr) = (node index, node range). The
// query range [l, r] must be non-empty and contained in [tl, tr].
struct segment_tree {
    // Pending multiplier not yet applied to the children of node v.
    int f[4 * max_n];
    // Sum of counts over the node's range, mod md.
    int sum[4 * max_n];
    // Number of update_element calls that landed in this subtree since the
    // last clear. update_mult(..., 0) does not decrease it, so an entry may be
    // counted here while its stored count is 0.
    int cnt_non_zero[4 * max_n];
    // Sum of value_of_sorted[pos] * count over the node's range, mod md.
    int sum_value_cnt[4 * max_n];

    // Sets every lazy multiplier to the identity. The other arrays are not
    // touched; they rely on zero-initialization of a global instance.
    void init()
    {
        for (int i=0;i<4*max_n;i++){
            f[i]=1;
        }
    }

    // Applies node v's pending multiplier to both children, then resets it.
    void push(int v) {
        if (f[v] != 1) {
            sum[2 * v] = 1ll * sum[2 * v] * f[v] % md;
            sum[2 * v + 1] = 1ll * sum[2 * v + 1] * f[v] % md;
            sum_value_cnt[2 * v] = 1ll * sum_value_cnt[2 * v] * f[v] % md;
            sum_value_cnt[2 * v + 1] = 1ll * sum_value_cnt[2 * v + 1] * f[v] % md;
            f[2 * v] = 1ll * f[2 * v] * f[v] % md;
            f[2 * v + 1] = 1ll * f[2 * v + 1] * f[v] % md;
            f[v] = 1;
        }
    }

    // Multiplies every count in ranks [l, r] by `value` (mod md).
    void update_mult(int v, int tl, int tr, int l, int r, int value) {
        if (tl == l && tr == r) {
            f[v] = 1ll * f[v] * value % md;
            sum[v] = 1ll * sum[v] * value % md;
            sum_value_cnt[v] = 1ll * sum_value_cnt[v] * value % md;
            return;
        }
        int mid = (tl + tr) / 2;
        push(v);
        if (r <= mid) {
            update_mult(2 * v, tl, mid, l, r, value);
        } else if (l > mid) {
            update_mult(2 * v + 1, mid + 1, tr, l, r, value);
        } else {
            update_mult(2 * v, tl, mid, l, mid, value);
            update_mult(2 * v + 1, mid + 1, tr, mid + 1, r, value);
        }
        sum[v] = sum[2 * v] + sum[2 * v + 1];
        if (sum[v]>=md){
            sum[v]-=md;
        }
        sum_value_cnt[v] = sum_value_cnt[2 * v] + sum_value_cnt[2 * v + 1];
        if (sum_value_cnt[v]>=md){
            sum_value_cnt[v]-=md;
        }
    }

    // Sets the count at rank `pos` to `value`. The value is assigned, not
    // added, and the leaf is marked as inserted for the next extraction.
    void update_element(int v, int tl, int tr, int pos, int value) {
        if (tl == tr) {
            LOG("element_updated");
            sum[v] = value;
            sum_value_cnt[v] = 1ll * value_of_sorted[pos] * value % md;
            cnt_non_zero[v]++;
            return;
        }
        int mid = (tl + tr) / 2;
        // Pending multipliers on the path must reach the sibling subtrees
        // before this node's sums are recomputed from its children.
        push(v);
        if (pos <= mid) {
            update_element(2 * v, tl, mid, pos, value);
        } else {
            update_element(2 * v + 1, mid + 1, tr, pos, value);
        }
        sum[v] = sum[2 * v] + sum[2 * v + 1];
        if (sum[v]>=md){
            sum[v]-=md;
        }
        sum_value_cnt[v] = sum_value_cnt[2 * v] + sum_value_cnt[2 * v + 1];
        if (sum_value_cnt[v]>=md){
            sum_value_cnt[v]-=md;
        }
        cnt_non_zero[v] = cnt_non_zero[2 * v] + cnt_non_zero[2 * v + 1];
    }

    // Returns the sum of counts over ranks [l, r], mod md.
    int get_sum(int v, int tl, int tr, int l, int r) {
        if (tl == l && tr == r) {
            return sum[v];
        }
        int mid = (tl + tr) / 2;
        push(v);
        if (r <= mid) {
            return get_sum(2 * v, tl, mid, l, r);
        } else if (l > mid) {
            return get_sum(2 * v + 1, mid + 1, tr, l, r);
        }
        int res = get_sum(2 * v, tl, mid, l, mid) + get_sum(2 * v + 1, mid + 1, tr, mid + 1, r);
        if (res>=md){
            res-=md;
        }
        return res;
    }

    // Returns the sum of value_of_sorted[pos] * count over ranks [l, r], mod md.
    int get_sum_value_cnt(int v, int tl, int tr, int l, int r) {
        if (tl == l && tr == r) {
            return sum_value_cnt[v];
        }
        int mid = (tl + tr) / 2;
        push(v);
        if (r <= mid) {
            return get_sum_value_cnt(2 * v, tl, mid, l, r);
        } else if (l > mid) {
            return get_sum_value_cnt(2 * v + 1, mid + 1, tr, l, r);
        }
        int res = get_sum_value_cnt(2 * v, tl, mid, l, mid) + get_sum_value_cnt(2 * v + 1, mid + 1, tr, mid + 1, r);
        if (res>=md){
            res-=md;
        }
        return res;
    }

    // Appends (rank, count) for every leaf inserted via update_element since
    // the last clear to `res`, in increasing rank order. Counts may be 0.
    // Zeroes sum, sum_value_cnt and cnt_non_zero on every visited node.
    // Subtrees whose cnt_non_zero is 0 hold no entries and are skipped, so the
    // cost is proportional to the number of inserted entries times the depth.
    void build_into_vector_and_clear(int v,int tl,int tr,vector<pii>& res) {
        if (cnt_non_zero[v]==0){
            return;
        }
        cnt_non_zero[v]=0;
        if (tl == tr) {
            res.pb(mp(tl,sum[v]));
            sum[v]=0;
            sum_value_cnt[v] = 0;
            return;
        }
        int mid = (tl + tr) / 2;
        // Leaves must see all pending multipliers before being read.
        push(v);
        build_into_vector_and_clear(2 * v, tl, mid, res);
        build_into_vector_and_clear(2 * v + 1, mid + 1, tr, res);
        sum[v]=0;
        sum_value_cnt[v] = 0;
    }
};

// The single segment tree shared by all DFS steps. Being a global, its arrays
// start zero-initialized; main() calls st.init() to set the lazy multipliers
// to 1.
segment_tree st;

// Number of vertices.
int n;

// Vertex weights, 0-indexed.
int a[max_n];

// Adjacency lists of the tree, 0-indexed. dfs1 reorders each list and removes
// the parent entry from it.
vector<int> reb[max_n];

// Subtree sizes for the tree rooted at vertex 0, computed by dfs0.
int sz[max_n];

// Fills sz[] with subtree sizes. `pred` is the parent of `now` (-1 for the
// root); every other neighbour is treated as a child.
void dfs0(int now,int pred)
{
    int total = 1;
    for (const int child : reb[now]) {
        if (child == pred) {
            continue;
        }
        dfs0(child, now);
        total += sz[child];
    }
    sz[now] = total;
}

// pw2[i] = 2^i mod md, filled in main().
int pw2[max_n];

// ans[v]: as accumulated by dfs1/merge_state_into_st, the sum over edge
// configurations of v's subtree of the minima of components that do not
// contain v (mod md).
int ans[max_n];
// dp[c]: saved state of a light child c, as (rank, count) pairs moved out of
// st by build_state_from_current_st.
vector<pii> dp[max_n];

// Moves every entry currently stored in the global tree `st` into `dp` as
// (rank, count) pairs in increasing rank order, leaving the tree empty.
void build_state_from_current_st(vector<pii>& dp)
{
    const int root = 1;
    const int lo = 0;
    const int hi = n - 1;
    st.build_into_vector_and_clear(root, lo, hi, dp);
}

// Merges the saved state dp[son] of a light child `son` of v into the state of
// v's component held in `st`, and adds newly finalized contributions to ans[v].
// Both states are lists of (rank of the component minimum, count of edge
// configurations). dp[son] is in increasing rank order, as produced by
// build_into_vector_and_clear. The edge v-son is either:
//   * cut  - son's component is closed and its minimum is added to ans[v]
//            once for every configuration of the remaining edges of v's
//            subtree;
//   * kept - the merged component's minimum is the smaller of the two minima.
void merge_state_into_st(int v,int son)
{
    D{
        LOG("merge_state_into_st",v,son);
        cerr<<"dp :: ";
        for (auto i:dp[son]){
            cerr<<value_of_sorted[i.fir]<<","<<i.sec<<" ";
        }
        cerr<<"\n";
        cerr<<"dp raw indices :: ";
        for (auto i:dp[son]){
            cerr<<i.fir<<","<<i.sec<<" ";
        }
        cerr<<"\n";
    }
    // Components already closed inside son's subtree. They are multiplied by
    // every choice of the edges of v's subtree outside son's subtree,
    // including the edge v-son itself.
    inc(ans[v],1ll*ans[son]*pw2[(sz[v]-1)-(sz[son]-1)]%md);

    // ways_a[ii]: total count in st over ranks strictly greater than the rank
    // of dp[son][ii], read before st is modified below. With the edge kept,
    // son's minimum survives exactly when v's current minimum is larger.
    // Entries at rank n-1 have nothing above them and keep the value 0.
    vector<int> ways_a(len(dp[son]));
    for (int ii=0;ii<len(dp[son]);ii++){
        pii i=dp[son][ii];
        if (i.fir!=n-1){
            ways_a[ii]=st.get_sum(1,0,n-1,i.fir+1,n-1);
        }
    }

    /// edge == 0
    {
        // Cut: son's component closes with minimum value_of_sorted[i.fir].
        // Free edges are those of v's subtree minus son's subtree minus v-son.
        for (auto i:dp[son]){
            inc(ans[v],1ll*value_of_sorted[i.fir]*i.sec%md*pw2[(sz[v]-1)-(sz[son]-1)-1]%md);
        }
        LOG("edge == 0","finished");
    }

    /// edge == 1
    {
        // Every st entry of rank j keeps its minimum when the edge is cut
        // (pw2[sz[son]-1] configurations of son's subtree), or when it is
        // kept and son's minimum has a rank > j. So its count is multiplied by
        // pw2[sz[son]-1] + (son counts with rank > j). Son's entries are
        // walked from the largest rank down, so current_ways only grows as
        // the multiplied rank range moves left. Son's own ranks never occur in
        // st (disjoint vertex sets), so covering them in a range is harmless.
        int last_updated=n;
        reverse(all(dp[son]));
        int current_ways=pw2[sz[son]-1];
        for (auto i:dp[son]){
            LOG("edge == 1, doing first",i.fir,i.sec);
            if (i.fir+1<=last_updated-1){
                LOG("update_mult on subseg",i.fir+1,last_updated-1,current_ways);
                st.update_mult(1,0,n-1,i.fir+1,last_updated-1,current_ways);
                last_updated=i.fir+1;
            }
            inc(current_ways,i.sec);
        }
        // Ranks below every son entry: all son configurations qualify.
        if (0<=last_updated-1){
            LOG("update_mult on subseg",0,last_updated-1,current_ways);
            st.update_mult(1,0,n-1,0,last_updated-1,current_ways);
            last_updated=0;
        }
        // Restore the increasing order that ways_a is indexed by.
        reverse(all(dp[son]));
        LOG("edge == 1","finished");
    }


    // Edge kept and son's minimum is the smaller one: insert son's entries
    // with count = (son count) * (st count above it).
    for (int ii=0;ii<len(dp[son]);ii++){
        pii i=dp[son][ii];
        st.update_element(1,0,n-1,i.fir,1ll*ways_a[ii]*i.sec%md);
    }
}

// Processes the subtree of `now` (parent `pred`, or -1 for the root) with
// small-to-large merging. st must be empty on entry. On return:
//   * st holds the state of now's component: rank of its minimum -> number of
//     configurations of the edges in now's subtree;
//   * ans[now] holds the contribution of components not containing now.
// Side effect: reb[now] is sorted by subtree size and loses its parent entry.
void dfs1(int now,int pred)
{
    LOG("dfs1",now,pred);
    // Largest subtree first. With sz computed from root 0, the parent has the
    // largest sz (sz[pred] > sz[now] > sz[child]), so it sorts to the front
    // and is removed below. reb[now] then lists children, heaviest first.
    sort(all(reb[now]),[&](const int& lhs,const int& rhs){
        return sz[lhs]>sz[rhs];
    });
    if (pred!=-1){
        reb[now].erase(reb[now].begin());
    }

    if (len(reb[now])==0){
        // Leaf: its component is {now}, reachable in exactly one way.
        LOG(now,"leaf",id_in_sorted[now]);
        st.update_element(1,0,n-1,id_in_sorted[now],1);
        return;
    }

    // Light children first. Each result is moved out of st into dp[child],
    // so st is empty again for the next recursive call.
    for (int i=1;i<len(reb[now]);i++){
        dfs1(reb[now][i],now);
        build_state_from_current_st(dp[reb[now][i]]);
    }

    // Heavy child last: its state stays in st and becomes the base for `now`.
    dfs1(reb[now][0],now);
    /// update_root_into_st
    {
        const int v=now;
        const int son=reb[now][0];
        LOG("base v son",v,son);
        // Components closed inside the heavy subtree, multiplied by every
        // choice of the remaining edges of v's subtree (including v-son).
        inc(ans[v],1ll*ans[son]*pw2[(sz[v]-1)-(sz[son]-1)]%md);

        {
            // Edge v-son cut: son's component closes. Its minimum is counted
            // for every configuration of v's other edges, excluding v-son.
            int st_sum_value_cnt=st.get_sum_value_cnt(1,0,n-1,0,n-1);
            LOG(v,son,st_sum_value_cnt);
            inc(ans[v],1ll*st_sum_value_cnt*pw2[(sz[v]-1)-(sz[son]-1)-1]%md);
            LOG("ans after this st_sum_value_cnt is ",ans[v]);
        }

        {
            // Edge v-son kept: entries ranked below v keep their minimum, and
            // entries ranked at or above v collapse onto v's rank. Edge cut:
            // v is alone, with pw2[sz[son]-1] configurations of son's subtree.
            int sum_greater_equal=st.get_sum(1,0,n-1,id_in_sorted[now],n-1);
            st.update_mult(1,0,n-1,id_in_sorted[now],n-1,0);
            LOG(v,(sum_greater_equal+pw2[sz[son]-1])%md);
            st.update_element(1,0,n-1,id_in_sorted[now],(sum_greater_equal+pw2[sz[son]-1])%md);
        }
    }

    LOG("before merges",now,ans[now]);

    // Fold in the saved light-child states one by one.
    for (int i=1;i<len(reb[now]);i++){
        merge_state_into_st(now,reb[now][i]);
        LOG("after merge with",reb[now][i],now,ans[now]);
    }

    LOG(now,ans[now]);
}

// Reads n, the vertex weights a[0..n-1] and n-1 tree edges (1-indexed
// endpoints). Prints the answer modulo md: ans[0] (components not containing
// the root) plus the root component's sum of minimum * count, after dfs1 has
// processed the whole tree.
int main() {
#ifdef __APPLE__
    // Local runs only. The file uses __APPLE__ to mean "author's machine"
    // (see the LOG/D macros). This must NOT run on the judge: there is no
    // input.txt there, and a failed freopen() has already closed stdin. Every
    // read then fails, n stays 0, and the program prints 0, which is the
    // "found: '0'" wrong answer on test #1.
    freopen("input.txt", "r", stdin);
#endif
//    freopen("output.txt", "w", stdout);

    ios_base::sync_with_stdio(0);
    cin.tie(0);

    pw2[0]=1;
    for (int i=1;i<max_n;i++){
        pw2[i]=1ll*pw2[i-1]*2%md;
    }

    cin>>n;
    // Rank vertices by (weight, index) so every vertex gets a distinct leaf.
    vector<pii> aa(n);
    for (int i=0;i<n;i++){
        cin>>a[i];
        aa[i]=mp(a[i],i);
    }
    sort(all(aa));
    for (int i=0;i<n;i++){
        id_in_sorted[aa[i].sec]=i;
        value_of_sorted[i]=aa[i].fir;
    }
    for (int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        u--;
        v--;
        reb[u].pb(v);
        reb[v].pb(u);
    }
    dfs0(0,-1);
    st.init();
    dfs1(0,-1);

    int answer=(ans[0]+st.get_sum_value_cnt(1,0,n-1,0,n-1))%md;
    LOG(st.get_sum_value_cnt(1,0,n-1,0,n-1));
    D{
        // Debug only: this empties st, which is fine because answer is
        // already computed.
        vector<pii> dp0;
        build_state_from_current_st(dp0);
        cerr<<"dp0 :: "<<"\n";
        for (auto i:dp0){
            cerr<<value_of_sorted[i.fir]<<" "<<i.sec<<"\n";
        }
    };
    cout<<answer<<"\n";
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Wrong Answer
time: 0ms
memory: 26868kb

input:

4
1 2 3 4
1 2
2 4
3 2

output:

0

result:

wrong answer 1st numbers differ - expected: '44', found: '0'