QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#296051 | #4845. DFS | HoMaMaOvO (Riku Kawasaki, Masaki Nishimoto, Yui Hosaka)# | RE | 770ms | 110232kb | C++14 | 15.8kb | 2024-01-02 01:28:30 | 2024-01-02 01:28:30 |
Judging History
answer
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
using namespace std;
using Int = long long;
// Debug printing helpers.
// A pair is written as "(first, second)".
template <class T1, class T2> ostream &operator<<(ostream &os, const pair<T1, T2> &a) {
  os << "(" << a.first << ", " << a.second << ")";
  return os;
}
// A vector is written as "[e0, e1, ...]". Only the first 256 elements are shown;
// a longer vector gets a trailing ", ..." marker.
template <class T> ostream &operator<<(ostream &os, const vector<T> &as) {
  const int sz = as.size();
  const int shown = (sz < 256) ? sz : 256;
  os << "[";
  for (int idx = 0; idx < shown; ++idx) {
    if (idx) os << ", ";
    os << as[idx];
  }
  if (sz > 256) os << ", ...";
  return os << "]";
}
// Prints every element of [a, b) to stderr, space separated, then a newline.
template <class T> void pv(T a, T b) {
  while (a != b) {
    cerr << *a << " ";
    ++a;
  }
  cerr << endl;
}
// chmin: lowers t to f when t > f; returns whether t changed.
template <class T> bool chmin(T &t, const T &f) {
  if (!(t > f)) return false;
  t = f;
  return true;
}
// chmax: raises t to f when t < f; returns whether t changed.
template <class T> bool chmax(T &t, const T &f) {
  if (!(t < f)) return false;
  t = f;
  return true;
}
// ANSI SGR escape sequence, e.g. COLOR("91") -> "\x1b[91m".
#define COLOR(s) ("\x1b[" s "m")
////////////////////////////////////////////////////////////////////////////////
// Integer modulo the compile-time constant M_, stored fully reduced in x (0 <= x < M).
// The arithmetic below leans on unsigned semantics:
//   += : x + a.x < 2*M must fit in `unsigned` (true for M < 2^31);
//   -= : when a.x > x the subtraction wraps to a value >= M, and adding M wraps back;
//   *= : widens to unsigned long long before reducing, so no overflow.
template <unsigned M_> struct ModInt {
static constexpr unsigned M = M_;
unsigned x;
constexpr ModInt() : x(0U) {}
constexpr ModInt(unsigned x_) : x(x_ % M) {}
constexpr ModInt(unsigned long long x_) : x(x_ % M) {}
// Signed inputs: C++ `%` keeps the sign of the dividend, so negative remainders are shifted up by M.
constexpr ModInt(int x_) : x(((x_ %= static_cast<int>(M)) < 0) ? (x_ + static_cast<int>(M)) : x_) {}
constexpr ModInt(long long x_) : x(((x_ %= static_cast<long long>(M)) < 0) ? (x_ + static_cast<long long>(M)) : x_) {}
ModInt &operator+=(const ModInt &a) { x = ((x += a.x) >= M) ? (x - M) : x; return *this; }
ModInt &operator-=(const ModInt &a) { x = ((x -= a.x) >= M) ? (x + M) : x; return *this; }
ModInt &operator*=(const ModInt &a) { x = (static_cast<unsigned long long>(x) * a.x) % M; return *this; }
// Division multiplies by the modular inverse (asserts inside inv() if none exists).
ModInt &operator/=(const ModInt &a) { return (*this *= a.inv()); }
// Exponentiation by squaring; a negative exponent is applied to the inverse.
ModInt pow(long long e) const {
if (e < 0) return inv().pow(-e);
ModInt a = *this, b = 1U; for (; e; e >>= 1) { if (e & 1) b *= a; a *= a; } return b;
}
// Extended Euclid on (M, x): tracks y with y * x == a (mod M) and asserts the final
// gcd a is 1, i.e. x must be invertible modulo M.
ModInt inv() const {
unsigned a = M, b = x; int y = 0, z = 1;
for (; b; ) { const unsigned q = a / b; const unsigned c = a - q * b; a = b; b = c; const int w = y - static_cast<int>(q) * z; y = z; z = w; }
assert(a == 1U); return ModInt(y);
}
ModInt operator+() const { return *this; }
ModInt operator-() const { ModInt a; a.x = x ? (M - x) : 0U; return a; }
ModInt operator+(const ModInt &a) const { return (ModInt(*this) += a); }
ModInt operator-(const ModInt &a) const { return (ModInt(*this) -= a); }
ModInt operator*(const ModInt &a) const { return (ModInt(*this) *= a); }
ModInt operator/(const ModInt &a) const { return (ModInt(*this) /= a); }
// Mixed operators with a non-ModInt left operand (e.g. `2 * m`): the left side is converted via ModInt(a).
template <class T> friend ModInt operator+(T a, const ModInt &b) { return (ModInt(a) += b); }
template <class T> friend ModInt operator-(T a, const ModInt &b) { return (ModInt(a) -= b); }
template <class T> friend ModInt operator*(T a, const ModInt &b) { return (ModInt(a) *= b); }
template <class T> friend ModInt operator/(T a, const ModInt &b) { return (ModInt(a) /= b); }
// True iff the residue is nonzero.
explicit operator bool() const { return x; }
bool operator==(const ModInt &a) const { return (x == a.x); }
bool operator!=(const ModInt &a) const { return (x != a.x); }
friend std::ostream &operator<<(std::ostream &os, const ModInt &a) { return os << a.x; }
};
////////////////////////////////////////////////////////////////////////////////
constexpr unsigned MO = 998244353;
using Mint = ModInt<MO>;
constexpr int LIM_INV = 800'010;
// Lookup tables for 0 <= i < LIM_INV, filled by prepare():
//   inv[i] = i^{-1}, fac[i] = i!, invFac[i] = (i!)^{-1}  (all mod MO; inv[0] stays 0).
Mint inv[LIM_INV], fac[LIM_INV], invFac[LIM_INV];
void prepare() {
  // Linear-time inverses via M = (M / i) * i + (M % i)  =>  1/i = -(M / i) * 1/(M % i).
  inv[1] = 1;
  for (int i = 2; i < LIM_INV; ++i) {
    const unsigned quot = Mint::M / i;
    const unsigned rem = Mint::M % i;
    inv[i] = -(quot * inv[rem]);
  }
  // Running products give factorials and their inverses side by side.
  fac[0] = 1;
  invFac[0] = 1;
  for (int i = 1; i < LIM_INV; ++i) {
    fac[i] = fac[i - 1] * i;
    invFac[i] = invFac[i - 1] * inv[i];
  }
}
// Binomial coefficient C(n, k) mod MO, also defined for negative n:
//   n <  0, k >= 0      : (-1)^k     * C(-n + k - 1, k)
//   n <  0, n - k >= 0  : (-1)^(n-k) * C(-k - 1, n - k)
//   anything else out of range is 0.
// Non-negative n must be below LIM_INV (checked by assert).
Mint binom(Int n, Int k) {
  if (n >= 0) {
    if (k < 0 || k > n) return 0;
    assert(n < LIM_INV);
    return fac[n] * invFac[k] * invFac[n - k];
  }
  if (k >= 0) return ((k & 1) ? -1 : +1) * binom(-n + k - 1, k);
  if (n - k >= 0) return (((n - k) & 1) ? -1 : +1) * binom(-k - 1, n - k);
  return 0;
}
// Meldable
// 0 for null, ts[0] = T()
// chPoint(u, a): point update
// chRange(u, a, b): range update s.t. T() -> T()
// T::push(T *l, T *r)
// T::pull(const T &l, const T &r)
// T::meld(const T &t)
//
// Sparse segment tree over [l0, r0). All trees share one fixed pool of NUM_NODES_
// slots and a tree is identified by its root index (0 = empty tree).
// Nodes released by meld() are put on a free list and handed out again by newNode(),
// so the pool only has to hold the nodes that are alive at the same time, not every
// node ever allocated. Previously a freed node was simply leaked, and large inputs
// exhausted the pool (assert in newNode).
template <class T, int NUM_NODES_ = 1 << 22> struct Seg {
  static constexpr int NUM_NODES = NUM_NODES_;
  int l0, r0;
  // High-water mark of pool slots handed out so far (slot 0 is the null node).
  int nodesLen;
  int ls[NUM_NODES], rs[NUM_NODES];
  T ts[NUM_NODES];
  // Slots released by meld(); reused before nodesLen grows.
  vector<int> freeNodes;
  // Resets the pool for a new key range [l0_, r0_); all previous trees become invalid.
  void init(int l0_, int r0_) {
    l0 = l0_;
    r0 = r0_;
    nodesLen = 1;
    freeNodes.clear();
    ls[0] = rs[0] = 0;
    ts[0] = T();
  }
  // Returns a fresh node with no children and a default payload.
  int newNode() {
    int u;
    if (!freeNodes.empty()) {
      u = freeNodes.back();
      freeNodes.pop_back();
    } else {
      assert(nodesLen < NUM_NODES);
      u = nodesLen++;
    }
    ls[u] = rs[u] = 0;
    ts[u] = T();
    return u;
  }
  // Returns slot u to the pool; the caller must not reference u afterwards.
  void freeNode(int u) {
    if (u) freeNodes.push_back(u);
  }
  void push(int u) {
    ts[u].push(ls[u] ? &ts[ls[u]] : nullptr, rs[u] ? &ts[rs[u]] : nullptr);
  }
  void pull(int u) {
    ts[u].pull(ts[ls[u]], ts[rs[u]]);
  }
  // Applies member function f(args...) to the leaf at key a, creating the path if needed.
  template <class F, class... Args>
  void chPoint(int &u, int l, int r, int a, F f, Args &&... args) {
    if (!u) u = newNode();
    if (l + 1 == r) {
      (ts[u].*f)(args...);
      return;
    }
    const int mid = l + ((r - l) >> 1);
    push(u);
    (a < mid)
      ? chPoint(ls[u], l, mid, a, f, args...)
      : chPoint(rs[u], mid, r, a, f, args...);
    pull(u);
  }
  template <class F, class... Args>
  void chPoint(int &u, int a, F f, Args &&... args) {
    assert(l0 <= a); assert(a < r0);
    chPoint(u, l0, r0, a, f, args...);
  }
  // Applies f(args...) to every existing maximal node covered by [a, b); absent
  // subtrees are skipped, hence the requirement that f maps T() to T().
  template <class F, class... Args>
  void chRange(int &u, int l, int r, int a, int b, F f, Args &&... args) {
    if (!u) return;
    if (b <= l || r <= a) return;
    if (a <= l && r <= b) {
      (ts[u].*f)(args...);
      return;
    }
    const int mid = l + ((r - l) >> 1);
    push(u);
    chRange(ls[u], l, mid, a, b, f, args...);
    chRange(rs[u], mid, r, a, b, f, args...);
    pull(u);
  }
  template <class F, class... Args>
  void chRange(int &u, int a, int b, F f, Args &&... args) {
    assert(l0 <= a); assert(a <= b); assert(b <= r0);
    chRange(u, l0, r0, a, b, f, args...);
  }
  // Aggregate of keys in [a, b) of the tree rooted at u (T() when empty).
  T get(int u, int l, int r, int a, int b) {
    if (!u) return T();
    if (b <= l || r <= a) return T();
    if (a <= l && r <= b) return ts[u];
    const int mid = l + ((r - l) >> 1);
    push(u);
    const T tL = get(ls[u], l, mid, a, b);
    const T tR = get(rs[u], mid, r, a, b);
    pull(u);
    T t;
    t.pull(tL, tR);
    return t;
  }
  T get(int u, int a, int b) {
    assert(l0 <= a); assert(a <= b); assert(b <= r0);
    return get(u, l0, r0, a, b);
  }
  // Merges tree v into tree u and returns the merged root. Every node of v that
  // overlaps u is released to the free list, so v must not be used afterwards.
  int meld(int u, int v, int l, int r) {
    if (!u) return v;
    if (!v) return u;
    if (l + 1 == r) {
      ts[u].meld(ts[v]);
      freeNode(v);
      return u;
    }
    const int mid = l + ((r - l) >> 1);
    push(u);
    push(v);
    ls[u] = meld(ls[u], ls[v], l, mid);
    rs[u] = meld(rs[u], rs[v], mid, r);
    // v's children now belong to u (or were released above); v itself is dead.
    freeNode(v);
    pull(u);
    return u;
  }
  int meld(int u, int v) {
    return meld(u, v, l0, r0);
  }
  void print(int depth, int u, int l, int r) const {
    if (!u) return;
    cerr << string(2 * depth, ' ') << u << " [" << l << ", " << r << ") " << ts[u] << endl;
    if (l + 1 == r) return;
    const int mid = l + ((r - l) >> 1);
    print(depth + 1, ls[u], l, mid);
    print(depth + 1, rs[u], mid, r);
  }
  void print(int u) const {
    if (!u) cerr << "[Seg::print] null" << endl;
    print(0, u, l0, r0);
  }
};
// Per-test-case input, filled in main().
int N;  // number of vertices
int R;  // root vertex (main() converts it to 0-indexed)
vector<int> C;  // value of each vertex, as read
vector<int> A;  // edge endpoints, 0-indexed (edge i joins A[i] and B[i])
vector<int> B;
// cus = (C[u], u) sorted ascending; ids[u] is the position of vertex u in cus.
vector<pair<Int, int>> cus;
vector<int> ids;
vector<vector<int>> graph;  // undirected adjacency lists
// Reference solver, only run under LOCAL in main() to cross-check fast::run().
// Keys of the per-vertex maps are ranks (ids[]); the final answer accumulates
// sum(map[key] * cus[key].first) over every vertex's map.
namespace brute {
// mns[u] = smallest rank ids[] inside the subtree of u (rooted at R).
vector<int> mns;
void dfs(int u, int p) {
mns[u] = ids[u];
for (const int v : graph[u]) if (p != v) {
dfs(v, u);
chmin(mns[u], mns[v]);
}
}
Mint ans;
// Returns the map (rank -> coefficient) for the subtree of u and adds its weighted
// total into ans. NOTE(review): recursion depth equals tree height.
map<int, Mint> solve(int u, int p) {
// Children sorted by x[k] = mns of the k-th child.
vector<pair<int, int>> xvs;
for (const int v : graph[u]) if (p != v) {
xvs.emplace_back(mns[v], v);
}
sort(xvs.begin(), xvs.end());
const int len = xvs.size();
map<int, Mint> tmp;
for (int j = 0; j < len; ++j) {
const int v = xvs[j].second;
const auto res = solve(v, u);
for (const auto &kv : res) {
assert(xvs[j].first <= kv.first);
// For an entry (key, w) of child j: every other child k with x[k] <= key receives
// w / ((k+1)(k+2)) at x[k] if k < j, or w / (k(k+1)) if k > j.
// After the loop, k = number of children with x[k] <= key (>= j + 1), and w / k stays at key.
int k;
for (k = 0; k < len && xvs[k].first <= kv.first; ++k) if (j != k) {
tmp[xvs[k].first] += ((k < j) ? (inv[k + 1] * inv[k + 2]) : (inv[k] * inv[k + 1])) * kv.second;
}
tmp[kv.first] += inv[k] * kv.second;
}
}
// Attach u: 1 at ids[u], and every key above ids[u] is clamped down to ids[u].
map<int, Mint> ret;
ret[ids[u]] += 1;
for (const auto &kv : tmp) {
ret[min(ids[u], kv.first)] += kv.second;
}
for (const auto &kv : ret) {
ans += kv.second * cus[kv.first].first;
}
/*
cerr<<"u = "<<u<<": xvs = "<<xvs<<endl;
cerr<<"u = "<<u<<": tmp = ";pv(tmp.begin(),tmp.end());
cerr<<"u = "<<u<<": ret = ";pv(ret.begin(),ret.end());
*/
return ret;
}
Mint run() {
mns.assign(N, -1);
dfs(R, -1);
ans = 0;
solve(R, -1);
return ans;
}
} // brute
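/*
Derivation sketch for slow::solve / fast::solve (shown for 4 children).
x[k] is the smallest id in the k-th child's subtree (children sorted by x[k]).
Row f[i] stands for child i's map; the five columns are the key ranges cut at
x[0] < x[1] < x[2] < x[3].
  xxxx : impossible (child i has no keys below x[i])
  ---- : contributes nothing to the current target
  oooo : earlier child (i < k), keys >= x[k], moved onto x[k] with weight 1/(k(k+1))
  @@@@ : later child (i > k), keys >= x[k+1], moved onto x[k] with weight 1/((k+1)(k+2))
range sum -> add to somewhere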
to x[0]
x[0] x[1] x[2] x[3]
f[0] xxxx|----|----|----|----
f[1] xxxx|xxxx|@@@@|@@@@|@@@@ 1/2
f[2] xxxx|xxxx|xxxx|@@@@|@@@@ 1/2
f[3] xxxx|xxxx|xxxx|xxxx|@@@@ 1/2
to x[1]
x[0] x[1] x[2] x[3]
f[0] xxxx|----|oooo|oooo|oooo 1/2
f[1] xxxx|xxxx|----|----|----
f[2] xxxx|xxxx|xxxx|@@@@|@@@@ 1/6
f[3] xxxx|xxxx|xxxx|xxxx|@@@@ 1/6
to x[2]
x[0] x[1] x[2] x[3]
f[0] xxxx|----|----|oooo|oooo 1/6
f[1] xxxx|xxxx|----|oooo|oooo 1/6
f[2] xxxx|xxxx|xxxx|----|----
f[3] xxxx|xxxx|xxxx|xxxx|@@@@ 1/12
to x[3]
x[0] x[1] x[2] x[3]
f[0] xxxx|----|----|----|oooo 1/12
f[1] xxxx|xxxx|----|----|oooo 1/12
f[2] xxxx|xxxx|xxxx|----|oooo 1/12
f[3] xxxx|xxxx|xxxx|xxxx|----
range mul: keys in [x[j], x[j+1]) (the last range is unbounded) are scaled by 1/(j+1)
x[0] x[1] x[2] x[3]
----|----|----|----|----
1/1 1/2 1/3 1/4
*/
// Intermediate solver: the brute recurrence regrouped into range sums / range
// multiplies over one melded map, i.e. the form that fast:: implements with a
// segment tree. Not called from main() in this file.
namespace slow {
// mns[u] = smallest rank ids[] inside the subtree of u (rooted at R).
vector<int> mns;
void dfs(int u, int p) {
mns[u] = ids[u];
for (const int v : graph[u]) if (p != v) {
dfs(v, u);
chmin(mns[u], mns[v]);
}
}
Mint ans;
// Returns the map (rank -> coefficient) for the subtree of u and adds its weighted
// total into ans.
map<int, Mint> solve(int u, int p) {
// Children sorted by x[j] = mns of the j-th child.
vector<pair<int, int>> xvs;
for (const int v : graph[u]) if (p != v) {
xvs.emplace_back(mns[v], v);
}
sort(xvs.begin(), xvs.end());
const int len = xvs.size();
// sum = union of the children's maps; fs[j] = mass of children 0..j-1 at keys >= x[j].
map<int, Mint> sum;
vector<Mint> fs(len, 0);
for (int j = 0; j < len; ++j) {
if (j) {
// range sum
for (const auto &kv : sum) if (xvs[j].first <= kv.first) {
fs[j] += kv.second;
}
}
const int v = xvs[j].second;
const auto res = solve(v, u);
// meld
for (const auto &kv : res) {
sum[kv.first] += kv.second;
}
}
map<int, Mint> tmp;
for (int j = 0; j < len; ++j) {
// Earlier children (0..j-1) move fs[j] / (j(j+1)) onto x[j].
if (j > 0) {
tmp[xvs[j].first] += (inv[j] * inv[j + 1]) * fs[j];
}
// Later children (j+1..len-1): all - fs[j+1] is their mass at keys >= x[j+1];
// it is moved onto x[j] with weight 1/((j+1)(j+2)).
if (j < len - 1) {
// range sum
Mint all = 0;
for (const auto &kv : sum) if (xvs[j + 1].first <= kv.first) {
all += kv.second;
}
tmp[xvs[j].first] += (inv[j + 1] * inv[j + 2]) * (all - fs[j + 1]);
}
}
// range mul
// Keys in [x[j], x[j+1]) keep 1/(j+1) of their mass.
for (int j = 0; j < len; ++j) {
for (const auto &kv : sum) if (xvs[j].first <= kv.first && (j + 1 == len || kv.first < xvs[j + 1].first)) {
tmp[kv.first] += inv[j + 1] * kv.second;
}
}
// attach u
// 1 at ids[u]; every key above ids[u] is clamped down to ids[u].
map<int, Mint> ret;
ret[ids[u]] += 1;
for (const auto &kv : tmp) {
ret[min(ids[u], kv.first)] += kv.second;
}
for (const auto &kv : ret) {
ans += kv.second * cus[kv.first].first;
}
/*
cerr<<"u = "<<u<<": xvs = "<<xvs<<endl;
cerr<<"u = "<<u<<": tmp = ";pv(tmp.begin(),tmp.end());
cerr<<"u = "<<u<<": ret = ";pv(ret.begin(),ret.end());
*/
return ret;
}
Mint run() {
mns.assign(N, -1);
dfs(R, -1);
ans = 0;
solve(R, -1);
return ans;
}
} // slow
namespace fast {
// Segment-tree payload for the per-subtree maps (key = rank ids[], see main()).
// A leaf at key x holds the mass `sum` of that key and sum1 = wt * sum, where
// wt = cus[x].first is supplied by pointAdd(). Internal nodes hold the totals of
// their children plus a pending multiplicative tag `lz`.
struct Node {
  Mint wt, sum, sum1;
  Mint lz;
  Node() : wt(0), sum(0), sum1(0), lz(1) {}
  friend ostream &operator<<(ostream &os, const Node &t) {
    return os << "(wt=" << t.wt << ", sum=" << t.sum << ", sum1=" << t.sum1 << ", lz=" << t.lz << ")";
  }
  // Hands the pending multiplier down to whichever children exist, then clears it.
  void push(Node *l, Node *r) {
    if (lz != 1) {
      if (l) l->mul(lz);
      if (r) r->mul(lz);
      lz = 1;
    }
  }
  void pull(const Node &l, const Node &r) {
    sum = l.sum + r.sum;
    sum1 = l.sum1 + r.sum1;
  }
  // Combines two leaves for the same key (hence the same wt). Both sides satisfy
  // sum1 = wt * sum, so the weighted parts add as t.sum1. The previous `+= t.sum`
  // dropped the factor wt and broke that invariant.
  void meld(const Node &t) {
    assert(wt == t.wt);
    sum += t.sum;
    sum1 += t.sum1;
  }
  // Scales this subtree by val (descendants lazily via lz).
  void mul(Mint val) {
    sum *= val;
    sum1 *= val;
    lz *= val;
  }
  // leaf: records the weight and adds val units of mass.
  void addToLeaf(Mint wt_, Mint val) {
    wt = wt_;
    sum += val;
    sum1 += wt * val;
  }
};
Seg<Node> seg;
// Adds val at key x of the tree rooted at `node` (creating it if empty). Zero
// updates are skipped so they never allocate nodes.
void pointAdd(int &node, int x, Mint val) {
  // cerr<<COLOR("93")<<"[pointAdd] "<<node<<" "<<x<<" "<<val<<COLOR()<<endl;
  if (val) {
    seg.chPoint(node, x, &Node::addToLeaf, cus[x].first, val);
  }
}
// mns[u] = smallest rank ids[] inside the subtree of u (rooted at R).
vector<int> mns;
void dfs(int u, int p) {
  mns[u] = ids[u];
  for (const int v : graph[u]) if (p != v) {
    dfs(v, u);
    chmin(mns[u], mns[v]);
  }
}
Mint ans;
// Same recurrence as slow::solve, but the map lives in `seg`.
// Returns the root of u's tree and adds its weighted total (root sum1) into ans.
// NOTE(review): dfs/solve recurse once per tree level, so a path-shaped tree of
// N vertices needs O(N) stack frames.
int solve(int u, int p) {
  // Children sorted by x[j] = xvs[j].first = mns of the child.
  vector<pair<int, int>> xvs;
  for (const int v : graph[u]) if (p != v) {
    xvs.emplace_back(mns[v], v);
  }
  sort(xvs.begin(), xvs.end());
  const int len = xvs.size();
  // Meld the children left to right; fs[j] = mass of children 0..j-1 at keys >= x[j].
  int sum = 0;
  vector<Mint> fs(len, 0);
  for (int j = 0; j < len; ++j) {
    if (j) {
      fs[j] = seg.get(sum, xvs[j].first, N).sum;
    }
    const int v = xvs[j].second;
    const int res = solve(v, u);
    sum = seg.meld(sum, res);
  }
  // alls[j] = mass of all children at keys >= x[j].
  vector<Mint> alls(len, 0);
  for (int j = 1; j < len; ++j) {
    alls[j] = seg.get(sum, xvs[j].first, N).sum;
  }
  // range mul: keys in [x[j], x[j+1]) keep 1/(j+1) of their mass.
  int tmp = sum;
  for (int j = 0; j < len; ++j) {
    seg.chRange(tmp, xvs[j].first, (j + 1 == len) ? N : xvs[j + 1].first, &Node::mul, inv[j + 1]);
  }
  // Range sums moved onto x[j], with the same weights as slow::solve.
  for (int j = 0; j < len; ++j) {
    Mint val = 0;
    if (j > 0) {
      val += (inv[j] * inv[j + 1]) * fs[j];
    }
    if (j < len - 1) {
      val += (inv[j + 1] * inv[j + 2]) * (alls[j + 1] - fs[j + 1]);
    }
    pointAdd(tmp, xvs[j].first, val);
  }
  // attach u: fold all mass at keys >= ids[u] into ids[u], then add 1 there.
  int ret = tmp;
  {
    const Mint val = seg.get(ret, ids[u], N).sum;
    seg.chRange(ret, ids[u], N, &Node::mul, 0);
    pointAdd(ret, ids[u], val);
  }
  pointAdd(ret, ids[u], 1);
  ans += seg.ts[ret].sum1;
  // cerr<<COLOR("91")<<"u = "<<u<<": xvs = "<<xvs<<COLOR()<<endl;
  // seg.print(ret);
  return ret;
}
Mint run() {
  mns.assign(N, -1);
  dfs(R, -1);
  ans = 0;
  seg.init(0, N);
  solve(R, -1);
  return ans;
}
} // fast
int main() {
  prepare();
  for (int numCases; ~scanf("%d", &numCases); ) {
    for (int caseId = 1; caseId <= numCases; ++caseId) {
      // Vertex count and root; the root is read 1-indexed.
      scanf("%d%d", &N, &R);
      R -= 1;
      C.resize(N);
      for (int &value : C) {
        scanf("%d", &value);
      }
      // N - 1 undirected edges, converted to 0-indexed endpoints.
      A.resize(N - 1);
      B.resize(N - 1);
      for (int e = 0; e < N - 1; ++e) {
        scanf("%d%d", &A[e], &B[e]);
        A[e] -= 1;
        B[e] -= 1;
      }
      // Rank vertices by (value, index): ids[u] is the position of u in sorted cus.
      cus.resize(N);
      for (int u = 0; u < N; ++u) {
        cus[u] = {C[u], u};
      }
      sort(cus.begin(), cus.end());
      ids.assign(N, -1);
      for (int rank = 0; rank < N; ++rank) {
        ids[cus[rank].second] = rank;
      }
      // cerr<<"cus = "<<cus<<endl;
      graph.assign(N, {});
      for (int e = 0; e < N - 1; ++e) {
        graph[A[e]].push_back(B[e]);
        graph[B[e]].push_back(A[e]);
      }
      const Mint ans = fast::run();
      printf("%u\n", ans.x);
#ifdef LOCAL
      const Mint brt = brute::run();
      cerr << "brt = " << brt << endl;
#endif
    }
#ifndef LOCAL
    break;
#endif
  }
  return 0;
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 7ms
memory: 78620kb
input:
4 1 1 1 3 3 3 3 4 3 1 3 2 6 1 5 2 4 1 3 6 1 2 1 6 2 3 2 4 4 5 5 1 5 4 3 2 1 1 2 1 3 3 4 3 5
output:
1 16 34 499122202
result:
ok 4 number(s): "1 16 34 499122202"
Test #2:
score: 0
Accepted
time: 54ms
memory: 80228kb
input:
7 5000 933 23306350 162661794 68618194 666430282 995855733 929210414 295740530 464135554 304211641 725090719 226242817 592655639 936895997 479520010 108891341 598601399 678169271 118406229 394867734 640888099 481066130 606481085 709600400 554804145 179044332 41718098 549318629 400214219 159098456 67...
output:
647896606 670593316 448857064 140431373 205960849 578484974 77271255
result:
ok 7 numbers
Test #3:
score: 0
Accepted
time: 48ms
memory: 80200kb
input:
7 5000 3103 370388267 486433577 320921400 202742370 718520472 895122554 359601184 337298427 146232648 830940586 826047977 229951976 919497287 67059290 843962706 777524684 196246825 682863698 122309598 743014832 349595264 964812795 660215381 963691135 376209511 759526115 822829377 916639371 802105594...
output:
371972695 888631092 915285114 181894451 144642157 385851099 551331307
result:
ok 7 numbers
Test #4:
score: 0
Accepted
time: 62ms
memory: 80220kb
input:
7 5000 4968 717470184 105172657 308383390 593830267 706026427 861034695 423461838 60718197 988253654 231757749 690694353 162215609 461907090 804341675 874001367 811223777 714324378 97578063 704527271 140108861 363348590 28177209 760573466 522321229 868341986 917525621 541050526 287840331 295369628 8...
output:
249142844 829572814 771685997 331594252 670718819 285030671 941548421
result:
ok 7 numbers
Test #5:
score: 0
Accepted
time: 752ms
memory: 110232kb
input:
6 100000 64523 754457004 816672702 108228103 727565046 575355672 557029941 374900316 667662208 247774859 241366803 327704349 984466565 870042639 983615018 950320614 473363640 338609139 877917484 271995097 819820270 213958988 431503923 943718683 511580322 791201903 944289849 290330452 148342386 88563...
output:
913802776 109864262 759204988 972877912 293869584 437197447
result:
ok 6 numbers
Test #6:
score: 0
Accepted
time: 752ms
memory: 110072kb
input:
6 100000 27972 828633097 164239555 181485176 220605565 677776542 945405871 80032153 66332854 697109172 781941604 919689741 375235146 744974712 849606202 656187811 476076039 737032058 914423438 292449741 668723961 67722543 681470786 510488095 55665910 324786541 329651259 397950166 495275796 250945521...
output:
739855553 129875796 690404165 733032567 628627455 517469818
result:
ok 6 numbers
Test #7:
score: 0
Accepted
time: 739ms
memory: 110084kb
input:
6 100000 99053 836797137 52289738 16685232 129865475 495829444 97013519 944813587 482856253 783712967 280370957 362687061 720487922 790405858 101237849 24038113 444097770 473476096 758876685 251305540 540743087 490978674 937704773 452442801 129907216 995115116 2093475 885869693 97229474 78775431 917...
output:
474661935 399436528 324886253 140783938 80304389 935874971
result:
ok 6 numbers
Test #8:
score: 0
Accepted
time: 749ms
memory: 110156kb
input:
6 100000 83118 237932086 805355107 875130961 45785348 463436298 235970586 792133756 802382318 612808476 551682021 372485130 452790514 513627049 543295307 217598321 265929567 797044227 688393222 126744189 35928058 811269903 849584261 483513038 911837564 366458567 951847531 105866839 756251409 4968014...
output:
898902290 811063805 206282848 258264260 148922868 224589860
result:
ok 6 numbers
Test #9:
score: 0
Accepted
time: 770ms
memory: 110076kb
input:
6 100000 10574 960703914 919137270 287906775 779918168 386941712 826368624 487253993 514726858 887984890 319651075 169632006 397669774 306766175 855993268 923893384 657564124 709799196 903310348 199644397 525583279 671731311 208189848 589129815 666199622 43201407 53804137 209640204 481408933 8963734...
output:
848526994 687580630 401382009 610479360 879854083 715820413
result:
ok 6 numbers
Test #10:
score: 0
Accepted
time: 748ms
memory: 110148kb
input:
6 100000 47654 766386374 205079444 260272024 972278825 717481371 349976769 283239755 15665182 300195285 987350791 13544206 466113372 892906150 988276842 194045085 558067840 335484149 509185986 186108072 274113084 111042012 993726778 61868884 324431830 214676266 481275089 711334587 593870246 66804272...
output:
765548596 532380615 903749692 242413341 967864239 381854796
result:
ok 6 numbers
Test #11:
score: -100
Runtime Error
input:
2 400000 73762 862365565 233904902 801015102 662519031 634410644 227172052 743137547 514055616 901427716 333289732 679283877 568568098 821858064 900217898 250453718 239331147 531665581 594636716 707168609 243610236 780894216 387870279 57816953 665974155 852470526 447162507 984329800 324745826 648826...