QOJ.ac

QOJ

ID题目提交者结果用时内存语言文件大小提交时间测评时间
#104658#6322. ForestryHgCl2WA 4ms19788kbC++145.5kb2023-05-11 16:15:142023-05-11 16:15:17

Judging History

你现在查看的是最新测评结果

  • [2023-08-10 23:21:45]
  • System Update: QOJ starts to keep a history of the judgings of all the submissions.
  • [2023-05-11 16:15:17]
  • 评测
  • 测评结果:WA
  • 用时:4ms
  • 内存:19788kb
  • [2023-05-11 16:15:14]
  • 提交

answer

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

namespace {
using std::cin;
using std::cout;
using std::int64_t;
using std::size_t;

namespace base {
template <typename T, size_t... sizes>
struct NestedArray {};

template <typename T, size_t size, size_t... sizes>
struct NestedArray<T, size, sizes...> {
  using Type = std::array<typename NestedArray<T, sizes...>::Type, size>;
};

template <typename T>
struct NestedArray<T> {
  using Type = T;
};

template <typename T, size_t... sizes>
using Array = typename NestedArray<T, sizes...>::Type;

void OptimizeIO() {
  std::ios::sync_with_stdio(false);
  cin.tie(nullptr), cout.tie(nullptr);
}

void OptimizeIO(const std::string &input_file, const std::string &output_file) {
  static std::ifstream input_stream(input_file);
  static std::ofstream output_stream(output_file);
  cin.rdbuf(input_stream.rdbuf());
  cout.rdbuf(output_stream.rdbuf());
  cin.tie(nullptr), cout.tie(nullptr);
}
}  // namespace base

using base::Array;

const int kMod = 998244353, kInv2 = (kMod + 1) / 2;

namespace mod {
inline void Add(int &a, int b) {
  a += b;
  if (a >= kMod) a -= kMod;
}

inline int Sum(int a, int b) {
  a += b;
  if (a >= kMod) a -= kMod;
  return a;
}

inline void Sub(int &a, int b) {
  a -= b;
  if (a < 0) a += kMod;
}

inline int Diff(int &a, int b) {
  a -= b;
  if (a < 0) a += kMod;
  return a;
}

inline void Mul(int &a, int b) { a = static_cast<int64_t>(a) * b % kMod; }

inline int Prod(int a, int b) { return static_cast<int64_t>(a) * b % kMod; }

int Pow(int a, int b) {
  int ans = 1, prod = a;

  while (b) {
    if (b & 1) mod::Mul(ans, prod);
    mod::Mul(prod, prod), b >>= 1;
  }

  return ans;
}

inline int Inv(int a) { return Pow(a, kMod - 2); }
}  // namespace mod

const int kMaxN = 3.0e5 + 5;
int n;
Array<int, kMaxN> f, sum_f, fa;
Array<int, kMaxN, 2> ch;
Array<std::vector<int>, kMaxN> edge;

struct Node {
  int val, id;
};

inline bool operator<(const Node &a, const Node &b) { return a.val < b.val; }

Array<Node, kMaxN> a;

struct Matrix {
  int a11, a13, a21, a23;
  Matrix() = default;
  Matrix(int a11, int a13, int a21, int a23)
      : a11(a11), a13(a13), a21(a21), a23(a23) {}
};

inline Matrix operator*(const Matrix &a, const Matrix &b) {
  return Matrix(mod::Prod(a.a11, b.a11),
                (static_cast<int64_t>(a.a11) * b.a13 + a.a13) % kMod,
                mod::Prod(a.a21, b.a11),
                (static_cast<int64_t>(a.a21) * b.a13 + a.a23) % kMod);
}

Array<Matrix, kMaxN> val, sum;

Matrix GenMatrix(int f, int sum_f) {
  int half = mod::Prod(f, kInv2);
  return {half, half, mod::Sum(1, half), mod::Sum(sum_f, half)};
}

inline bool IsRoot(int p) { return p != ch[fa[p]][0] && p != ch[fa[p]][1]; }

inline int Get(int p) { return p == ch[fa[p]][1]; }

void PushUp(int p) {
  int x = ch[p][0], y = ch[p][1];

  if (!x && !y) {
    sum[p] = val[p];
  } else if (!x) {
    sum[p] = val[p] * sum[y];
  } else if (!y) {
    sum[p] = sum[x] * val[p];
  } else {
    sum[p] = sum[x] * val[p] * sum[y];
  }
}

void Rotate(int p) {
  int q = fa[p], k = Get(p);
  if (!IsRoot(q)) ch[fa[q]][Get(q)] = p;
  fa[p] = fa[q];
  ch[q][k] = ch[p][k ^ 1], fa[ch[p][k ^ 1]] = q;
  ch[p][k ^ 1] = q, fa[q] = p;
  PushUp(q), PushUp(p);
}

void Splay(int p) {
  while (!IsRoot(p)) {
    int q = fa[p];
    if (!IsRoot(q)) Get(p) == Get(q) ? Rotate(q) : Rotate(p);
    Rotate(p);
  }
}

inline std::pair<int, int> Query(int p) {
  return {mod::Sum(sum[p].a11, sum[p].a13), mod::Sum(sum[p].a21, sum[p].a23)};
}

void Access(int p) {
  Splay(p);
  int sum_f = mod::Diff(val[p].a23, val[p].a11);
  if (ch[p][1]) mod::Add(sum_f, Query(ch[p][1]).first);
  val[p] = {0, 0, 1, sum_f}, ch[p][1] = 0;
  PushUp(p);
  int q = p;
  p = fa[p];

  while (p) {
    Splay(p);
    int f = mod::Sum(val[p].a11, val[p].a11);
    int sum_f = mod::Diff(val[p].a23, val[p].a11);
    auto tmp = Query(q);
    mod::Mul(f, mod::Inv(tmp.first)), mod::Sub(f, tmp.second);

    if (ch[p][1]) {
      tmp = Query(ch[p][1]);
      mod::Mul(f, tmp.first), mod::Add(sum_f, tmp.second);
    }

    val[p] = GenMatrix(f, sum_f);
    ch[p][1] = q, PushUp(p);
    q = p, p = fa[p];
  }
}

void Dfs(int u, int fa) {
  ::fa[u] = fa;
  f[u] = 1;

  for (int v : edge[u]) {
    if (v == fa) continue;
    Dfs(v, u);
    mod::Mul(f[u], mod::Prod(f[v] + 1, kInv2));
    mod::Add(sum_f[u], sum_f[v]);
  }

  val[u] = GenMatrix(f[u], sum_f[u]);
  PushUp(u);
  mod::Add(sum_f[u], f[u]);
}

int Main() {
  base::OptimizeIO();
  cin >> n;

  for (int i = 1; i <= n; i++) {
    cin >> a[i].val;
    a[i].id = i;
  }

  std::sort(a.begin() + 1, a.begin() + n + 1);

  for (int i = 1; i < n; i++) {
    int u, v;
    cin >> u >> v;
    edge[u].emplace_back(v), edge[v].emplace_back(u);
  }

  Dfs(1, 0);
  int ans = 0, p = 1;
  std::pair<int, int> res = Query(1);
  mod::Add(ans, mod::Prod(a[1].val, res.first + res.second));

  while (p < n) {
    int i = p;
    while (i < n && a[i].val == a[p].val) Access(a[i++].id);
    std::pair<int, int> res = Query(1);
    mod::Add(ans, mod::Prod(a[i].val - a[p].val, res.first + res.second));
    p = i;
  }

  for (int i = 1; i <= n - 2; i++) mod::Add(ans, ans);
  cout << ans << "\n";
  return 0;
}
}  // namespace

int main() { return Main(); }

详细

Test #1:

score: 0
Wrong Answer
time: 4ms
memory: 19788kb

input:

4
1 2 3 4
1 2
2 4
3 2

output:

32

result:

wrong answer 1st numbers differ - expected: '44', found: '32'