#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <tuple>
#include <utility>
#include <vector>
// printf-style diagnostic output written to stderr.
#define debug(...) fprintf(stderr, __VA_ARGS__)

// 64-bit type used for products that are reduced modulo P.
using LL = long long;

constexpr int N = 3005;       // capacity of every vertex- or label-indexed array
constexpr int P = 998244353;  // modulus applied to all counts

int n;     // number of tree vertices in the current test (n - 1 edges are read)
int K;     // number of labels in the current test
int f[N];  // f[i]: label read for vertex i

// E[a][b] = {x, y}: the input edge joining labels a and b, where x carries
// label a and y carries label b (the mirrored entry E[b][a] is {y, x}).
std::pair<int, int> E[N][N];

int U[N];       // U[i], V[i]: endpoints of the i-th input edge
int V[N];
int dp[N][N];   // dp[u][i]: count (mod P) for label u with vertex i of c[u] chosen
int nd[N];      // scratch row used while dfs2 folds one child into dp[u]
int dis[N][N];  // dis[s][v]: number of edges between s and v in the input tree

int fa[N];   // disjoint-set parent pointers
int siz[N];  // disjoint-set component sizes (meaningful at roots)

std::vector<int> G[N];  // adjacency lists: first the input tree, later the graph over labels
std::vector<int> c[N];  // c[k]: vertices whose label is k
// Returns the representative of u's set in the disjoint-set forest `fa`.
// While walking up, every visited node is re-pointed at its grandparent
// (path halving), which shortens later lookups.
inline int find(int u) {
  for (;;) {
    const int parent = fa[u];
    if (parent == u) return u;
    const int grand = fa[parent];
    fa[u] = grand;  // skip one level
    u = grand;
  }
}
// Unites the sets containing u and v. The root of u's set is attached below
// the root of v's set and the surviving root's size absorbs the other one.
// Does nothing when both are already in the same set.
inline void merge(int u, int v) {
  const int from = find(u);
  const int into = find(v);
  if (from == into) return;
  siz[into] += siz[from];
  fa[from] = into;
}
// Depth-first walk over G that fills d[v] = d[u] + 1 for every neighbour v of
// u other than `fa`, then recurses. The caller must have set d[] for the
// start vertex. Note that the parameter `fa` (the vertex we arrived from)
// shadows the global disjoint-set array of the same name.
void dfs1(int *d, int u, int fa) {
  for (int v : G[u]) {
    if (v == fa) continue;  // do not walk back along the edge we came from
    d[v] = d[u] + 1;
    dfs1(d, v, u);
  }
}
// Returns true when, measured from u, x is strictly closer than y; equal
// distances are decided by the keys: true iff dx < dy.
inline bool comp(int u, int x, int y, int dx, int dy) {
  const int to_x = dis[u][x];
  const int to_y = dis[u][y];
  if (to_x != to_y) return to_x < to_y;
  return dx < dy;  // tie on distance: the smaller key wins
}
// DP over the label graph stored in G, processing label u whose parent label
// is `fa` (this parameter shadows the global disjoint-set array).
//
// On return, for every vertex i in c[u]:
//   dp[u][i] = product over child labels v of
//              ( sum of dp[v][j] over the j in c[v] compatible with i )  mod P,
// i.e. the number (mod P) of ways to pick one vertex per label in u's subtree,
// with i picked for u, such that every parent/child pair is compatible.
//
// With {x, y} = E[u][v] (x labelled u, y labelled v), a pair (i, j) is
// compatible when
//   * dis[x][i] and dis[y][j] differ by at most one,
//   * y is closer to j than to i (a distance tie needs v < u), and
//   * x is closer to i than to j (a distance tie needs u < v).
//
// nd[] is a shared scratch row; it is only touched after the recursive call
// for the child has finished, so nested calls do not clobber each other.
void dfs2(int u, int fa) {
  const std::vector<int> &mine = c[u];  // c[] is not modified during the DP

  for (int i : mine) dp[u][i] = 1;

  for (int v : G[u]) {
    if (v == fa) continue;
    dfs2(v, u);

    for (int i : mine) nd[i] = 0;

    int x, y;
    std::tie(x, y) = E[u][v];

    for (int j : c[v]) {
      const int ways_v = dp[v][j];
      if (ways_v == 0) continue;  // contributes nothing to any i
      const int dist_yj = dis[y][j];

      for (int i : mine) {
        if (std::abs(dis[x][i] - dist_yj) > 1) continue;
        if (comp(y, j, i, v, u) && comp(x, i, j, u, v))
          // nd[i] < P and the product is below P * P, so the sum fits in LL.
          nd[i] = static_cast<int>((nd[i] + LL(dp[u][i]) * ways_v) % P);
      }
    }

    for (int i : mine) dp[u][i] = nd[i];
  }
}
// Handles one test case: reads n, K, the n - 1 tree edges and one label per
// vertex, then prints a single line with the answer modulo P.
//
// Prints 0 when the vertices of some label do not form one connected piece of
// the tree, or when the number of such pieces differs from K. Otherwise the
// labels are contracted into a graph over labels and the answer is the sum of
// dp[f[1]][i] over the vertices i carrying vertex 1's label, as computed by
// dfs2.
//
// If the input ends early or is not numeric, the function returns without
// printing anything instead of working on stale values.
// NOTE(review): n, K and the labels are used as array indices without range
// checks; this assumes n < N and labels in 1..K — confirm against the input spec.
void solve() {
  if (scanf("%d%d", &n, &K) != 2) return;

  for (int k = 1; k <= K; ++k) c[k].clear();
  for (int v = 1; v <= n; ++v) G[v].clear();

  // Input tree.
  for (int e = 1; e < n; ++e) {
    if (scanf("%d%d", &U[e], &V[e]) != 2) return;
    G[U[e]].push_back(V[e]);
    G[V[e]].push_back(U[e]);
  }

  // All-pairs distances: one DFS per source vertex.
  for (int s = 1; s <= n; ++s) {
    std::fill(dis[s] + 1, dis[s] + n + 1, 0);
    dfs1(dis[s], s, 0);
  }

  // Labels, bucketed per label value.
  for (int v = 1; v <= n; ++v) {
    if (scanf("%d", &f[v]) != 1) return;
    c[f[v]].push_back(v);
  }

  // Union the endpoints of every edge whose two ends share a label.
  for (int v = 1; v <= n; ++v) {
    fa[v] = v;
    siz[v] = 1;
  }
  for (int e = 1; e < n; ++e) {
    if (f[U[e]] == f[V[e]]) merge(U[e], V[e]);
  }

  // Each component must cover its whole label class, and there must be
  // exactly K components.
  int cnt = 0;
  for (int v = 1; v <= n; ++v) {
    if (fa[v] != v) continue;
    if (siz[v] != static_cast<int>(c[f[v]].size())) {
      puts("0");
      return;
    }
    ++cnt;
  }
  if (cnt != K) {
    puts("0");
    return;
  }

  // Reuse G for the graph over labels; remember which input edge realises
  // each label-to-label adjacency, oriented both ways.
  for (int v = 1; v <= n; ++v) G[v].clear();
  for (int e = 1; e < n; ++e) {
    const int a = f[U[e]];
    const int b = f[V[e]];
    if (a == b) continue;
    G[a].push_back(b);
    G[b].push_back(a);
    E[a][b] = {U[e], V[e]};
    E[b][a] = {V[e], U[e]};
  }

  const int root = f[1];
  dfs2(root, 0);  // 0 is passed as the parent: the root has none

  int ans = 0;
  for (int v : c[root]) ans = (ans + dp[root][v]) % P;
  printf("%d\n", ans);
}
// Reads the number of test cases from stdin and runs solve() once per case.
// A missing or non-numeric count, or a non-positive one, runs no cases.
int main() {
  int t = 0;
  if (scanf("%d", &t) != 1) return 0;
  // `t-- > 0` rather than `t--`, so a negative count cannot loop until overflow.
  while (t-- > 0) solve();
  return 0;
}