QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#104800 | #6322. Forestry | Kevin5307 | WA | 7ms | 60100kb | C++14 | 3.3kb | 2023-05-12 00:02:20 | 2023-05-12 00:02:22 |
Judging History
answer
//Author: Kevin5307
#include<bits/stdc++.h>
//#pragma GCC optimize("O2")
using namespace std;
#define ll long long
#define ull unsigned ll
#define pb push_back
#define mp make_pair
#define ALL(x) (x).begin(),(x).end()
#define rALL(x) (x).rbegin(),(x).rend()
#define srt(x) sort(ALL(x))
#define rev(x) reverse(ALL(x))
#define rsrt(x) sort(rALL(x))
#define sz(x) (int)(x.size())
#define inf 0x3f3f3f3f
#define pii pair<int,int>
#define lb(v,x) (int)(lower_bound(ALL(v),x)-v.begin())
#define ub(v,x) (int)(upper_bound(ALL(v),x)-v.begin())
#define uni(v) v.resize(unique(ALL(v))-v.begin())
#define longer __int128_t
// Print S (plus a newline, via puts) and terminate the program with exit code 0.
void die(string S)
{
	const char *text=S.c_str();
	puts(text);
	exit(0);
}
namespace _kv_hash
{
	// Per-run seed read once from the steady clock, so shift() is not stable across runs.
	const ull mask=std::chrono::steady_clock::now().time_since_epoch().count();
	// Masked xorshift mixer (shifts 13 / 7 / 17) wrapped in XORs with the seed.
	// Not called anywhere in the visible code.
	ull shift(ull x)
	{
		ull mixed=x^mask;
		mixed^=mixed<<13;
		mixed^=mixed>>7;
		mixed^=mixed<<17;
		return mixed^mask;
	}
}
using namespace _kv_hash;
// Prime modulus; (mod+1)/2 is the modular inverse of 2 used throughout.
const ll mod=998244353;
// Node pool shared by all dynamic segment trees (index 0 = absent child).
// f: per-node mass, g: mass weighted by the compressed value, tag: pending multiplier.
// NOTE(review): dfs calls update() up to twice per vertex, each possibly allocating one
// node per tree level, which may exceed this pool for n near 3e5 -- confirm the bound.
ll f[6006000];
ll g[6006000];
ll tag[6006000];
int ls[6006000];
int rs[6006000];
// Index of the last allocated pool node.
int tot;
// Sorted, de-duplicated vertex weights (coordinate compression, built in main).
vector<ll> vec;
// Vertex weights, 1-indexed.
ll a[300300];
// Adjacency lists of the tree, 1-indexed.
vector<int> G[300300];
// Running sums threaded through merge(); reset by MergeTree().
ll sumA;
ll sumB;
void pushdown(int u)
{
if(!u) return ;
f[u]=f[u]*tag[u]%mod;
g[u]=g[u]*tag[u]%mod;
if(ls[u])
tag[ls[u]]=tag[ls[u]]*tag[u]%mod;
if(rs[u])
tag[rs[u]]=tag[rs[u]]*tag[u]%mod;
tag[u]=1;
return ;
}
int merge(int u,int v,int l,int r)
{
pushdown(u);
pushdown(v);
if(!u||!v)
{
sumA=(sumA+f[u])%mod;
sumB=(sumB+(f[v]*(mod+1)/2%mod)%mod)%mod;
if(!u) tag[v]=tag[v]*sumA%mod*((mod+1)/2)%mod;
if(!v) tag[u]=tag[u]*sumB%mod;
// cerr<<l<<" "<<r<<" "<<sumA<<" "<<sumB<<endl;
// f[u+v]=f[u];
// g[u+v]=g[u];
pushdown(u+v);
return u+v;
}
if(l==r)
{
ll val=(f[v]*(mod+1)/2%mod)%mod;
f[u]=(f[u]*val+f[u]*sumB+f[v]*sumA)%mod;
g[u]=(g[u]*val+g[u]*sumB+g[v]*sumA)%mod;
sumA=(sumA+f[u])%mod;
sumB=(sumB+val)%mod;
return u;
}
int mid=(l+r)/2;
rs[u]=merge(rs[u],rs[v],mid+1,r);
ls[u]=merge(ls[u],ls[v],l,mid);
f[u]=(f[ls[u]]+f[rs[u]])%mod;
g[u]=(g[ls[u]]+g[rs[u]])%mod;
// cerr<<l<<" "<<r<<" "<<sumA<<" "<<sumB<<" "<<f[u]<<" "<<g[u]<<endl;
return u;
}
// Merge tree v into tree u over the full compressed range, with fresh running sums.
int MergeTree(int u,int v)
{
	sumB=0;
	sumA=0;
	const int last=sz(vec)-1;
	return merge(u,v,0,last);
}
// Add mass v at compressed position p of the tree rooted at u (covering [l,r]).
// Missing nodes on the path are allocated from the pool.
// f[leaf] += v, g[leaf] += v*vec[p]; ancestors keep subtree sums.
void update(int u,int l,int r,int p,ll v)
{
	pushdown(u);
	if(l==r)
	{
		f[u]=(f[u]+v)%mod;
		g[u]=(g[u]+v*vec[p])%mod;
		return ;
	}
	int mid=(l+r)/2;
	if(p<=mid)
	{
		if(!ls[u])
			ls[u]=++tot;
		update(ls[u],l,mid,p,v);
	}
	else
	{
		if(!rs[u])
			rs[u]=++tot;
		update(rs[u],mid+1,r,p,v);
	}
	// pushdown(u) handed u's multiplier to BOTH children, but only one was visited.
	// Materialize the sibling as well, or its f/g are stale. pushdown(0) is a no-op.
	pushdown(ls[u]);
	pushdown(rs[u]);
	f[u]=(f[ls[u]]+f[rs[u]])%mod;
	g[u]=(g[ls[u]]+g[rs[u]])%mod;
}
// rt[x]: root node (pool index) of vertex x's segment tree, assigned in dfs.
int rt[300300]={};
// Answer accumulated modulo mod by dfs; rescaled in main before printing.
ll ans=0;
void dfs(int u,int fa)
{
rt[u]=++tot;
update(rt[u],0,sz(vec)-1,lb(vec,a[u]),1);
for(auto v:G[u])
if(v!=fa)
{
dfs(v,u);
// cerr<<u<<" "<<v<<":"<<endl;
update(rt[v],0,sz(vec)-1,sz(vec)-1,1);
rt[u]=MergeTree(rt[u],rt[v]);
}
ll v=g[rt[u]];
if(fa)
ans=(ans+v*(mod+1)/2)%mod;
else
ans=(ans+v)%mod;
}
int main()
{
// freopen("toptree.in","r",stdin);
// freopen("toptree.out","w",stdout);
for(int i=0;i<6006000;i++)
tag[i]=1;
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%lld",a+i);
vec.pb(a[i]);
}
srt(vec);
uni(vec);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
G[u].pb(v);
G[v].pb(u);
}
dfs(1,0);
for(int i=1;i<n;i++)
ans=(ans+ans)%mod;
printf("%lld\n",ans);
return 0;
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 7ms
memory: 59708kb
input:
4 1 2 3 4 1 2 2 4 3 2
output:
44
result:
ok 1 number(s): "44"
Test #2:
score: -100
Wrong Answer
time: 2ms
memory: 60100kb
input:
5 3 5 6 5 1 4 1 2 3 3 5 1 3
output:
142
result:
wrong answer 1st numbers differ - expected: '154', found: '142'