QOJ.ac

QOJ

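| ID | 题目 (Problem) | 提交者 (Submitter) | 结果 (Result) | 用时 (Time) | 内存 (Memory) | 语言 (Language) | 文件大小 (File size) | 提交时间 (Submitted) | 测评时间 (Judged) |
|---|---|---|---|---|---|---|---|---|---|
| #729549 | #9570. Binary Tree | ucup-team4435 | RE | 0ms | 3628kb | C++20 | 4.3kb | 2024-11-09 17:20:20 | 2024-11-09 17:20:20 |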

Judging History

你现在查看的是最新测评结果 (You are viewing the latest judging result.)

  • [2024-11-09 17:20:20]
  • 评测
  • 测评结果:RE
  • 用时:0ms
  • 内存:3628kb
  • [2024-11-09 17:20:20]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;

// Short names for the wide arithmetic types.
using ld = long double;
using ll = long long;

// Expands to the [begin, end) iterator pair of a container, for use with
// algorithms such as sort/find/count.
#define all(a) std::begin(a), std::end(a)
// Container size as a signed int, to avoid signed/unsigned comparisons.
#define len(a) static_cast<int>((a).size())

// Fixed-point combinator that lets a lambda call itself recursively.
//
// The wrapped callable receives a reference to this combinator as its first
// argument, followed by the call arguments, e.g.
//     y_combinator([&](auto self, int v) -> void { ...; self(v + 1); })(0);
// A recursive lambda usually needs an explicit return type (such as `-> void`),
// because its return type cannot be deduced while its own body is being
// instantiated.
template<typename Fun>
struct y_combinator {
    Fun fun;

    // Taking the callable by value accepts both temporaries (moved in) and
    // named lvalue callables (copied in). The former `const Fun&&` parameter
    // rejected lvalues, and forwarding it as `const Fun&&` always copied.
    explicit y_combinator(Fun f) : fun(std::move(f)) {}

    // Invokes the wrapped callable with `self` bound to this combinator.
    // decltype(auto) keeps reference return types intact instead of decaying.
    template<typename... Args>
    decltype(auto) operator()(Args&&... args) const {
        return fun(std::cref(*this), std::forward<Args>(args)...);
    }
};

// Solves one test case of the interactive tree-search problem.
//
// Input: n, then for each vertex i (1-based) its two children l and r, where 0
// means "no child". The tree is stored as an undirected adjacency list.
//
// Interaction: "? v u" is sent. The reply is read as 0 (v's side holds the
// hidden vertex), 1 (neither side: the vertex between them or beyond), or
// 2 (u's side). "! x" reports the answer.
//
// Each round locates a centroid of the current candidate component and queries
// two of its neighbours. Only the part that can still contain the hidden vertex
// is kept. Because every branch of a centroid holds at most half of the
// component, each query at least halves the candidate count.
void solve(int /* test_num */) {
    int n;
    cin >> n;
    vector<vector<int>> g(n);
    for (int i = 0, l, r; i < n; i++) {
        cin >> l >> r;
        l--, r--;  // 1-based in the input; an absent child (0) becomes -1
        if (l != -1) {
            g[l].push_back(i);
            g[i].push_back(l);
        }
        if (r != -1) {
            g[r].push_back(i);
            g[i].push_back(r);
        }
    }

    // Sends "? v u" (1-based) and returns the judge's reply. std::endl flushes,
    // so the judge sees the query before we block on its answer.
    auto query = [&](int v, int u) {
        cout << "? " << v + 1 << ' ' << u + 1 << endl;
        int res;
        cin >> res;
        return res;
    };

    // cands[v] is true while v may still be the hidden vertex. The candidate
    // set is always one connected subtree.
    vector<bool> cands(n, true);
    // sz[v] is the size of v's subtree inside the candidate component, rooted
    // as in the most recent init_sizes() call.
    vector<int> sz(n);

    auto init_sizes = [&](int root) {
        y_combinator([&](auto dfs, int v, int p) -> void {
            assert(cands[v]);
            sz[v] = 1;
            for (auto u : g[v]) {
                if (u != p && cands[u]) {
                    dfs(u, v);
                    sz[v] += sz[u];
                }
            }
        })(root, -1);
    };

    while (count(all(cands), true) > 1) {
        int root = find(all(cands), true) - cands.begin();
        assert(root < n);

        init_sizes(root);
        // Size of the CURRENT candidate component. The centroid test must use
        // this, not n. After the first round the component is smaller than n,
        // and comparing against n stopped the walk at a non-centroid. That
        // broke the halving guarantee and tripped the assertions below.
        const int total = sz[root];
        int centroid = root, par = -1;

        // Centroid walk: step into any child subtree holding more than half of
        // the component. When none exists, every branch (including the one
        // toward the parent) has at most total / 2 vertices.
        while (true) {
            bool found = false;
            for (auto u : g[centroid]) {
                if (cands[u] && u != par && sz[u] * 2 > total) {
                    found = true;
                    par = centroid;
                    centroid = u;
                    break;
                }
            }

            if (!found) {
                break;
            }
        }

        // Re-root the sizes at the centroid so sz[neighbour] is that branch's size.
        root = centroid;
        init_sizes(root);

        vector<pair<int, int>> childs;  // (branch size, neighbour), ascending
        for (auto u : g[root]) {
            if (cands[u]) {
                childs.emplace_back(sz[u], u);
            }
        }
        sort(all(childs));

        vector<int> ncands;

        // Collects every candidate reachable from v without passing through p.
        auto find_ncands = y_combinator([&](auto dfs, int v, int p) -> void {
            assert(cands[v]);
            ncands.push_back(v);
            for (auto u : g[v]) {
                if (cands[u] && u != p) {
                    dfs(u, v);
                }
            }
        });

        if (len(childs) == 3) {
            // Query the two largest branches. On a tie, only the smallest branch
            // and the centroid itself remain.
            int a = childs[0].second, b = childs[1].second, c = childs[2].second;
            int ret = query(b, c);

            if (ret == 0) {
                find_ncands(b, root);
            } else if (ret == 1) {
                find_ncands(a, root);
                ncands.push_back(root);
            } else {
                assert(ret == 2);
                find_ncands(c, root);
            }
        } else if (len(childs) == 2) {
            int a = childs[0].second, b = childs[1].second;
            int ret = query(a, b);

            if (ret == 0) {
                find_ncands(a, root);
            } else if (ret == 1) {
                ncands = {root};
            } else {
                assert(ret == 2);
                find_ncands(b, root);
            }
        } else if (len(childs) == 1) {
            // A true centroid with a single neighbour branch implies the
            // component has exactly two vertices.
            assert(count(all(cands), true) == 2);
            int a = root, b = childs[0].second;
            auto ret = query(a, b);

            if (ret == 0) {
                ncands = {a};
            } else {
                assert(ret == 2);
                ncands = {b};
            }
        } else {
            assert(false);
        }

        fill(all(cands), false);
        for (auto v : ncands) {
            cands[v] = true;
        }
    }

    assert(count(all(cands), true) == 1);
    int ans = 0;
    while (!cands[ans]) {
        ans++;
    }
    cout << "! " << ans + 1 << endl;
}

// Entry point: reads the number of test cases, then runs solve() on each one,
// numbering them from 1.
// Untying cin from cout is compatible with the interactive protocol here,
// because every line solve() prints ends with std::endl, which flushes.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int tests = 0;
    std::cin >> tests;

    int current = 0;
    while (current < tests) {
        ++current;
        solve(current);
    }
}

詳細信息

Test #1:

score: 100
Accepted
time: 0ms
memory: 3628kb

input:

2
5
0 0
1 5
2 4
0 0
0 0
2
0
2
0 2
0 0
2

output:

? 5 3
? 3 4
! 3
? 2 1
! 1

result:

ok OK (2 test cases)

Test #2:

score: -100
Runtime Error

input:

5555
8
2 0
8 6
0 0
3 0
0 0
7 0
0 0
5 4
2

output:

? 4 2

result: