QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#326946 | #7997. 树 V 图 | ftiasch | WA | 0ms | 3824kb | C++23 | 11.3kb | 2024-02-14 15:14:07 | 2024-02-14 15:14:07 |
Judging History
answer
#include <bits/stdc++.h>
#include <experimental/type_traits>
// {{{ boilerplate
// Type trait: holds `true` exactly when T is some instantiation
// Template<Args...> of the given class template (type parameters only).
template <template <typename...> class Template, typename T>
struct is_specialization_of : std::bool_constant<false> {};
template <template <typename...> class Template, typename... Args>
struct is_specialization_of<Template, Template<Args...>>
    : std::bool_constant<true> {};
// Variable-template shorthand for is_specialization_of<Template, T>::value.
template <template <typename...> class Template, typename T>
inline constexpr bool is_specialization_of_v =
    is_specialization_of<Template, T>::value;
// Text written by IOBaseT::operator<< for bool values. Either macro can be
// overridden by defining it before this point.
#ifndef YES
#define YES "Yes"
#endif
#ifndef NO
#define NO "No"
#endif
// CRTP base that layers typed reading/writing over a derived policy `IO`.
// The derived class must provide read1(T&&) and write1(T&&) for single values;
// this base splits tuples and begin()-iterable containers into elements.
template <typename IO> struct IOBaseT {
// Reads a value and returns it (T defaults to int).
//  * std::tuple: each element is read in order (see read_t_).
//  * anything with begin() other than std::string: read in place element by
//    element, so the container must already have its final size.
//  * everything else: forwarded to IO::read1.
template <typename T = int> T read(T &&v = T{}) {
using DecayedT = std::decay_t<T>;
if constexpr (is_specialization_of_v<std::tuple, DecayedT>) {
read_t_(v, std::make_index_sequence<std::tuple_size_v<DecayedT>>());
} else if constexpr (is_vector_like<DecayedT>()) {
for (auto it = v.begin(); it != v.end(); it++) {
read(*it);
}
} else {
static_cast<IO *>(this)->template read1(std::forward<T>(v));
}
return v;
}
// Writes o: a bool as YES/NO, a begin()-iterable container (not std::string)
// as its elements separated by single spaces (no trailing newline), and any
// other value via IO::write1. Returns *this for chaining.
template <typename T> IOBaseT &operator<<(const T &o) {
if constexpr (std::is_same_v<bool, T>) {
return static_cast<IO *>(this)->write1(o ? YES : NO), *this;
} else if constexpr (is_vector_like<T>()) {
bool first = true;
for (auto it = o.begin(); it != o.end(); it++) {
if (first) {
first = false;
} else {
static_cast<IO *>(this)->template write1(' ');
}
static_cast<IO *>(this)->template write1(*it);
}
return *this;
} else {
return static_cast<IO *>(this)->template write1(o), *this;
}
}
// helpers
// Reads n values of type T into a freshly sized vector.
template <typename T = int> std::vector<T> read_v(int n) {
return read(std::vector<T>(n));
}
// Reads one value per listed type and returns them as a tuple.
template <typename... Ts> std::tuple<Ts...> read_t() {
return read(std::tuple<Ts...>{});
}
private:
// Detection helpers: "vector-like" means T has begin() and is not std::string.
template <typename T> using has_begin_t = decltype(std::declval<T>().begin());
template <typename T> static constexpr bool is_vector_like() {
return !std::is_same_v<T, std::string> &&
std::experimental::is_detected_v<has_begin_t, T>;
}
// Assigns each tuple element from a fresh read of its element type. The fold
// over the built-in comma operator sequences the reads left to right.
template <typename Tuple, std::size_t... Index>
void read_t_(Tuple &t, std::index_sequence<Index...>) {
(..., (std::get<Index>(t) = read(std::tuple_element_t<Index, Tuple>{})));
}
};
// iostream-backed policy for IOBaseT. The base class calls the private
// read1/write1 hooks, hence the friend declaration.
struct IO : public IOBaseT<IO> {
  friend class IOBaseT<IO>;

  // Unless `sync` is requested, decouple C++ streams from C stdio and untie
  // cin from cout for faster bulk I/O.
  explicit IO(bool sync = false) {
    if (sync) {
      return;
    }
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
  }

private:
  template <typename Value> void read1(Value &&target) { std::cin >> target; }
  template <typename Value> void write1(Value &&value) { std::cout << value; }
};
// Computes a^n by binary exponentiation (square-and-multiply).
// T must expose a static mul_id() returning its multiplicative identity and
// an operator*=; N must be an integral type.
template <typename T, typename N = uint64_t>
static constexpr T binpow(T a, N n) {
  static_assert(std::is_integral_v<N>);
  auto acc = T::mul_id();
  for (; n; n >>= 1) {
    if (n & 1) {
      acc *= a;
    }
    a *= a;
  }
  return acc;
}
// Holds one lazily default-constructed T, created on first access and shared
// by every caller for the rest of the program.
template <class T> struct Singleton {
  static_assert(std::is_default_constructible_v<T>);
  static T &instance() {
    static T the_instance;
    return the_instance;
  }
};
// Convenience accessor equivalent to Singleton<T>::instance().
template <typename T> static inline T &singleton() {
  auto &shared = Singleton<T>::instance();
  return shared;
}
namespace mod {
// MultiplierT<M> maps a residue storage type M to:
//  * M2: an unsigned type twice as wide, able to hold a product of two Ms;
//  * LOG_M_BITS: log2 of M's bit width;
//  * mul_hi(x, y): the upper half of the double-width product of two M2s.
template <typename M> struct MultiplierT {};
template <> struct MultiplierT<uint32_t> {
  using M2 = uint64_t;
  static constexpr int LOG_M_BITS = 5;
  static constexpr M2 mul_hi(M2 x, M2 y) {
    return static_cast<__uint128_t>(x) * static_cast<__uint128_t>(y) >> 64;
  }
};
template <> struct MultiplierT<uint64_t> {
  using M2 = __uint128_t;
  static constexpr int LOG_M_BITS = 6;
  // There is no 256-bit type, so the high half is assembled from 64-bit limbs;
  // `cy` collects the partial sums that can carry into bit 128.
  static constexpr M2 mul_hi(M2 x, M2 y) {
    M2 x_lo = x & UINT64_MAX;
    M2 x_hi = x >> 64;
    M2 y_lo = y & UINT64_MAX;
    M2 y_hi = y >> 64;
    M2 lo_lo = x_lo * y_lo;
    M2 hi_lo = x_hi * y_lo;
    M2 lo_hi = x_lo * y_hi;
    M2 cy = (hi_lo & UINT64_MAX) + (lo_hi & UINT64_MAX) + (lo_lo >> 64);
    return x_hi * y_hi + (hi_lo >> 64) + (lo_hi >> 64) + (cy >> 64);
  }
};
template <typename M> using m2_t = typename MultiplierT<M>::M2;

// Modular integer driven by a policy `Mod`.
//  * `Mod::M` is the storage type.
//  * With a static `Mod::MOD`, `Mod::reduce` is called statically.
//  * Otherwise the modulus lives in singleton<Mod>() (set_mod/get_mod/reduce).
//  * If the policy provides wrap/unwrap/add/sub, values are kept in that
//    representation (the code mentions Montgomery); otherwise the stored value
//    is the residue itself.
template <typename Mod> struct ModWrapperT {
  using M = typename Mod::M;
  using M2 = m2_t<M>;

private:
  // traits
  template <typename T> using has_static_mod_t = decltype(T::MOD);
  static constexpr bool has_static_mod =
      std::experimental::is_detected_v<has_static_mod_t, Mod>;
  // Montgomery performs arith in the Montgomery domain
  template <typename T>
  using has_wrap_t = decltype(T::wrap(std::declval<M>()));
  static constexpr bool has_wrap =
      std::experimental::is_detected_v<has_wrap_t, Mod>;

public:
  // Sets the runtime modulus; a no-op for policies with a static MOD.
  static void set_mod(M mod) {
    if constexpr (!has_static_mod) {
      singleton<Mod>().set_mod(mod);
    }
  }
  static constexpr M mod() {
    if constexpr (has_static_mod) {
      return Mod::MOD;
    } else {
      return singleton<Mod>().get_mod();
    }
  }
  static constexpr ModWrapperT mul_id() { return ModWrapperT::normalize(1); }
  // (-1)^n.
  static constexpr ModWrapperT neg_id(uint64_t n) {
    return (n & 1) ? -ModWrapperT{1} : ModWrapperT{1};
  }
  // Residue of an arbitrary double-width value.
  static constexpr ModWrapperT normalize(M2 x) {
    return ModWrapperT{static_cast<M>(x % mod())};
  }
  constexpr ModWrapperT() : x{construct(0)} {}
  // Residue of x_. Out-of-range and negative inputs are reduced into
  // [0, mod()) first; storing them unreduced would break every later
  // operation, all of which assume a canonical residue.
  template <typename T = M>
  explicit constexpr ModWrapperT(T x_ = 0) : x{construct(reduce_input(x_))} {
    static_assert(std::numeric_limits<T>::digits <=
                  std::numeric_limits<M>::digits);
  }
  // Canonical residue in [0, mod()).
  constexpr M get() const {
    if constexpr (has_wrap) {
      return Mod::unwrap(x);
    } else {
      return x;
    }
  }
  constexpr bool operator==(const ModWrapperT &other) const {
    return get() == other.get();
  }
  constexpr ModWrapperT &operator+=(const ModWrapperT &other) {
    if constexpr (has_wrap) {
      Mod::add(x, other.x);
    } else {
      x += other.x;
      if (x >= mod()) {
        x -= mod();
      }
    }
    return *this;
  }
  constexpr ModWrapperT &operator-=(const ModWrapperT &other) {
    if constexpr (has_wrap) {
      Mod::sub(x, other.x);
    } else {
      x += mod() - other.x;
      if (x >= mod()) {
        x -= mod();
      }
    }
    return *this;
  }
  // Returns *this by reference like the other compound assignments, so that
  // chained forms such as `(a *= b) *= c` act on `a`, not on a temporary.
  constexpr ModWrapperT &operator*=(const ModWrapperT &other) {
    auto p = static_cast<M2>(x) * static_cast<M2>(other.x);
    if constexpr (has_static_mod) {
      x = Mod::reduce(p);
    } else {
      x = singleton<Mod>().reduce(p);
    }
    return *this;
  }
  constexpr ModWrapperT &operator/=(const ModWrapperT &other) {
    return *this *= other.inv();
  }
  // Inverse as x^(mod()-2) (Fermat); only meaningful when mod() is prime.
  constexpr ModWrapperT inv() const { return binpow(*this, mod() - 2); }
  // helper arith
  constexpr bool operator!=(const ModWrapperT &other) const {
    return !(*this == other);
  }
  constexpr ModWrapperT operator+(const ModWrapperT &other) const {
    ModWrapperT copy = *this;
    return copy += other;
  }
  constexpr ModWrapperT operator-() const {
    ModWrapperT copy{0};
    copy -= *this;
    return copy;
  }
  constexpr ModWrapperT operator-(const ModWrapperT &other) const {
    ModWrapperT copy = *this;
    return copy -= other;
  }
  constexpr ModWrapperT operator*(const ModWrapperT &other) const {
    ModWrapperT copy = *this;
    return copy *= other;
  }
  constexpr ModWrapperT operator/(const ModWrapperT &other) const {
    ModWrapperT copy = *this;
    return copy /= other;
  }

private:
  // Maps a canonical residue into the policy's storage representation.
  static constexpr M construct(M x) {
    if constexpr (has_wrap) {
      return Mod::wrap(x);
    } else {
      return x;
    }
  }
  // Reduces an arbitrary input value into [0, mod()).
  template <typename T> static constexpr M reduce_input(T v) {
    if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
      if (v < 0) {
        // -(v + 1) cannot overflow, unlike -v for the minimum value.
        M r = static_cast<M>(static_cast<M>(-(v + 1)) % mod());
        // v == -(r + 1) (mod mod()).
        return r + 1 == mod() ? M{0} : static_cast<M>(mod() - 1 - r);
      }
    }
    return static_cast<M>(static_cast<M>(v) % mod());
  }
  M x;
};
} // namespace mod
namespace std {
template <typename Mod>
ostream &operator<<(ostream &out, const mod::ModWrapperT<Mod> &w) {
return out << w.get();
}
} // namespace std
namespace mod {
template <typename M_, M_ MOD_> struct ModBaseT {
using M = M_;
static constexpr M MOD = MOD_;
static_assert((MOD - 1) <= (std::numeric_limits<M_>::max() >> 1));
private:
using M2 = m2_t<M>;
public:
static constexpr M reduce(M2 x) { return x % MOD; }
};
template <uint64_t M> using Mod64T = ModWrapperT<ModBaseT<uint64_t, M>>;
template <uint32_t M> using ModT = ModWrapperT<ModBaseT<uint32_t, M>>;
} // namespace mod
using mod::Mod64T;
using mod::ModT;
namespace ranges = std::ranges;
namespace views = std::views;
// }}}
using Mod = ModT<998'244'353>;
// Capacity for vertices and colors; arc arrays hold two arcs per tree edge.
constexpr int N = 3000;
// Vertex count and color count of the current test case.
int n;
int m;
// belong[u]: 0-based color of vertex u (main stores input - 1).
int belong[N];
// Arc lists of the input tree: arc i runs from to[i ^ 1] to to[i].
int head[N];
int to[N << 1];
int next[N << 1];
// Filled by prepare() while walking the tree from vertex 0.
int depth[N];
int parent[N];
// Union-find parents used to join same-colored tree neighbours.
int dsu_parent[N];
// leader_[c]: first vertex of color c reached by prepare(), -1 if none yet.
int leader_[N];
// dist_[u][v]: memoized distance, -1 while not yet computed.
int dist_[N][N];
// visit[c]: whether color c occurs in the input (written by main only).
bool visit[N];
// dp[U][d]: count for component leader U, indexed by the distance d of the
// component's chosen vertex from U (filled by dfs()).
Mod dp[N][N];
namespace ct /* component tree */ {
// Child component leaders of each leader, as intrusive singly linked lists.
int head[N];
int next[N];
}
namespace gr /* grouping */ {
// Member vertices of each component, keyed by its leader.
int head[N];
int next[N];
}
// Returns the representative of u's union-find set. Afterwards every vertex
// on the walked path points directly at that representative.
int dsu_find(int u) {
  int root = u;
  while (dsu_parent[root] != root) {
    root = dsu_parent[root];
  }
  while (dsu_parent[u] != root) {
    const int up = dsu_parent[u];
    dsu_parent[u] = root;
    u = up;
  }
  return root;
}
// Mutable slot holding the leader of u's color (-1 until prepare() sets it).
int &leader(int u) {
  const int color = belong[u];
  return leader_[color];
}
// Recursive walk of the input tree from u. The caller must already have set
// parent[u] and depth[u]. For each visited vertex it:
//  * makes the first vertex of each color reached the leader of that color,
//    and rejects the input if u is not in its leader's union-find set (the
//    color class would then form more than one component);
//  * pushes u onto gr::head[leader] (the members of the leader's component);
//  * for each child v: records parent[v] and depth[v], unions v with u when
//    both share a color, and pushes leader(v) onto ct::head[leader(u)] when
//    the leaders differ (the component tree).
// Returns false as soon as a vertex is rejected, true otherwise.
bool prepare(int u) {
  if (!~leader(u)) {
    leader(u) = u;
  }
  auto U = leader(u);
  if (U != dsu_find(u)) {
    return false;
  }
  gr::next[u] = gr::head[U];
  gr::head[U] = u;
  for (auto it = head[u]; ~it; it = next[it]) {
    auto v = to[it];
    if (v != parent[u]) {
      // dist() compares depth[] values to decide which endpoint to step up
      // from, so this must be the real depth. It previously stored u itself,
      // which produced wrong distances and could index parent[0] == -1.
      depth[v] = depth[u] + 1;
      parent[v] = u;
      if (belong[u] == belong[v]) {
        // dsu_find(u) above compressed u, so dsu_parent[u] is its set's root.
        dsu_parent[v] = dsu_parent[u];
      }
      if (!prepare(v)) {
        return false;
      }
      auto V = leader(v);
      if (U != V) {
        ct::next[V] = ct::head[U];
        ct::head[U] = V;
      }
    }
  }
  return true;
}
// Distance between u and v, memoized in dist_ (the caller seeds dist_[u][u]
// with 0 and everything else with -1). It always steps up from the deeper
// endpoint v. This assumes depth[]/parent[] describe the tree rooted at
// vertex 0; then the path from the shallower u reaches v through parent[v].
int dist(int u, int v) {
  if (depth[v] < depth[u]) {
    std::swap(u, v);
  }
  int &memo = dist_[u][v];
  if (memo == -1) {
    memo = 1 + dist(u, parent[v]);
  }
  return memo;
}
std::vector<int> children;
void dfs(int u) {
for (auto v = ct::head[u]; ~v; v = ct::next[v]) {
dfs(v);
}
std::fill(dp[u], dp[u] + n, Mod{0});
children.clear();
for (auto v = ct::head[u]; ~v; v = ct::next[v]) {
children.push_back(v);
}
for (auto c = gr::head[u]; ~c; c = gr::next[c]) {
Mod ways{1};
for (auto &&v : children) {
auto d_u = dist(c, v);
Mod vways{0};
// = d_u - 2
vways += dp[v][d_u - 1];
if (belong[u] < belong[v]) {
if (d_u >= 2) {
vways += dp[v][d_u - 2];
}
} else {
vways += dp[v][d_u];
}
ways *= vways;
}
dp[u][dist(u, c)] += ways;
}
}
int main() {
IO io;
auto T = io.read();
while (T--) {
std::tie(n, m) = io.read_t<int, int>();
std::fill(head, head + n, -1);
for (int i = 0; i < (n - 1) << 1; i++) {
to[i] = io.read() - 1;
}
for (int i = 0; i < (n - 1) << 1; i++) {
next[i] = head[to[i ^ 1]];
head[to[i ^ 1]] = i;
}
std::fill(visit, visit + m, false);
for (int i = 0; i < n; i++) {
visit[belong[i] = io.read() - 1] = true;
}
auto solve = [&]() -> Mod {
depth[0] = 0;
parent[0] = -1;
std::iota(dsu_parent, dsu_parent + n, 0);
std::fill(leader_, leader_ + m, -1);
std::fill(ct::head, ct::head + n, -1);
std::fill(gr::head, gr::head + n, -1);
if (!prepare(0)) {
return Mod{0};
}
for (int u = 0; u < n; u++) {
std::fill(dist_[u], dist_[u] + n, -1);
dist_[u][u] = 0;
}
dfs(0);
return std::accumulate(dp[0], dp[0] + n, Mod{0});
};
io << solve() << '\n';
}
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 0
Wrong Answer
time: 0ms
memory: 3824kb
input:
10 15 2 10 5 3 5 12 5 10 9 11 7 3 8 2 4 7 1 15 14 8 13 15 6 2 1 4 8 11 15 1 1 1 1 2 1 1 1 2 2 1 2 1 1 1 15 3 8 11 12 8 1 3 13 15 5 9 10 13 6 12 14 4 4 9 15 5 11 10 2 14 7 2 6 3 3 2 3 2 2 3 2 1 2 1 1 3 1 2 1 15 5 1 7 5 2 11 9 6 8 13 3 14 12 3 1 8 9 5 10 10 11 5 1 12 13 10 15 11 4 3 3 3 2 3 2 1 2 2 2 ...
output:
5 0 4 0 5 2 1 0 15 0
result:
wrong answer 1st numbers differ - expected: '11', found: '5'