QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#316180 | #7895. Graph Partitioning 2 | ucup-team1055# | WA | 1ms | 3752kb | C++20 | 7.5kb | 2024-01-27 17:56:36 | 2024-01-27 17:56:37 |
Judging History
answer
#include<bits/stdc++.h>
#define rep(i,s,n) for(int i = int(s); i < int(n); i++)
#define rrep(i,s,n) for(int i = int(n) - 1; i >= int(s); i--)
#define all(v) (v).begin(), (v).end()
using ll = long long;
using ull = unsigned long long;
using ld = long double;
/// Replaces `a` with `b` unless `a <= b` already holds.
/// Returns true exactly when `a` was overwritten.
template<class T>
bool chmin(T &a, T b) {
    // Written as !(a <= b) rather than b < a so unordered values (e.g. NaN)
    // take the same path as before.
    if (!(a <= b)) {
        a = b;
        return true;
    }
    return false;
}
/// Replaces `a` with `b` unless `a >= b` already holds.
/// Returns true exactly when `a` was overwritten.
template<class T>
bool chmax(T &a, T b) {
    // Written as !(a >= b) rather than b > a so unordered values (e.g. NaN)
    // take the same path as before.
    if (!(a >= b)) {
        a = b;
        return true;
    }
    return false;
}
using namespace std;
/// Arithmetic on residues modulo the compile-time constant `m`.
///
/// The stored residue `a` is kept in [0, m). Products are formed in
/// `long long` before reduction, so (m - 1) * (m - 1) must fit in
/// `long long` (true for 998244353). `inv()`, `operator/` and negative
/// exponents use Fermat's little theorem and are only correct for prime `m`.
template<long long m> struct modint {
    // inv() computes pow(m - 2); a modulus below 2 would make that exponent
    // negative and the inverse meaningless.
    static_assert(m >= 2, "modint needs a modulus of at least 2");

    using mint = modint;
    long long a;

    /// Reduces any `x`, negative values included, into [0, m).
    modint(long long x = 0) : a((x % m + m) % m) {}

    static constexpr long long mod() {
        return m;
    }
    long long val() const {
        return a;
    }
    /// Mutable access to the raw residue; callers must keep it in [0, m).
    long long& val() {
        return a;
    }

    /// Returns this^n by binary exponentiation.
    /// A negative `n` is evaluated as inv()^(-n).
    mint pow(long long n) const {
        if (n < 0) {
            // Peel off one factor so that -(n + 1) cannot overflow, even for
            // n == LLONG_MIN. inv() itself only uses a non-negative exponent.
            const mint base = inv();
            return base.pow(-(n + 1)) * base;
        }
        mint res = 1;
        mint x = a;
        while (n) {
            if (n & 1) res *= x;
            x *= x;
            n >>= 1;
        }
        return res;
    }
    /// Multiplicative inverse via Fermat: this^(m - 2). Meaningless for 0.
    mint inv() const {
        return pow(m - 2);
    }

    mint & operator+=(const mint rhs) {
        a += rhs.a;
        if (a >= m) a -= m;
        return *this;
    }
    mint & operator-=(const mint rhs) {
        if (a < rhs.a) a += m;
        a -= rhs.a;
        return *this;
    }
    mint & operator*=(const mint rhs) {
        a = a * rhs.a % m;
        return *this;
    }
    mint & operator/=(mint rhs) {
        *this *= rhs.inv();
        return *this;
    }
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const modint &lhs, const modint &rhs) {
        return lhs.a == rhs.a;
    }
    friend bool operator!=(const modint &lhs, const modint &rhs) {
        return !(lhs == rhs);
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
};
// Residues modulo 998244353; `mint` is the shorthand used by the solver.
using modint998244353 = modint<998244353>;
using mint = modint998244353;
// Size threshold constant (1000) on `k`.
constexpr int th = 1000;
void solve() {
int n,k;
std::cin >> n >> k;
std::vector g(n, std::vector<int>());
rep(i,0,n-1) {
int u,v;
std::cin >> u >> v;
u--; v--;
g[u].emplace_back(v);
g[v].emplace_back(u);
}
std::vector<int> dfs_order;
std::vector<int> par(n, -1);
auto dfs = [&](auto &&self, int v) -> void {
dfs_order.emplace_back(v);
for(auto nv: g[v]) {
if(nv == par[v]) continue;
par[nv] = v;
self(self, nv);
}
};
dfs(dfs, 0);
// k is small
if(k < th && false) {
auto mul = [&](const std::vector<mint> &a, const std::vector<mint> &b) -> std::vector<mint> {
int s = a.size(), t = b.size();
int sz = std::min(k+2, s + t - 1);
std::vector<mint> ab(sz, 0);
rep(i,0,s) {
rep(j,0,t) {
if(i + j >= sz) break;
ab[i + j] += a[i] * b[j];
}
}
return ab;
};
std::vector dp(n, std::vector<mint>(2, 0));
for(auto v: dfs_order | std::views::reverse) {
dp[v][1] = 1;
for(auto nv: g[v]) {
if(nv == par[v]) continue;
dp[v] = mul(dp[v], dp[nv]);
}
if((int)dp[v].size() > k) {
dp[v][0] += dp[v][k];
}
if((int)dp[v].size() > k + 1) {
dp[v][0] += dp[v][k+1];
dp[v][k+1] = 0;
}
}
mint ans = dp[0][0];
std::cout << ans.val() << '\n';
}else{
vector<vector<int>> ikeru = g;
vector<int> siz(n);
vector<int> mada = {~0, 0};
vector<int> tansaku(n);
tansaku[0] = 1;
while(!mada.empty()){
int i = mada.back();
mada.pop_back();
if (i >= 0){
for (int j:ikeru[i]){
if (tansaku[j] == 0){
mada.push_back(~j);
mada.push_back(j);
tansaku[j] = 1;
}
}
}else{
i = ~i;
for(int j: ikeru[i]){
if (tansaku[j] == 2){
siz[i] += siz[j];
}
}
siz[i] += 1;
tansaku[i] = 2;
}
}
vector<map<pair<int,int>,mint>> e(n);
vector<map<pair<int,int>,mint>> dp(n);
fill(tansaku.begin(), tansaku.end(), 0);
mada.push_back(~0);
mada.push_back(0);
tansaku[0] = 1;
while(!mada.empty()){
int i = mada.back();
mada.pop_back();
if (i >= 0){
for (int j: ikeru[i]){
if (tansaku[j] == 0){
tansaku[j] = 1;
mada.push_back(~j);
mada.push_back(j);
}
}
}else{
i = ~i;
int mx_siz = -1;
int mx_ind = -1;
vector<int> dar;
ll num_k1 = siz[i]%k;
ll num_k0 = (siz[i] - (ll)(num_k1)*(k+1))/k;
for (int j: ikeru[i]){
if (tansaku[j] == 2){
if (chmax(mx_siz, (int)e[j].size())){
mx_ind = j;
}
dar.push_back(j);
}
}
if (mx_ind >= 0){
swap(e[i], e[mx_ind]);
}
for (int j: dar){
if (j == mx_ind) continue;
map<pair<int,int>,mint> np;
for (auto [x, c]: e[j]){
for (auto [y, d]: e[i]){
np[pair(x.first + y.first, x.second + y.second)] += c * d;
}
}
swap(e[i], np);
}
for (int j: dar){
e[j].clear();
}
e[i][pair(0, 0)] += 1;
if (num_k0 >= 0 && num_k1 >= 0){
if (num_k0 > 0){
if (e[i].find(pair(num_k0-1, num_k1)) != e[i].end()){
dp[i][pair(num_k0, num_k1)] += e[i][pair(num_k0-1, num_k1)];
}
}
if (num_k1 > 0){
if (e[i].find(pair(num_k0, num_k1-1)) != e[i].end()){
dp[i][pair(num_k0, num_k1)] += e[i][pair(num_k0, num_k1-1)];
}
}
}
e[i][pair(0, 0)] -= 1;
for (auto [x,c]: dp[i]) e[i][x] += c;
tansaku[i] = 2;
}
}
mint ans = 0;
for (auto[x,c] : dp[0]){
ans += c;
}
cout << ans.val() << '\n';
}
}
/// Entry point: reads the number of test cases, then runs solve() once per case.
int main() {
    // Fast unsynchronized iostreams; this program does not mix in C stdio.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int t = 0;
    std::cin >> t;
    for (; t != 0; --t) {
        solve();
    }
}
詳細信息
Test #1:
score: 0
Wrong Answer
time: 1ms
memory: 3752kb
input:
2 8 2 1 2 3 1 4 6 3 5 2 4 8 5 5 7 4 3 1 2 1 3 2 4
output:
0 1
result:
wrong answer 1st lines differ - expected: '2', found: '0'