QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#197419 | #5372. 杀蚂蚁简单版 | Benzenesir | 0 | 0ms | 0kb | C++14 | 3.4kb | 2023-10-02 15:45:07 | 2023-10-02 15:45:07 |
answer
#include <cstdio>
#include <cmath>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <map>
#include <unordered_map>
#include <set>
#include <bitset>
#include <stack>
#include <tuple>
#include <bitset>
// Contest-style shorthand macros. These are plain textual substitutions: they
// apply to every later line of this file, including any header that is
// #included after them. Many of them are not referenced by the solution code.
#define ll long long
#define ull unsigned long long
#define ld long double
#define db double
// fp: ascending loop a = b, b+1, ..., c (inclusive).
#define fp(a, b, c) for (int a = b; a <= c; a++)
// fd: descending loop a = b, b-1, ..., c (inclusive).
#define fd(a, b, c) for (int a = b; a >= c; a--)
#define pii pair<int, int>
#define inf 0x3f3f3f3f
#define base 127
// Modulus for all modular arithmetic in the solution.
#define mod 998244353
#define eb emplace_back
// NOTE(review): presumably renames y1/y0/x1/x0 so they cannot collide with
// the y0/y1 Bessel functions some <cmath>/<math.h> implementations declare.
#define y1 y114
#define y0 y514
#define x1 x114
#define x0 x514
#define mpr make_pair
#define met(x, t) memset(x, t, sizeof(x))
#define fir first
#define sec second
#include <numeric>
#include <stdlib.h>
#include <assert.h>
using namespace std;
// Reads the next decimal integer from stdin using getchar().
// Non-digit characters before the number are skipped; if a '-' appears among
// them, the result is negated.
// Returns 0 when end of input is reached before any digit, so truncated input
// cannot make the skip loop spin forever. The character is kept in an int so
// that EOF stays distinguishable from every real character.
inline int rd() {
    int x = 0, f = 1;
    int ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == EOF)
            return 0;
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    // x * 10 + digit, written as shifts plus XOR with '0' (48).
    while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
    return x * f;
}
// Capacity of every per-node array (nodes are numbered 1..n).
const int N = 100000 + 10;
// Node count and query count read in main().
int n, q;
// acc[i][v]: the 2^i-th ancestor of v (0 past the top), filled by dfs().
int acc[20][N];
// dep[v]: depth of v, filled by dfs().
int dep[N];
// NOTE(review): fa is not referenced anywhere in this file.
int fa[N];
// val[v]: value of node v as read from input.
ll val[N];
// c[0..2][v]: per-node prefix accumulators (mod `mod`), filled by dfs().
ll c[3][N];
// tot[v]: sum of val over v's neighbours (mod `mod`), filled in main().
ll tot[N];
// g[v]: adjacency list of v.
vector<int> g[N];
// Modular power by square-and-multiply: returns x^p mod `mod`.
// With the default exponent p = mod - 2, and mod = 998244353 being prime, this
// is the modular inverse of x (Fermat's little theorem), provided x is not 0 mod `mod`.
// Operands are expected to be reduced mod `mod` so that products fit in ll.
ll inv(ll x, ll p = mod - 2) {
    ll result = 1;
    ll square = x;
    while (p) {
        if (p & 1)
            result = result * square % mod;
        square = square * square % mod;
        p >>= 1;
    }
    return result;
}
// Roots the tree at `now`, treating `f` as its already-processed parent. For
// every node v reachable from `now` without passing through `f`, with parent p:
//   acc[0][v] = p and acc[i + 1][v] = acc[i][acc[i][v]] (binary lifting),
//   dep[v]    = dep[p] + 1,
//   c[0][v]   = c[0][p] + tot[v] * val[v],
//   c[1][v]   = c[1][p] + inv(val[v] * val[p]),
//   c[2][v]   = c[2][p] + (tot[v] * val[v]) * c[1][v],
// with all c values reduced modulo `mod`.
// An explicit stack replaces recursion: on a path-shaped tree with n ~ 1e5 the
// recursive form nests ~1e5 frames deep, which may overflow a default-sized
// stack. Each node reads only values of its ancestors, and those are always
// processed before it, so this parent-before-child order gives the same results.
void dfs(int now, int f) {
    vector<pair<int, int>> pending;  // (node, parent) pairs still to process
    pending.emplace_back(now, f);
    while (!pending.empty()) {
        int v = pending.back().first, p = pending.back().second;
        pending.pop_back();
        acc[0][v] = p, dep[v] = dep[p] + 1;
        for (int u = acc[0][v], i = 0; u; u = acc[++i][v]) acc[i + 1][v] = acc[i][u];
        ll h1 = tot[v] * val[v] % mod, h2 = inv(val[v] * val[p] % mod);
        c[0][v] = (c[0][p] + h1) % mod;
        c[1][v] = (c[1][p] + h2) % mod;
        c[2][v] = (c[2][p] + h1 * c[1][v] % mod) % mod;
        for (int x : g[v])
            if (x != p)
                pending.emplace_back(x, v);
    }
}
// Lowest common ancestor of x and y, using the binary-lifting table acc[][]
// and depths dep[] filled by dfs() (acc[i][v] is the 2^i-th ancestor of v,
// or 0 past the top).
// __lg is only ever called with a positive argument. In libstdc++ versions
// where __lg wraps __builtin_clz, __lg(0) is undefined, and the previous form
// evaluated it whenever the depth gap was or became 0, and when dep[x] == 0.
int lca(int x, int y) {
    if (dep[x] < dep[y])
        swap(x, y);
    // Lift x to y's depth, removing the highest set bit of the gap each step.
    for (int gap = dep[x] - dep[y]; gap > 0;) {
        int bit = __lg(gap);
        x = acc[bit][x];
        gap -= 1 << bit;
    }
    if (x == y)
        return x;
    // Lift both while their ancestors differ; they then sit just below the LCA.
    for (int i = dep[x] > 0 ? __lg(dep[x]) : -1; i >= 0; --i)
        if (acc[i][x] != acc[i][y])
            x = acc[i][x], y = acc[i][y];
    return acc[0][x];
}
// Sum of the per-node terms of accumulator c[d] over the tree path x .. y
// (both endpoints included), modulo `mod`. c[d][v] is a prefix sum from the
// top down to v, so the path sum is c[x] + c[y] - c[top] - c[parent(top)],
// where top = lca(x, y).
ll calc(int x, int y, int d) {
    const ll *pre = c[d];
    const int top = lca(x, y);
    ll sum = (pre[x] + pre[y]) % mod;
    sum = (sum + mod - pre[top]) % mod;
    return (sum + mod - pre[acc[0][top]]) % mod;
}
// Reads n, val[1..n], the n - 1 tree edges and q queries (s, x, y) from stdin,
// and prints one answer per query.
signed main() {
    // Redirection to ant.in / ant.out is opt-in (compile with -DANT_FILE_IO).
    // Doing it unconditionally got this program rejected with
    // "Dangerous Syscalls" on the online judge, which provides input on stdin;
    // the file opens performed by freopen are the likely cause.
#ifdef ANT_FILE_IO
    freopen("ant.in", "r", stdin);
    freopen("ant.out", "w", stdout);
#endif
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    n = rd();
    fp(i, 1, n) val[i] = rd();
    fp(i, 1, n - 1) {
        int u = rd(), v = rd();
        // tot[v] accumulates the values of v's neighbours (mod `mod`).
        (tot[u] += val[v]) %= mod;
        (tot[v] += val[u]) %= mod;
        g[u].emplace_back(v), g[v].emplace_back(u);
    }
    // Root at node 2 with node 1 as its parent (dep[1] == 0, c[*][1] == 0).
    // dfs() never enters node 1, so any other neighbours of node 1 stay unvisited.
    // NOTE(review): presumably the problem guarantees node 1 is a leaf adjacent
    // to node 2 — confirm against the statement.
    dfs(2, 1);
    q = rd();
    while (q--) {
        int s = rd(), x = rd(), y = rd();
        int h = lca(x, y), h1 = lca(x, s), h2 = lca(y, s);
        ll ans = 0;
        if (h1 == h2) {
            ans = calc(x, y, 0) * calc(h1, 2, 1) % mod;
        } else {
            // After this swap, lca(x, s) differs from h = lca(x, y).
            if (h == h1)
                swap(x, y);
            h1 = lca(x, s);
            ans = (calc(x, h1, 0) + mod) % mod * calc(h1, 2, 1) % mod;
            (ans += (calc(y, h, 0) + mod - calc(h, h, 0)) % mod * calc(h, 2, 1) % mod) %= mod;
            (ans += calc(acc[0][h1], h, 2)) %= mod;
        }
        // '\n' rather than endl: endl flushes the stream on every query.
        cout << ans << '\n';
    }
    return 0;
}
Details
Tip: Click on the bar to expand more detailed information
Subtask #1:
score: 0
Dangerous Syscalls
Test #1:
score: 0
Dangerous Syscalls
input:
5 1 1 1 1 1 1 2 2 3 2 4 3 5 1 2 4 2
output:
result:
Subtask #2:
score: 0
Skipped
Dependency #1:
0%
Subtask #3:
score: 0
Dangerous Syscalls
Test #18:
score: 0
Dangerous Syscalls
input:
100000 13643 13546 7538 2233 7731 14619 19601 8438 9556 19888 17313 1060 15168 11207 11183 16074 10758 7469 13444 9658 18326 4735 7542 13836 5863 7903 7212 14714 10416 18506 13435 14502 15271 13205 14887 18074 8353 19807 1767 19148 7343 10823 14211 66 17168 8305 1210 5436 18552 3659 886 18416 19261 ...
output:
result:
Subtask #4:
score: 0
Skipped
Dependency #3:
0%
Subtask #5:
score: 0
Skipped
Dependency #2:
0%
Subtask #6:
score: 0
Skipped
Dependency #1:
0%