// #641946 | #4046. 钥匙 (Keys) | hos_lyric | 100 ✓ | 1584ms | 152404kb | C++14 | 10.9kb | 2024-10-15 03:50:56 | 2024-10-15 03:50:57
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
using namespace std;
// 64-bit integer alias.
using Int = long long;

// Debug output for a pair: "(first, second)".
template <class T1, class T2> ostream &operator<<(ostream &os, const pair<T1, T2> &a) {
  os << "(" << a.first << ", " << a.second << ")";
  return os;
}

// Debug output for a vector: "[a0, a1, ...]". At most the first 256 elements
// are printed; a longer vector is marked with a trailing ", ...".
template <class T> ostream &operator<<(ostream &os, const vector<T> &as) {
  const int total = as.size();
  const int shown = (total < 256) ? total : 256;
  os << "[";
  for (int k = 0; k < shown; ++k) {
    if (k != 0) os << ", ";
    os << as[k];
  }
  if (total > shown) os << ", ...";
  return os << "]";
}

// Prints the iterator range [a, b) to stderr, each element followed by a space,
// then ends the line.
template <class T> void pv(T a, T b) {
  while (a != b) {
    cerr << *a << " ";
    ++a;
  }
  cerr << endl;
}

// If f is smaller than t (t > f), stores f into t. Returns whether t changed.
template <class T> bool chmin(T &t, const T &f) {
  const bool improved = (t > f);
  if (improved) t = f;
  return improved;
}

// If f is larger than t (t < f), stores f into t. Returns whether t changed.
template <class T> bool chmax(T &t, const T &f) {
  const bool improved = (t < f);
  if (improved) t = f;
  return improved;
}

// Builds an ANSI SGR escape string literal from a literal code, e.g. COLOR("31").
#define COLOR(s) ("\x1b[" s "m")
// Heavy-light decomposition of a tree with vertices 0, ..., n - 1.
//
// Usage: Hld hld(n); call hld.ae(u, v) for every tree edge; then hld.build(root).
// After build():
//   par[u]  : parent of u (-1 for the root);  dep[u]: depth (root = 0);
//   sz[u]   : subtree size;
//   dis[u]  : DFS preorder index;  fin[u]: one past the last index of u's subtree,
//             so subtree(u) = { w : dis[u] <= dis[w] < fin[u] };
//   sid[i]  : the vertex whose dis is i (inverse of dis);
//   head[u] : minimum-depth vertex of the heavy path containing u.
// The heavy child is always visited first, so each heavy path occupies a
// contiguous range of dis; the path operations below rely on that.
// NOTE: dfsSz/dfsHld are recursive; the recursion depth equals the tree height.
struct Hld {
  int n, rt;
  // needs to be tree
  // vertex lists
  // modified in build(rt) (parent removed, heavy child first)
  std::vector<std::vector<int>> graph;
  std::vector<int> sz, par, dep;
  // DFS clock used while assigning dis/fin
  int zeit;
  std::vector<int> dis, fin, sid;
  // head vertex (minimum depth) in heavy path
  std::vector<int> head;

  Hld() : n(0), rt(-1), zeit(0) {}
  explicit Hld(int n_) : n(n_), rt(-1), graph(n), zeit(0) {}

  // Adds the undirected edge {u, v}.
  void ae(int u, int v) {
    assert(0 <= u); assert(u < n);
    assert(0 <= v); assert(v < n);
    graph[u].push_back(v);
    graph[v].push_back(u);
  }

  // Fills sz/par/dep for the subtree of u (par[u] and dep[u] already set) and
  // removes each child's back edge to u from the child's adjacency list.
  void dfsSz(int u) {
    sz[u] = 1;
    for (const int v : graph[u]) {
      auto it = std::find(graph[v].begin(), graph[v].end(), u);
      if (it != graph[v].end()) graph[v].erase(it);
      par[v] = u;
      dep[v] = dep[u] + 1;
      dfsSz(v);
      sz[u] += sz[v];
    }
  }

  // Assigns dis/fin/head in the subtree of u. The largest child is swapped to
  // graph[u][0], visited first, and continues u's heavy path.
  void dfsHld(int u) {
    dis[u] = zeit++;
    const int deg = graph[u].size();
    if (deg > 0) {
      int vm = graph[u][0];
      int jm = 0;
      for (int j = 1; j < deg; ++j) {
        const int v = graph[u][j];
        if (sz[vm] < sz[v]) {
          vm = v;
          jm = j;
        }
      }
      std::swap(graph[u][0], graph[u][jm]);
      head[vm] = head[u];
      dfsHld(vm);
      for (int j = 1; j < deg; ++j) {
        const int v = graph[u][j];
        head[v] = v;  // a light child starts a new heavy path
        dfsHld(v);
      }
    }
    fin[u] = zeit;
  }

  // Roots the tree at rt_ and computes every array documented above.
  // graph is modified in place, so this is meant to be called once, after all ae().
  void build(int rt_) {
    assert(0 <= rt_); assert(rt_ < n);
    rt = rt_;
    sz.assign(n, 0);
    par.assign(n, -1);
    dep.assign(n, -1);
    dep[rt] = 0;
    dfsSz(rt);
    zeit = 0;
    dis.assign(n, -1);
    fin.assign(n, -1);
    head.assign(n, -1);
    head[rt] = rt;
    dfsHld(rt);
    assert(zeit == n);  // fails if the edges did not form a connected tree
    sid.assign(n, -1);
    for (int u = 0; u < n; ++u) sid[dis[u]] = u;
  }

  // Debug output: draws the tree in DFS order. Even rows (2d) hold the labels of
  // depth-d vertices; odd rows (2d - 1) hold connectors: '|' under a vertex that
  // continues its parent's heavy path, "---+" for a vertex starting a new one.
  // Prints nothing before build().
  friend std::ostream &operator<<(std::ostream &os, const Hld &hld) {
    if (hld.sid.empty()) return os;  // not built: max_element below would be UB
    const int maxDep = *std::max_element(hld.dep.begin(), hld.dep.end());
    std::vector<std::string> ss(2 * maxDep + 1);
    int pos = 0, maxPos = 0;
    for (int j = 0; j < hld.n; ++j) {
      const int u = hld.sid[j];
      const int d = hld.dep[u];
      if (hld.head[u] == u) {
        if (j != 0) {
          pos = maxPos + 1;
          ss[2 * d - 1].resize(pos, '-');
          ss[2 * d - 1] += '+';
        }
      } else {
        ss[2 * d - 1].resize(pos, ' ');
        ss[2 * d - 1] += '|';
      }
      ss[2 * d].resize(pos, ' ');
      ss[2 * d] += std::to_string(u);
      if (maxPos < static_cast<int>(ss[2 * d].size())) {
        maxPos = ss[2 * d].size();
      }
    }
    for (int d = 0; d <= 2 * maxDep; ++d) os << ss[d] << '\n';
    return os;
  }

  // Whether v lies in the subtree of u (u itself included).
  bool contains(int u, int v) const {
    return (dis[u] <= dis[v] && dis[v] < fin[u]);
  }

  // Lowest common ancestor of u and v.
  // Comparing dis[u] with dis[v] (rather than the heads' dis) is sound: while the
  // heads differ, dis[u] > dis[v] implies head[u] is not an ancestor of v, so
  // the LCA lies strictly above head[u] and u may jump to par[head[u]].
  int lca(int u, int v) const {
    assert(0 <= u); assert(u < n);
    assert(0 <= v); assert(v < n);
    for (; head[u] != head[v]; ) (dis[u] > dis[v]) ? (u = par[head[u]]) : (v = par[head[v]]);
    return (dis[u] > dis[v]) ? v : u;
  }

  // The ancestor of u that is d levels above it (u itself for d = 0),
  // or -1 if dep[u] < d.
  int jumpUp(int u, int d) const {
    assert(0 <= u); assert(u < n);
    assert(d >= 0);
    if (dep[u] < d) return -1;
    const int tar = dep[u] - d;
    for (u = head[u]; ; u = head[par[u]]) {
      // heavy path from head u is contiguous in dis, so index by depth offset
      if (dep[u] <= tar) return sid[dis[u] + (tar - dep[u])];
    }
  }

  // The vertex at distance d from u on the u-v path, or -1 if d exceeds its length.
  int jump(int u, int v, int d) const {
    assert(0 <= u); assert(u < n);
    assert(0 <= v); assert(v < n);
    assert(d >= 0);
    const int l = lca(u, v);
    const int du = dep[u] - dep[l], dv = dep[v] - dep[l];
    if (d <= du) {
      return jumpUp(u, d);
    } else if (d <= du + dv) {
      return jumpUp(v, du + dv - d);
    } else {
      return -1;
    }
  }

  // [u, v) or [u, v]
  // Calls f(l, r) for half-open dis ranges [l, r) covering the vertical path from
  // u up to its ancestor v; v itself is covered iff inclusive.
  template <class F> void doPathUp(int u, int v, bool inclusive, F f) const {
    assert(contains(v, u));
    for (; head[u] != head[v]; u = par[head[u]]) f(dis[head[u]], dis[u] + 1);
    if (inclusive) {
      f(dis[v], dis[u] + 1);
    } else {
      if (v != u) f(dis[v] + 1, dis[u] + 1);
    }
  }

  // not path order, include lca(u, v) or not
  // Calls f(l, r) for dis ranges covering the u-v path; lca(u, v) covered iff inclusive.
  template <class F> void doPath(int u, int v, bool inclusive, F f) const {
    const int l = lca(u, v);
    doPathUp(u, l, false, f);
    doPathUp(v, l, inclusive, f);
  }

