QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#576039 | #7895. Graph Partitioning 2 | user10086 | WA | 1ms | 9000kb | C++17 | 3.1kb | 2024-09-19 18:03:51 | 2024-09-19 18:03:51 |
Judging History
answer
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 100010;       // array bound for vertices (indices 1..n)
const int QN = 320;         // column count of Small::dp, i.e. component sizes 0..QN-1
const int MOD = 998244353;  // all counts are reported modulo this prime
int n, k;                   // current test case: vertex count and the size parameter k
vector<int> g[N];           // undirected adjacency lists of the current tree
// Accumulate y into x modulo MOD; callers pass x in [0, MOD) and y >= 0.
void add(int& x, int y)
{
    x = (x + y) % MOD;
}
namespace Small
{
// Size-indexed tree DP; requires k < QN because dp has QN columns (indices 0..k are used).
// After the children of x merged so far:
//   dp[x][s], 1 <= s <= k : ways in which x's component has size s and is still open
//                           (it may keep growing through later children or x's parent);
//   dp[x][0]              : ways in which x's component is already finished (size k or k + 1).
// A component reaching exactly size k is recorded both as finished (dp[.][0]) and as open
// (dp[.][k]); the open copy can only grow to k + 1, so no partition is counted twice.
int dp[N][QN], tmp[QN], sz[N];
void dfs(int x, int fa)
{
    sz[x] = 1;
    dp[x][1] = 1;
    if (k == 1) dp[x][0] = 1;  // a lone vertex is already a finished component when k == 1
    for (int child : g[x])
    {
        if (child == fa) continue;
        dfs(child, x);
        const int cap = min(k, sz[x] + sz[child]);
        fill(tmp, tmp + cap + 1, 0);
        const int topA = min(k - 1, sz[x]);
        for (int a = 1; a <= topA; a++)
        {
            // b == 0: the child's component is finished, so the edge (x, child) is cut.
            // b >= 1: the edge is kept and the two open components merge into size a + b.
            const int topB = min(min(k - 1, sz[child]), k + 1 - a);
            for (int b = 0; b <= topB; b++)
            {
                const int ways = dp[x][a] * dp[child][b];
                const int merged = a + b;
                if (merged <= k) add(tmp[merged], ways);
                if (merged >= k) add(tmp[0], ways);  // merged is k or k + 1: finish it here
            }
        }
        // Open size k: keep it while cutting a finished child, or absorb a 1-vertex child -> k + 1.
        add(tmp[k], dp[x][k] * dp[child][0]);
        add(tmp[0], dp[x][k] * dp[child][1]);
        // Lone x absorbs an open child component of size k -> k + 1 (for k == 1 this is the
        // same product as the line above, so it is skipped to avoid double counting).
        if (k != 1) add(tmp[0], dp[x][1] * dp[child][k]);
        // x's component is finished: the child's component must be finished as well.
        add(tmp[0], dp[x][0] * dp[child][0]);
        copy(tmp, tmp + cap + 1, dp[x]);
        sz[x] += sz[child];
    }
}
// Clears the rows used by this test case and returns the number of valid partitions.
int solve()
{
    for (int v = 1; v <= n; v++) fill(dp[v], dp[v] + k + 1, 0);
    dfs(1, 0);
    return dp[1][0];
}
}
namespace Large
{
int dp[N][QN][2], tmp[N][2], sz[N];
void dfs(int x, int fa)
{
sz[x] = 1, dp[x][0][0] = 1;
if (k == 1) dp[x][1][1] = 1;
auto r = [&](int x, int a)
{
return (sz[x] - a * k) % (k + 1);
};
for (int y : g[x])
{
if (y == fa) continue;
dfs(y, x);
for (int i = 0; i * k <= sz[x] + sz[y]; i++) tmp[i][0] = tmp[i][1] = 0;
for (int a = 0; a * k <= sz[x]; a++)
for (int b = 0; b * k <= sz[y]; b++)
{
int r1 = r(x, a), r2 = r(y, b);
printf("a = %lld, b = %lld, r1 = %lld, r2 = %lld\n", a, b, r1, r2);
if (r1 + r2 < k) add(tmp[a + b][0], dp[x][a][0] * dp[y][b][1]);
if (r1 + r2 == k + 1) add(tmp[a + b][1], dp[x][a][0] * dp[y][b][1]);
if (r1 + r2 == k) add(tmp[a + b + 1][1], dp[x][a][0] * dp[y][b][1]);
if (r2 == 0) add(tmp[a + b][1], dp[x][a][1] * dp[y][b][1]);
}
printf("%lld -> %lld\n", y, x);
for (int i = 0; i * k <= sz[x] + sz[y]; i++) dp[x][i][0] = tmp[i][0], dp[x][i][1] = tmp[i][1], printf("dp(%lld, %lld, 0) = %lld, dp(%lld, %lld, 1) = %lld\n", x, i, dp[x][i][0], x, i, dp[x][i][1]);
sz[x] += sz[y];
}
for (int i = 0; i * k <= sz[x]; i++) add(dp[x][i][1], dp[x][i][0]);
}
int solve()
{
for (int i = 1; i <= n; i++)
for (int j = 0; j * k <= n; j++)
dp[i][j][0] = dp[i][j][1] = 0;
dfs(1, 0);
int ans = 0;
for (int i = 0; i * k <= n; i++)
if ((n - i * k) % (k + 1) == 0) add(ans, dp[1][i][1]);
return ans;
}
}
void solve()
{
cin >> n >> k;
for (int i = 1; i <= n; i++) g[i].clear();
for (int i = 1, u, v; i <= n - 1; i++) cin >> u >> v, g[u].push_back(v), g[v].push_back(u);
assert(k < QN);
cout << Large::solve() << '\n';
}
// Entry point: fast unsynchronised iostreams, then the given number of independent test cases.
signed main()
{
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    int t;
    cin >> t;
    for (; t != 0; --t) solve();
}
詳細信息
Test #1:
score: 0
Wrong Answer
time: 1ms
memory: 9000kb
input:
2 8 2 1 2 3 1 4 6 3 5 2 4 8 5 5 7 4 3 1 2 1 3 2 4
output:
0 0 a = 0, b = 0, r1 = 1, r2 = 1 6 -> 4 dp(4, 0, 0) = 0, dp(4, 0, 1) = 0 dp(4, 1, 0) = 0, dp(4, 1, 1) = 1 a = 0, b = 0, r1 = 1, r2 = 2 a = 0, b = 1, r1 = 1, r2 = 0 4 -> 2 dp(2, 0, 0) = 0, dp(2, 0, 1) = 0 dp(2, 1, 0) = 1, dp(2, 1, 1) = 0 a = 0, b = 0, r1 = 1, r2 = 0 a = 0, b = 1, r1 = 1, r2 = 1 2 -> ...
result:
wrong answer 1st lines differ - expected: '2', found: '0'