QOJ.ac

QOJ

ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间
#307482 | #7895. Graph Partitioning 2 | asxziill | RE | 0ms | 3624kb | C++23 | 8.6kb | 2024-01-18 17:49:56 | 2024-01-18 17:49:57

Judging History

你现在查看的是最新测评结果

  • [2024-01-18 17:49:57]
  • 评测
  • 测评结果:RE
  • 用时:0ms
  • 内存:3624kb
  • [2024-01-18 17:49:56]
  • 提交

answer

#include <bits/stdc++.h>

// 64-bit signed integer aliases; `i64` is the one the templates below use.
using i64 = long long;
using ll = i64;
// Binary exponentiation: returns a^b by repeated squaring.
// T must be constructible from 1 and support `*=`.
// NOTE(review): presumably b >= 0; a negative b still terminates but does not
// produce an inverse.
template<class T>
constexpr T power(T a, i64 b) {
    T acc = 1;
    while (b != 0) {
        if (b % 2 != 0) {
            acc *= a;
        }
        b /= 2;
        a *= a;
    }
    return acc;
}
 
// Returns (a * b) mod p, normalised into [0, p), for 64-bit operands without
// 128-bit arithmetic. The quotient q ~= a*b/p is estimated in long double, and
// the remainder is recovered from the low 64 bits of a*b - q*p.
// The wrapping products/difference are computed in unsigned 64-bit arithmetic,
// where overflow is well defined. Doing them in signed i64 overflows, which is
// undefined behaviour and is rejected outright during constant evaluation.
// NOTE(review): assumes p > 0 and operands small enough (e.g. 0 <= a, b < p,
// p up to ~1e18) for the long double estimate to be accurate; that is the
// usual contract of this trick and is how MLong calls it.
constexpr i64 mul(i64 a, i64 b, i64 p) {
    using u64 = unsigned long long;
    const i64 q = static_cast<i64>(1.L * a * b / p);
    const u64 low = static_cast<u64>(a) * static_cast<u64>(b)
                  - static_cast<u64>(q) * static_cast<u64>(p);
    // Modular u64 -> i64 conversion (defined in C++20, GCC's documented
    // behaviour before that) recovers the small signed remainder.
    i64 res = static_cast<i64>(low);
    res %= p;
    if (res < 0) {
        res += p;
    }
    return res;
}
template<i64 P>
struct MLong {
    i64 x;
    constexpr MLong() : x{} {}
    constexpr MLong(i64 x) : x{norm(x % getMod())} {}
    
    static i64 Mod;
    constexpr static i64 getMod() {
        if (P > 0) {
            return P;
        } else {
            return Mod;
        }
    }
    constexpr static void setMod(i64 Mod_) {
        Mod = Mod_;
    }
    constexpr i64 norm(i64 x) const {
        if (x < 0) {
            x += getMod();
        }
        if (x >= getMod()) {
            x -= getMod();
        }
        return x;
    }
    constexpr i64 val() const {
        return x;
    }
    explicit constexpr operator i64() const {
        return x;
    }
    constexpr MLong operator-() const {
        MLong res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MLong inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MLong &operator*=(MLong rhs) & {
        x = mul(x, rhs.x, getMod());
        return *this;
    }
    constexpr MLong &operator+=(MLong rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MLong &operator-=(MLong rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MLong &operator/=(MLong rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr MLong operator*(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MLong operator+(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MLong operator-(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MLong operator/(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr std::istream &operator>>(std::istream &is, MLong &a) {
        i64 v;
        is >> v;
        a = MLong(v);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MLong &a) {
        return os << a.val();
    }
    friend constexpr bool operator==(MLong lhs, MLong rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MLong lhs, MLong rhs) {
        return lhs.val() != rhs.val();
    }
};
 
// Default runtime modulus for the MLong<0> instantiation: 10^18 + 9.
template<>
i64 MLong<0LL>::Mod = 1'000'000'000'000'000'009LL;
 
template<int P>
struct MInt {
    int x;
    constexpr MInt() : x{} {}
    constexpr MInt(i64 x) : x{norm(x % getMod())} {}
    
    static int Mod;
    constexpr static int getMod() {
        if (P > 0) {
            return P;
        } else {
            return Mod;
        }
    }
    constexpr static void setMod(int Mod_) {
        Mod = Mod_;
    }
    constexpr int norm(int x) const {
        if (x < 0) {
            x += getMod();
        }
        if (x >= getMod()) {
            x -= getMod();
        }
        return x;
    }
    constexpr int val() const {
        return x;
    }
    explicit constexpr operator int() const {
        return x;
    }
    constexpr MInt operator-() const {
        MInt res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MInt inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MInt &operator*=(MInt rhs) & {
        x = 1LL * x * rhs.x % getMod();
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr MInt operator*(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MInt operator+(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MInt operator-(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MInt operator/(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {
        i64 v;
        is >> v;
        a = MInt(v);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) {
        return os << a.val();
    }
    friend constexpr bool operator==(MInt lhs, MInt rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MInt lhs, MInt rhs) {
        return lhs.val() != rhs.val();
    }
};
 
// Default runtime modulus for the MInt<0> instantiation.
template<>
int MInt<0>::Mod = 998'244'353;
 
template<int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();
using Z = MInt<998244353>;

// Square-root threshold for the maximum n: floor(sqrt(1e5)) + 1 = 317.
// Written as a literal because std::sqrt is not constexpr in standard C++
// (the previous initializer only compiled through a GCC extension).
constexpr int B = 317;

void solve(){
	int n, k;
	std::cin >> n >> k;
	std::vector<std::vector<int>> t(n);
	for (int i = 0; i < n - 1; i++){
		int u, v;
		std::cin >> u >> v;
		u--, v--;
		t[u].push_back(v);
		t[v].push_back(u);
	}

	std::vector<int> siz(n);
    // k <= B
	if (k <= B){
		auto dfs = [&](auto self, int u, int p)->std::vector<Z>{
			siz[u] = 1;

            std::vector<Z> dp(k + 2);
            dp[0] = 1;
			for (int v : t[u]){
				if (v == p) continue;
				auto dpv = self(self, v, u);

                std::vector<Z> f(k + 2);
				for (int i = std::min(k, siz[u] + siz[v]); i >= 0; i--){
                    for (int j = std::max(0, i - siz[u]); j <= std::min(i, siz[v]); j++){
                        f[i] += dp[i - j] * dpv[j];
                    }
                }

                std::swap(f, dp);
                siz[u] += siz[v];
			}

            for (int i = k; i >= 0; i--){
                dp[i + 1] = dp[i];
            }
            dp[0] = dp[k] + dp[k + 1];

            return dp;
		};

        auto res = dfs(dfs, 0, -1);
        std::cout << res[0] << "\n";
	}
    else{
        // assert(1 == -1);

        //移除几个k块,剩下的就是模(k + 1)的大小
        std::vector<std::bitset<B>> vis(n);//对应是否可以组成
        auto dfs = [&](auto self, int u, int p)->std::vector<Z>{
            std::vector<Z> dp(B);
            std::vector<int> fsiz(B);//剩下的连通块大小
            dp[0] = 1;
            vis[u][0] = 1;

            for (int v : t[u]){
                if (v == p) continue;

                auto dpv = self(self, v, u);
                std::vector<Z> f(B);
                std::vector<int> dps(B);
                std::bitset<B> vu;
                for (int i = std::min(k, (siz[u] + siz[v]) / k); i >= 0; i--){
                    //i - j <= siz[u] / k
                    for (int j = std::max(0, i - (siz[u] / k)); j <= std::min(siz[v] / k, i); j++){
                        f[i] += dp[i - j] * dpv[j];
                        if (vis[v][j] && vis[u][i - j]){
                            dps[i] = std::max(dps[i], fsiz[i - j] + ((siz[v] - j * k) % (k + 1)));
                            vu[i] = 1;
                        }

                    }
                }

                std::swap(dp, f);
                std::swap(dps, fsiz);
                std::swap(vu, vis[u]);
                siz[u] += siz[v];
            }
            siz[u] += 1;

            for (int i = k; i >= 0; i--){
                if (fsiz[i] + 1 > k + 1){
                    dp[i] = 0;
                    vis[u][i] = 0;
                    continue;
                }
                if (fsiz[i] + 1 == k){
                    dp[i + 1] += dp[i];
                    if (vis[u][i] == 1){
                        vis[u][i + 1] = 1;
                    }
                }
            }

            return dp;
        };

        auto res = dfs(dfs, 0, -1);
        Z ans = 0;
        for (int i = 0; i <= k; i++){
            if ((n - i * k) % (k + 1) == 0){
                ans += res[i];
            }
        }
        std::cout << ans << "\n";
    }
}

int main(){
    // Untie and desynchronise iostreams for fast bulk input.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // First token is the number of test cases; each is handled by solve().
    int cases;
    std::cin >> cases;
    for (; cases != 0; --cases) {
        solve();
    }
    return 0;
}

详细

Test #1:

score: 100
Accepted
time: 0ms
memory: 3624kb

input:

2
8 2
1 2
3 1
4 6
3 5
2 4
8 5
5 7
4 3
1 2
1 3
2 4

output:

2
1

result:

ok 2 lines

Test #2:

score: -100
Runtime Error

input:

5550
13 4
10 3
9 1
10 8
3 11
8 5
10 7
9 6
13 5
9 7
2 7
5 12
4 8
8 2
4 1
3 4
7 8
2 5
6 7
4 8
2 3
11 1
11 10
1 4
9 10
8 4
3 6
5 7
6 1
10 2
11 7
11 1
17 2
14 16
13 15
17 3
15 11
1 6
13 2
13 17
4 8
14 10
8 14
14 5
9 12
14 2
12 17
17 6
15 7
14 6
2 14
2 13
2 4
8 4
3 11
7 3
14 1
11 9
13 3
5 10
6 8
3 10
14 ...

output:

0
3
112
0
1
0
1
0
0
0
1
0
1
0
0
1
0
140
0
0
0
814
1
6
1
1
2
2
0
612
0
1
0
0
0
1
1
0
0
121
4536
0
0
1718
0
0
1
0
444
1
1908
1813
3
74
0
1
0
46
0
0
0
0
0
0
0
0
0
1
0
1
1
1
239
0
0
0
1
0
0
0
1
0
1
0
0
1
1
0
0
0
1
0
0
0
48
0
2
0
0
0
1
364
0
206
0
0
76
0
1
0
0
2
0
1
2
0
0
1
0
0
4
0
1
1
0
0
1
1
1
0
0
1
1
...

result: