QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#705189 | #4548. Rock Tree | TheZone | AC ✓ | 1706ms | 190508kb | C++14 | 6.4kb | 2024-11-02 22:28:36 | 2024-11-02 22:28:37

Judging History

你现在查看的是最新测评结果

  • [2024-11-02 22:28:37]
  • 评测
  • 测评结果:AC
  • 用时:1706ms
  • 内存:190508kb
  • [2024-11-02 22:28:36]
  • 提交

answer

#include <bits/stdc++.h>

// Shorthand macros for pair members and for push_back / make_pair.
#define f first
#define s second
#define pb push_back
#define mp make_pair
// From here on every `int` token expands to `long long`
// (which is why the entry point below is spelled `signed main`).
#define int long long

using namespace std;

typedef long long ll;
typedef pair<int, int> pii;
typedef pair<long long, long long> pll;

// N sizes the per-vertex arrays declared below; inf is used as a penalty in dfs;
// mod is the modulus used by sum / sub / mult / bp / inv.
const int N = 100500, inf = 1e9, mod = 998244353;
// Not referenced by the code visible in this file.
const ll INF = 1e18;

// Adds a and b and subtracts mod once if the total reached it.
// The result lies in [0, mod) whenever both arguments do.
int sum(int a, int b) {
    const int total = a + b;
    return total >= mod ? total - mod : total;
}

// Subtracts b from a and adds mod once if the difference went negative.
// The result lies in [0, mod) whenever both arguments do.
int sub(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + mod : diff;
}

// Returns a * b reduced modulo mod; the product is formed in 64-bit arithmetic.
int mult(int a, int b) {
    const long long product = 1ll * a * b;
    return product % mod;
}

// Binary exponentiation: returns a raised to the power b, with every
// multiplication done through mult (i.e. modulo mod).
int bp(int a, int b) {
    int acc = 1;
    for (; b; b >>= 1) {
        // The lowest remaining bit of the exponent selects the current power of a.
        if (b & 1) {
            acc = mult(acc, a);
        }
        a = mult(a, a);
    }
    return acc;
}

// Returns x^(mod - 2) computed by bp; by Fermat's little theorem this is the
// modular inverse of x for a prime modulus when x is not a multiple of it.
int inv(int x) {
    const int exponent = mod - 2;
    return bp(x, exponent);
}

// n, k: read per test case in solve (k is decremented once right after reading).
// a[v]: value read for vertex v.  h[v]: depth assigned in dfs (child = parent + 1).
// mh[v]: largest h within v's subtree, computed in dfs.  ans: running maximum printed by solve.
int n, k, a[N], h[N], mh[N], ans;
// Undirected adjacency lists; filled in solve and cleared at the end of each test case.
vector<int> g[N];


// Capacity of the statically allocated node pool t[].
const int M = 4e6;

// One node of the pooled segment tree; children are created on demand by push.
struct node {
    int a, b;   // pending tags: `a` is added, `b` is applied as max-with (see push)
    int l, r;   // indices of the children inside t[], -1 while a child does not exist
    int mx;     // maximum maintained for this node (see upd / pull)

    // A fresh node carries zero tags, no children and a maximum of 0.
    node() : a(0), b(0), l(-1), r(-1), mx(0) {}
} t[M];

// Number of pool slots handed out by nn() so far; solve resets it to 0 after each test case.
int t_n;
// dp[v]: index in t[] of the segment-tree root associated with vertex v (assigned in dfs).
int dp[N];


// Takes the next free slot of the pool t[], resets it to a default node and
// returns its index.
//
// The pool limit is checked BEFORE the slot is written. If the pool is
// exhausted the process prints a diagnostic and aborts. The earlier guard was
// an empty `while (t_n == M - 5);` spin: a side-effect-free infinite loop is
// undefined behaviour, so an optimiser may drop it and let later calls write
// past the end of t[]; when kept, it only produced a silent hang.
// The limit itself (slots 0 .. M - 6 are usable) is unchanged.
int nn() {
    if (t_n >= M - 5) {
        std::fputs("nn: segment tree node pool exhausted\n", stderr);
        std::abort();
    }
    t[t_n] = node();
    return t_n++;
}

// Pushes the pending tags of node v down to its two children, creating any
// missing child through nn() first.
//
// For each child the additive tag `a` is applied to a / b / mx, then the
// max-with tag `b` is applied to b / mx. Afterwards both tags of v are reset
// to 0. Note that a tag b == 0 (fresh or just reset) still takes part in the
// max, so a push raises a child's b and mx to at least 0.
void push(int v) {
    auto &parent = t[v];
    if (parent.l == -1) {
        parent.l = nn();
    }
    if (parent.r == -1) {
        parent.r = nn();
    }

    const auto add = parent.a;
    const auto floor_val = parent.b;
    auto apply = [&](int child) {
        auto &c = t[child];
        c.a += add;
        c.b = max(c.b + add, floor_val);
        c.mx = max(c.mx + add, floor_val);
    };
    apply(parent.l);
    apply(parent.r);

    parent.a = 0;
    parent.b = 0;
}

// Recomputes the maximum of node v from its two children (both must exist).
void pull(int v) {
    const auto &lhs = t[t[v].l];
    const auto &rhs = t[t[v].r];
    t[v].mx = max(lhs.mx, rhs.mx);
}

// Range update on the subtree rooted at node v, which covers positions [tl, tr].
//   l, r : inclusive query range (an empty range l > r is a no-op)
//   tp   : 0 = add x to the range, 1 = raise the range to at least x
//   x    : operand of the update
// Fully covered nodes are updated in place (tags and maximum); otherwise the
// tags are pushed down, both halves are visited and the maximum is rebuilt.
void upd(int v, int tl, int tr, int l, int r, int tp, int x) {
    const bool overlaps = l <= r && tl <= r && l <= tr;
    if (!overlaps) {
        return;
    }

    const bool covered = l <= tl && tr <= r;
    if (covered) {
        auto &cur = t[v];
        switch (tp) {
            case 0:
                cur.a += x;
                cur.b += x;
                cur.mx += x;
                break;
            case 1:
                cur.b = max(cur.b, x);
                cur.mx = max(cur.mx, x);
                break;
            default:
                break;
        }
        return;
    }

    push(v);
    const int mid = (tl + tr) >> 1;
    upd(t[v].l, tl, mid, l, r, tp, x);
    upd(t[v].r, mid + 1, tr, l, r, tp, x);
    pull(v);
}

// Adds x to every position in [l, r] of the tree rooted at v;
// the tree spans positions 0 .. n - 1.
void upd_add(int v, int l, int r, int x) {
    const int kind_add = 0;
    const int last = n - 1;
    upd(v, 0, last, l, r, kind_add, x);
}

// Raises every position in [l, r] of the tree rooted at v to at least x;
// the tree spans positions 0 .. n - 1.
void upd_max(int v, int l, int r, int x) {
    const int kind_max = 1;
    const int last = n - 1;
    upd(v, 0, last, l, r, kind_max, x);
}

// Point query: walks from node v (covering [tl, tr]) down to the leaf for
// position p, pushing pending tags at every inner node on the way, and
// returns that leaf's maximum.
int get(int v, int tl, int tr, int p) {
    while (tl != tr) {
        push(v);
        const int mid = (tl + tr) >> 1;
        if (p <= mid) {
            v = t[v].l;
            tr = mid;
        } else {
            v = t[v].r;
            tl = mid + 1;
        }
    }
    return t[v].mx;
}

// Returns the value stored at position p of the tree rooted at v;
// the tree spans positions 0 .. n - 1.
int get_val(int v, int p) {
    const int last = n - 1;
    return get(v, 0, last, p);
}

// Scratch buffers used by dfs when merging a child through the array-based branch
// (indexed by depth offset 0 .. k from the current vertex).
int A[N], B[N], C[N];

// Depth-first traversal from v (parent p; -1 for the root call in solve).
// Fills h[] / mh[] for v's subtree and builds dp[v], a segment tree indexed by
// absolute depth h[.]. dp[v] takes over the tree of the child with the largest
// mh, every other child is merged into it, then a[v] is added over depths
// [h[v], min(h[v] + k, mh[v])] and ans is raised to the tree's root maximum.
void dfs(int v, int p) {
    mh[v] = h[v];
    // u: child with the largest mh seen so far (the first such child wins ties).
    int u = -1;
    for (auto to: g[v]) {
        if (to == p)
            continue;
        h[to] = h[v] + 1;
        dfs(to, v);
        mh[v] = max(mh[v], mh[to]);
        if (u == -1 || mh[u] < mh[to])
            u = to;
    }
    if (u == -1) {
        // No children: start a fresh tree holding a[v] at depth h[v].
        dp[v] = nn();
        upd_add(dp[v], h[v], h[v], a[v]);
    } else {
        // Reuse u's tree in place; dp[u] and dp[v] refer to the same root from now on.
        dp[v] = dp[u];
        for (auto to: g[v]) {
            if (to == p || to == u)
                continue;
            // Height of to's subtree measured from v.
            int to_mx = mh[to] - h[v];
            if (to_mx < k / 3 && k >= 2) {
                // Shallow child: merge depth by depth over h[v] .. mh[to].
                // x is the running maximum of dp[to] over the depths visited so far.
                int x = 0;
                for (int i = h[v]; i <= mh[to]; i++) {
                    int d = i - h[v];
                    x = max(x, get_val(dp[to], h[v] + d));
                    upd_add(dp[v], h[v] + d, h[v] + d, x);
                    // Mirrored offset k - d, updated only if it lies within dp[v]'s depth range.
                    int rd = k - d;
                    if (h[v] + rd <= mh[v])
                        upd_add(dp[v], h[v] + rd, h[v] + rd, x);
                    // On the last depth, the final x is also added to the whole range
                    // strictly between the two offsets (clipped to mh[v]).
                    if (i == mh[to])
                        upd_add(dp[v], h[v] + d + 1, min(mh[v], h[v] + rd - 1), x);
                }
                // Second sweep: positions pos = min(mh[v], h[v] + k) - d for d = to_mx down to 0,
                // each raised to the running maximum lst of the values read so far.
                int lst = 0;
                for (int i = h[v]; i <= mh[to]; i++) {
                    int d = to_mx - i + h[v];
                    int pos = min(mh[v], h[v] + k) - d;
                    if (pos < h[v])
                        break;
                    int x = get_val(dp[v], pos);
                    lst = max(lst, x);
                    upd_max(dp[v], pos, pos, lst);
                }

            } else {
                // Otherwise merge through the scratch arrays over offsets 0 .. k.
                // A[i] / B[i]: value of dp[v] / dp[to] at depth h[v] + i, or 0 past the
                // respective maximum depth.
                for (int i = 0; i <= k; i++) {
                    if (h[v] + i <= mh[v])
                        A[i] = get_val(dp[v], h[v] + i);
                    else
                        A[i] = 0;
                    if (h[v] + i <= mh[to])
                        B[i] = get_val(dp[to], h[v] + i);
                    else
                        B[i] = 0;
                }
                // Turn both arrays into prefix maxima.
                for (int i = 1; i <= k; i++) {
                    A[i] = max(A[i], A[i - 1]);
                    B[i] = max(B[i], B[i - 1]);
                }
                // Combine: one side at offset i, the other at offset min(i, k - i).
                for (int i = 0; i <= k; i++) {
                    C[i] = max(
                            A[i] + B[min(i, k - i)],
                            B[i] + A[min(i, k - i)]
                    );
                }
                for (int i = 1; i <= k; i++) {
                    C[i] = max(C[i], C[i - 1]);
                }
                // Write the combined values back into dp[v] (only within its depth range).
                for (int i = 0; i <= k; i++) {
                    if (h[v] + i <= mh[v])
                        upd_max(dp[v], h[v] + i, h[v] + i, C[i]);
                }
            }
        }
        upd_add(dp[v], h[v], min(h[v] + k, mh[v]), a[v]);
    }
    ans = max(ans, t[dp[v]].mx);
    // Penalise depth h[v] + k before the tree is handed to the parent.
    // NOTE(review): the same position is penalised twice under the same condition
    // (first by inf, then by 1e12, a double literal converted to the integer
    // parameter); the first call looks like a leftover - confirm before removing.
    if (h[v] + k <= mh[v])
        upd_add(dp[v], h[v] + k, h[v] + k, -inf);
    if (h[v] + k <= mh[v])
        upd_add(dp[v], h[v] + k, h[v] + k, -1e12);
}

// Fixed-seed Mersenne Twister; not used by any active code visible in this file.
mt19937 rnd(228);
// Not used by any active code visible in this file.
int dist[55][55];

// Reads and answers one test case.
//
// Input: n and k, then n vertex values a[0..n-1], then n - 1 edges given as
// 1-based vertex pairs. k is decremented once before use. The answer is
// printed on its own line.
//
// The line is terminated with '\n' instead of std::endl: endl forces a flush
// of cout for every test case, while the bytes written are the same.
void solve() {
    cin >> n >> k;
    k--;
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }
    for (int i = 0; i < n - 1; i++) {
        int v, u;
        cin >> v >> u;
        v--, u--;  // switch to 0-based vertex ids
        g[v].pb(u);
        g[u].pb(v);
    }

    ans = 0;
    dfs(0, -1);
    // dfs only ever raises ans from 0, so 0 means it found nothing positive;
    // in that case report the largest single vertex value instead.
    if (ans == 0)
        ans = *max_element(a, a + n);
    cout << ans << '\n';

    // Reset per-test state: adjacency lists and the segment-tree node pool.
    for (int i = 0; i < n; i++)
        g[i].clear();
    t_n = 0;
}

// Entry point: reads the number of test cases and runs solve() once per case.
// With DEBUG defined, standard input is redirected to input.txt.
signed main() {
#ifdef DEBUG
    freopen("input.txt", "r", stdin);
#endif
    ios_base::sync_with_stdio(false);
    int tests = 1000;
    cin >> tests;
    while (tests-- > 0) {
        solve();
    }
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 1706ms
memory: 190508kb

input:

88
49707 15234
-53 -7 34 -79 25 -63 -3 58 -60 -29 -64 -51 81 -45 -22 73 -46 7 -17 10 24 -81 -75 85 -19 88 46 12 0 -87 21 -88 -71 -2 61 50 24 48 -48 -67 46 43 87 59 -60 97 71 19 -36 91 54 73 25 -62 -92 74 10 100 52 -4 -11 65 89 65 -100 -79 77 -53 41 5 65 -47 77 20 -25 0 5 10 82 -21 27 31 91 -85 -57 -...

output:

1539829
47120
1779436
9475
100
2015
1166766
2833267
61582773
34428
186218
7915
62876367
83732
24766
9992
486
1799544
-1
7966
6266
9012
5770
1151949
7258
399
5526
24745
8213
119391577
11
7810
8851
7288
16694
8546
768
1
12759
1252
6510
1607629
231818575
6869
27986
11151
11221
199
4587
1410036
28210
12...

result:

ok 88 lines