QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#263624 | #5421. Factories Once More | KhNURE_KIVI | RE | 0ms | 0kb | C++20 | 15.4kb | 2023-11-24 23:55:41 | 2023-11-24 23:55:43 |
answer
//#pragma GCC optimize("Ofast", "unroll-loops")
//#pragma GCC target("sse", "sse2", "sse3", "ssse3", "sse4")
#ifdef LOCAL
#include <iostream>
#include <cmath>
#include <algorithm>
#include <stdio.h>
#include <cstdint>
#include <cstring>
#include <string>
#include <cstdlib>
#include <vector>
#include <bitset>
#include <map>
#include <queue>
#include <ctime>
#include <stack>
#include <set>
#include <list>
#include <random>
#include <deque>
#include <functional>
#include <iomanip>
#include <sstream>
#include <fstream>
#include <complex>
#include <numeric>
#include <cassert>
#include <array>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <thread>
#else
#include <bits/stdc++.h>
#endif
// Short-hand macros used throughout the file. Macro arguments are parenthesised so that
// arguments such as `*p` expand correctly: len(*p) -> (*p).size(), all(*p) -> (*p).begin(), ...
#define all(a) (a).begin(), (a).end()
#define len(a) static_cast<int>((a).size())
#define mp make_pair
#define pb push_back
#define fir first
#define sec second
#define fi first
#define se second
using namespace std;
using pii = pair<int, int>;
using ll = long long;
using ld = long double;
/// Lowers `a` to `b` when `b` is strictly smaller.
/// @return true iff `a` was changed.
template<typename T>
inline bool umin(T &a, T b) {
    const bool improves = b < a;
    if (improves) {
        a = b;
    }
    return improves;
}
/// Raises `a` to `b` when `b` is strictly larger.
/// @return true iff `a` was changed.
template<typename T>
inline bool umax(T &a, T b) {
    const bool improves = a < b;
    if (improves) {
        a = b;
    }
    return improves;
}
#ifdef LOCAL
// Local debugging helpers, enabled only when LOCAL is defined:
//   D   - runs the statement that follows it exactly once;
//   LOG - writes the stringified arguments, " ::", then each value to cerr, ending with endl.
#define D for (bool _FLAG = true; _FLAG; _FLAG = false)
#define LOG(...) print(#__VA_ARGS__" ::", __VA_ARGS__) << endl
// Writes every argument to cerr followed by a space; returns cerr so output can be chained.
template <class ...Ts> auto &print(Ts ...ts) { return ((cerr << ts << " "), ...); }
#else
// Judge build: a statement prefixed with D never runs, and LOG expands to nothing.
#define D while (false)
#define LOG(...)
#endif // LOCAL
// Size of the per-vertex global tables (10^5 plus slack).
const int max_n = 100000 + 11;
// Large sentinel value.
const int inf = 1000111222;
//struct point {
// int x, y;
//};
//
//inline bool cmp (point a, point b) {
// return a.x < b.x || a.x == b.x && a.y < b.y;
//}
//
//inline bool cw (point a, point b, point c) {
// return a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y) < 0;
//}
//
//inline bool ccw (point a, point b, point c) {
// return a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y) > 0;
//}
//
//void convex_hull (vector<point> & a) {
// if (len(a) <= 1) {
// return;
// }
// sort (a.begin(), a.end(), cmp);
// point p1 = a[0], p2 = a.back();
// vector<point> up = {p1}, down = {p1};
// for (int i = 1; i < len(a); i++) {
// if (i == len(a) - 1 || cw (p1, a[i], p2)) {
// while (len(up) >= 2 && !cw(up[len(up) - 2], up.back(), a[i])) {
// up.pop_back();
// }
// up.pb(a[i]);
// }
// if (i == len(a) - 1 || ccw (p1, a[i], p2)) {
// while (len(down) >= 2 && !ccw(down[len(down) - 2], down.back(), a[i])) {
// down.pop_back();
// }
// down.pb(a[i]);
// }
// }
// a = up;
//// for (int i = len(down) - 2; i > 0; i--) {
//// a.pb(down[i]);
//// }
//}
//
//
//
//int dp[max_n][max_n], n, k, cnt[max_n];
//
//ll area(const vector<point>& fig) {
// ll res = 0;
// for (unsigned i = 0; i < fig.size(); i++) {
// point p = i ? fig[i - 1] : fig.back();
// point q = fig[i];
// res += (p.x - q.x) * (p.y + q.y);
// }
// return abs(res);
//}
//inline void dfs (int v, int p = -1) {
// cnt[v] = 1;
// for (auto [to, len] : edge[v]) {
// if (to == p) {
// continue;
// }
// dfs(to, v);
// for (int i1 = k; i1 >= 0; i1--) {
// for (int j = min(cnt[to], k - i1); j >= 0; j--) {
// umax(dp[v][i1 + j], dp[to][j] + dp[v][i1] + len * j * (k - j));
// if (i1 + j + 1 <= k) {
// umax(dp[v][i1 + j + 1], dp[to][j] + dp[v][i1] + len * j * (k - j));
// }
// }
// }
// cnt[v] += cnt[to];
// }
// vector <point> have;
//// LOG(cnt[v]);
// for (int i = 0; i <= min(k, cnt[v]); i++) {
// have.pb(point{i, dp[v][i]});
//// LOG(i, dp[v][i]);
// }
// auto h = have;
// convex_hull(h);
//// LOG(len(have), len(h), area(have), area(h));
//// assert(area(have) == area(h));
// if (area(have) != area(h)) {
// exit(3);
// }
//}
/// Disjoint-set union over vertices 0..n-1 with path compression and union by size.
struct dsu {
public:
    int n;
    vector <int> p, cnt;  // p: parent link; cnt: component size (meaningful at roots)

    /// Makes v its own root.
    inline void make_set (int v) {
        p[v] = v;
    }

    dsu (int n) : n(n), p(n), cnt(n, 1) {
        for (int v = 0; v < n; ++v) {
            make_set(v);
        }
    }

    /// Returns the root of v's set, pointing every vertex on the path directly at it.
    inline int get (int v) {
        int root = v;
        while (p[root] != root) {
            root = p[root];
        }
        while (p[v] != root) {
            const int next = p[v];
            p[v] = root;
            v = next;
        }
        return root;
    }

    /// Joins the sets of a and b. The smaller set is hung under the larger one.
    /// @return false if they were already in the same set.
    inline bool unite (int a, int b) {
        int ra = get(a);
        int rb = get(b);
        if (ra == rb) {
            return false;
        }
        if (cnt[ra] > cnt[rb]) {
            swap(ra, rb);
        }
        p[ra] = rb;
        cnt[rb] += cnt[ra];
        return true;
    }
};
// Non-zero: main generates a random tree instead of reading the input.
constexpr int debug = 0;
// Adjacency lists filled in main: edge[u] holds (neighbour, weight) pairs.
vector<pair<int, int>> edge[max_n];
// Fixed-seed generator used by randll / randld.
mt19937 rng(228);
/// Uniform random value of integer type T in the closed range [l, r], drawn from `rng`.
template<typename T = int>
inline T randll(T l = INT_MIN, T r = INT_MAX) {
    uniform_int_distribution<T> dist(l, r);
    return dist(rng);
}
/// Uniform random long double in [l, r), drawn from `rng`.
inline ld randld(ld l = INT_MIN, ld r = INT_MAX) {
    uniform_real_distribution<ld> dist(l, r);
    return dist(rng);
}
vector <int> used;  // per-vertex flags (sized in main); the dfs below skips vertices with used[v] != 0

/// Counts the vertices reachable from v (coming from parent p) without entering flagged vertices.
/// Stores into `center` the first vertex, in post-order, whose count exceeds sz / 2. This is the
/// usual centroid search when sz is the component size. `center` is only assigned while it is -1.
/// @return the number of vertices counted in v's part.
inline int dfs (int v, int &center, int sz, int p = -1) {
    int subtree = 1;
    for (const auto &e : edge[v]) {
        const int to = e.first;
        if (to != p && !used[to]) {
            subtree += dfs(to, center, sz, v);
        }
    }
    if (center == -1 && 2 * subtree > sz) {
        center = v;
    }
    return subtree;
}
/// Point / 2-D vector with 64-bit integer coordinates.
struct point {
    long long x, y;
    point operator + (const point &p) const {
        return point{x + p.x, y + p.y};
    }
    point operator - (const point &p) const {
        return point{x - p.x, y - p.y};
    }
    /// z-component of the cross product (*this) x p; positive when p turns counter-clockwise from *this.
    long long cross(const point &p) const {
        return x * p.y - y * p.x;
    }
};

/// Rotates P so that it starts at its lowest vertex (ties broken by the smallest x).
/// The cyclic order of the vertices is kept, and an empty P is left untouched.
void reorder_polygon(std::vector<point> &P) {
    const auto lowest = std::min_element(P.begin(), P.end(), [](const point &a, const point &b) {
        return a.y < b.y || (a.y == b.y && a.x < b.x);
    });
    std::rotate(P.begin(), lowest, P.end());
}

/// Minkowski sum of two polygons, computed by merging their edge sequences.
/// NOTE(review): the edge-merging method assumes both polygons are convex and listed
/// counter-clockwise; no caller is visible to confirm that.
/// @return the vertices of P + Q, starting at the lowest one; empty if either input is empty.
std::vector<point> minkowski(std::vector<point> P, std::vector<point> Q) {
    if (P.empty() || Q.empty()) {
        return {};
    }
    reorder_polygon(P);
    reorder_polygon(Q);
    // Append the first two vertices so P[i + 1] and Q[j + 1] stay valid (cyclic indexing).
    P.push_back(P[0]);
    P.push_back(P[1]);
    Q.push_back(Q[0]);
    Q.push_back(Q[1]);
    const std::size_t p_edges = P.size() - 2;
    const std::size_t q_edges = Q.size() - 2;
    std::vector<point> result;
    result.reserve(p_edges + q_edges);
    std::size_t i = 0, j = 0;
    while (i < p_edges || j < q_edges) {
        result.push_back(P[i] + Q[j]);
        const long long turn = (P[i + 1] - P[i]).cross(Q[j + 1] - Q[j]);
        // Advance along the edge with the smaller polar angle; when turn == 0 advance both.
        if (turn >= 0 && i < p_edges) {
            ++i;
        }
        if (turn <= 0 && j < q_edges) {
            ++j;
        }
    }
    return result;
}
//vector <ll> dp[max_n];
// Input parameters read in main; k is also the count passed to sum_first_k for the answer.
int n = 0;
int k = 0;
// Referenced only by commented-out code; unused by the active program.
vector<int> GG;
//inline void convolve (int a, int b) {
// vector <point> A(len(dp[a])), B(len(dp[b]));
// for (int i = 0; i < len(dp[a]); i++) {
// A[i] = point{ dp[a][i], i};
// }
// for (int i = 0; i < len(dp[b]); i++) {
// B[i] = point{dp[b][i], i};
// }
// auto res = minkowski(A, B);
// dp[a].resize(min(k + 1, len(A) + len(B) - 1));
// int j = 1;
// ll last = 0, val = 0;
// if (res[0].y != 0 || res[0].x != 0) {
// exit(47);
// }
//// LOG("here");
//// for (auto &i : res) {
//// LOG(i.x, i.y);
//// }
// for (int i = 1; i < len(res); i++) {
//// LOG(i);
//// LOG(res[i].y);
// while (j < res[i].y && j < len(dp[a])) {
// dp[a][j] = (j - last) * (res[i].x - val) + val * (res[i].y - last);
// if (dp[a][j] % (res[i].y - last) != 0) {
// exit(48);
// }
// dp[a][j] /= (res[i].y - last);
//// LOG(j, dp[a][j]);
// ++j;
// }
// if (j < len(dp[a])) {
// if (j != res[i].y) {
// exit(49);
// }
// dp[a][j] = res[i].x;
// ++j;
// }
// if (last > res[i].y) {
// break;
// }
// last = res[i].y;
// val = res[i].x;
// }
// LOG(len(res));
// dp[b].clear();
//// LOG(j, len(dp[a]));
// if (j != len(dp[a])) {
// exit(50);
// }
//// for (auto &kk : dp[a]) {
//// LOG(kk);
//// }
//}
//inline int calc (int l, int r) {
// if (l == r) {
// return GG[r];
// }
// int x = (l + r) >> 1;
// int L = calc(l, x);
// int R = calc(x + 1, r);
// convolve(R, L);
// return R;
//}
int cnt[max_n] = {};  // cnt[v]: number of vertices in v's subtree, filled by dfs
std::mt19937 generator;  // source of random heap priorities for treap nodes

/// Node of an implicit-key treap holding a sequence of 64-bit values.
/// The sequence is kept in non-increasing order of `value` (see explicit_split / add_element).
/// Values are 64-bit because the affine updates from add_push multiply edge weights by k and by
/// positions. Those products can exceed int range for large inputs.
struct treap {
    int sz;                    // number of nodes in this subtree
    long long value;           // this node's element; pending pushes for this node are already applied
    unsigned prior;            // random heap priority (larger priority ends up nearer the root, see merge)
    treap *left, *right;
    long long push_k, push_b;  // pending "add push_k * pos + push_b" still owed to the children

    treap(long long v)
        : sz(1), value(v), prior(generator()), left(nullptr), right(nullptr), push_k(0), push_b(0) {}
};

/// Returns a fresh single-node treap with v's value; v must be a leaf.
treap* make_treap_leaf_copy(treap* v)
{
    assert(v->left == nullptr);
    assert(v->right == nullptr);
    return new treap(v->value);
}

/// Size of subtree t (0 for an empty tree).
int get_size(treap *t) {
    return t == nullptr ? 0 : t->sz;
}

/// Recomputes t->sz from its children; no-op for an empty tree.
void update(treap *&t) {
    if (t == nullptr) {
        return;
    }
    t->sz = 1 + get_size(t->left) + get_size(t->right);
}

/// Adds k * pos + b to every element of subtree t, where pos is the 0-based position inside t.
/// Applied to t->value immediately and deferred for the children via push_k / push_b.
void add_push(treap* t, long long k, long long b)
{
    assert(t != nullptr);
    t->value += 1LL * get_size(t->left) * k + b;
    t->push_k += k;
    t->push_b += b;
}

/// Hands t's pending affine update down to its children and clears it.
/// Right-subtree positions start at size(left) + 1, so that child's offset is shifted accordingly.
void make_push(treap *t)
{
    if (!t) {
        return;
    }
    if (t->left) {
        add_push(t->left, t->push_k, t->push_b);
    }
    if (t->right) {
        add_push(t->right, t->push_k, t->push_b + (get_size(t->left) + 1LL) * t->push_k);
    }
    t->push_k = 0;
    t->push_b = 0;
}

/// Concatenates sequences t1 then t2; returns the new root.
treap *merge(treap *t1, treap *t2) {
    if (t1 == nullptr) {
        return t2;
    }
    if (t2 == nullptr) {
        return t1;
    }
    make_push(t1);
    make_push(t2);
    if (t1->prior <= t2->prior) {
        t2->left = merge(t1, t2->left);
        update(t2);
        return t2;
    }
    t1->right = merge(t1->right, t2);
    update(t1);
    return t1;
}

/// Splits `our` so that l receives its first sz elements and r the rest.
void split_size(treap *our, int sz, treap *&l, treap *&r) {
    if (our == nullptr) {
        l = nullptr;
        r = nullptr;
        return;
    }
    make_push(our);
    if (get_size(our->left) + 1 == sz) {
        l = our;
        r = our->right;
        l->right = nullptr;
    } else if (get_size(our->left) >= sz) {
        r = our;
        split_size(r->left, sz, l, r->left);
    } else {
        l = our;
        split_size(l->right, sz - get_size(l->left) - 1, l->right, r);
    }
    update(l);
    make_push(l);
    update(r);
    make_push(r);
}

/// Splits by value: elements with value >= key go to l, elements with value < key go to r.
/// Correct because the sequence is kept in non-increasing order of value.
void explicit_split(treap *our, long long key, treap *&l, treap *&r) {
    if (our == nullptr) {
        l = nullptr;
        r = nullptr;
        return;
    }
    make_push(our);
    if (our->value < key) {
        r = our;
        explicit_split(r->left, key, l, r->left);
    } else {
        l = our;
        explicit_split(l->right, key, l->right, r);
    }
    update(l);
    update(r);
    make_push(l);
    make_push(r);
}

typedef treap* barik_set;

/// Inserts x into the non-increasing sequence t, after every element >= x.
void add_element(barik_set &t, long long x) {
    treap *nt = new treap(x);
    treap *ge, *lt;
    explicit_split(t, x, ge, lt);
    t = merge(ge, merge(nt, lt));
}
//void explicit_erase(treap *&t, int x) {
// treap *buf1, *buf2, *buf3, *buf4;
// split(t, x, buf1, buf2);
// split(buf2, x + 1, buf3, buf4);
// t = merge(buf1, buf4);
//}
/// Moves every element of A into the non-increasing set B, in the same order as before
/// (node, then left subtree, then right subtree).
/// A's nodes are detached and re-inserted into B instead of being copied. The previous version
/// allocated a copy of each element and never freed the originals.
/// On return A is nullptr; any other pointer to A's old root now points into B.
void do_naive_dfs_merge(barik_set &A, barik_set &B)
{
    if (!A) {
        return;
    }
    make_push(A);  // settle pending updates so the node value and both children are exact
    treap *node = A;
    treap *lo = node->left;
    treap *hi = node->right;
    A = nullptr;
    node->left = nullptr;
    node->right = nullptr;
    update(node);  // node is now a single-element treap
    treap *ge = nullptr, *lt = nullptr;
    explicit_split(B, node->value, ge, lt);  // >= value goes to ge, < value goes to lt
    B = merge(ge, merge(node, lt));
    do_naive_dfs_merge(lo, B);
    do_naive_dfs_merge(hi, B);
}
/// Small-to-large union: pours the smaller set into the larger one and returns the larger.
/// On a size tie, B is poured into A and A is returned.
barik_set merge_two_sets(barik_set A, barik_set B)
{
    const bool a_is_smaller = get_size(A) < get_size(B);
    barik_set &small = a_is_smaller ? A : B;
    barik_set &large = a_is_smaller ? B : A;
    do_naive_dfs_merge(small, large);
    return large;
}
/// Adds slope * i + intercept to the i-th (0-based) element of A.
/// A must be non-empty: add_push asserts on a null tree.
void barik_push_kx_plus_b(barik_set &A, ll slope, ll intercept)
{
    add_push(A, slope, intercept);
}
/// Inserts a single 0 into the non-increasing set A, after every element >= 0.
/// dfs calls this for non-leaf vertices; leaves get add_element(dp[v], 0) instead.
void shift_right(barik_set& A)
{
    treap *non_negative = nullptr, *negative = nullptr;
    explicit_split(A, 0, non_negative, negative);  // >= 0 goes left, < 0 goes right
    A = merge(merge(non_negative, new treap(0)), negative);
}
/// Helper for sum_first_k. Walks A in order and accumulates elements while k != 0.
/// k is shared across the recursion and decremented once per element taken.
ll do_naive_dfs_sum(barik_set &A, int& k)
{
    if (A == nullptr || k == 0) {
        return 0ll;
    }
    make_push(A);  // children must receive pending updates before they are read
    const ll from_left = do_naive_dfs_sum(A->left, k);
    ll own = 0;
    if (k != 0) {
        own = A->value;
        --k;
    }
    const ll from_right = do_naive_dfs_sum(A->right, k);
    return from_left + own + from_right;
}
/// Sum of the first k elements of A in sequence order (the k largest, since A is non-increasing).
ll sum_first_k(barik_set A, int k)
{
    int remaining = k;
    return do_naive_dfs_sum(A, remaining);
}
barik_set dp[max_n] = {};  // dp[v]: the set built for vertex v by dfs (empty until then)
/// Debug helper: writes A's elements to cerr in order, each followed by a space.
/// Pending pushes are applied top-down while walking, so the printed values are exact.
void print_into_cerr_dfs(barik_set A)
{
    vector<treap *> path;
    treap *cur = A;
    while (cur != nullptr || !path.empty()) {
        while (cur != nullptr) {
            make_push(cur);
            path.push_back(cur);
            cur = cur->left;
        }
        cur = path.back();
        path.pop_back();
        cerr << cur->value << " ";
        cur = cur->right;
    }
}
/// Debug helper: prints A on one labelled line of cerr.
void print_into_cerr(barik_set A)
{
    const char *label = "barik set :: ";
    cerr << label;
    print_into_cerr_dfs(A);
    cerr << "\n";
}
inline void dfs (int v, int p = -1) {
LOG(v);
cnt[v] = 1;
for (auto [to, len] : edge[v]) {
if (to == p) {
continue;
}
dfs(to, v);
LOG(to, v);
// for (int j = 0; j < len(dp[to]); j++) {
// dp[to][j] += len * 1ll * j * (k - j);
// }
ll B = k - 1;
ll K = -2;
// dp[to].add(K, B);
// LOG(K * len, B * len);
barik_push_kx_plus_b(dp[to], K * len, B * len);
// gg.pb(to);
cnt[v] += cnt[to];
// dp[v].merge(dp[to]);
// LOG(to, v, "finish 2");
LOG("v");
print_into_cerr(dp[v]);
LOG("to");
print_into_cerr(dp[to]);
dp[v] = merge_two_sets(dp[v], dp[to]);
LOG("merge");
print_into_cerr(dp[v]);
// LOG(to, v, "finish");
}
if (cnt[v] == 1) {
// LOG("here");
add_element(dp[v], 0);
// LOG("here 2");
}
else {
// dp[v].shift();
// LOG("shift");
shift_right(dp[v]);
// LOG("shift 2");
}
LOG(v, "start");
print_into_cerr(dp[v]);
}
int main() {
freopen("input.txt", "r", stdin);
// freopen("output.txt", "w", stdout);
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
m1:
n = 1e4;
for (int i = 0; i <= n; i++) {
edge[i].clear();
// for (int j = 0; j <= n; j++) {
// dp[i][j] = 0;
// }
}
k = n;
if (!debug) {
cin >> n >> k;
}
used.resize(n);
dsu t(n);
LOG("here");
for (int i = 1, u, v, w; i < n; i++) {
if (!debug) {
cin >> u >> v >> w;
--u, --v;
}
else {
// u = i, v = i - 1;
w = randll(1, 100);
u = randll(0, n - 1);
v = randll(0, n - 1);
while (!t.unite(u, v)) {
u = randll(0, n - 1);
v = randll(0, n - 1);
}
}
edge[u].pb({v, w});
edge[v].pb({u, w});
}
dfs(0);
ll ans = 0;
for (int i = 0; i < k; i++) {
/// add a[i]
}
ans = sum_first_k(dp[0], k);
// ans /= 2;
cout << ans << '\n';
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 0
Runtime Error
input:
6 3
1 2 3
2 3 2
2 4 1
1 5 2
5 6 3