QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#104702 | #6322. Forestry | HgCl2 | WA | 1ms | 11916kb | C++14 | 5.6kb | 2023-05-11 18:45:19 | 2023-05-11 18:45:22 |
Judging History
answer
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
namespace {
using std::cin;
using std::cout;
using std::int64_t;
using std::size_t;
namespace base {
template <typename T, size_t... sizes>
struct NestedArray {};
template <typename T, size_t size, size_t... sizes>
struct NestedArray<T, size, sizes...> {
using Type = std::array<typename NestedArray<T, sizes...>::Type, size>;
};
template <typename T>
struct NestedArray<T> {
using Type = T;
};
template <typename T, size_t... sizes>
using Array = typename NestedArray<T, sizes...>::Type;
void OptimizeIO() {
std::ios::sync_with_stdio(false);
cin.tie(nullptr), cout.tie(nullptr);
}
void OptimizeIO(const std::string &input_file, const std::string &output_file) {
static std::ifstream input_stream(input_file);
static std::ofstream output_stream(output_file);
cin.rdbuf(input_stream.rdbuf());
cout.rdbuf(output_stream.rdbuf());
cin.tie(nullptr), cout.tie(nullptr);
}
} // namespace base
using base::Array;
const int kMod = 998244353, kInv2 = (kMod + 1) / 2;
namespace mod {
inline void Add(int &a, int b) {
a += b;
if (a >= kMod) a -= kMod;
}
inline int Sum(int a, int b) {
a += b;
if (a >= kMod) a -= kMod;
return a;
}
inline void Sub(int &a, int b) {
a -= b;
if (a < 0) a += kMod;
}
inline int Diff(int a, int b) {
a -= b;
if (a < 0) a += kMod;
return a;
}
inline void Mul(int &a, int b) { a = static_cast<int64_t>(a) * b % kMod; }
inline int Prod(int a, int b) { return static_cast<int64_t>(a) * b % kMod; }
int Pow(int a, int b) {
int ans = 1, prod = a;
while (b) {
if (b & 1) Mul(ans, prod);
Mul(prod, prod), b >>= 1;
}
return ans;
}
inline int Inv(int a) { return Pow(a, kMod - 2); }
} // namespace mod
const int kMaxN = 3.0e5 + 5;
int n;
Array<int, kMaxN> f, sum_f, fa;
Array<int, kMaxN, 2> ch;
Array<bool, kMaxN> vis;
Array<std::vector<int>, kMaxN> edge;
struct Node {
int val, id;
};
inline bool operator<(const Node &a, const Node &b) { return a.val < b.val; }
Array<Node, kMaxN> a;
struct Matrix {
int a11, a13, a21, a23;
};
inline Matrix operator*(const Matrix &a, const Matrix &b) {
return {
mod::Prod(a.a11, b.a11),
static_cast<int>((static_cast<int64_t>(a.a11) * b.a13 + a.a13) % kMod),
mod::Prod(a.a21, b.a11),
static_cast<int>((static_cast<int64_t>(a.a21) * b.a13 + a.a23))};
}
Array<Matrix, kMaxN> val, sum;
struct Info {
int g, h;
};
inline Info operator+(const Info &a, const Info &b) {
return {mod::Prod(a.g, b.g), mod::Sum(a.h, b.h)};
}
inline Info operator-(const Info &a, const Info &b) {
if (a.g == 0 || b.h == 1) return {a.g, mod::Diff(a.h, b.h)};
return {mod::Prod(a.g, mod::Inv(b.g)), mod::Diff(a.h, b.h)};
}
Matrix GenMatrix(const Info &info, bool f) {
if (f) return {0, 0, 1, info.h};
int half = mod::Prod(info.g, kInv2);
return {half, half, mod::Sum(half, 1), mod::Sum(half, info.h)};
}
Info Retrodict(const Matrix &a, bool f) {
if (f) return {0, a.a23};
int g = mod::Sum(a.a11, a.a11);
return {g, mod::Diff(a.a23, g)};
}
void Dfs(int u, int fa) {
f[u] = 1;
for (int v : edge[u]) {
if (v == fa) continue;
Dfs(v, u);
mod::Mul(f[u], mod::Prod(f[v] + 1, kInv2));
mod::Add(sum_f[u], sum_f[v]);
}
::fa[u] = fa;
sum[u] = val[u] = GenMatrix({f[u], sum_f[u]}, false);
mod::Add(sum_f[u], f[u]);
}
inline bool IsRoot(int p) { return p != ch[fa[p]][0] && p != ch[fa[p]][1]; }
inline int Get(int p) { return p == ch[fa[p]][1]; }
inline void PushUp(int p) {
int x = ch[p][0], y = ch[p][1];
if (!x && !y) {
sum[p] = val[p];
} else if (!x) {
sum[p] = val[p] * sum[y];
} else if (!y) {
sum[p] = sum[x] * val[p];
} else {
sum[p] = sum[x] * val[p] * sum[y];
}
}
void Rotate(int p) {
int q = fa[p], k = Get(p);
if (!IsRoot(q)) ch[fa[q]][Get(q)] = p;
fa[p] = fa[q];
ch[q][k] = ch[p][k ^ 1], fa[ch[p][k ^ 1]] = q;
ch[p][k ^ 1] = q, fa[q] = p;
PushUp(q), PushUp(p);
}
void Splay(int p) {
while (!IsRoot(p)) {
int q = fa[p];
if (!IsRoot(q)) Get(p) == Get(q) ? Rotate(q) : Rotate(p);
Rotate(p);
}
}
inline Info Query(int p) {
return {mod::Sum(sum[p].a11, sum[p].a13), mod::Sum(sum[p].a21, sum[p].a23)};
}
void Access(int p) {
int q = 0;
while (p) {
Splay(p);
Info info = Retrodict(val[p], vis[p]);
if (ch[p][1]) info = info + Query(ch[p][1]);
if (q) info = info - Query(q);
ch[p][1] = q, PushUp(p);
q = p, p = fa[p];
}
}
void Modify(int p) {
Access(p), Splay(p);
vis[p] = true;
Info info = Retrodict(val[p], false);
mod::Sub(info.h, Query(p).g);
val[p] = GenMatrix(info, true);
PushUp(p);
}
int Main() {
base::OptimizeIO();
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i].val;
a[i].id = i;
}
std::sort(a.begin() + 1, a.begin() + n + 1);
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
edge[u].emplace_back(v), edge[v].emplace_back(u);
}
Dfs(1, 0);
int ans = 0, p = 1;
mod::Add(ans, mod::Prod(a[1].val, f[1] + sum_f[1]));
while (p < n) {
int i = p;
while (i < n && a[i].val == a[p].val) Modify(a[i++].id);
Info res = Query(1);
mod::Add(ans, mod::Prod(a[i].val - a[p].val, res.g + res.h));
p = i;
}
mod::Mul(ans, mod::Pow(2, n - 2));
cout << ans << "\n";
return 0;
}
} // namespace
int main() { return Main(); }
详细
Test #1:
score: 0
Wrong Answer
time: 1ms
memory: 11916kb
input:
4 1 2 3 4 1 2 2 4 3 2
output:
50
result:
wrong answer 1st numbers differ - expected: '44', found: '50'