#include "joitour.h"
#include <bits/stdc++.h>
#define int long long
using namespace std;
constexpr int R = 262144;  // 2^18; not referenced anywhere in the code shown (presumably a leftover size constant)
// Three per-type counter arrays (indexed by vertex type 0, 1, 2) that add component-wise.
// How each field is filled is decided by the callers (init() / add_to_root()).
struct Node {
    array<int, 3> cnt;
    array<long long, 3> sum_cnt;
    array<long long, 3> sum_heavy;

    // All counters start at zero.
    Node() {
        cnt.fill(0);
        sum_cnt.fill(0);
        sum_heavy.fill(0);
    }

    // Component-wise sum of every counter array.
    // The operand is taken by const reference and the operator is const. No copy of `b` is
    // made on each call, and const Nodes can appear on either side.
    Node operator+(const Node& b) const {
        Node res;
        for (int i = 0; i < 3; ++i) {
            res.cnt[i] = cnt[i] + b.cnt[i];
            res.sum_cnt[i] = sum_cnt[i] + b.sum_cnt[i];
            res.sum_heavy[i] = sum_heavy[i] + b.sum_heavy[i];
        }
        return res;
    }
};
// Capacity of the per-vertex arrays below.
constexpr int maxn = 2e5 + 5;
// g[v]: adjacency list of v. After dfs(0) it holds only v's children, with a
// largest-subtree (heavy) child at g[v][0].
vector<int> g[maxn];
// pr[v]: parent of v in the tree rooted at 0 (set by dfs; pr[0] keeps its zero
// initial value). sz[v]: number of vertices in v's subtree.
int pr[maxn], sz[maxn];
// Roots the tree at vertex 0 (the only caller passes v = 0). It fills pr[] and sz[],
// removes each non-root vertex's parent from its adjacency list, and moves a
// largest-subtree child to g[w][0].
//
// Implemented iteratively. A recursive DFS nests once per tree level, which can reach
// n (up to ~2e5) frames on a path-shaped tree and overflow a default-sized call stack.
void dfs(int v) {
    // Top-down pass: list vertices so that every parent precedes its children.
    vector<int> order;
    order.push_back(v);
    for (size_t idx = 0; idx < order.size(); ++idx) {
        const int w = order[idx];
        for (auto u : g[w]) {
            if (u != pr[w]) {
                pr[u] = w;
                order.push_back(u);
            }
        }
    }
    // Bottom-up pass: every child is finished before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int w = *it;
        sz[w] = 1;
        for (auto u : g[w]) {
            if (u != pr[w]) {
                sz[w] += sz[u];
            }
        }
        // Move the parent entry to the back, then drop it (vertex 0 has no parent entry).
        for (int i = 0; i < (int)g[w].size(); ++i) {
            if (g[w][i] == pr[w]) {
                swap(g[w][i], g[w].back());
            }
        }
        if (w != 0) {
            g[w].pop_back();
        }
        // Bring a child with the largest subtree to the front.
        for (int i = 0; i < (int)g[w].size(); ++i) {
            if (sz[g[w][i]] > sz[g[w][0]]) {
                swap(g[w][i], g[w][0]);
            }
        }
    }
}
// tp[v]: current type of vertex v (0, 1 or 2).
int tp[maxn];
// cnt_all[k]: number of vertices of type k in the whole tree (kept current by change()).
int cnt_all[3];
// cnt[v][k]: type-k vertices in v's subtree. cnt_heavy[v][k]: type-k vertices among v
// itself plus its heavy child g[v][0]'s subtree. Both are written only by calc_cnt(),
// i.e. they describe the initial types.
int cnt[maxn][3], cnt_heavy[maxn][3];
// Next preorder index handed out by calc_cnt().
int t = 0;
// Preorder interval of v's subtree is [tin[v], tout[v]).
int tin[maxn], tout[maxn];
// ord[i]: the vertex whose preorder index is i.
int ord[maxn];
// Preorder traversal of the rooted tree built by dfs(), starting at v.
// It assigns tin/tout/ord, counts every vertex's type into cnt_all, and fills:
//   cnt[w][k]       - type-k vertices in w's subtree,
//   cnt_heavy[w][k] - type-k vertices among w and its heavy child g[w][0]'s subtree.
//
// Implemented with an explicit frame stack. A recursive version nests once per tree
// level, up to n (~2e5) deep on a path-shaped tree, which can overflow a default stack.
void calc_cnt(int v) {
    // Each frame: (vertex, index of the next child of that vertex to visit).
    vector<pair<int, size_t>> frames;
    // Work done when a vertex is first reached, in preorder.
    auto enter = [&](int w) {
        tin[w] = t++;
        ord[tin[w]] = w;
        cnt_all[tp[w]]++;
        cnt[w][tp[w]]++;
        cnt_heavy[w][tp[w]]++;
        frames.emplace_back(w, 0);
    };
    enter(v);
    while (!frames.empty()) {
        const int w = frames.back().first;
        if (frames.back().second < g[w].size()) {
            const int u = g[w][frames.back().second++];
            enter(u);
            continue;
        }
        // Every child of w is finished, so cnt[] of each child is final.
        frames.pop_back();
        if (!g[w].empty()) {
            for (int i = 0; i < 3; ++i) {
                cnt_heavy[w][i] += cnt[g[w][0]][i];
            }
        }
        tout[w] = t;
        // Fold w's completed subtree counts into its parent (the frame below it).
        if (!frames.empty()) {
            const int p = frames.back().first;
            for (int i = 0; i < 3; ++i) {
                cnt[p][i] += cnt[w][i];
            }
        }
    }
}
// sum[v]: over v's light children u (every child except g[v][0]), the sum of
// (#type-0 in u's subtree) * (#type-2 in u's subtree). Seeded in init(), adjusted in add_to_root().
long long sum[maxn];
// cnt_sum[k]: accumulated in init() over type-1 vertices; not read anywhere in the code shown.
long long cnt_sum[3];
// Running value returned by num_tours(); set by init() and updated by change().
long long ans = 0;
// vals[v]: .cnt is v's own type indicator, .sum_cnt the type counts of v's subtree,
// .sum_heavy the type counts of v plus its heavy child's subtree. Seeded in init(),
// updated by add_to_root().
Node vals[maxn];
void init(int32_t n, vector<int32_t> f, vector<int32_t> u, vector<int32_t> v, int32_t q) {
for (int i = 0; i < n; ++i) {
tp[i] = f[i];
}
for (int i = 0; i < n - 1; ++i) {
g[u[i]].push_back(v[i]);
g[v[i]].push_back(u[i]);
}
dfs(0);
calc_cnt(0);
ans = cnt_all[0] * 1LL * cnt_all[1] * 1LL * cnt_all[2];
for (int i = 0; i < n; ++i) {
if (tp[i] == 1) {
for (auto j : g[i]) {
ans -= cnt[j][0] * 1LL * cnt[j][2];
}
for (int j = 0; j < 3; ++j) {
cnt_sum[j] += cnt[i][j];
}
ans -= (cnt_all[0] - cnt[i][0]) * 1LL * (cnt_all[2] - cnt[i][2]);
}
for (auto j : g[i]) {
if (j != g[i][0]) {
sum[i] += cnt[j][0] * 1LL * cnt[j][2];
}
}
vals[i].cnt[tp[i]]++;
for (int j = 0; j < 3; ++j) {
vals[i].sum_cnt[j] = cnt[i][j];
vals[i].sum_heavy[j] = cnt_heavy[i][j];
}
}
}
// Adds `tp` (change() passes -1 or +1) to the count of type-`vl` vertices located at v.
// It updates v's own counters, sum_cnt of every ancestor, sum_heavy where the path enters
// an ancestor through its heavy child, and the light-child pair sums sum[] otherwise.
// Note: the parameter `tp` shadows the global type array; only the delta is used here.
void add_to_root(int v, int vl, int tp) {
    vals[v].cnt[vl] += tp;
    vals[v].sum_cnt[vl] += tp;
    vals[v].sum_heavy[vl] += tp;
    int prv = v, nw = pr[v];
    while (prv != 0) {
        vals[nw].sum_cnt[vl] += tp;
        if (prv == g[nw][0]) {
            vals[nw].sum_heavy[vl] += tp;
        } else if (vl != 1) {
            // sum[nw] holds (#type-0) * (#type-2) for each light child subtree.
            // Shifting the type-vl count of prv's subtree by tp changes that product by
            // tp * (current type-(vl^2) count of prv's subtree).
            // cnt[][] only describes the initial types and is never updated. The
            // maintained vals[prv].sum_cnt must be used instead (it is already current,
            // and its vl^2 entry is untouched by this update).
            sum[nw] += tp * vals[prv].sum_cnt[vl ^ 2];
        }
        prv = nw;
        nw = pr[nw];
    }
}
pair<Node, long long> get_to_root(int v) {
int prv = v, nw = pr[v];
Node ans;
long long ret = 0;
while (prv != 0) {
if (tp[nw] == 1) {
ans = ans + vals[nw];
ret += vals[prv].cnt[tp[v] ^ 2];
}
prv = nw;
nw = pr[nw];
}
return make_pair(ans, ret);
}
// Returns the component-wise sum of vals[w] over the type-1 vertices w whose preorder
// index lies in [l, r). Linear in r - l.
Node get_segm(int l, int r) {
    Node acc;
    for (int pos = l; pos < r; ++pos) {
        const int w = ord[pos];
        if (tp[w] != 1) {
            continue;
        }
        acc = acc + vals[w];
    }
    return acc;
}
void change(int32_t x, int32_t y) {
if (tp[x] == y) {
return;
}
ans -= cnt_all[0] * 1LL * cnt_all[1] * 1LL * cnt_all[2];
cnt_all[tp[x]]--;
ans += cnt_all[0] * 1LL * cnt_all[1] * 1LL * cnt_all[2];
// cout << "? " << ans << endl;
if (tp[x] == 0 || tp[x] == 2) {
auto [node, ret] = get_to_root(x);
ans += ret;
auto node2 = get_segm(tin[0], tout[0]);
ans += (node2.cnt[1] - node.cnt[1]) * 1LL * cnt_all[tp[x] ^ 2] - (node2.sum_cnt[tp[x] ^ 2] - node.sum_cnt[tp[x] ^ 2]);
} else {
ans += sum[x];
if (!g[x].empty()) {
ans += vals[x].sum_heavy[0] * 1LL * vals[x].sum_heavy[2];
}
ans += (cnt_all[0] - vals[x].sum_cnt[0]) * 1LL * (cnt_all[2] - vals[x].sum_cnt[2]);
}
add_to_root(x, tp[x], -1);
tp[x] = y;
ans -= cnt_all[0] * 1LL * cnt_all[1] * 1LL * cnt_all[2];
cnt_all[tp[x]]++;
ans += cnt_all[0] * 1LL * cnt_all[1] * 1LL * cnt_all[2];
if (tp[x] == 0 || tp[x] == 2) {
auto [node, ret] = get_to_root(x);
ans -= ret;
auto node2 = get_segm(tin[0], tout[0]);
ans -= (node2.cnt[1] - node.cnt[1]) * 1LL * cnt_all[tp[x] ^ 2] - (node2.sum_cnt[tp[x] ^ 2] - node.sum_cnt[tp[x] ^ 2]);
} else {
// cout << "?" << cnt_all[0] << ' ' << vals[x].sum_cnt[0] << ' ' << cnt_all[2] << ' ' << vals[x].sum_cnt[2] << endl;
ans -= sum[x];
ans -= vals[x].sum_heavy[0] * 1LL * vals[x].sum_heavy[2];
ans -= (cnt_all[0] - vals[x].sum_cnt[0]) * 1LL * (cnt_all[2] - vals[x].sum_cnt[2]);
}
add_to_root(x, tp[x], 1);
}
// Grader entry point: reports the current answer maintained by init() and change().
long long num_tours() {
    const long long current = ans;
    return current;
}