  // (vs, ps): compressed tree
  // vs: DFS order (sorted by dis)
  // vs[ps[x]]: the parent of vs[x]
  // ids[vs[x]] = x, not set for non-tree vertex
  // (ids keeps stale values from earlier calls for vertices not in this tree.)
  std::vector<int> ids;
  // Builds the compressed (virtual) tree of the vertex set us: us plus the LCAs
  // of DFS-adjacent vertices. ps[0] = -1 for the compressed root. us must be non-empty.
  std::pair<std::vector<int>, std::vector<int>> compress(std::vector<int> us) {
    // O(n) first time
    ids.resize(n, -1);
    std::sort(us.begin(), us.end(), [&](int u, int v) -> bool {
      return (dis[u] < dis[v]);
    });
    us.erase(std::unique(us.begin(), us.end()), us.end());
    int usLen = us.size();
    assert(usLen >= 1);
    for (int x = 1; x < usLen; ++x) us.push_back(lca(us[x - 1], us[x]));
    std::sort(us.begin(), us.end(), [&](int u, int v) -> bool {
      return (dis[u] < dis[v]);
    });
    us.erase(std::unique(us.begin(), us.end()), us.end());
    usLen = us.size();
    for (int x = 0; x < usLen; ++x) ids[us[x]] = x;
    std::vector<int> ps(usLen, -1);
    for (int x = 1; x < usLen; ++x) ps[x] = ids[lca(us[x - 1], us[x])];
    return std::make_pair(us, ps);
  }
};
// Offline "add a constant on an axis-parallel rectangle, then read points".
// Record every update with add() and every query with get(), then call run();
// anss[j] is the answer to the j-th get() call.
// A sweep over x with a Fenwick tree over the compressed y coordinates of the
// query points: O((A + B) log(A + B)) for A rectangles and B points.
template <class X, class Y, class T> struct StaticRectAddPointSum {
  struct Rect {
    X x0, x1;
    Y y0, y1;
  };
  std::vector<Rect> as;
  std::vector<std::pair<X, Y>> bs;
  // vals[i]: amount added by rectangle as[i]; anss[j]: result for point bs[j].
  std::vector<T> vals, anss;

  // Adds val to every point (x, y) with x0 <= x < x1 and y0 <= y < y1.
  void add(X x0, X x1, Y y0, Y y1, const T &val) {
    assert(x0 <= x1); assert(y0 <= y1);
    as.push_back(Rect{x0, x1, y0, y1});
    vals.push_back(val);  // run() reads vals[i] for every rectangle i
  }

  // Requests the sum at (x, y); the answer appears in anss after run().
  void get(X x, Y y) {
    bs.emplace_back(x, y);
  }

  void run() {
    const int asLen = as.size(), bsLen = bs.size();
    // Event ids: 2i = rectangle i starts at x0, 2i+1 = it ends at x1,
    // 2*asLen + j = query j. Sorting by (x, id) handles all rectangle events
    // before queries at equal x (same x ==> add then get), which makes the
    // x-range half-open.
    std::vector<std::pair<X, int>> events((asLen << 1) + bsLen);
    for (int i = 0; i < asLen; ++i) {
      events[i << 1    ] = std::make_pair(as[i].x0, i << 1    );
      events[i << 1 | 1] = std::make_pair(as[i].x1, i << 1 | 1);
    }
    for (int j = 0; j < bsLen; ++j) {
      events[(asLen << 1) + j] = std::make_pair(bs[j].first, (asLen << 1) + j);
    }
    std::sort(events.begin(), events.end());
    // Only query y's get Fenwick slots; rectangle y-bounds map via lower_bound.
    std::vector<Y> ys(bsLen);
    for (int j = 0; j < bsLen; ++j) {
      ys[j] = bs[j].second;
    }
    std::sort(ys.begin(), ys.end());
    ys.erase(std::unique(ys.begin(), ys.end()), ys.end());
    const int ysLen = ys.size();
    std::vector<T> bit(ysLen, 0);
    anss.assign(bsLen, 0);
    for (const auto &event : events) {
      if (event.second >= asLen << 1) {
        // Query: Fenwick prefix sum over slots [0, slot(y)].
        const int j = event.second - (asLen << 1);
        T sum = 0;
        for (int l = std::lower_bound(ys.begin(), ys.end(), bs[j].second) - ys.begin() + 1; l > 0; l &= l - 1) {
          sum += bit[l - 1];
        }
        anss[j] = sum;
      } else {
        // Rectangle starts (+val) or ends (-val): difference update on [y0, y1).
        const int i = event.second >> 1;
        const T val = (event.second & 1) ? -vals[i] : vals[i];
        for (int l = std::lower_bound(ys.begin(), ys.end(), as[i].y0) - ys.begin(); l < ysLen; l |= l + 1) {
          bit[l] += val;
        }
        for (int l = std::lower_bound(ys.begin(), ys.end(), as[i].y1) - ys.begin(); l < ysLen; l |= l + 1) {
          bit[l] -= val;
        }
      }
    }
  }
};
// ---- Problem state shared by add(), dfs() and main() ----
// Vertex count and query count.
int N;
int Q;
// Per-vertex input: O[u] is the vertex kind (dfs() treats O[u] == 1 specially;
// presumably 1 = key, otherwise chest — confirm against the statement),
// C[u] is the vertex color.
vector<int> O;
vector<int> C;
// Tree edges (A[i], B[i]).
vector<int> A;
vector<int> B;
// Query endpoints: path from S[q] to T[q].
vector<int> S;
vector<int> T;
// Decomposition of the input tree.
Hld hld;
// Accumulates rectangles from add() and answers one point per query.
StaticRectAddPointSum<int, int, int> f;
// +1 point if S[q] -> s -> t -> T[q]
// Registers the ordered pair (s, t): every query whose path from S[q] to T[q]
// visits s and later t must receive +1. Queries are points
// (dis[S[q]], dis[T[q]]), so the condition is a union of rectangles in DFS-order
// coordinates. A subtree of w is the range [L[w], R[w]) = [dis[w], fin[w]).
void add(int s, int t) {
  const auto &L = hld.dis;
  const auto &R = hld.fin;
  assert(s != t);
  if (hld.contains(s, t)) {
    // s is a proper ancestor of t; u is the child of s towards t.
    // Need: S outside subtree(u), T inside subtree(t).
    const int u = hld.jumpUp(t, hld.dep[t] - hld.dep[s] - 1);
    f.add(0, L[u], L[t], R[t], +1);
    f.add(R[u], N, L[t], R[t], +1);
  } else if (hld.contains(t, s)) {
    // t is a proper ancestor of s; u is the child of t towards s.
    // Need: S inside subtree(s), T outside subtree(u).
    const int u = hld.jumpUp(s, hld.dep[s] - hld.dep[t] - 1);
    f.add(L[s], R[s], 0, L[u], +1);
    f.add(L[s], R[s], R[u], N, +1);
  } else {
    // Neither is an ancestor of the other: S inside subtree(s), T inside subtree(t).
    f.add(L[s], R[s], L[t], R[t], +1);
  }
}
// Compressed tree of the color currently processed in main():
// vsps = hld.compress(vertices of that color) (vertices in DFS order, parent
// indices), graph = adjacency lists over those compressed indices.
pair<vector<int>, vector<int>> vsps;
vector<vector<int>> graph;

// Walks the compressed tree from index x (vertex vsps.first[x]), not going back
// to index p, on behalf of the start vertex s of color c.
// d is a balance over color-c vertices on the walk:
//   O[u] == 1 increments it (presumably a key);
//   any other kind decrements it (presumably a chest).
// The vertex that brings it back to 0 is paired with s via add(s, u).
// The walk stops there: continuing would let s be paired a second time further
// down the same branch. Non-color-c vertices (LCAs added by compress) are
// passed through unchanged.
void dfs(int c, int s, int x, int p, int d) {
  const int u = vsps.first[x];
  if (C[u] == c) {
    if (O[u] == 1) {
      ++d;
    } else {
      if (--d == 0) {
        // cerr<<"c = "<<c<<", s = "<<s<<", u = "<<u<<endl;
        add(s, u);
        return;
      }
    }
  }
  for (const int y : graph[x]) if (p != y) {
    dfs(c, s, y, x, d);
  }
}
// Reads test cases until EOF.
// Input as parsed:
//   N Q;
//   N lines "O[u] C[u]";
//   N - 1 edges "A B";
//   Q queries "S T".
// NOTE(review): vertices and colors are assumed 1-indexed in the input and are
// converted to 0-indexed here (uss has N slots indexed by color) — confirm
// against the statement.
// Algorithm:
//   1. For each color, build the compressed tree of its vertices.
//   2. From every vertex with O == 1, let dfs() find the vertices it pairs with;
//      add() turns each pair into rectangles.
//   3. Query q is the point (dis[S[q]], dis[T[q]]); its answer is the number of
//      rectangles covering it.
int main() {
  for (; ~scanf("%d%d", &N, &Q); ) {
    O.resize(N);
    C.resize(N);
    for (int u = 0; u < N; ++u) {
      scanf("%d%d", &O[u], &C[u]);
      --C[u];
    }
    A.resize(N - 1);
    B.resize(N - 1);
    for (int i = 0; i < N - 1; ++i) {
      scanf("%d%d", &A[i], &B[i]);
      --A[i];
      --B[i];
    }
    S.resize(Q);
    T.resize(Q);
    for (int q = 0; q < Q; ++q) {
      scanf("%d%d", &S[q], &T[q]);
      --S[q];
      --T[q];
    }

    hld = Hld(N);
    for (int i = 0; i < N - 1; ++i) {
      hld.ae(A[i], B[i]);
    }
    hld.build(0);  // dis/fin/lca are needed by compress() and add()

    f = {};
    vector<vector<int>> uss(N);
    for (int u = 0; u < N; ++u) uss[C[u]].push_back(u);
    for (int c = 0; c < N; ++c) if (uss[c].size()) {
      vsps = hld.compress(uss[c]);
      const int n = vsps.first.size();
      graph.assign(n, {});
      for (int y = 1; y < n; ++y) {
        const int x = vsps.second[y];
        graph[x].push_back(y);
        graph[y].push_back(x);
      }
      for (const int u : uss[c]) if (O[u] == 1) {
        dfs(c, u, hld.ids[u], -1, 0);
      }
    }
    for (int q = 0; q < Q; ++q) {
      f.get(hld.dis[S[q]], hld.dis[T[q]]);
    }
    f.run();  // fills f.anss; reading it without run() would be out of bounds
    for (int q = 0; q < Q; ++q) {
      printf("%d\n", f.anss[q]);
    }
  }
  return 0;
}
