QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#104685 | #6322. Forestry | HgCl2 | WA | 1ms | 21756kb | C++14 | 5.7kb | 2023-05-11 17:27:34 | 2023-05-11 17:27:37 |
Judging History
answer
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
namespace {
using std::cin;
using std::cout;
using std::int64_t;
using std::size_t;
namespace base {
// NestedArray<T, d0, d1, ..., dk>::Type is
// std::array<std::array<...std::array<T, dk>..., d1>, d0>: the first size is
// the outermost dimension. With no sizes at all it is simply T.
template <typename Elem, std::size_t... Dims>
struct NestedArray {};
// Zero dimensions left: the element type itself.
template <typename Elem>
struct NestedArray<Elem> {
  using Type = Elem;
};
// Peel off the outermost dimension and recurse on the remaining ones.
template <typename Elem, std::size_t Outer, std::size_t... Inner>
struct NestedArray<Elem, Outer, Inner...> {
  using Type = std::array<typename NestedArray<Elem, Inner...>::Type, Outer>;
};
// Convenience alias, e.g. Array<int, 3, 4> == std::array<std::array<int, 4>, 3>.
template <typename Elem, std::size_t... Dims>
using Array = typename NestedArray<Elem, Dims...>::Type;

// Decouple C++ streams from C stdio and untie cin/cout for faster I/O.
void OptimizeIO() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  std::cout.tie(nullptr);
}

// Redirect cin/cout to the given files and untie them. The file streams are
// function-local statics, so their buffers outlive this call (and only the
// file names passed on the first call ever take effect).
void OptimizeIO(const std::string &input_file, const std::string &output_file) {
  static std::ifstream in_file(input_file);
  static std::ofstream out_file(output_file);
  std::cin.rdbuf(in_file.rdbuf());
  std::cout.rdbuf(out_file.rdbuf());
  std::cin.tie(nullptr);
  std::cout.tie(nullptr);
}
}  // namespace base
using base::Array;
const int kMod = 998244353;
const int kInv2 = (kMod + 1) / 2;  // 2 * kInv2 == 1 (mod kMod)
namespace mod {
// Modular arithmetic helpers. Operands of Sum/Diff/Add/Sub are expected to
// lie in [0, kMod) already: only one conditional correction is applied.
inline int Sum(int a, int b) {
  const int total = a + b;
  return total >= kMod ? total - kMod : total;
}
inline void Add(int &a, int b) { a = Sum(a, b); }
inline int Diff(int a, int b) {
  const int delta = a - b;
  return delta < 0 ? delta + kMod : delta;
}
inline void Sub(int &a, int b) { a = Diff(a, b); }
inline int Prod(int a, int b) {
  return static_cast<int>(static_cast<std::int64_t>(a) * b % kMod);
}
inline void Mul(int &a, int b) { a = Prod(a, b); }
// a^b mod kMod by binary exponentiation (expects b >= 0).
int Pow(int a, int b) {
  int result = 1;
  for (int power = a; b != 0; b >>= 1) {
    if (b & 1) result = Prod(result, power);
    power = Prod(power, power);
  }
  return result;
}
// Fermat inverse (kMod is prime). Inv(0) yields 0 since 0 has no inverse.
inline int Inv(int a) { return Pow(a, kMod - 2); }
}  // namespace mod
const int kMaxN = 3.0e5 + 5;
int n;
Array<int, kMaxN> f, sum_f, fa;
Array<int, kMaxN, 2> ch;
Array<bool, kMaxN> vis;
Array<std::vector<int>, kMaxN> edge;
struct Node {
int val, id;
};
inline bool operator<(const Node &a, const Node &b) { return a.val < b.val; }
Array<Node, kMaxN> a;
struct Matrix {
int a11, a13, a21, a23;
Matrix() = default;
Matrix(int a11, int a13, int a21, int a23)
: a11(a11), a13(a13), a21(a21), a23(a23) {}
};
inline Matrix operator*(const Matrix &a, const Matrix &b) {
return Matrix(mod::Prod(a.a11, b.a11),
(static_cast<int64_t>(a.a11) * b.a13 + a.a13) % kMod,
mod::Prod(a.a21, b.a11),
(static_cast<int64_t>(a.a21) * b.a13 + a.a23) % kMod);
}
Array<Matrix, kMaxN> val, sum;
struct Info {
int f, sum_f;
};
inline Info operator+(const Info &a, const Info &b) {
return {mod::Prod(a.f, b.f), mod::Sum(a.sum_f, b.sum_f)};
}
inline Info operator-(const Info &a, const Info &b) {
if (a.f && b.f != 1) {
return {mod::Prod(a.f, mod::Inv(b.f)), mod::Diff(a.sum_f, b.sum_f)};
}
return {a.f, mod::Diff(a.sum_f, b.sum_f)};
}
Matrix GenMatrix(Info info, bool f) {
int half = mod::Prod(info.f, kInv2);
return {half, half, mod::Sum(1, half), mod::Sum(info.sum_f, half)};
}
inline bool IsRoot(int p) { return p != ch[fa[p]][0] && p != ch[fa[p]][1]; }
inline int Get(int p) { return p == ch[fa[p]][1]; }
void PushUp(int p) {
int x = ch[p][0], y = ch[p][1];
if (!x && !y) {
sum[p] = val[p];
} else if (!x) {
sum[p] = val[p] * sum[y];
} else if (!y) {
sum[p] = sum[x] * val[p];
} else {
sum[p] = sum[x] * val[p] * sum[y];
}
}
void Rotate(int p) {
int q = fa[p], k = Get(p);
if (!IsRoot(q)) ch[fa[q]][Get(q)] = p;
fa[p] = fa[q];
ch[q][k] = ch[p][k ^ 1], fa[ch[p][k ^ 1]] = q;
ch[p][k ^ 1] = q, fa[q] = p;
PushUp(q), PushUp(p);
}
void Splay(int p) {
while (!IsRoot(p)) {
int q = fa[p];
if (!IsRoot(q)) Get(p) == Get(q) ? Rotate(q) : Rotate(p);
Rotate(p);
}
}
inline Info Query(int p) {
return {mod::Sum(sum[p].a11, sum[p].a13), mod::Sum(sum[p].a21, sum[p].a23)};
}
inline Info Retrodict(int p, bool f) {
if (f) {
return {0, val[p].a23};
} else {
return {mod::Sum(val[p].a11, val[p].a11),
mod::Diff(val[p].a23, val[p].a11)};
}
}
void Access(int p) {
int tmp = p;
vis[p] = true;
int q = 0;
while (p) {
Splay(p);
Info info = Retrodict(p, p == tmp ? false : true);
if (ch[p][1]) info = info + Query(ch[p][1]);
info = info - Query(q);
val[p] = GenMatrix(info, vis[p]), ch[p][1] = q, PushUp(p);
q = p, p = fa[p];
}
}
void Dfs(int u, int fa) {
::fa[u] = fa;
f[u] = 1;
for (int v : edge[u]) {
if (v == fa) continue;
Dfs(v, u);
mod::Mul(f[u], mod::Prod(f[v] + 1, kInv2));
mod::Add(sum_f[u], sum_f[v]);
}
val[u] = GenMatrix({f[u], sum_f[u]}, false);
PushUp(u);
mod::Add(sum_f[u], f[u]);
}
int Main() {
base::OptimizeIO();
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i].val;
a[i].id = i;
}
std::sort(a.begin() + 1, a.begin() + n + 1);
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
edge[u].emplace_back(v), edge[v].emplace_back(u);
}
Dfs(1, 0);
int ans = 0, p = 1;
mod::Add(ans, mod::Prod(a[1].val, f[1] + sum_f[1]));
while (p < n) {
int i = p;
while (i < n && a[i].val == a[p].val) Access(a[i++].id);
Info res = Query(1);
mod::Add(ans, mod::Prod(a[i].val - a[p].val, res.f + res.sum_f));
p = i;
}
for (int i = 1; i <= n - 2; i++) mod::Add(ans, ans);
cout << ans << "\n";
return 0;
}
} // namespace
// Program entry point: delegates to the solution and forwards its exit code.
int main() {
  const int status = Main();
  return status;
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 0
Wrong Answer
time: 1ms
memory: 21756kb
input:
4
1 2 3 4
1 2
2 4
3 2
output:
40
result:
wrong answer 1st numbers differ - expected: '44', found: '40'