QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#316108 | #7895. Graph Partitioning 2 | ucup-team1055# | RE | 363ms | 26300kb | C++20 | 7.2kb | 2024-01-27 17:25:57 | 2024-01-27 17:25:58 |
Judging History
answer
#include<bits/stdc++.h>
// rep(i,s,n): ascending loop, i takes the values s, s+1, ..., n-1 (both bounds cast to int).
#define rep(i,s,n) for(int i = int(s); i < int(n); i++)
// rrep(i,s,n): descending loop over the same half-open range, i = n-1 down to s.
#define rrep(i,s,n) for(int i = int(n) - 1; i >= int(s); i--)
// all(v): expands to the begin/end iterator pair of v, for passing to algorithms.
#define all(v) (v).begin(), (v).end()
// Short aliases for the wide arithmetic types.
using ll = long long;
using ull = unsigned long long;
using ld = long double;
/// Replaces `a` with `b` unless `a <= b` already holds.
/// @return true when `a` was overwritten, false when it was left alone.
template<class T>
bool chmin(T &a, T b) {
    const bool already_small_enough = (a <= b);
    if (!already_small_enough) {
        a = b;
    }
    return !already_small_enough;
}
/// Replaces `a` with `b` unless `a >= b` already holds.
/// @return true when `a` was overwritten, false when it was left alone.
template<class T>
bool chmax(T &a, T b) {
    const bool already_large_enough = (a >= b);
    if (!already_large_enough) {
        a = b;
    }
    return !already_large_enough;
}
// Makes every name from namespace std visible unqualified for the rest of the file.
using namespace std;
/// Integer arithmetic modulo the compile-time constant `m`.
///
/// The stored value `a` is always kept reduced to the range [0, m).
/// `inv()` (and therefore division and negative powers) uses Fermat's little
/// theorem, so those operations are only meaningful when `m` is prime and the
/// value is non-zero. Multiplication computes `a * rhs.a` in `long long`, so
/// `m` must be small enough that (m-1)^2 fits in 64 bits.
template<long long m> struct modint {
    using mint = modint;

    long long a;

    /// Builds the residue of `x`; negative inputs are mapped into [0, m).
    modint(long long x = 0) : a((x % m + m) % m) {}

    /// The modulus this type works with.
    static constexpr long long mod() {
        return m;
    }

    /// Read access to the reduced representative.
    long long val() const {
        return a;
    }

    /// Mutable access to the representative; the caller must keep it in [0, m).
    long long& val() {
        return a;
    }

    /// Raises the value to the power `n` by binary exponentiation.
    /// A negative `n` is interpreted as a power of the inverse. Without that
    /// branch the loop below would never terminate, because shifting a
    /// negative exponent right never reaches zero.
    mint pow(long long n) const {
        if (n < 0) {
            const mint inverse = inv();
            // inverse^(-n) is computed as inverse^(-(n+1)) * inverse so that
            // negating the most negative long long never overflows.
            mint result = inverse.pow(-(n + 1));
            result *= inverse;
            return result;
        }
        mint res = 1;
        mint x = a;
        while (n) {
            if (n & 1) res *= x;
            x *= x;
            n >>= 1;
        }
        return res;
    }

    /// Multiplicative inverse as this^(m-2); valid for prime `m` only.
    mint inv() const {
        return pow(m - 2);
    }

    mint& operator+=(const mint rhs) {
        a += rhs.a;
        if (a >= m) a -= m;
        return *this;
    }

    mint& operator-=(const mint rhs) {
        // Lift first so the subtraction never goes below zero.
        if (a < rhs.a) a += m;
        a -= rhs.a;
        return *this;
    }

    mint& operator*=(const mint rhs) {
        a = a * rhs.a % m;
        return *this;
    }

    mint& operator/=(mint rhs) {
        *this *= rhs.inv();
        return *this;
    }

    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }

    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }

    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }

    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }

    friend bool operator==(const modint& lhs, const modint& rhs) {
        return lhs.a == rhs.a;
    }

    friend bool operator!=(const modint& lhs, const modint& rhs) {
        return !(lhs == rhs);
    }

    mint operator+() const {
        return *this;
    }

    /// Additive inverse, computed as 0 - *this.
    mint operator-() const {
        return mint() - *this;
    }
};
// Modular integer type for the modulus 998244353; `mint` is the short name used by solve().
using modint998244353 = modint<998244353>;
using mint = modint998244353;
// Threshold on k: solve() uses one strategy when k < th and a different one otherwise.
const int th = 1000;
void solve() {
int n,k;
std::cin >> n >> k;
std::vector g(n, std::vector<int>());
rep(i,0,n-1) {
int u,v;
std::cin >> u >> v;
u--; v--;
g[u].emplace_back(v);
g[v].emplace_back(u);
}
std::vector<int> dfs_order;
std::vector<int> par(n, -1);
auto dfs = [&](auto &&self, int v) -> void {
dfs_order.emplace_back(v);
for(auto nv: g[v]) {
if(nv == par[v]) continue;
par[nv] = v;
self(self, nv);
}
};
dfs(dfs, 0);
// k is small
if(k < th) {
auto mul = [&](const std::vector<mint> &a, const std::vector<mint> &b) -> std::vector<mint> {
int s = a.size(), t = b.size();
int sz = std::min(k+2, s + t - 1);
std::vector<mint> ab(sz, 0);
rep(i,0,s) {
rep(j,0,t) {
if(i + j >= sz) break;
ab[i + j] += a[i] * b[j];
}
}
return ab;
};
std::vector dp(n, std::vector<mint>(2, 0));
for(auto v: dfs_order | std::views::reverse) {
dp[v][1] = 1;
for(auto nv: g[v]) {
if(nv == par[v]) continue;
dp[v] = mul(dp[v], dp[nv]);
}
if((int)dp[v].size() > k) {
dp[v][0] += dp[v][k];
}
if((int)dp[v].size() > k + 1) {
dp[v][0] += dp[v][k+1];
dp[v][k+1] = 0;
}
}
mint ans = dp[0][0];
std::cout << ans.val() << '\n';
}else{
vector<vector<int>> ikeru = g;
vector<int> siz(n);
vector<int> mada = {~0, 0};
vector<int> tansaku(n);
tansaku[0] = 1;
while(!mada.empty()){
int i = mada.back();
mada.pop_back();
if (i >= 0){
for (int j:ikeru[i]){
if (tansaku[j] == 0){
mada.push_back(~j);
mada.push_back(j);
tansaku[j] = 1;
}
}
}else{
i = ~i;
for(int j: ikeru[i]){
if (tansaku[j] == 2){
siz[i] += siz[j];
}
}
siz[i] += 1;
tansaku[i] = 2;
}
}
vector<map<pair<int,int>,mint>> e(n);
vector<map<pair<int,int>,mint>> dp(n);
fill(tansaku.begin(), tansaku.end(), 0);
mada.push_back(~0);
mada.push_back(0);
tansaku[0] = 1;
while(!mada.empty()){
int i = mada.back();
mada.pop_back();
if (i >= 0){
for (int j: ikeru[i]){
if (tansaku[j] == 0){
tansaku[j] = 1;
mada.push_back(~j);
mada.push_back(j);
}
}
}else{
i = ~i;
int mx_siz = -1;
int mx_ind = -1;
vector<int> dar;
ll num_k1 = siz[i]%k;
ll num_k0 = (siz[i] - (ll)(num_k1)*(k+1))/k;
for (int j: ikeru[i]){
if (tansaku[j] == 2){
if (chmax(mx_siz, (int)e[j].size())){
mx_ind = j;
}
dar.push_back(j);
}
}
swap(e[i], e[mx_ind]);
for (int j: dar){
if (j == mx_ind) continue;
for (auto [x, c]: e[j]){
e[i][x] += c;
}
}
for (int j: dar){
e[j].clear();
}
e[i][pair(0, 0)] += 1;
if (num_k0 >= 0 && num_k1 >= 0){
if (num_k0 > 0){
if (e[i].find(pair(num_k0-1, num_k1)) != e[i].end()){
dp[i][pair(num_k0, num_k1)] += e[i][pair(num_k0-1, num_k1)];
}
}
if (num_k1 > 0){
if (e[i].find(pair(num_k0, num_k1-1)) != e[i].end()){
dp[i][pair(num_k0, num_k1)] += e[i][pair(num_k0, num_k1-1)];
}
}
}
e[i][pair(0, 0)] -= 1;
for (auto [x,c]: dp[i]) e[i][x] += c;
tansaku[i] = 2;
}
}
mint ans = 0;
for (auto[x,c] : dp[0]){
ans += c;
}
cout << ans.val() << '\n';
}
}
/// Entry point: switches the C++ streams to fast mode, reads the number of
/// test cases and runs solve() once per case.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int cases;
    std::cin >> cases;
    for (; cases != 0; --cases) {
        solve();
    }
}
详细
Test #1:
score: 100
Accepted
time: 0ms
memory: 3592kb
input:
2 8 2 1 2 3 1 4 6 3 5 2 4 8 5 5 7 4 3 1 2 1 3 2 4
output:
2 1
result:
ok 2 lines
Test #2:
score: 0
Accepted
time: 56ms
memory: 6140kb
input:
5550 13 4 10 3 9 1 10 8 3 11 8 5 10 7 9 6 13 5 9 7 2 7 5 12 4 8 8 2 4 1 3 4 7 8 2 5 6 7 4 8 2 3 11 1 11 10 1 4 9 10 8 4 3 6 5 7 6 1 10 2 11 7 11 1 17 2 14 16 13 15 17 3 15 11 1 6 13 2 13 17 4 8 14 10 8 14 14 5 9 12 14 2 12 17 17 6 15 7 14 6 2 14 2 13 2 4 8 4 3 11 7 3 14 1 11 9 13 3 5 10 6 8 3 10 14 ...
output:
0 3 112 0 1 0 1 0 0 0 1 0 1 0 0 1 0 140 0 0 0 814 1 6 1 1 2 2 0 612 0 1 0 0 0 1 1 0 0 121 4536 0 0 1718 0 0 1 0 444 1 1908 1813 3 74 0 1 0 46 0 0 0 0 0 0 0 0 0 1 0 1 1 1 239 0 0 0 1 0 0 0 1 0 1 0 0 1 1 0 0 0 1 0 0 0 48 0 2 0 0 0 1 364 0 206 0 0 76 0 1 0 0 2 0 1 2 0 0 1 0 0 4 0 1 1 0 0 1 1 1 0 0 1 1 ...
result:
ok 5550 lines
Test #3:
score: 0
Accepted
time: 272ms
memory: 24676kb
input:
3 99990 259 23374 69108 82204 51691 8142 67119 48537 97966 51333 44408 33147 68485 21698 86824 15746 58746 78761 86975 58449 61819 69001 68714 25787 2257 25378 14067 64899 68906 29853 31359 75920 85420 76072 11728 63836 55505 43671 98920 77281 25176 40936 66517 61029 61440 66908 52300 92101 59742 69...
output:
259200 247 207766300
result:
ok 3 lines
Test #4:
score: 0
Accepted
time: 363ms
memory: 26300kb
input:
3 99822 332 11587 83046 63424 60675 63423 73718 74622 40130 5110 26562 28361 80899 30886 70318 8708 11068 34855 96504 7904 75735 31904 42745 87892 55105 82374 81319 77407 82147 91475 12343 13470 95329 58766 95716 83232 44156 75907 92437 69785 93598 47857 33018 62668 31394 24238 72675 98254 43583 180...
output:
315881300 4505040 185631154
result:
ok 3 lines
Test #5:
score: -100
Runtime Error
input:
3 99021 1000 41739 4318 72541 76341 31227 15416 49232 13808 50837 51259 74464 11157 92684 84646 95226 64673 74155 82511 33301 31373 5901 29318 38227 98893 96752 57411 35167 42401 24344 90803 6956 33753 51120 24535 29594 2646 70305 32961 93079 38070 49273 48987 62799 77986 94353 84447 74970 31546 263...