//Author: Kevin
#pragma GCC optimize(3)
#pragma GCC optimize("-fgcse")
#pragma GCC target("avx","sse2")
#pragma GCC optimize("-fgcse-lm")
#pragma GCC optimize("-fipa-sra")
#pragma GCC optimize("-ftree-pre")
#pragma GCC optimize("-ftree-vrp")
#pragma GCC optimize("-fpeephole2")
#pragma GCC optimize("-ffast-math")
#pragma GCC optimize("-fsched-spec")
#pragma GCC optimize("unroll-loops")
#include<bits/stdc++.h>
#include"joitour.h"
//#pragma GCC optimize("O2")
using namespace std;
// Competitive-programming shorthand macros used throughout this file.
// Integer type aliases.
#define ll long long
#define ull unsigned ll
// Container helpers: pb constructs an element in place at the back, mp builds a pair.
#define pb emplace_back
#define mp make_pair
// Whole-range iterator pairs (forward / reverse) and range algorithms built on them.
#define ALL(x) (x).begin(),(x).end()
#define rALL(x) (x).rbegin(),(x).rend()
#define srt(x) sort(ALL(x))
#define rev(x) reverse(ALL(x))
#define rsrt(x) sort(rALL(x))
// Container size as int; note that x is not parenthesised in the expansion.
#define sz(x) (int)(x.size())
// Large sentinel value (0x3f3f3f3f = 1061109567).
#define inf 0x3f3f3f3f
#define pii pair<int,int>
// Index of the first element >= x (lb) / > x (ub) in a sorted container v.
#define lb(v,x) (int)(lower_bound(ALL(v),x)-v.begin())
#define ub(v,x) (int)(upper_bound(ALL(v),x)-v.begin())
// Removes adjacent duplicates from v (sort v first to deduplicate completely).
#define uni(v) v.resize(unique(ALL(v))-v.begin())
// 128-bit signed integer (GCC/Clang extension type).
#define longer __int128_t
// Writes S followed by a newline to stdout, then ends the program with exit status 0.
void die(string S)
{
    const char *text=S.c_str();
    puts(text);
    exit(0);
}
const int maxn=200200;
// Counting structure for one centroid component (Solver keeps one per centroid).
//
// The vertices of the component are labelled 1..n in DFS preorder; label 1 is
// the root (the centroid).  type[v] is 0, 1 or 2.  Under point type changes the
// structure maintains the number of triples (a,b,c) with type[a]=0, type[b]=1,
// type[c]=2, b on the path a-c, and the path a-c passing through the root.
//
// Because labels form a preorder, the subtree of v is the label range
// [v, mxd[v]].  Lookups combine a static prefix-sum array psum[tr] (filled once
// in build()) with a Fenwick tree BIT[tr] that holds all later changes:
//   tr = 0 / 2 : number of non-root type-0 / type-2 vertices with label <= x;
//   tr = 1     : number of non-root type-1 ancestors of x (x itself included),
//                stored as +1 at v and -1 at mxd[v]+1 for each type-1 vertex v.
class Tree
{
public:
    // Aggregates of one child subtree of the root.
    struct Data
    {
        long long c0,c1,c2;  // number of type-0 / type-1 / type-2 vertices
        long long c10;       // sum over type-0 vertices of their non-root type-1 ancestors
        long long c12;       // sum over type-2 vertices of their non-root type-1 ancestors
        Data():c0(0),c1(0),c2(0),c10(0),c12(0){}
    };
    int n=0,cnt=0;                       // component size; number of child subtrees of the root
    vector<long long> BIT[3],psum[3];
    vector<long long> type,subtree,mxd;  // per label: type, child-subtree id (0 = root), largest label in subtree
    vector<vector<int>> G;               // adjacency lists over labels
    vector<Data> data;                   // indexed by child-subtree id 1..cnt
    // Sums of the Data fields over all child subtrees, plus the same-subtree
    // correction terms d1 = sum(c0*c12 + c2*c10) and d2 = sum(c0*c2).
    long long S0=0,S1=0,S2=0,S10=0,S12=0,d1=0,d2=0;

    // Fenwick point update: adds v at position x of tree tr.
    inline void update(int tr,int x,int v)
    {
        for(;x<=n;x+=x&(-x))
            BIT[tr][x]+=v;
    }
    // psum[tr][x] plus the Fenwick prefix sum of tree tr over [1,x] (see class comment).
    // Accumulates in long long, matching the storage type, instead of narrowing to int.
    inline long long query(int tr,int x)
    {
        long long ret=psum[tr][x];
        for(;x>0;x-=x&(-x))
            ret+=BIT[tr][x];
        return ret;
    }
    // Number of qualifying triples through the root for the current types.
    inline long long query()
    {
        // b is a non-root vertex: pair every (a,b) or (b,c) chain with an end point
        // anywhere, then drop pairs whose end points share a child subtree (d1).
        long long ret=S10*S2+S0*S12-d1;
        if(type[1]==0) ret+=S12;       // the root is a
        if(type[1]==1) ret+=S0*S2-d2;  // the root is b: a and c in different child subtrees
        if(type[1]==2) ret+=S10;       // the root is c
        return ret;
    }
    // Prepares an empty structure for a component of tot vertices.
    // assign() (rather than resize()) also clears any data from a previous use.
    void init(int tot)
    {
        cnt=0;
        S0=S1=S2=S10=S12=d1=d2=0;
        n=tot;
        data.assign(n+5,Data());
        type.assign(n+5,0);
        subtree.assign(n+5,0);
        mxd.assign(n+5,0);
        G.assign(n+5,vector<int>());
        for(int i=0;i<3;i++)
        {
            BIT[i].assign(n+5,0);
            psum[i].assign(n+5,0);
        }
    }
    // DFS from u (parent fa) that fills subtree[] and mxd[].  When u is the root
    // (fa == 0), each child opens a new child-subtree id; otherwise children
    // inherit id.  mxd[u] is the largest label in u's subtree, which closes the
    // range [u, mxd[u]] because labels are a preorder.
    // NOTE(review): recursive; the depth equals the height of the component.
    void dfs(int u,int fa,int id)
    {
        subtree[u]=id;
        mxd[u]=u;
        for(int v:G[u])
        {
            if(v==fa) continue;
            if(!fa) id=++cnt;
            dfs(v,u,id);
            mxd[u]=max(mxd[u],mxd[v]);
        }
    }
    // Adds (sign=+1) or removes (sign=-1) child subtree id's aggregates to or from
    // the totals S* and the correction terms d1, d2.
    void applyTotals(int id,int sign)
    {
        const Data &d=data[id];
        S0+=sign*d.c0;
        S1+=sign*d.c1;
        S2+=sign*d.c2;
        S10+=sign*d.c10;
        S12+=sign*d.c12;
        d1+=sign*(d.c0*d.c12+d.c2*d.c10);
        d2+=sign*(d.c0*d.c2);
    }
    // Inserts (sign=+1) or erases (sign=-1) the non-root vertex u, with its current
    // type[u], into its child subtree's Data and into the Fenwick trees.
    void applyVertex(int u,int sign)
    {
        Data &d=data[subtree[u]];
        if(type[u]==0)
        {
            d.c0+=sign;
            d.c10+=sign*query(1,u);  // type-1 ancestors of u
            update(0,u,sign);
        }
        else if(type[u]==1)
        {
            d.c1+=sign;
            // every type-0 / type-2 vertex in u's subtree gains or loses one type-1 ancestor
            d.c10+=sign*(query(0,mxd[u])-query(0,u-1));
            d.c12+=sign*(query(2,mxd[u])-query(2,u-1));
            update(1,u,sign);
            update(1,mxd[u]+1,-sign);
        }
        else if(type[u]==2)
        {
            d.c2+=sign;
            d.c12+=sign*query(1,u);  // type-1 ancestors of u
            update(2,u,sign);
        }
    }
    // Computes every aggregate from the types set via setType() and the edges
    // added via addEdge(), and returns query().  Every type-1 vertex is counted
    // exactly once in c1.
    long long build()
    {
        dfs(1,0,0);
        // Static marks: type-0/2 point counts; type-1 subtree-range difference marks.
        for(int i=2;i<=n;i++)
        {
            if(type[i]==1)
            {
                psum[1][i]++;
                psum[1][mxd[i]+1]--;
            }
            else
                psum[type[i]][i]++;
        }
        for(int i=1;i<=n;i++)
            for(int j=0;j<3;j++)
                psum[j][i]+=psum[j][i-1];
        for(int i=2;i<=n;i++)
        {
            Data &d=data[subtree[i]];
            if(type[i]==0)
            {
                d.c0++;
                d.c10+=psum[1][i];
            }
            else if(type[i]==1)
                d.c1++;
            else if(type[i]==2)
            {
                d.c2++;
                d.c12+=psum[1][i];
            }
        }
        for(int i=1;i<=cnt;i++)
            applyTotals(i,+1);
        return query();
    }
    // Changes the type of label u to t and keeps every aggregate consistent.
    // Aggregates never include the root, which is only read by query().
    void update(int u,int t)
    {
        if(u==1)
        {
            type[u]=t;
            return;
        }
        const int id=subtree[u];
        applyTotals(id,-1);
        applyVertex(u,-1);
        type[u]=t;
        applyVertex(u,+1);
        applyTotals(id,+1);
    }
    // Adds the undirected edge u-v (labels).
    void addEdge(int u,int v)
    {
        G[u].emplace_back(v);
        G[v].emplace_back(u);
    }
    // Sets the initial type of label x; only meaningful before build().
    void setType(int x,int t)
    {
        type[x]=t;
    }
}tree[maxn];
namespace Solver
{
int n;
vector<int> G[maxn];
vector<int> vd[maxn];
int type[maxn],siz[maxn],mx[maxn],rt,tot,ban[maxn],f[maxn],dfn;
ll ans;
void dfs1(int u,int fa)
{
siz[u]=1;
mx[u]=0;
for(auto v:G[u])
if(v!=fa&&!ban[v])
{
dfs1(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],tot-siz[u]);
if(mx[u]<mx[rt])
rt=u;
}
void dfs2(int u,int fa,int ind)
{
dfn++;
vd[u].pb(dfn);
tree[ind].setType(dfn,type[u]);
int tmp=dfn;
for(auto v:G[u])
if(v!=fa&&!ban[v])
{;
tree[ind].addEdge(tmp,dfn+1);
dfs2(v,u,ind);
}
}
void build(int u,int lst)
{
rt=0;
dfs1(u,0);
u=rt;
f[u]=lst;
tree[u].init(tot);
dfn=0;
dfs2(u,0,u);
ans+=tree[u].build();
dfs1(u,0);
vector<pii> vec;
for(auto v:G[u])
if(!ban[v])
vec.pb(v,siz[v]);
ban[u]=1;
for(auto pr:vec)
{
tot=pr.second;
build(pr.first,u);
}
}
void build2()
{
queue<int> q;
q.push(1);
q.push(0);
q.push(n);
while(!q.empty())
{
int u=q.front();
q.pop();
int tmp=q.front();
q.pop();
tot=q.front();
q.pop();
rt=0;
dfs1(u,0);
u=rt;
f[u]=tmp;
tree[u].init(tot);
dfn=0;
dfs2(u,0,u);
ans+=tree[u].build();
dfs1(u,0);
ban[u]=1;
for(auto v:G[u])
if(!ban[v])
{
q.push(v);
q.push(u);
q.push(siz[v]);
}
}
}
void init(int N,vector<int> F,vector<int> U,vector<int> V,int Q)
{
n=N;
for(int i=1;i<=N;i++)
type[i]=F[i-1];
for(int i=0;i<N-1;i++)
{
G[U[i]+1].pb(V[i]+1);
G[V[i]+1].pb(U[i]+1);
}
mx[0]=inf;
tot=n;
build2();
}
void change(int u,int c)
{
type[u]=c;
int tmp=u;
int p=sz(vd[u])-1;
while(u)
{
int cdfn=vd[tmp][p--];
ans-=tree[u].query();
tree[u].update(cdfn,c);
ans+=tree[u].query();
u=f[u];
}
}
}
// Grader entry point.  Forwards the tree description (0-based vertices) to
// Solver::init.  The vectors are not used after this call, so they are moved
// instead of being copied into Solver::init's by-value parameters.
void init(int N,vector<int> F,vector<int> U,vector<int> V,int Q)
{
    Solver::init(N,std::move(F),std::move(U),std::move(V),Q);
}
// Grader entry point: sets vertex X (0-based) to type Y.
void change(int X,int Y)
{
    const int vertex=X+1;  // Solver numbers vertices from 1
    Solver::change(vertex,Y);
}
ll num_tours()
{
return Solver::ans;
}