QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#125912 | #6322. Forestry | Energy_is_not_over | WA | 0ms | 26868kb | C++17 | 10.8kb | 2023-07-17 21:50:16 | 2023-07-17 21:50:19 |
Judging History
answer
//#pragma GCC optimize("Ofast", "unroll-loops")
//#pragma GCC target("sse", "sse2", "sse3", "ssse3", "sse4")
#ifdef __APPLE__
#include <iostream>
#include <cmath>
#include <algorithm>
#include <stdio.h>
#include <cstdint>
#include <cstring>
#include <string>
#include <cstdlib>
#include <vector>
#include <bitset>
#include <map>
#include <queue>
#include <ctime>
#include <stack>
#include <set>
#include <list>
#include <random>
#include <deque>
#include <functional>
#include <iomanip>
#include <sstream>
#include <fstream>
#include <complex>
#include <numeric>
#include <cassert>
#include <array>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <thread>
#else
#include <bits/stdc++.h>
#endif
// Shorthand macros and type aliases used throughout the file.
// all(a) expands to the two arguments a.begin(),a.end().
#define all(a) a.begin(),a.end()
// len(a) is a.size() cast to int.
#define len(a) (int)(a.size())
#define mp make_pair
#define pb push_back
// Two spellings each for the pair members: fir/fi -> first, sec/se -> second.
#define fir first
#define sec second
#define fi first
#define se second
using namespace std;
typedef pair<int, int> pii;
typedef long long ll;
typedef long double ld;
/// Lowers `a` to `b` when `b` is strictly smaller than `a`.
/// @return true if `a` was overwritten, false if it was left unchanged.
template<typename T>
bool umin(T &a, T b) {
    const bool improved = b < a;
    if (improved) {
        a = b;
    }
    return improved;
}
/// Raises `a` to `b` when `a` is strictly smaller than `b`.
/// @return true if `a` was overwritten, false if it was left unchanged.
template<typename T>
bool umax(T &a, T b) {
    const bool improved = a < b;
    if (improved) {
        a = b;
    }
    return improved;
}
// Debug helpers. They are active only when __APPLE__ evaluates to non-zero;
// in every other build D never runs its block and LOG expands to nothing.
#if __APPLE__
// D { ... } executes the block that follows exactly once.
#define D for (bool _FLAG = true; _FLAG; _FLAG = false)
// LOG(args...) writes the stringified argument list followed by " ::", then
// every argument value, each followed by a space, to cerr and ends the line.
#define LOG(...) print(#__VA_ARGS__" ::", __VA_ARGS__) << endl
// Streams each argument to cerr followed by a space and returns the stream.
template <class ...Ts> auto &print(Ts ...ts) { return ((cerr << ts << " "), ...); }
#else
#define D while (false)
#define LOG(...)
#endif
/// Capacity of the fixed-size global arrays.
const int max_n = 300000 + 10;
/// Large sentinel value (not referenced in the visible code).
const int inf = 1000111222;
/// Modulus for all counting arithmetic.
const int md = 998244353;

/// Adds `b` to `a` in place and subtracts `md` once if the result reached it.
/// Keeps `a` in [0, md) provided both operands were already in that range.
void inc(int& a,int b)
{
    const int total = a + b;
    a = total < md ? total : total - md;
}
// id_in_sorted[v]: position of vertex v after main sorts the (value, vertex) pairs ascending.
int id_in_sorted[max_n];
// value_of_sorted[p]: the vertex value found at sorted position p (inverse view of the above).
int value_of_sorted[max_n];
/// Segment tree with lazy range multiplication, stored in heap layout
/// (root is node 1, children of v are 2v and 2v+1). Per node:
///   f             - multiplier still to be applied to the children (1 = none),
///   sum           - sum of the stored values in the node's range, modulo md,
///   sum_value_cnt - sum of value_of_sorted[pos] * stored value, modulo md,
///   cnt_non_zero  - leaf: number of update_element calls since the last clear;
///                   inner node: sum of its children as of the last
///                   update_element that passed through it. Zero lets
///                   build_into_vector_and_clear skip the whole subtree.
struct segment_tree {
    int f[4 * max_n];
    int sum[4 * max_n];
    int cnt_non_zero[4 * max_n];
    int sum_value_cnt[4 * max_n];

    /// Sets every pending multiplier to the neutral value 1.
    void init()
    {
        for (int& factor : f) {
            factor = 1;
        }
    }

    /// (x + y) reduced once by md; operands are expected in [0, md).
    static int add_mod(int x, int y)
    {
        const int s = x + y;
        return s >= md ? s - md : s;
    }

    /// x * y modulo md, computed in 64 bits.
    static int mul_mod(int x, int y)
    {
        return static_cast<int>(1ll * x * y % md);
    }

    /// Multiplies the pending factor and both aggregates of node v.
    void scale(int v, int factor)
    {
        f[v] = mul_mod(f[v], factor);
        sum[v] = mul_mod(sum[v], factor);
        sum_value_cnt[v] = mul_mod(sum_value_cnt[v], factor);
    }

    /// Recomputes both aggregates of node v from its two children.
    void pull_sums(int v)
    {
        sum[v] = add_mod(sum[2 * v], sum[2 * v + 1]);
        sum_value_cnt[v] = add_mod(sum_value_cnt[2 * v], sum_value_cnt[2 * v + 1]);
    }

    /// Hands the pending multiplier of v down to both children.
    void push(int v) {
        if (f[v] == 1) {
            return;
        }
        const int pending = f[v];
        scale(2 * v, pending);
        scale(2 * v + 1, pending);
        f[v] = 1;
    }

    /// Multiplies every stored value in [l, r] by `value`.
    /// Node v covers [tl, tr]; [l, r] must lie inside it.
    void update_mult(int v, int tl, int tr, int l, int r, int value) {
        if (tl == l && tr == r) {
            scale(v, value);
            return;
        }
        const int mid = (tl + tr) / 2;
        const int left = 2 * v;
        const int right = left + 1;
        push(v);
        if (r <= mid) {
            update_mult(left, tl, mid, l, r, value);
        } else if (l > mid) {
            update_mult(right, mid + 1, tr, l, r, value);
        } else {
            update_mult(left, tl, mid, l, mid, value);
            update_mult(right, mid + 1, tr, mid + 1, r, value);
        }
        pull_sums(v);
    }

    /// Overwrites the value stored at `pos` with `value` and bumps the leaf's
    /// cnt_non_zero so the position is reported by the next extraction.
    void update_element(int v, int tl, int tr, int pos, int value) {
        if (tl == tr) {
            LOG("element_updated");
            sum[v] = value;
            sum_value_cnt[v] = mul_mod(value_of_sorted[pos], value);
            ++cnt_non_zero[v];
            return;
        }
        const int mid = (tl + tr) / 2;
        push(v);
        if (pos > mid) {
            update_element(2 * v + 1, mid + 1, tr, pos, value);
        } else {
            update_element(2 * v, tl, mid, pos, value);
        }
        pull_sums(v);
        cnt_non_zero[v] = cnt_non_zero[2 * v] + cnt_non_zero[2 * v + 1];
    }

    /// Shared range query: sums `table` (either `sum` or `sum_value_cnt`)
    /// over [l, r], pushing pending multipliers on the way down.
    int query(const int* table, int v, int tl, int tr, int l, int r) {
        if (tl == l && tr == r) {
            return table[v];
        }
        const int mid = (tl + tr) / 2;
        push(v);
        if (r <= mid) {
            return query(table, 2 * v, tl, mid, l, r);
        }
        if (l > mid) {
            return query(table, 2 * v + 1, mid + 1, tr, l, r);
        }
        const int left_part = query(table, 2 * v, tl, mid, l, mid);
        const int right_part = query(table, 2 * v + 1, mid + 1, tr, mid + 1, r);
        return add_mod(left_part, right_part);
    }

    /// Sum of the stored values over [l, r], modulo md.
    int get_sum(int v, int tl, int tr, int l, int r) {
        return query(sum, v, tl, tr, l, r);
    }

    /// Sum of value_of_sorted[pos] * stored value over [l, r], modulo md.
    int get_sum_value_cnt(int v, int tl, int tr, int l, int r) {
        return query(sum_value_cnt, v, tl, tr, l, r);
    }

    /// Appends (position, stored value) for every leaf whose cnt_non_zero is
    /// positive, in increasing position order, and zeroes the aggregates and
    /// counters of every node visited. Subtrees with cnt_non_zero == 0 are
    /// not entered.
    void build_into_vector_and_clear(int v,int tl,int tr,vector<pii>& res) {
        if (cnt_non_zero[v] == 0) {
            return;
        }
        cnt_non_zero[v] = 0;
        if (tl != tr) {
            const int mid = (tl + tr) / 2;
            push(v);
            build_into_vector_and_clear(2 * v, tl, mid, res);
            build_into_vector_and_clear(2 * v + 1, mid + 1, tr, res);
        } else {
            res.emplace_back(tl, sum[v]);
        }
        sum[v] = 0;
        sum_value_cnt[v] = 0;
    }
};
// The single global segment tree; dfs1 and merge_state_into_st both work on it.
segment_tree st;
// Number of vertices, read in main.
int n;
// a[i]: value read for vertex i (0-based).
int a[max_n];
// reb[v]: adjacency list of v. dfs1 sorts it by subtree size and erases the parent entry.
vector<int> reb[max_n];
// sz[v]: number of vertices in the subtree of v for the rooting used by dfs0.
int sz[max_n];
/// Fills sz[] for the subtree hanging below `now`.
/// @param now   vertex being processed
/// @param pred  neighbour the call came from (-1 for the root); it is skipped
void dfs0(int now,int pred)
{
    sz[now] = 1;
    for (const int child : reb[now]) {
        if (child == pred) {
            continue;
        }
        dfs0(child, now);
        sz[now] += sz[child];
    }
}
// pw2[i] = 2^i modulo md, filled in main.
int pw2[max_n];
// ans[v]: running total accumulated for v by dfs1 / merge_state_into_st; main adds
// the tree's remaining sum_value_cnt to ans[0] to form the printed answer.
int ans[max_n];
// dp[v]: (sorted position, stored value) pairs extracted from st after dfs1(v);
// filled only for children that are not first in their parent's sorted list.
vector<pii> dp[max_n];
// Moves the current contents of the global tree into `dp`: appends every recorded
// (position, value) pair over positions [0, n-1] and clears those entries from st.
void build_state_from_current_st(vector<pii>& dp)
{
st.build_into_vector_and_clear(1,0,n-1,dp);
}
// Folds the saved state dp[son] of a non-first child `son` into the global tree
// st (which holds the state built for `v` so far) and adds to ans[v].
// NOTE(review): the reading "st[p] counts the edge subsets in which the component
// containing v has its minimum at sorted position p" is inferred from how the
// numbers are combined, not stated in this file -- confirm against the problem.
void merge_state_into_st(int v,int son)
{
// Debug dump of dp[son]; D never runs this block unless the __APPLE__ helpers are active.
D{
LOG("merge_state_into_st",v,son);
cerr<<"dp :: ";
for (auto i:dp[son]){
cerr<<value_of_sorted[i.fir]<<","<<i.sec<<" ";
}
cerr<<"\n";
cerr<<"dp raw indices :: ";
for (auto i:dp[son]){
cerr<<i.fir<<","<<i.sec<<" ";
}
cerr<<"\n";
}
// Carry over what ans[son] already holds, scaled by 2^(sz[v]-sz[son]).
inc(ans[v],1ll*ans[son]*pw2[(sz[v]-1)-(sz[son]-1)]%md);
// ways_a[ii]: sum of st over positions strictly above dp[son][ii].fir, sampled
// BEFORE st is modified below. Stays 0 when that position is already n-1.
vector<int> ways_a(len(dp[son]));
for (int ii=0;ii<len(dp[son]);ii++){
pii i=dp[son][ii];
if (i.fir!=n-1){
ways_a[ii]=st.get_sum(1,0,n-1,i.fir+1,n-1);
}
}
/// edge == 0
// (presumably the case where edge v-son is removed) every entry of dp[son]
// adds value * count * 2^(sz[v]-sz[son]-1) to ans[v]; st is not touched.
{
for (auto i:dp[son]){
inc(ans[v],1ll*value_of_sorted[i.fir]*i.sec%md*pw2[(sz[v]-1)-(sz[son]-1)-1]%md);
}
LOG("edge == 0","finished");
}
/// edge == 1
// (presumably the case where edge v-son is kept) dp[son] is walked from the
// largest position down. current_ways starts at 2^(sz[son]-1) and grows by the
// count of each entry passed, so every stretch of st strictly above the current
// entry is multiplied by 2^(sz[son]-1) plus the counts of the entries already
// passed. last_updated is the lowest position already multiplied and only moves
// when a non-empty stretch was processed. dp[son] is restored to ascending
// order afterwards.
{
int last_updated=n;
reverse(all(dp[son]));
int current_ways=pw2[sz[son]-1];
for (auto i:dp[son]){
LOG("edge == 1, doing first",i.fir,i.sec);
if (i.fir+1<=last_updated-1){
LOG("update_mult on subseg",i.fir+1,last_updated-1,current_ways);
st.update_mult(1,0,n-1,i.fir+1,last_updated-1,current_ways);
last_updated=i.fir+1;
}
inc(current_ways,i.sec);
}
// Remaining prefix [0, last_updated-1] gets the full factor (all entries included).
if (0<=last_updated-1){
LOG("update_mult on subseg",0,last_updated-1,current_ways);
st.update_mult(1,0,n-1,0,last_updated-1,current_ways);
last_updated=0;
}
reverse(all(dp[son]));
LOG("edge == 1","finished");
}
// Finally overwrite st at each position of dp[son] with (pre-merge sum of st above
// that position) * (the entry's count). update_element assigns, it does not add.
for (int ii=0;ii<len(dp[son]);ii++){
pii i=dp[son][ii];
st.update_element(1,0,n-1,i.fir,1ll*ways_a[ii]*i.sec%md);
}
}
// Builds the state of the subtree of `now` inside the global tree st and
// accumulates ans[now]. On return st holds the state for `now`; the caller
// decides whether to extract it into dp[] or keep working on it.
// `pred` is the parent of `now`, or -1 for the root. Requires sz[] from dfs0.
void dfs1(int now,int pred)
{
LOG("dfs1",now,pred);
// Neighbours by decreasing sz. The parent's subtree contains this one, so when a
// parent exists it has the largest sz, lands at index 0 and is erased below.
sort(all(reb[now]),[&](const int& lhs,const int& rhs){
return sz[lhs]>sz[rhs];
});
if (pred!=-1){
reb[now].erase(reb[now].begin());
}
// No children: the state is a single entry of value 1 at this vertex's sorted position.
if (len(reb[now])==0){
LOG(now,"leaf",id_in_sorted[now]);
st.update_element(1,0,n-1,id_in_sorted[now],1);
return;
}
// Every child except the first (the one with the largest sz) is solved first and its
// state moved out of st into dp[child], leaving st empty for the next call.
for (int i=1;i<len(reb[now]);i++){
dfs1(reb[now][i],now);
build_state_from_current_st(dp[reb[now][i]]);
}
// The first child is solved last so that its state stays in st and is reused in place.
dfs1(reb[now][0],now);
/// update_root_into_st
// Turns the first child's state in st into a state for `now` itself.
{
const int v=now;
const int son=reb[now][0];
LOG("base v son",v,son);
inc(ans[v],1ll*ans[son]*pw2[(sz[v]-1)-(sz[son]-1)]%md);
{
// Adds the whole tree's sum_value_cnt, scaled by 2^(sz[v]-sz[son]-1), to ans[v].
int st_sum_value_cnt=st.get_sum_value_cnt(1,0,n-1,0,n-1);
LOG(v,son,st_sum_value_cnt);
inc(ans[v],1ll*st_sum_value_cnt*pw2[(sz[v]-1)-(sz[son]-1)-1]%md);
LOG("ans after this st_sum_value_cnt is ",ans[v]);
}
{
// Everything stored at positions >= id_in_sorted[now] is zeroed and its total,
// plus 2^(sz[son]-1), is written at id_in_sorted[now] instead.
int sum_greater_equal=st.get_sum(1,0,n-1,id_in_sorted[now],n-1);
st.update_mult(1,0,n-1,id_in_sorted[now],n-1,0);
LOG(v,(sum_greater_equal+pw2[sz[son]-1])%md);
st.update_element(1,0,n-1,id_in_sorted[now],(sum_greater_equal+pw2[sz[son]-1])%md);
}
}
LOG("before merges",now,ans[now]);
// Fold the saved states of the remaining children into st one at a time.
for (int i=1;i<len(reb[now]);i++){
merge_state_into_st(now,reb[now][i]);
LOG("after merge with",reb[now][i],now,ans[now]);
}
LOG(now,ans[now]);
}
/// Reads the tree from standard input, runs both DFS passes and prints the
/// answer modulo md.
///
/// Input: n, then n vertex values, then n-1 edges as 1-based vertex pairs.
/// Output: a single integer followed by a newline.
int main() {
#ifdef __APPLE__
    // Local debugging only. This redirect used to be unconditional: on a machine
    // without input.txt the freopen fails, every subsequent read fails, n stays 0
    // and the program prints 0 regardless of the real input.
    freopen("input.txt", "r", stdin);
#endif
//    freopen("output.txt", "w", stdout);
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    // Powers of two modulo md.
    pw2[0]=1;
    for (int i=1;i<max_n;i++){
        pw2[i]=1ll*pw2[i-1]*2%md;
    }

    cin>>n;
    vector<pii> aa(n);
    for (int i=0;i<n;i++){
        cin>>a[i];
        aa[i]=mp(a[i],i);
    }

    // Rank the vertices by value; the segment tree is indexed by these ranks.
    sort(all(aa));
    for (int i=0;i<n;i++){
        id_in_sorted[aa[i].sec]=i;
        value_of_sorted[i]=aa[i].fir;
    }

    // Edges arrive 1-based; store them 0-based in both directions.
    for (int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        u--;
        v--;
        reb[u].pb(v);
        reb[v].pb(u);
    }

    dfs0(0,-1);
    st.init();
    dfs1(0,-1);

    // ans[0] plus whatever is still stored in the tree for the root's state.
    int answer=(ans[0]+st.get_sum_value_cnt(1,0,n-1,0,n-1))%md;
    LOG(st.get_sum_value_cnt(1,0,n-1,0,n-1));
    // Debug dump of the root state; note that it clears st, so it must stay
    // after `answer` has been computed.
    D{
        vector<pii> dp0;
        build_state_from_current_st(dp0);
        cerr<<"dp0 :: "<<"\n";
        for (auto i:dp0){
            cerr<<value_of_sorted[i.fir]<<" "<<i.sec<<"\n";
        }
    };
    cout<<answer<<"\n";
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 0
Wrong Answer
time: 0ms
memory: 26868kb
input:
4 1 2 3 4 1 2 2 4 3 2
output:
0
result:
wrong answer 1st numbers differ - expected: '44', found: '0'