QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#649619 | #7895. Graph Partitioning 2 | tumu1t | WA | 0ms | 3660kb | C++20 | 5.7kb | 2024-10-18 03:51:02 | 2024-10-18 03:51:03 |
Judging History
answer
#include <bits/stdc++.h>
typedef long long LL;
using std::cin, std::cout, std::endl, std::vector, std::pair;
template <typename T>
constexpr T power(T a, long long b)
{
T res = 1;
for (; b; b >>= 1, a *= a)
if (b & 1)
res *= a;
return res;
}
template <int P>
class Mint
{
public:
int x;
static int MOD;
constexpr Mint() : x() {}
constexpr Mint(long long _x) : x{norm(_x % getMOD())} {}
constexpr static void setMOD(int _MOD) { MOD = _MOD; }
constexpr static int getMOD() { return P > 0 ? P : MOD; }
constexpr int norm(int x) const { return x >= 0 && x < getMOD() ? x : (x < 0 ? x += getMOD() : x -= getMOD()); }
constexpr int val() const { return x; }
explicit constexpr operator int() const { return x; }
constexpr Mint operator-() const
{
Mint res;
res.x = norm(getMOD() - x);
return res;
}
constexpr Mint inv() const
{
assert(x != 0);
return power(*this, getMOD() - 2);
}
constexpr Mint &operator*=(Mint rhs) &
{
x = 1LL * x * rhs.x % getMOD();
return *this;
}
constexpr Mint &operator+=(Mint rhs) &
{
x = norm(x + rhs.x);
return *this;
}
constexpr Mint &operator-=(Mint rhs) &
{
x = norm(x - rhs.x);
return *this;
}
constexpr Mint &operator/=(Mint rhs) & { return *this *= rhs.inv(); }
friend constexpr Mint operator*(Mint lhs, Mint rhs)
{
Mint res = lhs;
res *= rhs;
return res;
}
friend constexpr Mint operator+(Mint lhs, Mint rhs)
{
Mint res = lhs;
res += rhs;
return res;
}
friend constexpr Mint operator-(Mint lhs, Mint rhs)
{
Mint res = lhs;
res -= rhs;
return res;
}
friend constexpr Mint operator/(Mint lhs, Mint rhs)
{
Mint res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream &operator>>(std::istream &is, Mint &a)
{
long long v;
is >> v;
a = Mint(v);
return is;
}
friend constexpr std::ostream &operator<<(std::ostream &os, const Mint &a) { return os << a.val(); }
friend constexpr bool operator==(Mint lhs, Mint rhs) { return lhs.val() == rhs.val(); }
friend constexpr bool operator!=(Mint lhs, Mint rhs) { return lhs.val() != rhs.val(); }
};
template <>
int Mint<0>::MOD = 998'244'353;
template <int V, int P>
constexpr Mint<P> Cinv = Mint<P>(V).inv();
constexpr int P = 998'244'353; // 可以修改这里的P 如果需要将变量设为P,需要将P变为0
using Z = Mint<P>;
// 998'244'353
void Solve1(int n, int k)
{
vector<vector<int>> G(n);
for (int i = 1; i <= n - 1; i++)
{
int x, y;
cin >> x >> y;
x -= 1;
y -= 1;
G[x].emplace_back(y);
G[y].emplace_back(x);
}
auto Dfs = [&](auto self, int at, int from) -> pair<int, vector<Z>>
{
vector<Z> Ans(k);
Ans[1] = 1;
int subTree = 1;
for (auto &v : G[at])
{
if (v == from)
continue;
else
{
const auto &[sz, SonAns] = self(self, v, at);
vector<Z> newAns(k);
for (int i = 1; i <= std::min<int>(subTree, k + 1); i++)
{
for (int j = 0; j <= std::min<int>(sz, k + 1) && i + j <= k + 1; j++)
{
newAns[i + j] += SonAns[j] * Ans[i];
}
}
Ans = std::move(newAns);
subTree += sz;
}
}
if ((int)Ans.size() >= k + 2)
Ans[0] += Ans[k + 1];
if ((int)Ans.size() >= k + 1)
Ans[0] += Ans[k];
return {subTree, Ans};
};
cout << Dfs(Dfs, 0, -1).second[0] << endl;
}
void Solve2(int n, int k)
{
vector<vector<int>> G(n);
for (int i = 1; i <= n - 1; i++)
{
int x, y;
cin >> x >> y;
x -= 1;
y -= 1;
G[x].emplace_back(y);
G[y].emplace_back(x);
}
// vector<std::map<int, Z>> Dp(n);
auto Dfs = [&](auto self, int at, int from) -> std::map<int, Z>
{
std::map<int, Z> Ans;
Ans[1] = 1;
for (auto &v : G[at])
{
if (v == from)
continue;
else
{
std::map<int, Z> newAns;
const std::map<int, Z> &sonAns = self(self, v, at);
for (auto &[i, d1] : Ans)
{
if (i != 0)
{
for (auto &[j, d2] : sonAns)
{
if (i + j > k + 1)
break;
newAns[i + j] += d1 * d2;
}
}
}
Ans = std::move(newAns);
}
}
if (auto it = Ans.find(k); it != Ans.end())
Ans[0] += (*it).second;
if (auto it = Ans.find(k + 1); it != Ans.end())
Ans[0] += (*it).second;
return Ans;
};
cout << Dfs(Dfs, 0, -1)[0] << endl;
}
bool t1 = false;
void Solve()
{
int n, k;
cin >> n >> k;
if (k < 800)
Solve1(n, k);
else
Solve2(n, k);
return;
}
int main()
{
std::ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int testcases = 1;
cin >> testcases;
for (int i = 1; i <= testcases; i++)
Solve();
}
详细
Test #1:
score: 0
Wrong Answer
time: 0ms
memory: 3660kb
input:
2 8 2 1 2 3 1 4 6 3 5 2 4 8 5 5 7 4 3 1 2 1 3 2 4
output:
0 0
result:
wrong answer 1st lines differ - expected: '2', found: '0'