QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#705189 | #4548. Rock Tree | TheZone | AC ✓ | 1706ms | 190508kb | C++14 | 6.4kb | 2024-11-02 22:28:36 | 2024-11-02 22:28:37 |
Judging History
answer
#include <bits/stdc++.h>
#define f first
#define s second
#define pb push_back
#define mp make_pair
#define int long long
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<long long, long long> pll;
// Global constants. Because of `#define int long long` above, these are long long.
//   N   - bound for every per-vertex array (a, h, mh, dp, g, A, B, C).
//   inf - 1e9; subtracted from a depth slot in dfs() to invalidate it.
//   mod - modulus used by sum/sub/mult/bp/inv (those helpers are not called by the
//         visible solution).
const int N = 100500, inf = 1e9, mod = 998244353;
// Not referenced by the visible code.
const ll INF = 1e18;
// Modular addition with a single conditional subtraction: the result lies in
// [0, mod) whenever both a and b already do.
int sum(int a, int b) {
    const int total = a + b;
    return total >= mod ? total - mod : total;
}
// Modular subtraction with a single conditional addition: the result lies in
// [0, mod) whenever both a and b already do.
int sub(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + mod : diff;
}
// Modular multiplication; the product is formed in long long before reducing.
int mult(int a, int b) {
    const long long product = 1ll * a * b;
    return product % mod;
}
// Binary exponentiation: returns a^b modulo `mod` by square-and-multiply.
// The loop runs while the exponent is non-zero (same as the original contract).
int bp(int a, int b) {
    int acc = 1;
    for (int base = a, e = b; e != 0; e >>= 1) {
        if (e & 1)
            acc = mult(acc, base);
        base = mult(base, base);
    }
    return acc;
}
// Modular inverse via Fermat's little theorem: x^(mod - 2).
// Valid when mod is prime (998244353 is) and x is not a multiple of mod.
int inv(int x) {
    const int fermat_exponent = mod - 2;
    return bp(x, fermat_exponent);
}
// Per-test state (long long because of the `int` macro):
//   n   - number of vertices; k - distance parameter (solve() decrements it before dfs)
//   a[] - vertex weights
//   h[] - depth of each vertex below the root (vertex 0); filled in by dfs()
//   mh[]- maximum depth found in each vertex's subtree; filled in by dfs()
//   ans - best value found so far; reset in solve(), raised in dfs()
int n, k, a[N], h[N], mh[N], ans;
// Adjacency lists of the tree; cleared at the end of every solve().
vector<int> g[N];
// Capacity of the segment-tree node pool `t`.
const int M = 4e6;
// One node of the pooled, lazily-allocated segment tree.
//   a, b  - pending lazy tag; applying it maps a child value y to max(y + a, b)
//           (see push()).
//   l, r  - indices of the children inside the pool `t`; -1 until push() allocates them.
//   mx    - maximum over this node's range, with this node's own tag already applied.
// The pool `t` is declared together with the type.
struct node {
    int a, b;
    int l, r;
    int mx;
    node() : a(0), b(0), l(-1), r(-1), mx(0) {}
} t[M];
// Number of pool nodes handed out by nn(); reset to 0 at the end of each solve().
int t_n;
// dp[v] = index (into `t`) of the root of the segment tree holding vertex v's DP
// values, indexed by absolute depth. dfs() may make dp[v] share its deepest child's tree.
int dp[N];
// Allocates a fresh, default-initialised node from the pool `t` and returns its index.
// Valid indices are 0 .. M - 6; requesting more aborts the program.
int nn() {
    // The previous guard was `while (t_n == M - 5);`, a loop with no side effects.
    // The C++ forward-progress rule lets the compiler assume such a loop terminates
    // and remove it, so pool exhaustion could go undetected. Fail loudly instead.
    if (t_n >= M - 5)
        abort();
    t[t_n] = node();
    return t_n++;
}
void push(int v) {
if (t[v].l == -1)
t[v].l = nn();
if (t[v].r == -1)
t[v].r = nn();
for (int u: {t[v].l, t[v].r}) {
t[u].a += t[v].a;
t[u].b += t[v].a;
t[u].mx += t[v].a;
t[u].b = max(t[u].b, t[v].b);
t[u].mx = max(t[u].mx, t[v].b);
}
t[v].a = 0;
t[v].b = 0;
}
void pull(int v) {
t[v].mx = max(t[t[v].l].mx, t[t[v].r].mx);
}
void upd(int v, int tl, int tr, int l, int r, int tp, int x) {
if (r < tl || tr < l || l > r)
return;
if (l <= tl && tr <= r) {
if (tp == 0) {
t[v].a += x;
t[v].b += x;
t[v].mx += x;
}
if (tp == 1) {
t[v].b = max(t[v].b, x);
t[v].mx = max(t[v].mx, x);
}
return;
}
push(v);
int tm = (tl + tr) >> 1;
upd(t[v].l, tl, tm, l, r, tp, x);
upd(t[v].r, tm + 1, tr, l, r, tp, x);
pull(v);
}
// Adds x to every position in [l, r] of the tree rooted at v (full range [0, n - 1]).
void upd_add(int v, int l, int r, int x) {
    const int kAddTag = 0;
    upd(v, 0, n - 1, l, r, kAddTag, x);
}
// Raises every position in [l, r] of the tree rooted at v to at least x.
void upd_max(int v, int l, int r, int x) {
    const int kMaxTag = 1;
    upd(v, 0, n - 1, l, r, kMaxTag, x);
}
// Point query: value at position p in the tree rooted at v covering [tl, tr].
// Descends iteratively, pushing tags on the way so the leaf's mx is up to date.
int get(int v, int tl, int tr, int p) {
    while (tl != tr) {
        push(v);
        const int mid = (tl + tr) >> 1;
        if (p > mid) {
            v = t[v].r;
            tl = mid + 1;
        } else {
            v = t[v].l;
            tr = mid;
        }
    }
    return t[v].mx;
}
// Value at position p of the tree rooted at v (full range [0, n - 1]).
int get_val(int v, int p) {
    const int lo = 0;
    const int hi = n - 1;
    return get(v, lo, hi, p);
}
// Scratch buffers for dfs()'s O(k) merge of a deep child. They are indexed 0..k,
// so this assumes k < N (TODO confirm against the problem constraints).
int A[N], B[N], C[N];
// Bottom-up DP over the subtree of v (p = parent, -1 for the root).
// Side effects:
//   - sets h[child] = h[v] + 1 for each child, and mh[v] = deepest depth in v's subtree;
//   - points dp[v] at a segment tree indexed by absolute depth (positions h[v]..mh[v]);
//   - raises the global `ans` to the overall maximum stored in that tree.
// `k` here is the input k minus one (solve() decrements it before calling dfs).
// The meaning of a slot is not spelled out in the code. Presumably slot h[v] + d holds
// the best total weight of a valid choice in v's subtree whose depth offset is <= d
// (TODO confirm against the problem statement).
// NOTE(review): recursion depth equals the tree height, so a path-shaped tree with large
// n may exhaust the call stack, depending on the judge's stack limit.
void dfs(int v, int p) {
mh[v] = h[v];
// u = child whose subtree reaches deepest ("long-path" child); its tree is reused.
int u = -1;
for (auto to: g[v]) {
if (to == p)
continue;
h[to] = h[v] + 1;
dfs(to, v);
mh[v] = max(mh[v], mh[to]);
if (u == -1 || mh[u] < mh[to])
u = to;
}
if (u == -1) {
// Leaf: a fresh tree whose only touched slot is v's own depth, holding a[v].
dp[v] = nn();
upd_add(dp[v], h[v], h[v], a[v]);
} else {
// Inherit the deepest child's tree by index (no copy). u's tree is modified in
// place from here on, so dp[u] no longer describes u alone.
dp[v] = dp[u];
for (auto to: g[v]) {
if (to == p || to == u)
continue;
// Height of this child's subtree, measured from v.
int to_mx = mh[to] - h[v];
// Shallow child: both loops below run only to_mx + 1 iterations (O(log n) each).
if (to_mx < k / 3 && k >= 2) {
// x = running maximum of the child's slots at offsets 0..d.
int x = 0;
for (int i = h[v]; i <= mh[to]; i++) {
int d = i - h[v];
x = max(x, get_val(dp[to], h[v] + d));
// Add x at offset d, and at the mirrored offset k - d if that slot exists.
upd_add(dp[v], h[v] + d, h[v] + d, x);
int rd = k - d;
if (h[v] + rd <= mh[v])
upd_add(dp[v], h[v] + rd, h[v] + rd, x);
// Last iteration: also add x on offsets d+1 .. min(mh[v]-h[v], k-d-1)
// (an empty range is ignored by upd).
if (i == mh[to])
upd_add(dp[v], h[v] + d + 1, min(mh[v], h[v] + rd - 1), x);
}
// Restore monotonicity over the window ending at min(mh[v], h[v] + k):
// walk its last to_mx + 1 positions in increasing order, chmax'ing each with the
// running maximum seen so far.
int lst = 0;
for (int i = h[v]; i <= mh[to]; i++) {
int d = to_mx - i + h[v];
int pos = min(mh[v], h[v] + k) - d;
// NOTE(review): here mh[v] - h[v] >= to_mx and k > to_mx, so pos starts at
// >= h[v] and only grows; this break appears unreachable.
if (pos < h[v])
break;
int x = get_val(dp[v], pos);
lst = max(lst, x);
upd_max(dp[v], pos, pos, lst);
}
} else {
// Deep child: O(k) merge through prefix-maximum arrays.
// A[i] / B[i] = value at offset i in v's / the child's tree (0 past its depth).
for (int i = 0; i <= k; i++) {
if (h[v] + i <= mh[v])
A[i] = get_val(dp[v], h[v] + i);
else
A[i] = 0;
if (h[v] + i <= mh[to])
B[i] = get_val(dp[to], h[v] + i);
else
B[i] = 0;
}
// Turn A and B into prefix maxima.
for (int i = 1; i <= k; i++) {
A[i] = max(A[i], A[i - 1]);
B[i] = max(B[i], B[i - 1]);
}
// Combine offset i on one side with offset <= min(i, k - i) on the other.
// This looks like enforcing i + j <= k between the two branches
// (TODO confirm intent).
for (int i = 0; i <= k; i++) {
C[i] = max(
A[i] + B[min(i, k - i)],
B[i] + A[min(i, k - i)]
);
}
for (int i = 1; i <= k; i++) {
C[i] = max(C[i], C[i - 1]);
}
// Write the merged prefix maxima back into v's tree (chmax, existing slots only).
for (int i = 0; i <= k; i++) {
if (h[v] + i <= mh[v])
upd_max(dp[v], h[v] + i, h[v] + i, C[i]);
}
}
}
// Non-leaf: add v's own weight to offsets 0 .. min(k, mh[v] - h[v]).
upd_add(dp[v], h[v], min(h[v] + k, mh[v]), a[v]);
}
// Root's mx covers every depth slot of v's tree.
ans = max(ans, t[dp[v]].mx);
// Invalidate the slot at offset k by pushing it far negative. Both calls target the
// same single position, so the net effect is subtracting inf + 1e12.
// NOTE(review): the first call looks redundant. -1e12 is a double literal converted
// to the long long parameter.
if (h[v] + k <= mh[v])
upd_add(dp[v], h[v] + k, h[v] + k, -inf);
if (h[v] + k <= mh[v])
upd_add(dp[v], h[v] + k, h[v] + k, -1e12);
}
// Debug leftovers: rnd and dist are referenced only from commented-out lines in solve()
// (random tree generation / brute-force distances), not by the live code.
mt19937 rnd(228);
int dist[55][55];
// Solves one test case.
// Reads n and k, n vertex weights, and n - 1 edges (1-indexed in the input), runs the
// tree DP from vertex 0, and prints the answer. Before returning it resets the per-test
// state it owns: the adjacency lists g[0..n-1] and the node pool counter t_n.
void solve() {
    cin >> n >> k;
    // dfs() works with k - 1.
    k--;
    for (int i = 0; i < n; i++)
        cin >> a[i];
    for (int i = 0; i < n - 1; i++) {
        int v, u;
        cin >> v >> u;
        v--, u--;  // convert to 0-based vertex ids
        g[v].pb(u);
        g[u].pb(v);
    }
    ans = 0;
    dfs(0, -1);
    // If the DP left ans at 0, fall back to the largest single weight.
    if (ans == 0)
        ans = *max_element(a, a + n);
    // '\n' instead of endl: no flush per test case; cout is flushed at normal exit.
    cout << ans << '\n';
    for (int i = 0; i < n; i++)
        g[i].clear();
    t_n = 0;
}
// Entry point: reads the number of test cases and runs solve() for each.
// `signed` is used because `int` is #defined to `long long`.
signed main() {
#ifdef DEBUG
    freopen("input.txt", "r", stdin);
#endif
    ios_base::sync_with_stdio(false);
    // Untie cin from cout so each read does not force a flush of pending output.
    cin.tie(nullptr);
    // Named `tests` rather than `t` so it does not shadow the global node pool `t`.
    // Starts at 0 so a failed read runs no test cases.
    int tests = 0;
    cin >> tests;
    for (int i = 1; i <= tests; i++)
        solve();
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 1706ms
memory: 190508kb
input:
88 49707 15234 -53 -7 34 -79 25 -63 -3 58 -60 -29 -64 -51 81 -45 -22 73 -46 7 -17 10 24 -81 -75 85 -19 88 46 12 0 -87 21 -88 -71 -2 61 50 24 48 -48 -67 46 43 87 59 -60 97 71 19 -36 91 54 73 25 -62 -92 74 10 100 52 -4 -11 65 89 65 -100 -79 77 -53 41 5 65 -47 77 20 -25 0 5 10 82 -21 27 31 91 -85 -57 -...
output:
1539829 47120 1779436 9475 100 2015 1166766 2833267 61582773 34428 186218 7915 62876367 83732 24766 9992 486 1799544 -1 7966 6266 9012 5770 1151949 7258 399 5526 24745 8213 119391577 11 7810 8851 7288 16694 8546 768 1 12759 1252 6510 1607629 231818575 6869 27986 11151 11221 199 4587 1410036 28210 12...
result:
ok 88 lines