#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("avx2")
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
// Deliberately 32-bit: every quantity that can exceed int is declared long long explicitly.
using ll = int;
using pii = pair<ll, ll>;

const ll Nm = 200005;   // capacity of the node-indexed arrays
ll N;                   // number of tree nodes (set by init)
long long ans = 0;      // running answer reported by num_tours()
vector<int> F, U, V;    // node labels and edge endpoints handed to init()
vector<pii> locs[Nm];   // per node: {index of its component in cdtr, local index in that component}
ll rlbl[Nm];            // scratch buffer for relabelling nodes while building components

// Segment tree over the post-order slots of every branch of every centroid.
// Leaves live at indices [Sm, 2*Sm); node (x >> e) + (1 << (E - e)) sits e levels above leaf x.
const ll E = 22;        // height of the tree
const ll Sm = 1 << E;   // number of leaf slots (4194304)
int na0[2 * Sm];        // number of active 0-labelled leaves below each node
int pd0[2 * Sm];        // pending tag added by wrtI, pushed towards the leaves by pdn0
int na2[2 * Sm];        // same as na0, for label 2
int pd2[2 * Sm];        // same as pd0, pushed by pdn2
ll allc = 0;            // next unused leaf slot
// Returns the number of trailing zero bits of x, i.e. the exponent of the
// largest power of two dividing x. wrtI() uses it to pick the largest aligned
// power-of-two block that starts at x (or ends at y).
// __builtin_ctz(0) is undefined behaviour, yet wrtI() calls v2(0) whenever a
// range starts at slot 0. For x == 0 we return the bit width (32). That is
// larger than v2(y + 1) for every y >= 0 used here, so wrtI() deterministically
// takes its right-end branch.
ll v2(ll x) {
    if (x == 0) {
        return 32;
    }
    return __builtin_ctz(x);
}
// Moves the pending pd0 tag of tree node p into both of its children and
// clears it on p. Only the tag moves; na0 counts are not touched.
inline void pdn0(ll p) {
    const ll left = p << 1;
    const ll right = left | 1;
    pd0[left] += pd0[p];
    pd0[right] += pd0[p];
    pd0[p] = 0;
}
// Moves the pending pd2 tag of tree node p into both of its children and
// clears it on p. Only the tag moves; na2 counts are not touched.
inline void pdn2(ll p) {
    const ll left = p << 1;
    const ll right = left | 1;
    pd2[left] += pd2[p];
    pd2[right] += pd2[p];
    pd2[p] = 0;
}
// Activates (v = +1) or deactivates (v = -1) a 0-labelled node stored at slot x.
// First, all pending pd0 tags on the path (levels 16..1) are pushed down, so
// pd0 at the leaf holds the total tag that wrtI() put on ranges covering x.
// Then na0 is adjusted on the leaf and its ancestors up to level 16.
// Returns v times that total tag, which callers add to n01.
// NOTE(review): only levels <= 16 are maintained. This presumably relies on
// every wrtI() range being shorter than 2^17 slots; confirm against the limits.
ll wrt0(ll x, ll v) {
    const ll leaf = x + Sm;
    for (ll lvl = 16; lvl >= 1; --lvl) {
        pdn0(leaf >> lvl);
    }
    for (ll node = leaf, lvl = 0; lvl <= 16; ++lvl, node >>= 1) {
        na0[node] += v;
    }
    return v * pd0[leaf];
}
// Activates (v = +1) or deactivates (v = -1) a 2-labelled node stored at slot x.
// First, all pending pd2 tags on the path (levels 16..1) are pushed down, so
// pd2 at the leaf holds the total tag that wrtI() put on ranges covering x.
// Then na2 is adjusted on the leaf and its ancestors up to level 16.
// Returns v times that total tag, which callers add to n21.
// NOTE(review): same level-16 assumption as wrt0().
ll wrt2(ll x, ll v) {
    const ll leaf = x + Sm;
    for (ll lvl = 16; lvl >= 1; --lvl) {
        pdn2(leaf >> lvl);
    }
    for (ll node = leaf, lvl = 0; lvl <= 16; ++lvl, node >>= 1) {
        na2[node] += v;
    }
    return v * pd2[leaf];
}
// Adds v to the pd0/pd2 tags of every slot in [x, y] (the subtree range of a
// 1-labelled node). The range is split into maximal aligned power-of-two
// blocks, peeling a block off whichever end has the smaller alignment.
// Returns {v * active 0s in [x, y], v * active 2s in [x, y]}, i.e. the
// changes to {n01, n21}. An empty range (x > y) yields {0, 0}.
pii wrtI(ll x, ll y, ll v) {
    ll add01 = 0;
    ll add21 = 0;
    while (x <= y) {
        const ll lo = v2(x);
        const ll hi = v2(y + 1);
        ll node;
        if (lo < hi) {
            node = (x >> lo) + (1 << (E - lo));
            x += 1 << lo;
        } else {
            node = (y >> hi) + (1 << (E - hi));
            y -= 1 << hi;
        }
        add01 += v * na0[node];
        add21 += v * na2[node];
        pd0[node] += v;
        pd2[node] += v;
    }
    return {add01, add21};
}
struct cst { //cdt subtree
ll M;
long long n0,n2,n21,n01;
vector<ll> Fn;
ll r;
vector<vector<ll>> fadj;
vector<ll> radj;
vector<int> tsz; //subtree size
vector<int> sti; //segtree index
// vector<long long> v0,v2,v21,v01;
/*void lft(ll x) {
v0[x]=(Fn[x]==0);
v2[x]=(Fn[x]==2);
v21[x]=0;
v01[x]=0;
for (ll y: fadj[x]) {
lft(y);
v0[x]+=v0[y];
v2[x]+=v2[y];
v01[x]+=v01[y];
v21[x]+=v21[y];
}
if (Fn[x]==1) {
v01[x]+=v0[x];
v21[x]+=v2[x];
}
}
void calc() {
vector<long long> emp(M,0);
v0=emp; v2=emp; v21=emp; v01=emp;
lft(r);
n0=v0[r]; n2=v2[r]; n21=v21[r]; n01=v01[r];
v0.clear(); v2.clear(); v21.clear(); v01.clear();
}*/
cst(ll r0, ll M0, vector<vector<ll>> adj,vector<ll> f0) {
n0=0; n2=0; n21=0; n01=0;
r=r0; M=M0;
Fn=f0;
vector<bool> found;
for (ll m=0;m<M;m++) {
tsz.push_back(0);
sti.push_back(0);
radj.push_back(-1);
found.push_back(0);
fadj.push_back((vector<ll>){});
}
// queue<ll> q;
// q.push(r);
// while (!q.empty()) {
// ll x = q.front(); q.pop();
// found[x]=1;
// for (ll y: adj[x]) {
// if (!found[y]) {
// radj[y]=x;
// q.push(y);
// fadj[x].push_back(y);
// }
// }
// }
stack<pii> q0;
q0.push({r,0});
while (!q0.empty()) {
pii p0 = q0.top(); q0.pop();
ll x = p0.first; ll t = p0.second;
if (t==0) {
found[x]=1;
q0.push({x,1});
for (ll y: adj[x]) {
if (!found[y]) {
radj[y]=x;
q0.push({y,0});
fadj[x].push_back(y);
}
}
} else {
tsz[x]=1;
// cout << "x="<<x<<", Fn[x]="<<Fn[x]<<"\n";
for (ll y: adj[x]) {
// cout << "y in fadj="<<y<<"\n";
if (radj[y]==x) {
//cout << "f1\n";
tsz[x]+=tsz[y];
}
}
//cout << "tsz[x]="<<tsz[x]<<"\n";
sti[x]=allc;
if (f0[x]==0) {
n0++;
n01 += wrt0(allc,1); //v is the DELTA
} else if (f0[x]==2) {
n2++;
n21 += wrt2(allc,1);
} else {
pii p1mod = wrtI(allc-tsz[x]+1,allc,1);
n01 = p1mod.first+n01;
n21 = p1mod.second+n21;
}
allc++;
}
}
//cout << "n0,n2,n01,n21="<<n0<<","<<n2<<","<<n01<<","<<n21<<"\n";
//calc();
}
/*void upd(ll x, ll v) {
Fn[x]=v;
calc();
}*/
void upd(ll x, ll v) {
if (Fn[x]==0) {
n0--;
n01 += wrt0(sti[x],-1);
} else if (Fn[x]==2) {
n2--;
n21 += wrt2(sti[x],-1);
} else {
pii p1mod = wrtI(sti[x]-tsz[x]+1,sti[x],-1);
n01 = p1mod.first+n01;
n21 = p1mod.second+n21;
}
Fn[x]=v;
if (Fn[x]==0) {
n0++;
n01 += wrt0(sti[x],1);
} else if (Fn[x]==2) {
n2++;
n21 += wrt2(sti[x],1);
} else {
pii p1mod = wrtI(sti[x]-tsz[x]+1,sti[x],1);
n01 = p1mod.first+n01;
n21 = p1mod.second+n21;
}
// cout << "n0,n2,n01,n21="<<n0<<","<<n2<<","<<n01<<","<<n21<<"\n";
}
};
struct cdt { //centroid decomp tree
ll M; //size
vector<vector<ll>> fadj;
vector<ll> Fn; //new F
vector<pii> strl; //subtree locations: {index of st, index in st}
vector<cst*> v1;
long long s21=0, s01=0, s0=0, s2=0, s210=0, s012=0, s02=0;
cdt(ll M1, vector<vector<ll>> fadj1, vector<ll> Fn1) { //fadj is really just adj oops
M=M1; fadj=fadj1; Fn=Fn1;
for (ll m=0;m<M;m++) {
strl.push_back((pii){0,0});
}
ll rcnt = 0;
for (ll x: fadj[0]) {
//unordered_map<ll,ll> rlbl; //relabel
vector<vector<ll>> nadj;
vector<ll> fnew;
ll Mn = 0;
queue<pii> q0;
q0.push({x,-1});
//cout << "x="<<x<<"\n";
while (!q0.empty()) {
pii p0 = q0.front(); q0.pop();
ll z = p0.first; ll pz = p0.second;
if (z==0) {
continue;
}
//if (rlbl.find(z)==rlbl.end()) {
rlbl[z]=Mn++;
//cout << "defining z="<<z<<" as "<<rlbl[z]<<"\n";
nadj.push_back((vector<ll>){});
fnew.push_back(Fn[z]);
strl[z]={rcnt,rlbl[z]};
//cout << "relabel: z="<<z<<"->"<<rlbl[z]<<"\n";
//locs[z].push_back({dind,rlbl[z]});
// }
if (pz != -1) {
//cout << "z,pz="<<z<<","<<pz<<"\n";
nadj[rlbl[z]].push_back(rlbl[pz]);
nadj[rlbl[pz]].push_back(rlbl[z]);
}
for (ll nz: fadj[z]) {
if (nz != pz && nz != 0) {
q0.push({nz,z});
}
}
}
v1.push_back(new cst(0LL,Mn,nadj,fnew));
rcnt++;
}
for (ll r=0;r<rcnt;r++) {
s21 += (v1[r]->n21);
s01 += (v1[r]->n01);
s0 += (v1[r]->n0);
s2 += v1[r]->n2;
s210 += (v1[r]->n21)*(v1[r]->n0);
s012 += (v1[r]->n01)*(v1[r]->n2);
s02 += (v1[r]->n0)*(v1[r]->n2);
}
ans += (s21*s0-s210+s01*s2-s012);
if (Fn[0]==0) {
ans += s21;
} else if (Fn[0]==1) {
ans += (s0*s2-s02);
} else {
ans += s01;
}
}
void upd(ll x, ll vf) {
if (x==0) {
ll v0 = Fn[0];
if (v0==0) {
ans -= s21;
} else if (v0==1) {
ans -= (s0*s2-s02);
} else {
assert(v0==2);
ans -= s01;
}
if (vf==0) {
ans += s21;
} else if (vf==1) {
ans += (s0*s2-s02);
} else {
assert(vf==2);
ans += s01;
}
} else {
ll v0 = Fn[0];
if (v0==0) {
ans -= s21;
} else if (v0==1) {
ans -= (s0*s2-s02);
} else {
assert(v0==2);
ans -= s01;
}
ans -= (s21*s0-s210+s01*s2-s012);
ll i = strl[x].first;
s21 -= (v1[i]->n21);
s01 -= (v1[i]->n01);
s2 -= (v1[i]->n2);
s0 -= (v1[i]->n0);
s210 -= (v1[i]->n21)*(v1[i]->n0);
s012 -= (v1[i]->n01)*(v1[i]->n2);
s02 -= (v1[i]->n0)*(v1[i]->n2);
(*v1[i]).upd(strl[x].second,vf);
i = strl[x].first;
s21 += (v1[i]->n21);
s01 += (v1[i]->n01);
s2 += (v1[i]->n2);
s0 += (v1[i]->n0);
s210 += (v1[i]->n21)*(v1[i]->n0);
s012 += (v1[i]->n01)*(v1[i]->n2);
s02 += (v1[i]->n0)*(v1[i]->n2);
v0 = Fn[0];
if (v0==0) {
ans += s21;
} else if (v0==1) {
ans += (s0*s2-s02);
} else {
assert(v0==2);
ans += s01;
}
ans += (s21*s0-s210+s01*s2-s012);
}
Fn[x]=vf;
}
};
vector<ll> adj[Nm];   // adjacency lists of the input tree (filled in init)
bool found[Nm];       // set once a node has been used as a centroid; such nodes are skipped afterwards
ll sz[Nm];            // subtree sizes inside the current component, written by getsz()
ll rev[Nm];           // NOTE(review): not referenced anywhere in this file
vector<cdt*> cdtr;    // every centroid component, indexed by construction order (dind)
// Fills sz[] for every node reachable from x without stepping onto pr or onto
// any node with found[] set (sizes of the subtrees hanging below x), and
// returns sz[x].
// Iterative: the recursive version used one stack frame per tree level,
// which on a path-shaped tree means a recursion depth of up to N.
ll getsz(ll x, ll pr = -1) {
    vector<pii> order;  // {node, parent}; every parent precedes its children
    order.push_back({x, pr});
    for (size_t i = 0; i < order.size(); i++) {
        const ll u = order[i].first;
        const ll p = order[i].second;
        sz[u] = 1;
        for (ll y : adj[u]) {
            if (y != p && !found[y]) {
                order.push_back({y, u});
            }
        }
    }
    // Reverse sweep: children are finished before their parents. The root
    // (index 0) is skipped so sz[pr] is never touched.
    for (size_t i = order.size(); i-- > 1;) {
        sz[order[i].second] += sz[order[i].first];
    }
    return sz[x];
}
// Starting at x (coming from pr), repeatedly steps into the first
// non-removed neighbour (other than where we came from) whose sz[] is at
// least half of sz0. Returns the node where no such step exists, i.e. the
// centroid of a component of size sz0 whose sizes were filled by getsz().
ll getctr(ll x, ll sz0, ll pr=-1) {
    ll cur = x;
    ll from = pr;
    bool moved = true;
    while (moved) {
        moved = false;
        for (ll y : adj[cur]) {
            if (y == from || found[y]) {
                continue;
            }
            if (2 * sz[y] >= sz0) {
                from = cur;
                cur = y;
                moved = true;
                break;
            }
        }
    }
    return cur;
}
ll dind = 0; // index in cdtr that the next built component will get
// Decomposes the not-yet-removed component containing x:
// 1. Find its centroid.
// 2. Relabel its nodes by BFS from the centroid (so the centroid is local 0)
//    and record {component index, local index} in locs[] for each node.
// 3. Build the cdt and mark the centroid as removed.
// 4. Recurse into each remaining piece around the centroid.
void bldDcmp(ll x=0) { // x: any node of the component
    const ll sz0 = getsz(x);
    const ll y = getctr(x, sz0);
    vector<vector<ll>> nadj; // adjacency in local labels
    vector<ll> fnew;         // labels in local order
    ll M = 0;
    queue<pii> q0;
    q0.push({y, -1});
    while (!q0.empty()) {
        const pii p0 = q0.front();
        q0.pop();
        const ll z = p0.first;
        const ll pz = p0.second;
        rlbl[z] = M++;
        nadj.emplace_back();
        fnew.push_back(F[z]);
        locs[z].push_back({dind, rlbl[z]});
        if (pz != -1) {
            nadj[rlbl[z]].push_back(rlbl[pz]);
            nadj[rlbl[pz]].push_back(rlbl[z]);
        }
        for (ll zn : adj[z]) {
            if (!found[zn] && zn != pz) {
                q0.push({zn, z});
            }
        }
    }
    // nadj/fnew are not needed here any more: hand them over instead of
    // deep-copying them into the cdt's by-value parameters.
    cdtr.push_back(new cdt(M, std::move(nadj), std::move(fnew)));
    found[y] = 1;
    dind++;
    for (ll z : adj[y]) {
        if (!found[z]) {
            bldDcmp(z);
        }
    }
}
void init(int N1, vector<int> F1, vector<int> U1, vector<int> V1, int Q) {
N=N1;
F=F1;
U=U1;
V=V1;
for (ll i=0;i<(N-1);i++) {
adj[U[i]].push_back(V[i]);
adj[V[i]].push_back(U[i]);
}
bldDcmp(); //build centroid decomposition
}
void change(int x, int y) {
for (pii p0: locs[x]) {
(*cdtr[p0.first]).upd(p0.second,y);
}
}
long long num_tours() {
return ans;
}