QOJ.ac

QOJ

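| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|----|---------|-----------|--------|------|--------|----------|-----------|-------------|------------|
| #576039 | #7895. Graph Partitioning 2 | user10086 | WA | 1ms | 9000kb | C++17 | 3.1kb | 2024-09-19 18:03:51 | 2024-09-19 18:03:51 |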

Judging History

你现在查看的是最新测评结果 (You are viewing the latest judging result.)

  • [2024-09-19 18:03:51]
  • 评测 (Judged)
  • 测评结果 (Result): WA
  • 用时 (Time): 1ms
  • 内存 (Memory): 9000kb
  • [2024-09-19 18:03:51]
  • 提交 (Submitted)

answer

#include <bits/stdc++.h>

using namespace std;

#define int long long

const int N = 1e5 + 10;     // size of vertex-indexed arrays (vertices are 1..n)
const int QN = 320;         // width of the per-vertex DP tables that are indexed by k
const int MOD = 998244353;  // all counts are reported modulo this prime

int n, k;
vector<int> g[N];  // adjacency lists of the current test's tree

// Accumulates y into x and reduces the result modulo MOD.
void add(int& x, int y)
{
	x = (x + y) % MOD;
}

namespace Small
{
	int dp[N][QN], tmp[QN], sz[N];

	void dfs(int x, int fa)
	{
		sz[x] = 1, dp[x][1] = 1;
		if (k == 1) dp[x][0] = 1;
		for (int y : g[x])
		{
			if (y == fa) continue;
			dfs(y, x);
			
			for (int i = 0; i <= min(k, sz[x] + sz[y]); i++) tmp[i] = 0;
			for (int i = 1; i <= min(k - 1, sz[x]); i++)
				for (int j = 0; j <= min(k - 1, sz[y]) && i + j <= k + 1; j++)
				{
					if (i + j <= k) add(tmp[i + j], dp[x][i] * dp[y][j]);
					if (i + j == k || i + j == k + 1) add(tmp[0], dp[x][i] * dp[y][j]);
				}
			add(tmp[k], dp[x][k] * dp[y][0]), add(tmp[0], dp[x][k] * dp[y][1]);
			if (k != 1) add(tmp[0], dp[x][1] * dp[y][k]);
			add(tmp[0], dp[x][0] * dp[y][0]);
			
			for (int i = 0; i <= min(k, sz[x] + sz[y]); i++) dp[x][i] = tmp[i];
			sz[x] += sz[y];
//			printf("%lld -> %lld\n", y, x);
//		for (int i = 0; i <= min(k, sz[x]); i++) printf("dp(%lld, %lld) = %lld\n", x, i, dp[x][i]);
		}
	}
	
	int solve()
	{
		for (int i = 1; i <= n; i++)
			for (int j = 0; j <= k; j++)
				dp[i][j] = 0;
		dfs(1, 0);
		return dp[1][0];
	}
}

namespace Large
{
	int dp[N][QN][2], tmp[N][2], sz[N];
	
	void dfs(int x, int fa)
	{
		sz[x] = 1, dp[x][0][0] = 1;
		if (k == 1) dp[x][1][1] = 1;
		auto r = [&](int x, int a)
		{
			return (sz[x] - a * k) % (k + 1);
		};
		
		for (int y : g[x])
		{
			if (y == fa) continue;
			dfs(y, x);
			
			for (int i = 0; i * k <= sz[x] + sz[y]; i++) tmp[i][0] = tmp[i][1] = 0;
			for (int a = 0; a * k <= sz[x]; a++)
				for (int b = 0; b * k <= sz[y]; b++)
				{
					int r1 = r(x, a), r2 = r(y, b);
					printf("a = %lld, b = %lld, r1 = %lld, r2 = %lld\n", a, b, r1, r2);
					if (r1 + r2 < k) add(tmp[a + b][0], dp[x][a][0] * dp[y][b][1]);
					if (r1 + r2 == k + 1) add(tmp[a + b][1], dp[x][a][0] * dp[y][b][1]);
					if (r1 + r2 == k) add(tmp[a + b + 1][1], dp[x][a][0] * dp[y][b][1]);
					if (r2 == 0) add(tmp[a + b][1], dp[x][a][1] * dp[y][b][1]);
				}
			
			printf("%lld -> %lld\n", y, x);
			for (int i = 0; i * k <= sz[x] + sz[y]; i++) dp[x][i][0] = tmp[i][0], dp[x][i][1] = tmp[i][1], printf("dp(%lld, %lld, 0) = %lld, dp(%lld, %lld, 1) = %lld\n", x, i, dp[x][i][0], x, i, dp[x][i][1]);
			sz[x] += sz[y];
		}
		for (int i = 0; i * k <= sz[x]; i++) add(dp[x][i][1], dp[x][i][0]);
	}
	
	int solve()
	{
		for (int i = 1; i <= n; i++)
			for (int j = 0; j * k <= n; j++)
				dp[i][j][0] = dp[i][j][1] = 0;
		dfs(1, 0);
		int ans = 0;
		for (int i = 0; i * k <= n; i++)
			if ((n - i * k) % (k + 1) == 0) add(ans, dp[1][i][1]);
		return ans;
	}
}

void solve()
{	
	cin >> n >> k;
	for (int i = 1; i <= n; i++) g[i].clear();
	for (int i = 1, u, v; i <= n - 1; i++) cin >> u >> v, g[u].push_back(v), g[v].push_back(u);
	assert(k < QN);
	cout << Large::solve() << '\n';
}

signed main()
{
	cin.tie(0)->sync_with_stdio(0);
	
	int t; cin >> t;
	while (t--) solve();	
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Wrong Answer
time: 1ms
memory: 9000kb

input:

2
8 2
1 2
3 1
4 6
3 5
2 4
8 5
5 7
4 3
1 2
1 3
2 4

output:

0
0
a = 0, b = 0, r1 = 1, r2 = 1
6 -> 4
dp(4, 0, 0) = 0, dp(4, 0, 1) = 0
dp(4, 1, 0) = 0, dp(4, 1, 1) = 1
a = 0, b = 0, r1 = 1, r2 = 2
a = 0, b = 1, r1 = 1, r2 = 0
4 -> 2
dp(2, 0, 0) = 0, dp(2, 0, 1) = 0
dp(2, 1, 0) = 1, dp(2, 1, 1) = 0
a = 0, b = 0, r1 = 1, r2 = 0
a = 0, b = 1, r1 = 1, r2 = 1
2 -> ...

result:

wrong answer 1st lines differ - expected: '2', found: '0'