QOJ.ac

QOJ

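ID | 题目 (Problem) | 提交者 (Submitter) | 结果 (Result) | 用时 (Time) | 内存 (Memory) | 语言 (Language) | 文件大小 (File size) | 提交时间 (Submitted) | 测评时间 (Judged)
#380643 | #5313. Please Save Pigeland | kevinyang# | RE | 3ms | 34136kb | C++17 | 4.1kb | 2024-04-07 07:34:28 | 2024-04-07 07:34:29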

Judging History

你现在查看的是最新测评结果

  • [2024-04-07 07:34:29]
  • 评测
  • 测评结果:RE
  • 用时:3ms
  • 内存:34136kb
  • [2024-04-07 07:34:28]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;
#define int long long
// Capacity of the per-vertex arrays; vertices are indexed 1..n, so this
// presumably assumes n < 500005 — confirm against the problem statement.
const int mxn = 500005;
// adj[u] holds (edge weight, neighbor) pairs; every edge is stored in both directions.
vector<vector<pair<int,int>>>adj(mxn);
// sz[u]: size of u's subtree inside the current centroid component (filled by dfs).
vector<int>sz(mxn);
// vis[u] != 0 once u has been chosen as a centroid and removed from further components.
vector<int>vis(mxn);
// bad[u] == 1 for each of the k marked vertices read from the input.
vector<int>bad(mxn);
// ans[u]: accumulated gcd of the distances (and their differences) from u to marked vertices.
vector<int>ans(mxn);
// sum[u]: accumulated total distance from u to all marked vertices.
vector<int>sum(mxn);
// Fills sz[] for the part of the current component hanging below u (coming
// from parent p): sz[u] becomes 1 plus the sizes of all non-removed children.
// NOTE(review): recursion depth equals the component's height, which can
// approach n on a path-shaped tree; this relies on the judge's stack limit.
void dfs(int u, int p){
	int total = 1;
	for(const auto& edge : adj[u]){
		const int child = edge.second;
		if(child == p || vis[child]) continue;
		dfs(child, u);
		total += sz[child];
	}
	sz[u] = total;
}
// Greatest common divisor of |a| and |b| (Euclid's algorithm).
// The result is never negative; gcd(x, 0) == |x| and gcd(0, 0) == 0.
int gcd(int a, int b){
	a = abs(a);
	b = abs(b);
	while(b != 0){
		const int rem = a % b;
		a = b;
		b = rem;
	}
	return a;
}
// Component-wise sum of two pairs: (a.first + b.first, a.second + b.second).
pair<int,int> add(pair<int,int>a, pair<int,int>b){
	pair<int,int> res = a;
	res.first += b.first;
	res.second += b.second;
	return res;
}
// Output of getcentroid(): a centroid of the component currently being processed.
int cent = 0;
// n: vertex count read in main(), then overwritten by centroid() with the size
// of the current component (getcentroid() compares against n / 2).
// k: number of marked vertices.
int n,k;
// Walks the current component (sizes must already be in sz[] and its size in n)
// and stores in `cent` a vertex u such that no non-removed child subtree and
// no "upward" remainder (n - sz[u]) exceeds n / 2. If several vertices
// qualify, the last one reached in post-order wins.
void getcentroid(int u, int p){
	const int half = n / 2;
	bool balanced = (n - sz[u] <= half);
	for(const auto& edge : adj[u]){
		const int child = edge.second;
		if(child == p || vis[child]) continue;
		getcentroid(child, u);
		if(sz[child] > half) balanced = false;
	}
	if(balanced) cent = u;
}
// Merges two summaries of a set of distances: first = one representative
// distance, second = gcd of all pairwise differences within the set.
// A summary whose first is below -1e17 is the "empty set" sentinel and is
// absorbed. Otherwise the merged set keeps a's representative and folds in the
// gap between the two representatives.
pair<int,int> op(pair<int,int> a, pair<int,int>b){
	const int emptyBound = -(int)1e17;
	if(a.first < emptyBound) return b;
	if(b.first < emptyBound) return a;
	int g = gcd(a.second, b.second);
	g = gcd(g, a.first - b.first);
	return make_pair(a.first, g);
}
pair<int,int> dfs1(int u, int p){
	pair<int,int>pr = {-(int)1e18,0};
	if(bad[u]){
		pr.first = 0;
	}
	for(auto [w,nxt]: adj[u]){
		if(nxt==p||vis[nxt])continue;
		pair<int,int>p2 = dfs1(nxt,u);
		p2.first += w;
		pr = op(pr,p2);
	}
	return pr;
}
// Pushes information about marked vertices reached through the current
// centroid into every vertex of one centroid subtree.
// d: distance from u to a representative marked vertex on the other side of
//    the centroid (grows by the edge weight at each step away from it).
// g: gcd of the differences among those marked vertices' distances.
// ans[u] absorbs both d and g.
void dfs2(int u, int p, int d, int g){
	ans[u] = gcd(gcd(ans[u], d), g);
	for(const auto& edge : adj[u]){
		const int v = edge.second;
		if(v == p || vis[v]) continue;
		dfs2(v, u, d + edge.first, g);
	}
}
// Returns (total distance from u to the marked vertices in its part of the
// current component, number of such marked vertices), looking away from p.
// A child's total is shifted by edgeWeight * (child's marked count).
pair<int,int> dfs3(int u, int p){
	int totalDist = 0;
	int markedCnt = bad[u] ? 1 : 0;
	for(const auto& edge : adj[u]){
		const int v = edge.second;
		if(v == p || vis[v]) continue;
		const pair<int,int> sub = dfs3(v, u);
		totalDist += sub.first + edge.first * sub.second;
		markedCnt += sub.second;
	}
	return make_pair(totalDist, markedCnt);
}
// Adds to sum[u] the total distance s from u to the cnt marked vertices that
// lie on the other side of the current centroid. Each step away from the
// centroid along an edge of weight w lengthens all cnt paths by w.
void dfs4(int u, int p, int s, int cnt){
	sum[u] += s;
	for(const auto& edge : adj[u]){
		const int v = edge.second;
		if(v == p || vis[v]) continue;
		const int extra = cnt * edge.first;
		dfs4(v, u, s + extra, cnt);
	}
}
// Processes the component containing x with centroid decomposition.
// 1. Finds its centroid `cent` and marks it removed (vis).
// 2. Updates ans[] (gcd of distances to marked vertices) for cent and for every
//    vertex in the component, via paths through cent.
// 3. Updates sum[] (total distance to marked vertices) the same way.
// 4. Recurses into each remaining sub-component.
// Side effect: the global n is overwritten with the component size.
void centroid(int x){
	dfs(x,0);
	n = sz[x];
	getcentroid(x,0);
	vis[cent] = true;
	// nodes/weights are 1-indexed lists of cent's non-removed neighbors and
	// the weights of the edges to them; index 0 is a dummy.
	vector<int>nodes;
	nodes.push_back(0);
	vector<int>weights;
	weights.push_back(0);
	int N = 0;
	for(auto [w,nxt]: adj[cent]){
		if(vis[nxt])continue;
		nodes.push_back(nxt);
		weights.push_back(w);
		N++;
	}
	//cout << cent << '\n';
	// Phase 1 (gcd): val[i] summarizes (op form) the distances from cent to the
	// marked vertices in neighbor i's subtree; pre/suf are prefix/suffix merges.
	// -1e18 is the "no marked vertex" sentinel understood by op().
	vector<pair<int,int>>pre(N+1,make_pair(-(int)1e18,0LL));
	vector<pair<int,int>>suf(N+2,make_pair(-(int)1e18,0LL));
	vector<pair<int,int>>val(N+1,make_pair(-(int)1e18,0LL));
	// A marked centroid contributes distance 0, seeded at both ends so that
	// every "all subtrees except i" query includes it.
	if(bad[cent]){
		pre[0] = {0,0};
		suf[N+1] = {0,0};
	}
	for(int i = 1; i<=N; i++){
		val[i] = dfs1(nodes[i],cent);
		val[i].first+=weights[i];
		//cout << "nodes " << nodes[i] << ' ' << val[i].first << ' ' << val[i].second << '\n';
		// Only non-empty subtrees contribute to cent's own gcd.
		if(val[i].first > -(int)1e17){
			ans[cent] = gcd(ans[cent],val[i].first);
			ans[cent] = gcd(ans[cent],val[i].second);
		}
		
	}
	for(int i = 1; i<=N; i++){
		pre[i] = op(pre[i-1],val[i]);
	}
	for(int i = N; i>=1; i--){
		suf[i] = op(suf[i+1],val[i]);
	}
	// For each subtree i, combine the marked vertices outside it (other subtrees
	// plus cent) and push that summary down subtree i.
	for(int i = 1; i<=N; i++){
		pair<int,int>p = op(pre[i-1],suf[i+1]);
		if(p.first < -(int)1e17)continue;
		//cout << nodes[i] << ' ' << p.first << ' ' <<  weights[i] << '\n';
		dfs2(nodes[i],cent,p.first+weights[i],p.second);
	}
	// Phase 2 (sums): reuse the same arrays as plain (distance sum, count) pairs,
	// so reset them to the additive identity first.
	for(int i = 0; i<=N; i++){
		pre[i] = {0,0};
		val[i] = {0,0};
		suf[i] = suf[i+1] = {0,0};
	}
	for(int i = 1; i<=N; i++){
		// s = total distance from cent to the d marked vertices in subtree i.
		auto [s, d] = dfs3(nodes[i],cent);
		s+=d*weights[i];
		sum[cent] += s;
		val[i] = {s,d};
	}
	for(int i = 1; i<=N; i++){
		pre[i] = add(pre[i-1],val[i]);
	}
	for(int i = N; i>=1; i--){
		suf[i] = add(suf[i+1],val[i]);
	}
	// Marked vertices outside subtree i (including a marked cent, distance 0)
	// are reached from nodes[i] by first crossing the edge to cent.
	for(int i = 1; i<=N; i++){
		auto [s,cnt] = add(pre[i-1],suf[i+1]);
		if(bad[cent]){
			cnt++;
		}
		dfs4(nodes[i],cent,s+cnt*weights[i],cnt);
	}

	// Recurse into the sub-components left after removing cent.
	for(auto [w,nxt] : adj[cent]){
		if(vis[nxt])continue;
		centroid(nxt);
	}
}
// Reads n, k, the k marked vertices and n-1 weighted edges, then runs the
// centroid decomposition, which fills sum[] and ans[]. Prints the minimum over
// all vertices i of 2 * sum[i] / ans[i].
signed main(){
	cin.tie(nullptr)->sync_with_stdio(false);
	cin >> n >> k;
	// centroid() overwrites the global n with component sizes, so keep the
	// real vertex count for the final scan.
	int N = n;
	for(int i = 1; i<=k; i++){
		int x;
		cin >> x;
		bad[x] = 1;
	}
	for(int i = 1; i<n; i++){
		int x,y,w;
		cin >> x >> y >> w;
		adj[x].push_back({w,y});
		adj[y].push_back({w,x});
	}
	centroid(1);
	int mn = (int)1e18;
	for(int i = 1; i<=N; i++){
		// ans[i] == 0 means every distance from i to a marked vertex is 0
		// (e.g. n = 1, k = 1, or i is the only marked vertex). Then sum[i] == 0
		// too and the cost is 0. Dividing would be an integer division by zero,
		// which was the Runtime Error on test #3.
		const int cost = (ans[i] == 0) ? 0 : sum[i] / ans[i] * 2;
		mn = min(mn, cost);
	}
	cout << mn << '\n';
	return 0;
}

详细信息 (Details)

Test #1:

score: 100
Accepted
time: 0ms
memory: 34120kb

input:

5 3
3 4 5
1 2 2
2 3 4
2 5 4
3 4 6

output:

8

result:

ok 1 number(s): "8"

Test #2:

score: 0
Accepted
time: 3ms
memory: 34136kb

input:

10 3
1 7 10
7 6 3
1 8 3
3 6 3
8 6 2
4 1 1
10 6 4
2 8 3
9 10 3
5 10 3

output:

24

result:

ok 1 number(s): "24"

Test #3:

score: -100
Runtime Error

input:

1 1
1

output:


result: