QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submitted at | Judged at |
---|---|---|---|---|---|---|---|---|---|
#104522 | #6322. Forestry | zhuyifan | WA | 2ms | 3416kb | C++14 | 6.1kb | 2023-05-10 22:18:45 | 2023-05-10 22:18:46 |
Judging History
answer
#include<bits/stdc++.h>
#define LL long long
#define SZ(x) ((LL)(x.size()))
using namespace std;
// Reads the next integer from stdin.
// Every character before the first digit is skipped; if a '-' appears anywhere in that
// skipped prefix the result is negative. The character right after the last digit is consumed.
// NOTE: if EOF is reached before any digit, the skip loop never terminates.
long long read(){
    long long value = 0;
    long long sign = 1;
    int c = getchar();
    for (; c < '0' || c > '9'; c = getchar())
        if (c == '-') sign = -1;
    for (; c >= '0' && c <= '9'; c = getchar())
        value = value * 10 + (c - '0');
    return value * sign;
}
// Prints x in decimal to stdout (a leading '-' when negative, no separator).
// The magnitude is built in unsigned arithmetic so that LLONG_MIN is printed correctly;
// negating it as a signed value would overflow (undefined behavior).
void write(long long x){
    unsigned long long mag = static_cast<unsigned long long>(x);
    if (x < 0) {
        putchar('-');
        mag = 0ULL - mag;  // two's-complement magnitude, well defined for unsigned
    }
    char digits[20];  // 2^64 has 20 decimal digits
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (len > 0) putchar(digits[--len]);
}
// Prints x followed by a newline.
void writeln(long long x){write(x);puts("");}
// Prints x followed by a single space.
void writecs(long long x){write(x);putchar(' ');}
const long long mod = 998244353;
// x <- x + y, folded back by one subtraction of mod (exact when x, y are in [0, mod)).
void addmod(long long &x,long long y){
    const long long sum = x + y;
    x = (sum >= mod) ? sum - mod : sum;
}
// x <- x - y, folded back by one addition of mod (exact when x, y are in [0, mod)).
void submod(long long &x,long long y){
    const long long diff = x - y;
    x = (diff < 0) ? diff + mod : diff;
}
// Brings x into [0, mod) when it is within one modulus of that range:
// first lifts a negative value by mod, then lowers a value >= mod by mod.
long long domod(long long x){
    if (x < 0) x += mod;
    return (x >= mod) ? x - mod : x;
}
namespace matrix{
struct mat{LL n,m,a[3][3];};
mat operator*(mat x,mat y){
mat z;z.n=x.n;z.m=y.m;
for(LL i=0;i<z.n;i++)
for(LL j=0;j<z.m;j++)z.a[i][j]=0;
for(LL k=0;k<x.m;k++)
for(LL i=0;i<x.n;i++)
for(LL j=0;j<y.m;j++)
addmod(z.a[i][j],x.a[i][k]*y.a[k][j]%mod);
return z;
}
}using matrix::mat;
const long long N = 300000+95;
long long n,a[N];
// Adjacency lists stored as linked edge records: head[u] is the most recently added edge
// leaving u, e[k].nxt links to the one added before it, and index 0 ends the list.
struct Edge{
    LL to,nxt;
};
Edge e[N<<1];
LL head[N],tot;
// Prepends the directed edge u -> v to u's list, taking the next record index from tot.
void add_e(LL u,LL v){
    const LL k=++tot;
    e[k].to=v;
    e[k].nxt=head[u];
    head[u]=k;
}
long long fa[N],dep[N],sz[N],son[N];
// First pass over the subtree of x (fa[x] must already be set by the caller).
// Fills dep[] (depth, root at 1 when fa[root] == 0), sz[] (subtree size), fa[] for the
// children, and son[x] = the first child with the strictly largest subtree (0 for a leaf).
void dfs(LL x){
    dep[x]=dep[fa[x]]+1;
    sz[x]=1;
    for(LL i=head[x];i;i=e[i].nxt){
        const LL y=e[i].to;
        if(y==fa[x])continue;
        fa[y]=x;
        dfs(y);
        sz[x]+=sz[y];
        if(sz[y]>sz[son[x]])son[x]=y;
    }
}
long long dfn[N],low[N],ed[N],tim,seq[N],pf[N];
// Heavy-first preorder labelling of the subtree of x, where ac is the top of x's chain.
// dfn[x]: preorder index (heavy child visited first, so a chain is a contiguous range);
// seq[]: inverse of dfn; pf[x]: chain top; ed[top]: dfn of the last (deepest) chain vertex;
// low[x]: largest dfn inside x's subtree.
void dfs2(LL x,LL ac){
    dfn[x]=++tim;
    seq[tim]=x;
    pf[x]=ac;
    ed[ac]=tim;
    if(son[x])dfs2(son[x],ac);
    for(LL i=head[x];i;i=e[i].nxt){
        const LL y=e[i].to;
        if(y!=fa[x]&&y!=son[x])dfs2(y,y);  // every light child starts its own chain
    }
    low[x]=tim;
}
LL ord[N];
// Sort key for ord[]: heavier vertices first.
bool cmp(LL x,LL y){return a[y]<a[x];}
LL dp[N][3];
bool vis[N];
// 3x3 identity matrix.
const mat unit = mat{3,3,{{1,0,0},{0,1,0},{0,0,1}}};
namespace SGT{
// Pool of small segment trees, one per vertex x (root index rt[x]). The tree of x has one
// leaf per light child y (slot id[y] in 1..lim[x]) holding that child's transition matrix.
// Internal nodes store left * right, so the root holds the product over all light children
// (the identity when x has none).
struct node{LL l,r;mat d;}s[N<<2];
// Node allocator owned by this pool. It is kept separate from the edge counter `tot`.
// Sharing that counter offsets every index by 2(n-1). With up to ~3n pool nodes (e.g. a star),
// that pushes indices past the end of s[] (size 4N).
LL cnt;
void pushup(LL p){s[p].d=s[s[p].l].d*s[s[p].r].d;}
// Allocates a tree over slots [l, r] into p, with all leaves set to the identity.
// An empty range (r < l) becomes a single identity node.
void build(LL &p,LL l,LL r){
    p=++cnt;
    if(l>=r){s[p].d=unit;return ;}
    const LL mid=(l+r)>>1;
    build(s[p].l,l,mid);
    build(s[p].r,mid+1,r);
    pushup(p);
}
// Sets slot x of the tree rooted at p (covering [l, r]) to v and refreshes the products above it.
void update(LL p,LL x,mat v,LL l,LL r){
    if(l==r){s[p].d=v;return ;}
    const LL mid=(l+r)>>1;
    if(x<=mid)update(s[p].l,x,v,l,mid);
    else update(s[p].r,x,v,mid+1,r);
    pushup(p);
}
// Product of all slots of the tree rooted at p.
mat query(LL p){return s[p].d;}
}
LL rt[N],lim[N],id[N];
// Row vector [1, vis[x], 0]: the dp triple of x before any child is merged in.
mat make(LL x){return mat{1,3,{{1,vis[x],0}}};}
// dp triple of x with only its light children merged: make(x) times the product kept in rt[x].
mat upd(LL x){
    const mat base=make(x);
    return base*SGT::query(rt[x]);
}
// 3x3 matrix that merges child y (dp d = dp[y]) into a parent row vector [p0, p1, p2]:
//   [p0,p1,p2] * trans(y) = [2 p0 d0, p1 (d0 + d1), 2 p2 d0 + p0 (2 d2 + d1)]   (mod p).
// With type == false, dp[y] is first overwritten by upd(y) (y's light-children-only dp).
// With type == true, the dp[y] currently stored is used as is.
mat trans(LL y,bool type=false){
    if(!type){
        const mat light=upd(y);
        for(int k=0;k<3;k++)dp[y][k]=light.a[0][k];
    }
    const LL d0=dp[y][0],d1=dp[y][1],d2=dp[y][2];
    const LL twice0=d0*2ll%mod;
    return mat{3,3,{{twice0,0,(d2*2ll+d1)%mod},
                    {0,(d0+d1)%mod,0},
                    {0,0,twice0}}};
}
// Initial bottom-up dp over the subtree of x (fa[] and son[] from dfs must be ready).
// Per vertex the recurrence below maintains:
//   dp[x][0]: number of keep/cut choices for the edges inside x's subtree;
//   dp[x][1]: how many of those leave x's component made only of vis vertices;
//   dp[x][2]: summed over all choices, the number of finished all-vis components not containing x.
// It also numbers the light children y of x (id[y] = 1..lim[x]) and builds x's light-child tree
// rt[x], storing trans(y,true) (built from y's full dp) in slot id[y].
void DFS(LL x){
    dp[x][0]=1;
    dp[x][1]=vis[x]?1:0;
    dp[x][2]=0;
    for(LL i=head[x];i;i=e[i].nxt){
        const LL y=e[i].to;
        if(y==fa[x])continue;
        DFS(y);
        const LL p0=dp[x][0],p1=dp[x][1],p2=dp[x][2];
        const LL c0=dp[y][0],c1=dp[y][1],c2=dp[y][2];
        // Merge child y: the edge (x,y) is either cut or kept.
        dp[x][0]=p0*c0*2ll%mod;
        dp[x][1]=p1*(c0+c1)%mod;
        dp[x][2]=(p2*c0*2ll%mod+p0*c1+p0*c2*2ll%mod)%mod;
        if(y!=son[x])id[y]=++lim[x];
    }
    SGT::build(rt[x],1,lim[x]);
    for(LL i=head[x];i;i=e[i].nxt){
        const LL y=e[i].to;
        if(y!=fa[x]&&y!=son[x])
            SGT::update(rt[x],id[y],trans(y,true),1,lim[x]);
    }
}
namespace seg{
struct node{LL l,r;mat d;}s[N<<2];
void pushup(LL p){s[p].d=s[p*2].d*s[p*2+1].d;}
void build(LL p,LL l,LL r){
s[p].l=l;s[p].r=r;
if(l==r){s[p].d=trans(seq[l]);return ;}
LL mid=(l+r)>>1;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
pushup(p);return ;
}
void update(LL p,LL x){
if(s[p].l==s[p].r){s[p].d=trans(seq[s[p].l]);return ;}
LL mid=(s[p].l+s[p].r)>>1;
if(x<=mid)update(p*2,x);
else update(p*2+1,x);
pushup(p);return ;
}
mat query(LL p,LL l,LL r){
if(l<=s[p].l&&s[p].r<=r)return s[p].d;
if(r<s[p].l || s[p].r<l || l>r)return unit;
LL mid=(s[p].l+s[p].r)>>1;
if(r<=mid)return query(p*2,l,r);
if(mid<l)return query(p*2+1,l,r);
return (query(p*2,l,r)*query(p*2+1,l,r));
}
}
void update(LL x){
while(x){
seg::update(1,dfn[x]);
if(fa[pf[x]]){
mat vl=seg::query(1,dfn[pf[x]],ed[pf[x]]);
SGT::update(rt[fa[pf[x]]],id[pf[x]],vl,1,lim[fa[pf[x]]]);
}
x=fa[pf[x]];
}
return ;
}
LL query(){
// mat vl=seg::query(1,dfn[1]+1,ed[1]);
/* cout<<" vl.n = "<<vl.n<<" vl.m = "<<vl.m<<endl;
for(LL i=0;i<vl.n;i++){
for(LL j=0;j<vl.m;j++)
cout<<vl.a[i][j]<<" ";
cout<<endl;
}*/
mat v=(upd(1)*seg::query(1,dfn[1]+1,ed[1]));
return domod(v.a[0][1]+v.a[0][2]);
}
int main(){
n=read();
for(LL i=1;i<=n;i++)a[i]=read();
for(LL i=1;i<n;i++){
LL u=read(),v=read();
add_e(u,v);add_e(v,u);
}
dfs(1);dfs2(1,1);
DFS(1);seg::build(1,1,n);
for(LL i=1;i<=n;i++)ord[i]=i;
sort(ord+1,ord+n+1,cmp);
LL ans=0;
for(LL l=1,r=0;l<=n;l=r+1){
r=l;while(r<n&&a[ord[r+1]]==a[ord[l]])r++;
for(LL i=l;i<=r;i++){vis[ord[i]]=true;update(ord[i]);}
LL SUM=query();
addmod(ans,SUM*(a[ord[r]]-a[ord[r+1]])%mod);
// cout<<" SUM = "<<SUM<<" l = "<<l<<" r = "<<r<<" a[ord[l]] = "<<a[ord[l]]<<endl;
/* for(LL i=1;i<=n;i++)
cout<<vis[i]<<" ";
cout<<endl;
for(LL i=1;i<=n;i++)
cout<<" i = "<<i<<" dp[i][0] = "<<dp[i][0]<<" dp[i][1] = "<<dp[i][1]<<" dp[i][2] = "<<dp[i][2]<<" son[i] = "<<son[i]<<endl;
cout<<endl;*/
/* for(LL i=1;i<=n;i++){
cout<<" i = "<<i<<" seq[i] = "<<seq[i]<<endl;
mat v=seg::query(1,i,i);
cout<<" v.n = "<<v.n<<" v.m = "<<v.m<<endl;
for(LL i=0;i<v.n;i++){
for(LL j=0;j<v.m;j++)
cout<<v.a[i][j]<<" ";
cout<<endl;
}
cout<<endl;
}*/
}
writeln(ans);
return 0;
}
/*
Hand-made test case:
input:
3
2 1 1
1 2
2 3
expected output:
10
*/
Details
Test #1:
score: 100
Accepted
time: 0ms
memory: 3416kb
input:
4
1 2 3 4
1 2
2 4
3 2
output:
44
result:
ok 1 number(s): "44"
Test #2:
score: -100
Wrong Answer
time: 2ms
memory: 3376kb
input:
5
3 5 6 5 1
4 1
2 3
3 5
1 3
output:
164
result:
wrong answer 1st numbers differ - expected: '154', found: '164'