QOJ.ac

QOJ

IDProblemSubmitterResultTimeMemoryLanguageFile sizeSubmit timeJudge time
#418996#8716. 树mrsunsWA 1ms5876kbC++206.6kb2024-05-23 16:52:472024-05-23 16:52:47

Judging History

你现在查看的是最新测评结果

  • [2024-05-23 16:52:47]
  • 评测
  • 测评结果:WA
  • 用时:1ms
  • 内存:5876kb
  • [2024-05-23 16:52:47]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
// #define max(a, b) ((a) > (b) ? (a) : (b))
// #define min(a, b) ((a) < (b) ? (a) : (b))
#define pb push_back
#define LNF 1e18
#define INF 0x7fffffff
#define int long long
#define lowbit(x) ((x) & (-x))
#define abs(x) llabs(x)
#define endl '\n'
#define Y() cout << "Yes" << endl
#define N() cout << "No" << endl
const db eps = 1e-9;
const int MOD = 1e9 + 7;
const int MAXN = 2e5 + 5;
int b[200005], add[200005], f[200005];
struct RMQ
{ // 求序列中区间最大?最小值的下标
    int n;
    vector<int> lg, nu;
    vector<array<int, 25>> dp; // 固定内部数组降低常数
    int get(const int &a, const int &b)
    {
        return nu[a] < nu[b] ? a : b; // 自定义比较函数,返回在原数组中较大的下标
    }
    RMQ() {}
    RMQ(vector<int> &v)
    { // v的有效下标从1开始,v.size()-1结束
        n = v.size() - 1;
        nu = v;
        lg.resize(n + 1);
        dp.resize(n + 1);
        lg[1] = 0;
        for (int i = 2; i <= n; i++)
        {
            lg[i] = lg[i >> 1] + 1;
        }
        for (int i = 1; i <= n; i++)
        {
            dp[i][0] = i;
        }
        for (int j = 1; j <= lg[n]; j++)
        {
            for (int i = 1; i <= n - (1 << j) + 1; i++)
            {
                dp[i][j] = get(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);
            }
        }
    }
    int getidex(const int &l, const int &r) // 返回最值下标
    {
        int len = lg[r - l + 1];
        return get(dp[l][len], dp[r - (1 << len) + 1][len]);
    }
    int getnum(const int &l, const int &r) // 返回最值
    {
        return nu[getidex(l, r)];
    }
};
struct GRAPH
{
    vector<vector<pair<int, int>>> pth;
    vector<int> dep, first, nu, ref;
    int cnt = 0;
    GRAPH(int n)
    {
        pth.resize(n + 1);
        dep.resize(n + 1);
        first.resize(n + 1);
        nu.resize(n * 2 + 1);
        ref.resize(n * 2 + 1);
        cnt = 0; // 构造好了以后直接加边
    }
    void add_edge(int u, int v, int w = 1)
    { // 加双向边
        pth[u].push_back({v, w});
        pth[v].push_back({u, w});
    }
    void dfs(int pos, int fa)
    {
        first[pos] = ++cnt; // 第一次碰到这个点的dfs序
        nu[cnt] = dep[pos]; // 这个dfs序的深度
        ref[cnt] = pos;
        for (auto [to, len] : pth[pos])
        {
            if (to == fa)
                continue;
            dep[to] = dep[pos] + len; // 当前点深度
            dfs(to, pos);
            nu[++cnt] = dep[pos];
            ref[cnt] = pos;
        }
    };
    RMQ rmq;
    void prepare(int st)
    { // st是树的根节点
        dfs(st, 0);
        rmq = RMQ(nu);
    }
    int query_lca(int x, int y)
    { // 询问两点间的最近公共祖先
        int l, r;
        l = first[x], r = first[y];
        if (l > r)
            swap(l, r);
        int t = rmq.getidex(l, r);
        int lca = ref[t];
        return lca;
    };
    int query_dis(int x, int y)
    { // 询问两点间的距离
        int len = dep[x] + dep[y] - 2 * dep[query_lca(x, y)];
        return len;
    };
};
void solve()
{
    int n, m, q;
    cin >> n >> m >> q;
    GRAPH g(n + 1);
    for (int i = 1; i < n; i++)
    {
        int u, v;
        cin >> u >> v;
        g.add_edge(u, v);
    }
    g.prepare(1);
    int ans = 2, fore = 0, flag = 0;
    for (int i = 1; i <= m; i++)
    {
        cin >> b[i];
    }
    auto cal = [&](int i) -> void
    {
        if (i == 2)
        {
            int lc = g.query_lca(b[2], b[1]);
            if (lc == b[2])
            {
                fore = 1;
            }
            else
            {
                fore = -1;
            }
            return;
        }
        f[i] = fore;
        int lca = g.query_lca(b[i], b[i - 1]);
        // cout << lca << endl;
        // cout << fore << " ";
        if (lca == b[i] || lca == b[i - 1])
        {
            // cout << i << endl;
            if (g.dep[b[i]] > g.dep[b[i - 1]])
            {
                flag = -1;
            }
            else
            {
                flag = 1;
            }
            if (fore != flag)
                ans++, add[i]++;
            fore = flag;

            // cout << ans << endl;
        }
        else
        {
            ans++, add[i]++;
            flag = 1;
            fore = -1;
        }
    };
    auto cal1 = [&](int i) -> void
    {
        if (i == 2)
        {
            int lc = g.query_lca(b[2], b[1]);
            if (lc == b[2])
            {
                fore = 1;
            }
            else
            {
                fore = -1;
            }
            f[i + 1] = fore;
            return;
        }
        add[i] = 0;
        fore = f[i];
        int lca = g.query_lca(b[i], b[i - 1]);
        // cout << lca << endl;
        // cout << fore << " ";
        if (lca == b[i] || lca == b[i - 1])
        {
            // cout << i << endl;
            if (g.dep[b[i]] > g.dep[b[i - 1]])
            {
                flag = -1;
            }
            else
            {
                flag = 1;
            }
            if (fore != flag)
                ans++, add[i]++;
            fore = flag;

            // cout << ans << endl;
        }
        else
        {
            ans++, add[i]++;
            flag = 1;
            fore = -1;
        }
        f[i + 1] = fore;
    };

    for (int i = 2; i <= m; i++)
    {
        cal(i);
        // cout << ans << endl;
    }
    // cout << endl;
    // f[m + 1] = fore;
    // if (f[m + 1] != 0)
    //     ans++;
    // for (int i = 1; i <= m; i++)
    //     cout << add[i] << " ";
    // cout << endl;
    // cout << ans << endl;
    while (q--)
    {
        int p, w;
        cin >> p >> w;
        if (m == 1 || m == 2)
        {
            cout << m << endl;
            continue;
        }
        ans -= (add[p] + add[p + 1] + add[p + 2] + add[p + 3]);
        // if (f[m + 1] != 0)
        //     ans--;
        b[p] = w;
        if (p != 1)
            cal1(p);

        if (p + 1 <= m)
            cal1(p + 1);
        if (p + 2 <= m)
            cal1(p + 2);
        if (p + 3 <= m)
            cal1(p + 3);
        cout << ans << endl;
    }
}
signed main()
{
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T = 1;
    // cin >> T;
    while (T--)
        solve();
    return 0;
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 1ms
memory: 5832kb

input:

5 5 3
2 1
3 2
1 4
5 1
1 5 4 2 3
1 3
5 3
3 3

output:

4
4
5

result:

ok 3 number(s): "4 4 5"

Test #2:

score: -100
Wrong Answer
time: 1ms
memory: 5876kb

input:

30 200 200
10 24
10 13
10 26
13 29
27 26
17 24
27 21
17 15
13 5
13 30
27 3
18 21
9 21
2 24
10 4
11 5
2 8
10 23
1 18
21 25
4 20
12 23
22 27
28 27
18 7
13 6
14 30
10 19
16 21
14 29 25 30 1 17 22 21 11 19 21 30 13 1 22 10 14 7 29 7 15 21 25 29 25 7 29 7 1 23 3 17 2 7 4 27 18 26 3 6 5 3 16 26 20 19 16 2...

output:

185
186
186
186
186
187
187
187
187
187
188
188
188
188
188
187
186
186
186
186
186
185
185
185
185
185
185
185
185
185
185
185
184
184
184
184
185
185
185
185
185
185
185
185
186
187
186
186
186
186
187
187
187
187
187
187
187
187
187
187
187
188
187
187
187
187
187
188
188
189
189
189
189
189
188
...

result:

wrong answer 1st numbers differ - expected: '174', found: '185'