QOJ.ac

QOJ

IDProblemSubmitterResultTimeMemoryLanguageFile sizeSubmit timeJudge time
#307484#7895. Graph Partitioning 2asxziillWA 262ms13792kbC++238.7kb2024-01-18 17:51:482024-01-18 17:51:50

Judging History

你现在查看的是最新测评结果

  • [2024-01-18 17:51:50]
  • 评测
  • 测评结果:WA
  • 用时:262ms
  • 内存:13792kb
  • [2024-01-18 17:51:48]
  • 提交

answer

#include <bits/stdc++.h>

using ll = long long;

using i64 = long long;
template<class T>
constexpr T power(T a, i64 b) {
    T res = 1;
    for (; b; b /= 2, a *= a) {
        if (b % 2) {
            res *= a;
        }
    }
    return res;
}
 
constexpr i64 mul(i64 a, i64 b, i64 p) {
    i64 res = a * b - i64(1.L * a * b / p) * p;
    res %= p;
    if (res < 0) {
        res += p;
    }
    return res;
}
template<i64 P>
struct MLong {
    i64 x;
    constexpr MLong() : x{} {}
    constexpr MLong(i64 x) : x{norm(x % getMod())} {}
    
    static i64 Mod;
    constexpr static i64 getMod() {
        if (P > 0) {
            return P;
        } else {
            return Mod;
        }
    }
    constexpr static void setMod(i64 Mod_) {
        Mod = Mod_;
    }
    constexpr i64 norm(i64 x) const {
        if (x < 0) {
            x += getMod();
        }
        if (x >= getMod()) {
            x -= getMod();
        }
        return x;
    }
    constexpr i64 val() const {
        return x;
    }
    explicit constexpr operator i64() const {
        return x;
    }
    constexpr MLong operator-() const {
        MLong res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MLong inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MLong &operator*=(MLong rhs) & {
        x = mul(x, rhs.x, getMod());
        return *this;
    }
    constexpr MLong &operator+=(MLong rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MLong &operator-=(MLong rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MLong &operator/=(MLong rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr MLong operator*(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MLong operator+(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MLong operator-(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MLong operator/(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr std::istream &operator>>(std::istream &is, MLong &a) {
        i64 v;
        is >> v;
        a = MLong(v);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MLong &a) {
        return os << a.val();
    }
    friend constexpr bool operator==(MLong lhs, MLong rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MLong lhs, MLong rhs) {
        return lhs.val() != rhs.val();
    }
};
 
template<>
i64 MLong<0LL>::Mod = i64(1E18) + 9;
 
template<int P>
struct MInt {
    int x;
    constexpr MInt() : x{} {}
    constexpr MInt(i64 x) : x{norm(x % getMod())} {}
    
    static int Mod;
    constexpr static int getMod() {
        if (P > 0) {
            return P;
        } else {
            return Mod;
        }
    }
    constexpr static void setMod(int Mod_) {
        Mod = Mod_;
    }
    constexpr int norm(int x) const {
        if (x < 0) {
            x += getMod();
        }
        if (x >= getMod()) {
            x -= getMod();
        }
        return x;
    }
    constexpr int val() const {
        return x;
    }
    explicit constexpr operator int() const {
        return x;
    }
    constexpr MInt operator-() const {
        MInt res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MInt inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MInt &operator*=(MInt rhs) & {
        x = 1LL * x * rhs.x % getMod();
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr MInt operator*(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MInt operator+(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MInt operator-(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MInt operator/(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {
        i64 v;
        is >> v;
        a = MInt(v);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) {
        return os << a.val();
    }
    friend constexpr bool operator==(MInt lhs, MInt rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MInt lhs, MInt rhs) {
        return lhs.val() != rhs.val();
    }
};
 
template<>
int MInt<0>::Mod = 998244353;
 
template<int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();
using Z = MInt<998244353>;

constexpr int B = std::sqrt(1e5) + 1;

void solve(){
	int n, k;
	std::cin >> n >> k;
	std::vector<std::vector<int>> t(n);
	for (int i = 0; i < n - 1; i++){
		int u, v;
		std::cin >> u >> v;
		u--, v--;
		t[u].push_back(v);
		t[v].push_back(u);
	}

	std::vector<int> siz(n);
    // k <= B
	if (k <= B){
		auto dfs = [&](auto self, int u, int p)->std::vector<Z>{
			siz[u] = 1;

            //还剩连通块的大小
            std::vector<Z> dp(k + 2);
            dp[0] = 1;
			for (int v : t[u]){
				if (v == p) continue;
				auto dpv = self(self, v, u);

                std::vector<Z> f(k + 2);
				for (int i = std::min(k, siz[u] + siz[v]); i >= 0; i--){
                    //i - j >= siz[u]
                    for (int j = std::max(0, i - siz[u]); j <= std::min(i, siz[v]); j++){
                        f[i] += dp[i - j] * dpv[j];
                    }
                }

                std::swap(f, dp);
                siz[u] += siz[v];
			}

            for (int i = k; i >= 0; i--){
                dp[i + 1] = dp[i];
            }
            dp[0] = dp[k] + dp[k + 1];

            return dp;
		};

        auto res = dfs(dfs, 0, -1);
        std::cout << res[0] << "\n";
	}
    else{
        // assert(1 == -1);

        //移除几个k块,剩下的就是模(k + 1)的大小
        std::vector<std::bitset<B + 1>> vis(n);//对应是否可以组成
        auto dfs = [&](auto self, int u, int p)->std::vector<Z>{
            std::vector<Z> dp(B + 1);
            std::vector<int> fsiz(B + 1);//剩下的连通块大小
            dp[0] = 1;
            vis[u][0] = 1;

            for (int v : t[u]){
                if (v == p) continue;

                auto dpv = self(self, v, u);
                std::vector<Z> f(B + 1);
                std::vector<int> dps(B + 1);
                std::bitset<B + 1> vu;
                for (int i = std::min(B, (siz[u] + siz[v]) / k); i >= 0; i--){
                    //i - j <= siz[u] / k
                    for (int j = std::max(0, i - (siz[u] / k)); j <= std::min(siz[v] / k, i); j++){
                        f[i] += dp[i - j] * dpv[j];
                        if (vis[v][j] && vis[u][i - j]){
                            dps[i] = std::max(dps[i], fsiz[i - j] + ((siz[v] - j * k) % (k + 1)));
                            vu[i] = 1;
                        }

                    }
                }

                std::swap(dp, f);
                std::swap(dps, fsiz);
                std::swap(vu, vis[u]);
                siz[u] += siz[v];
            }
            siz[u] += 1;

            for (int i = B; i >= 0; i--){
                if (fsiz[i] + 1 > k + 1){
                    dp[i] = 0;
                    vis[u][i] = 0;
                    continue;
                }
                if (fsiz[i] + 1 == k){
                    dp[i + 1] += dp[i];
                    if (vis[u][i] == 1){
                        vis[u][i + 1] = 1;
                    }
                }
            }

            return dp;
        };

        auto res = dfs(dfs, 0, -1);
        Z ans = 0;
        for (int i = 0; i <= B; i++){
            if ((n - i * k) % (k + 1) == 0){
                ans += res[i];
            }
        }
        std::cout << ans << "\n";
    }
}

int main(){
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int t;
    std::cin >> t;
    while (t--){
    	solve();
    }
    return 0;
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 0ms
memory: 3484kb

input:

2
8 2
1 2
3 1
4 6
3 5
2 4
8 5
5 7
4 3
1 2
1 3
2 4

output:

2
1

result:

ok 2 lines

Test #2:

score: 0
Accepted
time: 67ms
memory: 5900kb

input:

5550
13 4
10 3
9 1
10 8
3 11
8 5
10 7
9 6
13 5
9 7
2 7
5 12
4 8
8 2
4 1
3 4
7 8
2 5
6 7
4 8
2 3
11 1
11 10
1 4
9 10
8 4
3 6
5 7
6 1
10 2
11 7
11 1
17 2
14 16
13 15
17 3
15 11
1 6
13 2
13 17
4 8
14 10
8 14
14 5
9 12
14 2
12 17
17 6
15 7
14 6
2 14
2 13
2 4
8 4
3 11
7 3
14 1
11 9
13 3
5 10
6 8
3 10
14 ...

output:

0
3
112
0
1
0
1
0
0
0
1
0
1
0
0
1
0
140
0
0
0
814
1
6
1
1
2
2
0
612
0
1
0
0
0
1
1
0
0
121
4536
0
0
1718
0
0
1
0
444
1
1908
1813
3
74
0
1
0
46
0
0
0
0
0
0
0
0
0
1
0
1
1
1
239
0
0
0
1
0
0
0
1
0
1
0
0
1
1
0
0
0
1
0
0
0
48
0
2
0
0
0
1
364
0
206
0
0
76
0
1
0
0
2
0
1
2
0
0
1
0
0
4
0
1
1
0
0
1
1
1
0
0
1
1
...

result:

ok 5550 lines

Test #3:

score: 0
Accepted
time: 262ms
memory: 9852kb

input:

3
99990 259
23374 69108
82204 51691
8142 67119
48537 97966
51333 44408
33147 68485
21698 86824
15746 58746
78761 86975
58449 61819
69001 68714
25787 2257
25378 14067
64899 68906
29853 31359
75920 85420
76072 11728
63836 55505
43671 98920
77281 25176
40936 66517
61029 61440
66908 52300
92101 59742
69...

output:

259200
247
207766300

result:

ok 3 lines

Test #4:

score: -100
Wrong Answer
time: 236ms
memory: 13792kb

input:

3
99822 332
11587 83046
63424 60675
63423 73718
74622 40130
5110 26562
28361 80899
30886 70318
8708 11068
34855 96504
7904 75735
31904 42745
87892 55105
82374 81319
77407 82147
91475 12343
13470 95329
58766 95716
83232 44156
75907 92437
69785 93598
47857 33018
62668 31394
24238 72675
98254 43583
180...

output:

0
1077840
549665160

result:

wrong answer 1st lines differ - expected: '315881300', found: '0'