QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submitted at | Judged at
#316108 | #7895. Graph Partitioning 2 | ucup-team1055 | RE | 363ms | 26300kb | C++20 | 7.2kb | 2024-01-27 17:25:57 | 2024-01-27 17:25:58

Judging History

You are viewing the latest judging result.

  • [2024-01-27 17:25:58]
  • Judged
  • Result: RE
  • Time: 363ms
  • Memory: 26300kb
  • [2024-01-27 17:25:57]
  • Submitted

answer

#include<bits/stdc++.h>

// rep(i, s, n): loop variable i runs over [s, n) in increasing order.
#define rep(i,s,n) for(int i = int(s); i < int(n); i++)
// rrep(i, s, n): loop variable i runs over [s, n) in decreasing order.
#define rrep(i,s,n) for(int i = int(n) - 1; i >= int(s); i--)
// all(v): expands to the begin/end iterator pair of container v.
#define all(v) (v).begin(), (v).end()

// Short names for wide integer and floating-point types.
using ll = long long;
using ull = unsigned long long;
using ld = long double;

/// Overwrites `a` with `b` unless `a <= b` already holds
/// (for totally ordered types: when `b` is strictly smaller).
/// @return true iff `a` was overwritten.
template<class T>
bool chmin(T &a, T b) {
    const bool changed = !(a <= b);
    if (changed) {
        a = b;
    }
    return changed;
}

/// Overwrites `a` with `b` unless `a >= b` already holds
/// (for totally ordered types: when `b` is strictly larger).
/// @return true iff `a` was overwritten.
template<class T>
bool chmax(T &a, T b) {
    const bool changed = !(a >= b);
    if (changed) {
        a = b;
    }
    return changed;
}

using namespace std;

/// Residue class modulo the compile-time constant m, stored normalised in [0, m).
/// inv() (and therefore division) uses Fermat's little theorem, so m is
/// assumed to be prime.
template<ll m> struct modint {
    using mint = modint;
    ll a;  // residue; every operator keeps it in [0, m) (writes through val() excepted)

    /// Maps any x, negative values included, to its residue in [0, m).
    modint(ll x = 0) : a(((x % m) + m) % m) {}

    static constexpr ll mod() { return m; }

    ll val() const { return a; }
    /// Mutable access to the raw residue; the caller must keep it in [0, m).
    ll &val() { return a; }

    /// this^n by square-and-multiply; n is assumed to be non-negative.
    mint pow(ll n) const {
        mint acc = 1;
        mint base = a;
        for (; n; n >>= 1) {
            if (n & 1) acc *= base;
            base *= base;
        }
        return acc;
    }

    /// Multiplicative inverse, computed as this^(m-2).
    mint inv() const { return pow(m - 2); }

    mint &operator+=(const mint rhs) {
        const ll sum = a + rhs.a;
        a = (sum >= m) ? sum - m : sum;
        return *this;
    }
    mint &operator-=(const mint rhs) {
        a = (a < rhs.a) ? a + m - rhs.a : a - rhs.a;
        return *this;
    }
    mint &operator*=(const mint rhs) {
        a = (a * rhs.a) % m;
        return *this;
    }
    mint &operator/=(mint rhs) {
        return *this *= rhs.inv();
    }

    friend mint operator+(const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res += rhs;
        return res;
    }
    friend mint operator-(const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res -= rhs;
        return res;
    }
    friend mint operator*(const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res *= rhs;
        return res;
    }
    friend mint operator/(const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res /= rhs;
        return res;
    }
    friend bool operator==(const modint &lhs, const modint &rhs) {
        return lhs.a == rhs.a;
    }
    friend bool operator!=(const modint &lhs, const modint &rhs) {
        return lhs.a != rhs.a;
    }

    mint operator+() const { return *this; }
    mint operator-() const {
        mint res;
        res -= *this;
        return res;
    }
};

// Modular integer type with modulus 998244353; all answers are reported modulo this.
using modint998244353 = modint<998244353>;
using mint = modint998244353;

// Branch selector in solve(): k < th runs the vector DP whose tables are
// truncated to k+2 entries; k >= th runs the other branch.
const int th = 1000;

void solve() {
    int n,k;
    std::cin >> n >> k;
    std::vector g(n, std::vector<int>());
    rep(i,0,n-1) {
        int u,v;
        std::cin >> u >> v;
        u--; v--;
        g[u].emplace_back(v);
        g[v].emplace_back(u);
    }
    std::vector<int> dfs_order;
    std::vector<int> par(n, -1);
    auto dfs = [&](auto &&self, int v) -> void {
        dfs_order.emplace_back(v);
        for(auto nv: g[v]) {
            if(nv == par[v]) continue;
            par[nv] = v;
            self(self, nv);
        }
    };
    dfs(dfs, 0);
    // k is small
    if(k < th) {
        auto mul = [&](const std::vector<mint> &a, const std::vector<mint> &b) -> std::vector<mint> {
            int s = a.size(), t = b.size();
            int sz = std::min(k+2, s + t - 1);
            std::vector<mint> ab(sz, 0);
            rep(i,0,s) {
                rep(j,0,t) {
                    if(i + j >= sz) break;
                    ab[i + j] += a[i] * b[j];
                }
            }
            return ab;
        };
        std::vector dp(n, std::vector<mint>(2, 0));
        for(auto v: dfs_order | std::views::reverse) {
            dp[v][1] = 1;
            for(auto nv: g[v]) {
                if(nv == par[v]) continue;
                dp[v] = mul(dp[v], dp[nv]);
            }
            if((int)dp[v].size() > k) {
                dp[v][0] += dp[v][k];
            }
            if((int)dp[v].size() > k + 1) {
                dp[v][0] += dp[v][k+1];
                dp[v][k+1] = 0;
            }
        }
        mint ans = dp[0][0];
        std::cout << ans.val() << '\n';
    }else{
        vector<vector<int>> ikeru = g;
        vector<int> siz(n);
        vector<int> mada = {~0, 0};
        vector<int> tansaku(n);
        tansaku[0] = 1;
        while(!mada.empty()){
            int i = mada.back();
            mada.pop_back();
            if (i >= 0){
                for (int j:ikeru[i]){
                    if (tansaku[j] == 0){
                        mada.push_back(~j);
                        mada.push_back(j);
                        tansaku[j] = 1;
                    }
                }
            }else{
                i = ~i;
                for(int j: ikeru[i]){
                    if (tansaku[j] == 2){
                        siz[i] += siz[j];
                    }
                }
                siz[i] += 1;
                tansaku[i] = 2;
            }
        }
        vector<map<pair<int,int>,mint>> e(n);
        vector<map<pair<int,int>,mint>> dp(n);
        fill(tansaku.begin(), tansaku.end(), 0);
        mada.push_back(~0);
        mada.push_back(0);
        tansaku[0] = 1;
        while(!mada.empty()){
            int i = mada.back();
            mada.pop_back();
            if (i >= 0){
                for (int j: ikeru[i]){
                    if (tansaku[j] == 0){
                        tansaku[j] = 1;
                        mada.push_back(~j);
                        mada.push_back(j);
                    }
                }
            }else{
                i = ~i;
                int mx_siz = -1;
                int mx_ind = -1;
                vector<int> dar;
                ll num_k1 = siz[i]%k;
                ll num_k0 = (siz[i] - (ll)(num_k1)*(k+1))/k;
                for (int j: ikeru[i]){
                    if (tansaku[j] == 2){
                        if (chmax(mx_siz, (int)e[j].size())){
                            mx_ind = j;
                        }
                        dar.push_back(j);
                    }
                }
                swap(e[i], e[mx_ind]);
                for (int j: dar){
                    if (j == mx_ind) continue;
                    for (auto [x, c]: e[j]){
                        e[i][x] += c;
                    }
                }
                for (int j: dar){
                    e[j].clear();
                }
                e[i][pair(0, 0)] += 1;
                if (num_k0 >= 0 && num_k1 >= 0){
                    if (num_k0 > 0){
                        if (e[i].find(pair(num_k0-1, num_k1)) != e[i].end()){
                            dp[i][pair(num_k0, num_k1)] += e[i][pair(num_k0-1, num_k1)];
                        }
                    }
                    if (num_k1 > 0){
                        if (e[i].find(pair(num_k0, num_k1-1)) != e[i].end()){   
                            dp[i][pair(num_k0, num_k1)] += e[i][pair(num_k0, num_k1-1)];
                        }
                    }
                }
                e[i][pair(0, 0)] -= 1;
                for (auto [x,c]: dp[i]) e[i][x] += c;
                tansaku[i] = 2;
            }
        }
        mint ans = 0;
        for (auto[x,c] : dp[0]){
            ans += c;
        }
        cout << ans.val() << '\n';
    }
}

int main() {
    // Drop C-stdio synchronisation and untie cin from cout for faster stream I/O.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int t;
    std::cin >> t;
    // The first integer is the number of test cases; solve() handles one each.
    for (; t != 0; --t) {
        solve();
    }
}

Details

Test #1:

score: 100
Accepted
time: 0ms
memory: 3592kb

input:

2
8 2
1 2
3 1
4 6
3 5
2 4
8 5
5 7
4 3
1 2
1 3
2 4

output:

2
1

result:

ok 2 lines

Test #2:

score: 0
Accepted
time: 56ms
memory: 6140kb

input:

5550
13 4
10 3
9 1
10 8
3 11
8 5
10 7
9 6
13 5
9 7
2 7
5 12
4 8
8 2
4 1
3 4
7 8
2 5
6 7
4 8
2 3
11 1
11 10
1 4
9 10
8 4
3 6
5 7
6 1
10 2
11 7
11 1
17 2
14 16
13 15
17 3
15 11
1 6
13 2
13 17
4 8
14 10
8 14
14 5
9 12
14 2
12 17
17 6
15 7
14 6
2 14
2 13
2 4
8 4
3 11
7 3
14 1
11 9
13 3
5 10
6 8
3 10
14 ...

output:

0
3
112
0
1
0
1
0
0
0
1
0
1
0
0
1
0
140
0
0
0
814
1
6
1
1
2
2
0
612
0
1
0
0
0
1
1
0
0
121
4536
0
0
1718
0
0
1
0
444
1
1908
1813
3
74
0
1
0
46
0
0
0
0
0
0
0
0
0
1
0
1
1
1
239
0
0
0
1
0
0
0
1
0
1
0
0
1
1
0
0
0
1
0
0
0
48
0
2
0
0
0
1
364
0
206
0
0
76
0
1
0
0
2
0
1
2
0
0
1
0
0
4
0
1
1
0
0
1
1
1
0
0
1
1
...

result:

ok 5550 lines

Test #3:

score: 0
Accepted
time: 272ms
memory: 24676kb

input:

3
99990 259
23374 69108
82204 51691
8142 67119
48537 97966
51333 44408
33147 68485
21698 86824
15746 58746
78761 86975
58449 61819
69001 68714
25787 2257
25378 14067
64899 68906
29853 31359
75920 85420
76072 11728
63836 55505
43671 98920
77281 25176
40936 66517
61029 61440
66908 52300
92101 59742
69...

output:

259200
247
207766300

result:

ok 3 lines

Test #4:

score: 0
Accepted
time: 363ms
memory: 26300kb

input:

3
99822 332
11587 83046
63424 60675
63423 73718
74622 40130
5110 26562
28361 80899
30886 70318
8708 11068
34855 96504
7904 75735
31904 42745
87892 55105
82374 81319
77407 82147
91475 12343
13470 95329
58766 95716
83232 44156
75907 92437
69785 93598
47857 33018
62668 31394
24238 72675
98254 43583
180...

output:

315881300
4505040
185631154

result:

ok 3 lines

Test #5:

score: -100
Runtime Error

input:

3
99021 1000
41739 4318
72541 76341
31227 15416
49232 13808
50837 51259
74464 11157
92684 84646
95226 64673
74155 82511
33301 31373
5901 29318
38227 98893
96752 57411
35167 42401
24344 90803
6956 33753
51120 24535
29594 2646
70305 32961
93079 38070
49273 48987
62799 77986
94353 84447
74970 31546
263...

output:


result: