QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#326946 | #7997. 树 V 图 | ftiasch | WA | 0ms | 3824kb | C++23 | 11.3kb | 2024-02-14 15:14:07 | 2024-02-14 15:14:07

Judging History

你现在查看的是最新测评结果

  • [2024-02-14 15:14:07]
  • 评测
  • 测评结果:WA
  • 用时:0ms
  • 内存:3824kb
  • [2024-02-14 15:14:07]
  • 提交

answer

#include <bits/stdc++.h>
#include <experimental/type_traits>
// {{{ boilerplate

/// Trait: `true` exactly when `T` is `Template<Args...>` for some type
/// arguments `Args...`; `false` for every other type.
template <template <typename...> class Template, typename T>
struct is_specialization_of : std::bool_constant<false> {};

template <template <typename...> class Template, typename... Args>
struct is_specialization_of<Template, Template<Args...>>
    : std::bool_constant<true> {};

/// Shorthand for `is_specialization_of<Template, T>::value`.
template <template <typename...> class Template, typename T>
inline constexpr bool is_specialization_of_v =
    is_specialization_of<Template, T>{}();

// Text printed for `true` / `false` by IOBaseT::operator<<. Either string can
// be overridden by defining YES / NO before this point.
#if !defined(YES)
#define YES "Yes"
#endif
#if !defined(NO)
#define NO "No"
#endif

// CRTP base that adds typed reading/writing on top of two primitives the
// derived class `IO` must supply: `read1(T&&)` and `write1(T&&)`.
template <typename IO> struct IOBaseT {
  // Reads a value of type T and returns it.
  //  - std::tuple: each element is read in order (see read_t_).
  //  - "vector-like" types (anything with a begin() member except
  //    std::string): every existing element is read in place, so the
  //    container must already have the desired size.
  //  - anything else: delegated to IO::read1.
  // When called with an lvalue, T deduces to a reference and the value is read
  // into the caller's object; with no argument a fresh T{} is read.
  template <typename T = int> T read(T &&v = T{}) {
    using DecayedT = std::decay_t<T>;
    if constexpr (is_specialization_of_v<std::tuple, DecayedT>) {
      read_t_(v, std::make_index_sequence<std::tuple_size_v<DecayedT>>());
    } else if constexpr (is_vector_like<DecayedT>()) {
      for (auto it = v.begin(); it != v.end(); it++) {
        read(*it);
      }
    } else {
      static_cast<IO *>(this)->template read1(std::forward<T>(v));
    }
    return v;
  }

  // Writes `o`.
  //  - bool: prints the YES / NO strings.
  //  - vector-like: each element is passed to IO::write1, separated by single
  //    spaces (no trailing separator).
  //  - anything else: passed to IO::write1 unchanged.
  template <typename T> IOBaseT &operator<<(const T &o) {
    if constexpr (std::is_same_v<bool, T>) {
      return static_cast<IO *>(this)->write1(o ? YES : NO), *this;
    } else if constexpr (is_vector_like<T>()) {
      bool first = true;
      for (auto it = o.begin(); it != o.end(); it++) {
        if (first) {
          first = false;
        } else {
          static_cast<IO *>(this)->template write1(' ');
        }
        static_cast<IO *>(this)->template write1(*it);
      }
      return *this;
    } else {
      return static_cast<IO *>(this)->template write1(o), *this;
    }
  }

  // helper

  // Reads `n` values of type T into a newly created vector.
  template <typename T = int> std::vector<T> read_v(int n) {
    return read(std::vector<T>(n));
  }

  // Reads one value of each type Ts..., in order, and returns them as a tuple.
  template <typename... Ts> std::tuple<Ts...> read_t() {
    return read(std::tuple<Ts...>{});
  }

private:
  // Detection alias: well-formed only when T has a begin() member.
  template <typename T> using has_begin_t = decltype(std::declval<T>().begin());

  // True for types with a begin() member, except std::string (which is
  // read/written as a single token through read1/write1).
  template <typename T> static constexpr bool is_vector_like() {
    return !std::is_same_v<T, std::string> &&
           std::experimental::is_detected_v<has_begin_t, T>;
  }

  // Assigns each tuple element from a fresh read(); the left comma fold
  // performs the reads in index order.
  template <typename Tuple, std::size_t... Index>
  void read_t_(Tuple &t, std::index_sequence<Index...>) {
    (..., (std::get<Index>(t) = read(std::tuple_element_t<Index, Tuple>{})));
  }
};

struct IO : public IOBaseT<IO> {
  friend class IOBaseT<IO>;

  /// Console I/O backend. Unless `sync` is true, decouples the C++ streams
  /// from C stdio and unties std::cin from std::cout.
  explicit IO(bool sync = false) {
    if (sync) {
      return;
    }
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
  }

private:
  // Primitive hooks invoked by IOBaseT (hence the friend declaration above).
  template <typename T> void read1(T &&value) { std::cin >> value; }
  template <typename T> void write1(T &&value) { std::cout << value; }
};

/// Computes `a` raised to the `n`-th power by square-and-multiply.
///
/// `T` must provide a static `mul_id()` returning the multiplicative identity
/// and an `operator*=`; `N` must be an integral type. Uses O(log n) products.
template <typename T, typename N = uint64_t>
static constexpr T binpow(T a, N n) {
  static_assert(std::is_integral_v<N>);
  auto acc = T::mul_id();
  for (; n != 0; n >>= 1, a *= a) {
    if ((n & 1) != 0) {
      acc *= a;
    }
  }
  return acc;
}

/// Lazily constructed, program-wide instance of a default-constructible `T`.
template <class T> struct Singleton {
  static_assert(std::is_default_constructible_v<T>);

  /// Returns the unique instance, constructing it on first call.
  static T &instance() {
    static T the_instance;
    return the_instance;
  }
};

/// Shorthand for `Singleton<T>::instance()`.
template <typename T> static inline T &singleton() {
  T &ref = Singleton<T>::instance();
  return ref;
}

namespace mod {

/// Maps a residue type `M` to its double-width type `M2` plus helpers.
/// Only the uint32_t and uint64_t specializations below are usable.
template <typename M> struct MultiplierT {};

template <> struct MultiplierT<uint32_t> {
  using M2 = uint64_t;

  // log2 of the residue bit width (2^5 = 32).
  static constexpr int LOG_M_BITS = 5;

  /// Upper 64 bits of the exact 128-bit product x * y.
  static constexpr M2 mul_hi(M2 x, M2 y) {
    const __uint128_t wide =
        static_cast<__uint128_t>(x) * static_cast<__uint128_t>(y);
    return static_cast<M2>(wide >> 64);
  }
};

template <> struct MultiplierT<uint64_t> {
  using M2 = __uint128_t;

  // log2 of the residue bit width (2^6 = 64).
  static constexpr int LOG_M_BITS = 6;

  /// Upper 128 bits of the exact 256-bit product x * y, assembled from four
  /// 64x64 -> 128 partial products (schoolbook multiplication).
  static constexpr M2 mul_hi(M2 x, M2 y) {
    constexpr M2 kLow = UINT64_MAX;
    const M2 a0 = x & kLow;
    const M2 a1 = x >> 64;
    const M2 b0 = y & kLow;
    const M2 b1 = y >> 64;
    const M2 p00 = a0 * b0;
    const M2 p10 = a1 * b0;
    const M2 p01 = a0 * b1;
    // Column for 2^64: its overflow is the carry into the high half.
    const M2 middle = (p00 >> 64) + (p10 & kLow) + (p01 & kLow);
    return a1 * b1 + (p10 >> 64) + (p01 >> 64) + (middle >> 64);
  }
};

/// Double-width companion type of residue type `M`.
template <typename M> using m2_t = typename MultiplierT<M>::M2;

/// Modular integer whose modulus and reduction come from the policy `Mod`.
///
/// `Mod` supplies the residue type `M` and `reduce(M2)`. The modulus is the
/// static member `Mod::MOD` when present; otherwise it is read from
/// `singleton<Mod>().get_mod()` (and set via `set_mod`). When `Mod` also has a
/// static `wrap(M)`, values are stored in wrapped form and converted back with
/// `Mod::unwrap`, and `+=` / `-=` go through `Mod::add` / `Mod::sub`.
template <typename Mod> struct ModWrapperT {
  using M = typename Mod::M;
  using M2 = m2_t<M>;

private:
  // traits
  template <typename T> using has_static_mod_t = decltype(T::MOD);
  static constexpr bool has_static_mod =
      std::experimental::is_detected_v<has_static_mod_t, Mod>;

  // Montgomery performs arith in the Montgomery domain
  template <typename T> using has_wrap_t = decltype(T::wrap(std::declval<M>()));
  static constexpr bool has_wrap =
      std::experimental::is_detected_v<has_wrap_t, Mod>;

public:
  /// Sets the runtime modulus; a no-op when `Mod` has a compile-time `MOD`.
  static void set_mod(M mod) {
    if constexpr (!has_static_mod) {
      singleton<Mod>().set_mod(mod);
    }
  }

  /// The current modulus.
  static constexpr M mod() {
    if constexpr (has_static_mod) {
      return Mod::MOD;
    } else {
      return singleton<Mod>().get_mod();
    }
  }

  /// Multiplicative identity (required by `binpow`).
  static constexpr ModWrapperT mul_id() { return ModWrapperT::normalize(1); }

  /// (-1)^n.
  static constexpr ModWrapperT neg_id(uint64_t n) {
    return (n & 1) ? -ModWrapperT{1} : ModWrapperT{1};
  }

  /// Builds a value from an arbitrary `x`, reducing it modulo `mod()` first.
  static constexpr ModWrapperT normalize(M2 x) {
    return ModWrapperT{static_cast<M>(x % mod())};
  }

  constexpr ModWrapperT() : x{construct(0)} {}

  /// Converts `x_` to `M` WITHOUT reducing it: values outside [0, mod()) are
  /// stored as-is, so use `normalize` for such inputs.
  template <typename T = M>
  explicit constexpr ModWrapperT(T x_ = 0) : x{construct(static_cast<M>(x_))} {
    static_assert(std::numeric_limits<T>::digits <=
                  std::numeric_limits<M>::digits);
  }

  /// The canonical residue in [0, mod()).
  constexpr M get() const {
    if constexpr (has_wrap) {
      return Mod::unwrap(x);
    } else {
      return x;
    }
  }

  constexpr bool operator==(const ModWrapperT &other) const {
    return get() == other.get();
  }

  constexpr ModWrapperT &operator+=(const ModWrapperT &other) {
    if constexpr (has_wrap) {
      Mod::add(x, other.x);
    } else {
      x += other.x;
      if (x >= mod()) {
        x -= mod();
      }
    }
    return *this;
  }

  constexpr ModWrapperT &operator-=(const ModWrapperT &other) {
    if constexpr (has_wrap) {
      Mod::sub(x, other.x);
    } else {
      // Add mod() first so the unsigned subtraction never wraps around.
      x += mod() - other.x;
      if (x >= mod()) {
        x -= mod();
      }
    }
    return *this;
  }

  // Returns *this by reference like the other compound operators, so chained
  // forms such as `(a *= b) *= c` update `a` twice instead of a temporary.
  constexpr ModWrapperT &operator*=(const ModWrapperT &other) {
    auto p = static_cast<M2>(x) * static_cast<M2>(other.x);
    if constexpr (has_static_mod) {
      x = Mod::reduce(p);
    } else {
      x = singleton<Mod>().reduce(p);
    }
    return *this;
  }

  constexpr ModWrapperT &operator/=(const ModWrapperT &other) {
    return *this *= other.inv();
  }

  /// Inverse via Fermat's little theorem (x^(mod - 2)); meaningful only for a
  /// prime modulus and a nonzero value.
  constexpr ModWrapperT inv() const { return binpow(*this, mod() - 2); }

  // helper arith

  constexpr bool operator!=(const ModWrapperT &other) const {
    return !(*this == other);
  }

  constexpr ModWrapperT operator+(const ModWrapperT &other) const {
    ModWrapperT copy = *this;
    return copy += other;
  }

  constexpr ModWrapperT operator-() const {
    ModWrapperT copy{0};
    copy -= *this;
    return copy;
  }

  constexpr ModWrapperT operator-(const ModWrapperT &other) const {
    ModWrapperT copy = *this;
    return copy -= other;
  }

  constexpr ModWrapperT operator*(const ModWrapperT &other) const {
    ModWrapperT copy = *this;
    return copy *= other;
  }

  constexpr ModWrapperT operator/(const ModWrapperT &other) const {
    ModWrapperT copy = *this;
    return copy /= other;
  }

private:
  // Converts a plain residue into the stored representation.
  static constexpr M construct(M x) {
    if constexpr (has_wrap) {
      return Mod::wrap(x);
    } else {
      return x;
    }
  }

  M x;
};

} // namespace mod

namespace std {

template <typename Mod>
ostream &operator<<(ostream &out, const mod::ModWrapperT<Mod> &w) {
  return out << w.get();
}

} // namespace std

namespace mod {

template <typename M_, M_ MOD_> struct ModBaseT {
  using M = M_;
  static constexpr M MOD = MOD_;

  static_assert((MOD - 1) <= (std::numeric_limits<M_>::max() >> 1));

private:
  using M2 = m2_t<M>;

public:
  static constexpr M reduce(M2 x) { return x % MOD; }
};

template <uint64_t M> using Mod64T = ModWrapperT<ModBaseT<uint64_t, M>>;
template <uint32_t M> using ModT = ModWrapperT<ModBaseT<uint32_t, M>>;

} // namespace mod

using mod::Mod64T;
using mod::ModT;

namespace ranges = std::ranges;
namespace views = std::views;

// }}}

// Arithmetic modulo 998244353.
using Mod = ModT<998'244'353>;

// Capacity of every per-vertex / per-group array below.
constexpr int N = 3000;

int n;             // vertex count of the current test case
int m;             // group count of the current test case
int belong[N];     // belong[v]: 0-based group of vertex v
int head[N];       // head[v]: first directed edge slot leaving v, -1 if none
int to[N << 1];    // to[e]: vertex that edge slot e points to
int next[N << 1];  // next[e]: next edge slot leaving the same vertex, -1 ends
int depth[N];      // written by prepare(), compared by dist()
int parent[N];     // parent in the DFS tree rooted at 0 (-1 for the root)
int dsu_parent[N]; // union-find links over same-group tree edges
int leader_[N];    // leader_[g]: leader vertex of group g, -1 if unset
int dist_[N][N];   // dist_[u][v]: memoised tree distance, -1 if unknown
bool visit[N];     // visit[g]: whether group g occurs in the input
Mod dp[N][N];      // dp[u][d]: DP row of the component led by u (see dfs())

namespace ct /* component tree */ {
int head[N]; // head[U]: first child-component leader under leader U, -1 if none
int next[N]; // next[V]: next sibling child-component leader after V
} // namespace ct

namespace gr /* grouping */ {
int head[N]; // head[U]: first vertex in the group led by U, -1 if none
int next[N]; // next[v]: next vertex of the same group after v
} // namespace gr

/// Returns the representative of `u` in the `dsu_parent` forest and re-points
/// every vertex on the visited path directly at it (path compression).
int dsu_find(int u) {
  int root = u;
  while (dsu_parent[root] != root) {
    root = dsu_parent[root];
  }
  for (int cur = u; cur != root;) {
    const int up = dsu_parent[cur];
    dsu_parent[cur] = root;
    cur = up;
  }
  return root;
}

/// Mutable leader slot of `u`'s group (`leader_` is indexed by group id).
int &leader(int u) {
  int &slot = leader_[belong[u]];
  return slot;
}

// DFS from `u` over the tree (the caller starts it at root 0). It
//  * assigns parent[] and depth[] (depth = number of edges from the root),
//  * links each vertex to its same-group tree parent in the DSU,
//  * records each group's leader (first vertex of the group reached by the
//    DFS) and chains the group's vertices into gr::head / gr::next,
//  * links every child component's leader under its parent component's leader
//    in ct::head / ct::next (the component tree).
// Returns false as soon as a vertex's group leader differs from its DSU root,
// i.e. the group is not connected in the tree. The global state is then only
// partially built.
bool prepare(int u) {
  if (!~leader(u)) {
    leader(u) = u;
  }
  auto U = leader(u);
  if (U != dsu_find(u)) {
    return false;
  }
  gr::next[u] = gr::head[U];
  gr::head[U] = u;
  for (auto it = head[u]; ~it; it = next[it]) {
    auto v = to[it];
    if (v != parent[u]) {
      // Store the real depth. dist() compares depth[] values to choose which
      // endpoint to walk upwards from. The previous `depth[v] = u` stored the
      // parent's index instead, which produced wrong distances and could walk
      // past the root (parent[0] == -1).
      depth[v] = depth[u] + 1;
      parent[v] = u;
      if (belong[u] == belong[v]) {
        // dsu_find(u) above compressed u's link, so this points v at the root.
        dsu_parent[v] = dsu_parent[u];
      }
      if (!prepare(v)) {
        return false;
      }
      auto V = leader(v);
      if (U != V) {
        ct::next[V] = ct::head[U];
        ct::head[U] = V;
      }
    }
  }
  return true;
}

/// Tree distance between vertices `u` and `v`, memoised in `dist_`.
///
/// Orders the pair so that depth[u] <= depth[v], then steps `v` up through
/// `parent` until a cached entry is hit. The caller pre-seeds dist_[x][x] = 0,
/// and -1 marks an entry that has not been computed yet.
int dist(int u, int v) {
  if (depth[v] < depth[u]) {
    std::swap(u, v);
  }
  int &memo = dist_[u][v];
  if (memo == -1) {
    memo = 1 + dist(u, parent[v]);
  }
  return memo;
}

// Scratch list of the child components of the component currently being
// combined in dfs(). It is shared by all recursion levels, which is why dfs()
// fills it only after all of its recursive calls have returned.
std::vector<int> children;

// Bottom-up DP over the component tree (ct::head / ct::next) below the
// component whose leader is `u`.
//
// On return, dp[u][d] (0 <= d < n) is the sum, over every vertex c of u's group
// (gr::head / gr::next) with dist(u, c) == d, of the product over the child
// components v of the dp[v] entries selected below.
// NOTE(review): presumably c is the candidate centre of u's region in the tree
// Voronoi diagram, and a child's centre must sit at a distance from v that
// keeps the boundary edge (parent[v], v) correctly split, with ties going to
// the smaller group id. Confirm against the problem statement.
void dfs(int u) {
  for (auto v = ct::head[u]; ~v; v = ct::next[v]) {
    dfs(v);
  }
  std::fill(dp[u], dp[u] + n, Mod{0});
  children.clear();
  for (auto v = ct::head[u]; ~v; v = ct::next[v]) {
    children.push_back(v);
  }
  for (auto c = gr::head[u]; ~c; c = gr::next[c]) {
    Mod ways{1};
    for (auto &&v : children) {
      // d_u: distance from candidate c to the child component's leader v.
      auto d_u = dist(c, v);
      Mod vways{0};
      // dp[v][d_u - 1] is always counted.
      vways += dp[v][d_u - 1];
      if (belong[u] < belong[v]) {
        // u's group id is smaller: dp[v][d_u - 2] is counted too (if d_u >= 2).
        if (d_u >= 2) {
          vways += dp[v][d_u - 2];
        }
      } else {
        // Otherwise dp[v][d_u] is counted instead.
        vways += dp[v][d_u];
      }
      ways *= vways;
    }
    dp[u][dist(u, c)] += ways;
  }
}

int main() {
  IO io;
  auto T = io.read();
  while (T--) {
    std::tie(n, m) = io.read_t<int, int>();
    std::fill(head, head + n, -1);
    for (int i = 0; i < (n - 1) << 1; i++) {
      to[i] = io.read() - 1;
    }
    for (int i = 0; i < (n - 1) << 1; i++) {
      next[i] = head[to[i ^ 1]];
      head[to[i ^ 1]] = i;
    }
    std::fill(visit, visit + m, false);
    for (int i = 0; i < n; i++) {
      visit[belong[i] = io.read() - 1] = true;
    }
    auto solve = [&]() -> Mod {
      depth[0] = 0;
      parent[0] = -1;
      std::iota(dsu_parent, dsu_parent + n, 0);
      std::fill(leader_, leader_ + m, -1);
      std::fill(ct::head, ct::head + n, -1);
      std::fill(gr::head, gr::head + n, -1);
      if (!prepare(0)) {
        return Mod{0};
      }
      for (int u = 0; u < n; u++) {
        std::fill(dist_[u], dist_[u] + n, -1);
        dist_[u][u] = 0;
      }
      dfs(0);
      return std::accumulate(dp[0], dp[0] + n, Mod{0});
    };
    io << solve() << '\n';
  }
}


Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Wrong Answer
time: 0ms
memory: 3824kb

input:

10
15 2
10 5
3 5
12 5
10 9
11 7
3 8
2 4
7 1
15 14
8 13
15 6
2 1
4 8
11 15
1 1 1 1 2 1 1 1 2 2 1 2 1 1 1
15 3
8 11
12 8
1 3
13 15
5 9
10 13
6 12
14 4
4 9
15 5
11 10
2 14
7 2
6 3
3 2 3 2 2 3 2 1 2 1 1 3 1 2 1
15 5
1 7
5 2
11 9
6 8
13 3
14 12
3 1
8 9
5 10
10 11
5 1
12 13
10 15
11 4
3 3 3 2 3 2 1 2 2 2 ...

output:

5
0
4
0
5
2
1
0
15
0

result:

wrong answer 1st numbers differ - expected: '11', found: '5'