#include <cassert>
#include <iostream>
#include <set>
#define int long long
using namespace std;
// Input: n (tree size and array length), m (number of segments), and the
// array a[1..n] of tree vertices.
int n, m;
int a[200005];

// Forward-star adjacency list. head[u] is the latest edge id leaving u,
// nxt[e] links to the previous edge of the same vertex, and to[e] is its
// endpoint. Edge ids start at 1 so that 0 terminates every chain. There is
// room for both directions of every tree edge.
int head[200005];
int nxt[400005];
int to[400005];
int ecnt;

// Prepends the directed edge u -> v to u's adjacency chain.
void add(int u, int v) {
    ++ecnt;
    to[ecnt] = v;
    nxt[ecnt] = head[u];
    head[u] = ecnt;
}

// Heavy-light decomposition tables, filled by dfs1/dfs2:
//   f[x] parent, dep[x] depth, sz[x] subtree size, son[x] heavy child
//   (0 if none), top[x] head of x's heavy chain, dfn[x] preorder index,
//   _dfn[i] vertex whose preorder index is i, ncnt last index handed out.
int f[200005], dep[200005], sz[200005], son[200005], top[200005];
int dfn[200005], _dfn[200005], ncnt;
// First HLD pass over the subtree of x (whose parent is fa, at depth d).
// Records parent, depth and subtree size of every vertex. It sets son[x] to
// the child with the strictly largest subtree; on ties the first such child
// in adjacency order wins. For leaves son[x] is left untouched.
void dfs1(int x, int fa, int d) {
    f[x] = fa;
    dep[x] = d;
    sz[x] = 1;
    for (int e = head[x]; e != 0; e = nxt[e]) {
        const int child = to[e];
        if (child == fa)
            continue;  // don't walk back up to the parent
        dfs1(child, x, d + 1);
        sz[x] += sz[child];
        if (sz[son[x]] < sz[child])
            son[x] = child;
    }
}
// Second HLD pass: assigns preorder indices, visiting the heavy child first,
// and chain heads. x joins the chain headed by t; dfn/_dfn record the
// preorder index and its inverse.
void dfs2(int x, int t) {
    top[x] = t;
    dfn[x] = ++ncnt;
    _dfn[ncnt] = x;
    if (son[x] == 0)
        return;  // no heavy child means x has no children at all
    // The heavy child continues the current chain...
    dfs2(son[x], t);
    // ...while every light child starts a new chain headed by itself.
    for (int e = head[x]; e; e = nxt[e]) {
        const int child = to[e];
        if (child == f[x] || child == son[x])
            continue;
        dfs2(child, child);
    }
}
// Lowest common ancestor via heavy-light decomposition. While x and y lie on
// different chains, jump the one whose chain head is deeper (x on ties) to the
// parent of that head. Once on the same chain, the shallower vertex is the LCA.
int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            y = f[top[y]];
        else
            x = f[top[x]];
    }
    return dep[y] <= dep[x] ? y : x;
}
// Number of edges on the tree path between x and y: each endpoint's depth
// below their LCA, summed.
int dist(int x, int y) {
    const int lca = LCA(x, y);
    return (dep[x] - dep[lca]) + (dep[y] - dep[lca]);
}
set<int> st;
int cur = 0;
void Add(int x) {
if (x == 1)
return;
int a, b;
int d = dfn[x];
set<int>::iterator it = st.upper_bound(d);
b = (it != st.end() ? _dfn[*it] : _dfn[*st.begin()]);
assert(it != st.begin());
--it;
a = _dfn[*it];
cur -= dist(a, b);
cur += dist(a, x);
cur += dist(b, x);
st.insert(dfn[x]);
}
void Erase(int x) {
if (x == 1)
return;
int a, b;
int d = dfn[x];
st.erase(st.find(d));
set<int>::iterator it = st.upper_bound(d);
b = (it != st.end() ? _dfn[*it] : _dfn[*st.begin()]);
assert(it != st.begin());
--it;
a = _dfn[*it];
cur += dist(a, b);
cur -= dist(a, x);
cur -= dist(b, x);
}
// dp[k][j]: maximum total cost of splitting a[1..j] into k contiguous
// non-empty segments, where a segment's cost is cur for that window.
int dp[505][30005];
// The current window is a[_l.._r]; it starts empty ([1, 0]).
int _l = 1;
int _r;
// Moves the window a[_l.._r] to a[tl..tr] (requires tl <= tr + 1), keeping
// st/cur in sync through Add/Erase.
// Both extensions run before either shrink. The window then only grows until
// it covers the target, and the shrinks remove only elements inside it, so
// Erase never sees a vertex that was never added. Shrinking the right end
// first is wrong: going from [mid, mid] to [L, mid'] with mid' < mid - 1 would
// walk _r below _l and "erase" a[mid - 1], which was not in the window.
void Move(int tl, int tr) {
    while (_l > tl) Add(a[--_l]);
    while (_r < tr) Add(a[++_r]);
    while (_l < tl) Erase(a[_l++]);
    while (_r > tr) Erase(a[_r--]);
}
// Divide-and-conquer optimisation for DP layer k. It fills dp[k][mid] for
// every mid in [l, r], given that the best start i of the last segment a[i..mid]
// lies in [L, R]:
//   dp[k][mid] = max_i dp[k-1][i-1] + cost(a[i..mid]).
// A window's cost is twice the size of a union of root-to-vertex paths, which
// is submodular. So for i <= i' <= j <= j':
//   cost(i, j) + cost(i', j') >= cost(i, j') + cost(i', j).
// This makes the best start non-decreasing in the segment end. After finding
// p for mid, the left half searches [L, p] and the right half [p, R].
void Solve(int l, int r, int L, int R, int k) {
    if (l > r)
        return;
    int mid = (l + r) >> 1;
    int p = L;
    // The last segment is non-empty (i <= mid) and, by the bound above, starts
    // no later than R. Honouring R keeps this at O(n log n) evaluations per
    // layer; scanning up to mid alone degrades to O(n^2).
    for (int i = L; i <= mid && i <= R; i++) {
        Move(i, mid);
        if (dp[k - 1][i - 1] + cur > dp[k][mid]) {
            dp[k][mid] = dp[k - 1][i - 1] + cur;
            p = i;
        }
    }
    Solve(l, mid - 1, L, p, k);
    Solve(mid + 1, r, p, R, k);
}
signed main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
dfs1(1, 0, 1);
dfs2(1, 1);
st.insert(1);
for (int i = 1; i <= n; i++) {
Add(a[i]);
dp[1][i] = cur;
// cout << i << " " << dp[1][i] << "\n";
}
for (int i = 1; i <= n; i++) Erase(a[i]);
for (int i = 2; i <= m; i++) Solve(i, n, i, n, i);
cout << dp[m][n] << "\n";
return 0;
}