QOJ.ac

QOJ

| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
| #129829 | #4548. Rock Tree | batrr# | RE | 0ms | 0kb | C++17 | 5.9kb | 2023-07-23 01:36:22 | 2023-07-23 01:36:24 |

Judging History

你现在查看的是最新测评结果

  • [2023-08-10 23:21:45]
  • System Update: QOJ starts to keep a history of the judgings of all the submissions.
  • [2023-07-23 01:36:24]
  • 评测
  • 测评结果:RE
  • 用时:0ms
  • 内存:0kb
  • [2023-07-23 01:36:22]
  • 提交

answer

#include <bits/stdc++.h>

#define f first
#define s second
#define pb push_back
#define mp make_pair

using namespace std;

// Short aliases for the integer types used throughout the solution.
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<long long, long long>;

// N: capacity of the per-vertex static arrays.
// inf / INF: large sentinels for int / long long values.
// mod: modulus for the modular-arithmetic helpers.
constexpr int N = 300500;
constexpr int inf = 1000000000;
constexpr int mod = 998244353;
constexpr ll INF = 1000000000000000000LL;

// Modular addition. Operands are expected to be reduced into [0, mod) already:
// only a single conditional subtraction is applied.
int sum(int a, int b) {
    const int total = a + b;
    return total >= mod ? total - mod : total;
}

// Modular subtraction. Operands are expected to be reduced into [0, mod)
// already: only a single conditional addition is applied.
int sub(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + mod : diff;
}

// Modular multiplication; the product is formed in 64 bits to avoid overflow.
int mult(int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % mod);
}

// Binary exponentiation: a^b modulo mod, by repeated squaring.
int bp(int a, int b) {
    int result = 1;
    for (int base = a, exp = b; exp != 0; exp >>= 1) {
        if (exp & 1)
            result = mult(result, base);
        base = mult(base, base);
    }
    return result;
}

// Modular inverse via Fermat's little theorem: x^(mod - 2).
// Relies on mod being prime (998244353 is).
int inv(int x) {
    const int exponent = mod - 2;
    return bp(x, exponent);
}

// Per-test input and DFS state (namespace-scope, so zero-initialized).
int n;             // number of vertices
int k;             // distance bound as read, decremented once in solve()
int a[N];          // value of each vertex
int h[N];          // depth of each vertex (children get parent depth + 1)
int mh[N];         // deepest depth reached inside each vertex's subtree
int ans;           // best value seen so far
vector<int> g[N];  // adjacency lists


// Initial capacity reserved for the node pool; the pool grows past it if a
// test needs more nodes.
const int M = 2e6;

// One node of the dynamically allocated (index-linked) segment tree.
// A lazy tag (a, b) means every value x in the subtree becomes max(x + a, b).
struct node {
    int a, b;   // pending lazy tag, not yet pushed to the children
    int l, r;   // child indices into `t`, or -1 if not allocated yet
    int mx;     // maximum over this node's range (own tag already applied)

    node() {
        a = b = 0;
        l = r = -1;
        mx = 0;
    }
};

// Node pool. It used to be a fixed `node t[M]` with no bound check in nn().
// Every push allocates both children and every leaf vertex starts its own
// tree, so a large test can need more than M nodes. That overflowed the array
// (undefined behaviour / runtime error).
// NOTE: growing the vector invalidates references into it, so never hold a
// `node&` across a call to nn().
std::vector<node> t;

int t_n;     // nodes handed out in the current test; solve() resets it to 0
int dp[N];   // dp[v]: index in `t` of the root of v's tree

// Returns the index of a freshly default-initialized node. Slots left over
// from earlier tests (after t_n is reset) are reused before the pool grows.
int nn() {
    if (t_n < static_cast<int>(t.size())) {
        t[t_n] = node();
    } else {
        if (t.empty())
            t.reserve(M);
        t.emplace_back();
    }
    return t_n++;
}

// Adds two ints, saturating at the int range instead of overflowing.
// dfs() adds -inf (-1e9) over overlapping ranges once per ancestor, so tags and
// maxima could fall below INT_MIN after three such adds with plain `+=`. That
// is signed-overflow UB and can surface as a huge bogus positive answer.
static int sat_add(int x, int y) {
    const long long s = static_cast<long long>(x) + y;
    if (s < std::numeric_limits<int>::min())
        return std::numeric_limits<int>::min();
    if (s > std::numeric_limits<int>::max())
        return std::numeric_limits<int>::max();
    return static_cast<int>(s);
}

// Pushes v's lazy tag (x -> max(x + a, b)) to its children, allocating any
// missing child first.
void push(int v) {
    // Take the new index before writing it into t[v]: no reference into the
    // pool is held while nn() may allocate.
    if (t[v].l == -1) {
        const int child = nn();
        t[v].l = child;
    }
    if (t[v].r == -1) {
        const int child = nn();
        t[v].r = child;
    }
    const int add = t[v].a;
    const int floor_val = t[v].b;
    for (int u : {t[v].l, t[v].r}) {
        t[u].a = sat_add(t[u].a, add);
        t[u].b = std::max(sat_add(t[u].b, add), floor_val);
        t[u].mx = std::max(sat_add(t[u].mx, add), floor_val);
    }
    // NOTE(review): the tag is reset to (0, 0), i.e. x -> max(x, 0), not to the
    // identity max(x, -inf), so every push also clamps the children at 0.
    // Kept as-is; presumably 0 stands for the "empty" choice. Confirm.
    t[v].a = 0;
    t[v].b = 0;
}

// Recomputes v's maximum from its two children.
void pull(int v) {
    t[v].mx = std::max(t[t[v].l].mx, t[t[v].r].mx);
}

// Range update on [l, r] within node v covering [tl, tr].
// tp == 0: add x to every value; tp == 1: raise every value to at least x.
void upd(int v, int tl, int tr, int l, int r, int tp, int x) {
    if (l > r || r < tl || tr < l)
        return;
    if (l <= tl && tr <= r) {
        if (tp == 0) {
            t[v].a = sat_add(t[v].a, x);
            t[v].b = sat_add(t[v].b, x);
            t[v].mx = sat_add(t[v].mx, x);
        }
        if (tp == 1) {
            t[v].b = std::max(t[v].b, x);
            t[v].mx = std::max(t[v].mx, x);
        }
        return;
    }
    push(v);
    const int tm = (tl + tr) >> 1;
    upd(t[v].l, tl, tm, l, r, tp, x);
    upd(t[v].r, tm + 1, tr, l, r, tp, x);
    pull(v);
}

// Adds x to every position in [l, r] of the tree rooted at v (range [0, n - 1]).
void upd_add(int v, int l, int r, int x) {
    const int add_op = 0;
    upd(v, 0, n - 1, l, r, add_op, x);
}

// Raises every position in [l, r] of the tree rooted at v to at least x.
void upd_max(int v, int l, int r, int x) {
    const int chmax_op = 1;
    upd(v, 0, n - 1, l, r, chmax_op, x);
}

// Point query: value at position p in the tree node v covering [tl, tr].
// Walks down iteratively, pushing tags on the way (this may allocate children).
int get(int v, int tl, int tr, int p) {
    while (tl != tr) {
        push(v);
        const int mid = (tl + tr) >> 1;
        if (p <= mid) {
            v = t[v].l;
            tr = mid;
        } else {
            v = t[v].r;
            tl = mid + 1;
        }
    }
    return t[v].mx;
}

// Value at position p of the tree rooted at v, over the full range [0, n - 1].
int get_val(int v, int p) {
    const int lo = 0;
    const int hi = n - 1;
    return get(v, lo, hi, p);
}

// Scratch buffers for the dense merge in dfs(): per-distance values of the two
// trees being merged (A, B) and their combination (C).
int A[N];
int B[N];
int C[N];

// Bottom-up DP over the subtree of v; p is v's parent (-1 for the root).
//
// Each subtree's values live in a segment tree (root index dp[v]) indexed by
// ABSOLUTE depth: position h[v] + d holds v's value for distance d below v.
// NOTE(review): the exact meaning of that value is not visible here
// (presumably the best total of a selection constrained by the distance bound
// k). Confirm against the problem statement.
//
// v reuses (shares, does not copy) the tree of the child whose subtree reaches
// deepest. Every other child is merged into it with work proportional to that
// child's own height, not v's.
//
// Writes h[] of v's children, mh[v] and dp[v], and raises the global ans.
// NOTE(review): recursion depth equals the tree height (up to n on a path),
// which may exceed a default-sized stack.
void dfs(int v, int p) {
    // mh[v]: deepest absolute depth reached inside v's subtree.
    mh[v] = h[v];
    // u: child with the deepest subtree (its tree is reused), -1 if v is a leaf.
    int u = -1;
    for (auto to: g[v]) {
        if (to == p)
            continue;
        h[to] = h[v] + 1;
        dfs(to, v);
        mh[v] = max(mh[v], mh[to]);
        if (u == -1 || mh[u] < mh[to])
            u = to;
    }
    if (u == -1) {
        // Leaf: fresh tree holding a[v] at distance 0.
        dp[v] = nn();
        upd_add(dp[v], h[v], h[v], a[v]);
    } else {
        // Take over the deepest child's tree, then fold in the other children.
        dp[v] = dp[u];
        for (auto to: g[v]) {
            if (to == p || to == u)
                continue;
            // to_mx: how many levels below v the child's subtree reaches.
            int to_mx = mh[to] - h[v];
            if (to_mx < k - to_mx) {
                // Shallow child (2 * to_mx < k): merge position by position,
                // touching only distances 0..to_mx.
                for (int i = h[v]; i <= mh[to]; i++) {
                    int d = i - h[v];
                    // x: the child's value at distance d from v.
                    int x = get_val(dp[to], h[v] + d);
                    // Add it at distance d ...
                    upd_add(dp[v], h[v] + d, h[v] + d, x);
                    // ... and at the mirrored distance k - d, if v's subtree
                    // reaches that deep.
                    int rd = k - d;
                    if (h[v] + rd <= mh[v])
                        upd_add(dp[v], h[v] + rd, h[v] + rd, x);
                    // At the child's deepest distance, also add x to every
                    // distance strictly between d and k - d.
                    if (i == mh[to])
                        upd_add(dp[v], h[v] + d + 1, h[v] + rd - 1, x);
                }
                // Running maximum (starting from 0), walking up from the deepest
                // position mh[v] over at most to_mx + 1 positions and never
                // above h[v]; each visited position is raised to it.
                int lst = 0;
                for (int i = h[v]; i <= mh[to]; i++) {
                    int d = i - h[v];
                    int pos = mh[v] - d;
                    if (pos < h[v])
                        break;
                    int x = get_val(dp[v], pos);
                    lst = max(lst, x);
                    upd_max(dp[v], pos, pos, lst);
                }
            } else {
                // Deep child (2 * to_mx >= k): dense O(k) merge via scratch arrays.
                // A[i] / B[i]: value at distance i in v's tree / the child's
                // tree (0 past the deepest depth of that subtree).
                for (int i = 0; i <= k; i++) {
                    if (h[v] + i <= mh[v])
                        A[i] = get_val(dp[v], h[v] + i);
                    else
                        A[i] = 0;
                    if (h[v] + i <= mh[to])
                        B[i] = get_val(dp[to], h[v] + i);
                    else
                        B[i] = 0;
                }
                // Turn A and B into prefix maxima over distances 0..i.
                for (int i = 1; i <= k; i++) {
                    A[i] = max(A[i], A[i - 1]);
                    B[i] = max(B[i], B[i - 1]);
                }
                // C[i] = max(A[i] + B[j], B[i] + A[j]) with j = min(i, k - i).
                for (int i = 0; i <= k; i++) {
                    C[i] = max(
                            A[i] + B[min(i, k - i)],
                            B[i] + A[min(i, k - i)]
                    );
                }
                // Add C[i] - A[i] at each distance i that exists in v's tree.
                // NOTE(review): A[i] is the prefix maximum here, so the stored
                // value becomes C[i] only if it already equalled its prefix
                // maximum. Confirm that invariant holds.
                for (int i = 0; i <= k; i++) {
                    if (h[v] + i <= mh[v])
                        upd_add(dp[v], h[v] + i, h[v] + i, C[i] - A[i]);
                }
            }
        }
        // Add a[v] at every distance 0..mh[v] - h[v].
        upd_add(dp[v], h[v], mh[v], a[v]);
    }
    // The maximum over v's whole tree is a candidate answer.
    ans = max(ans, t[dp[v]].mx);
//    for(int i = 0; i <= k && h[v] + i <= mh[v]; i++)
//        cerr << get_val(dp[v], h[v] + i) << " ";
//    cerr << endl;
    // Push distances >= k (k was already decremented in solve()) down by -inf,
    // presumably so ancestors cannot use them. A position can receive this once
    // per ancestor whose range covers it, so the tree's add arithmetic must not
    // overflow int.
    if(h[v] + k <= mh[v])
        upd_add(dp[v], h[v] + k, mh[v] + k, -inf);
}

// Handles one test case: reads n, k, the vertex values and n - 1 edges
// (1-indexed in the input), runs the DP from vertex 0 and prints the answer.
// Finally resets the adjacency lists and the node pool for the next test.
void solve() {
    cin >> n >> k;
    --k;  // the DP works with k - 1
    for (int i = 0; i < n; ++i)
        cin >> a[i];
    for (int e = 1; e < n; ++e) {
        int x, y;
        cin >> x >> y;
        --x;
        --y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    ans = 0;
    dfs(0, -1);
    // Fallback when the DP's best value is 0: take the largest single vertex value.
    if (ans == 0)
        ans = *max_element(a, a + n);
    cout << ans << "\n";
    for_each(g, g + n, [](vector<int>& adj) { adj.clear(); });
    t_n = 0;
}

// Entry point: reads the number of test cases and solves each one.
int main() {
#ifdef DEBUG
    freopen("input.txt", "r", stdin);
#endif
    ios_base::sync_with_stdio(false);
    // Untie cin from cout; otherwise every read flushes pending output.
    cin.tie(nullptr);
    // Named `tests` rather than `t` so it does not shadow the global node pool.
    int tests = 1;
    cin >> tests;
    for (int tc = 1; tc <= tests; ++tc) {
//        cout << "Case #" << tc << endl;
        solve();
    }
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Runtime Error

input:

88
49707 15234
-53 -7 34 -79 25 -63 -3 58 -60 -29 -64 -51 81 -45 -22 73 -46 7 -17 10 24 -81 -75 85 -19 88 46 12 0 -87 21 -88 -71 -2 61 50 24 48 -48 -67 46 43 87 59 -60 97 71 19 -36 91 54 73 25 -62 -92 74 10 100 52 -4 -11 65 89 65 -100 -79 77 -53 41 5 65 -47 77 20 -25 0 5 10 82 -21 27 31 91 -85 -57 -...

output:

1571104
1295009126
1779449
9475
100
3197
1174277

result: