QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#316180 | #7895. Graph Partitioning 2 | ucup-team1055# | WA | 1ms | 3752kb | C++20 | 7.5kb | 2024-01-27 17:56:36 | 2024-01-27 17:56:37

Judging History

你现在查看的是最新测评结果

  • [2024-01-27 17:56:37]
  • 评测
  • 测评结果:WA
  • 用时:1ms
  • 内存:3752kb
  • [2024-01-27 17:56:36]
  • 提交

answer

#include<bits/stdc++.h>

// rep(i,s,n): loop an int i upward over the half-open range [s, n);
// both bounds are converted to int.
#define rep(i,s,n) for(int i = int(s); i < int(n); i++)
// rrep(i,s,n): loop an int i downward from n-1 to s (inclusive).
#define rrep(i,s,n) for(int i = int(n) - 1; i >= int(s); i--)
// all(v): expands to the two arguments v.begin(), v.end().
#define all(v) (v).begin(), (v).end()

// Short aliases for the wide arithmetic types.
using ll = long long;
using ull = unsigned long long;
using ld = long double;

// Lowers `a` to `b` unless `a <= b` already holds.
// Returns true exactly when `a` was overwritten.
template<class T>
bool chmin(T &a, T b) {
    const bool replaced = !(a <= b);
    if (replaced) {
        a = b;
    }
    return replaced;
}

// Raises `a` to `b` unless `a >= b` already holds.
// Returns true exactly when `a` was overwritten.
template<class T>
bool chmax(T &a, T b) {
    const bool replaced = !(a >= b);
    if (replaced) {
        a = b;
    }
    return replaced;
}

// Lets the code below refer to standard-library names without the std:: prefix.
using namespace std;

// Integer arithmetic modulo the compile-time constant `m`.
// The constructor normalises its argument with ((x % m + m) % m); the
// arithmetic operators assume both operands are already normalised.
// NOTE(review): inv() raises to the power m-2, which is only a modular inverse
// when m is prime -- confirm before instantiating with another modulus.
template<ll m> struct modint {
    using mint = modint;

    // Stored representative; also reachable (mutably) through val().
    ll a;

    modint(ll x = 0) {
        a = (x % m + m) % m;
    }

    static constexpr ll mod() { return m; }

    ll val() const { return a; }
    ll& val() { return a; }

    // Binary exponentiation: multiplies in the current square whenever the
    // lowest remaining bit of n is set.
    mint pow(ll n) const {
        mint result = 1;
        for (mint base = a; n; n >>= 1, base *= base) {
            if (n & 1) result *= base;
        }
        return result;
    }

    mint inv() const { return pow(m - 2); }

    mint & operator+=(const mint rhs){
        a += rhs.a;
        a -= (a >= m) ? m : 0;
        return *this;
    }
    mint & operator-=(const mint rhs){
        a -= rhs.a;
        if (a < 0) a += m;
        return *this;
    }
    mint & operator*=(const mint rhs){
        a *= rhs.a;
        a %= m;
        return *this;
    }
    mint & operator/=(mint rhs){
        return *this *= rhs.inv();
    }

    // The binary operators work on a copy of the left operand.
    friend mint operator+(const mint& lhs, const mint& rhs){
        mint out = lhs;
        out += rhs;
        return out;
    }
    friend mint operator-(const mint& lhs, const mint& rhs){
        mint out = lhs;
        out -= rhs;
        return out;
    }
    friend mint operator*(const mint& lhs, const mint& rhs){
        mint out = lhs;
        out *= rhs;
        return out;
    }
    friend mint operator/(const mint& lhs, const mint& rhs){
        mint out = lhs;
        out /= rhs;
        return out;
    }

    friend bool operator==(const modint &lhs, const modint &rhs){
        return lhs.a == rhs.a;
    }
    friend bool operator!=(const modint &lhs, const modint &rhs){
        return lhs.a != rhs.a;
    }

    mint operator+() const { return *this; }
    mint operator-() const {
        mint negated;
        negated -= *this;
        return negated;
    }
};

// Modular integer type used for all counting: arithmetic modulo 998244353.
using modint998244353 = modint<998244353>;
using mint = modint998244353;

// Threshold on k that solve() compares against when picking between its
// small-k and large-k strategies.
const int th = 1000;

void solve() {
    int n,k;
    std::cin >> n >> k;
    std::vector g(n, std::vector<int>());
    rep(i,0,n-1) {
        int u,v;
        std::cin >> u >> v;
        u--; v--;
        g[u].emplace_back(v);
        g[v].emplace_back(u);
    }
    std::vector<int> dfs_order;
    std::vector<int> par(n, -1);
    auto dfs = [&](auto &&self, int v) -> void {
        dfs_order.emplace_back(v);
        for(auto nv: g[v]) {
            if(nv == par[v]) continue;
            par[nv] = v;
            self(self, nv);
        }
    };
    dfs(dfs, 0);
    // k is small
    if(k < th && false) {
        auto mul = [&](const std::vector<mint> &a, const std::vector<mint> &b) -> std::vector<mint> {
            int s = a.size(), t = b.size();
            int sz = std::min(k+2, s + t - 1);
            std::vector<mint> ab(sz, 0);
            rep(i,0,s) {
                rep(j,0,t) {
                    if(i + j >= sz) break;
                    ab[i + j] += a[i] * b[j];
                }
            }
            return ab;
        };
        std::vector dp(n, std::vector<mint>(2, 0));
        for(auto v: dfs_order | std::views::reverse) {
            dp[v][1] = 1;
            for(auto nv: g[v]) {
                if(nv == par[v]) continue;
                dp[v] = mul(dp[v], dp[nv]);
            }
            if((int)dp[v].size() > k) {
                dp[v][0] += dp[v][k];
            }
            if((int)dp[v].size() > k + 1) {
                dp[v][0] += dp[v][k+1];
                dp[v][k+1] = 0;
            }
        }
        mint ans = dp[0][0];
        std::cout << ans.val() << '\n';
    }else{
        vector<vector<int>> ikeru = g;
        vector<int> siz(n);
        vector<int> mada = {~0, 0};
        vector<int> tansaku(n);
        tansaku[0] = 1;
        while(!mada.empty()){
            int i = mada.back();
            mada.pop_back();
            if (i >= 0){
                for (int j:ikeru[i]){
                    if (tansaku[j] == 0){
                        mada.push_back(~j);
                        mada.push_back(j);
                        tansaku[j] = 1;
                    }
                }
            }else{
                i = ~i;
                for(int j: ikeru[i]){
                    if (tansaku[j] == 2){
                        siz[i] += siz[j];
                    }
                }
                siz[i] += 1;
                tansaku[i] = 2;
            }
        }
        vector<map<pair<int,int>,mint>> e(n);
        vector<map<pair<int,int>,mint>> dp(n);
        fill(tansaku.begin(), tansaku.end(), 0);
        mada.push_back(~0);
        mada.push_back(0);
        tansaku[0] = 1;
        while(!mada.empty()){
            int i = mada.back();
            mada.pop_back();
            if (i >= 0){
                for (int j: ikeru[i]){
                    if (tansaku[j] == 0){
                        tansaku[j] = 1;
                        mada.push_back(~j);
                        mada.push_back(j);
                    }
                }
            }else{
                i = ~i;
                int mx_siz = -1;
                int mx_ind = -1;
                vector<int> dar;
                ll num_k1 = siz[i]%k;
                ll num_k0 = (siz[i] - (ll)(num_k1)*(k+1))/k;
                for (int j: ikeru[i]){
                    if (tansaku[j] == 2){
                        if (chmax(mx_siz, (int)e[j].size())){
                            mx_ind = j;
                        }
                        dar.push_back(j);
                    }
                }
                if (mx_ind >= 0){
                    swap(e[i], e[mx_ind]);
                }
                for (int j: dar){
                    if (j == mx_ind) continue;
                    map<pair<int,int>,mint> np;
                    for (auto [x, c]: e[j]){
                        for (auto [y, d]: e[i]){
                            np[pair(x.first + y.first, x.second + y.second)] += c * d;
                        }
                    }
                    swap(e[i], np);
                }
                for (int j: dar){
                    e[j].clear();
                }
                e[i][pair(0, 0)] += 1;
                if (num_k0 >= 0 && num_k1 >= 0){
                    if (num_k0 > 0){
                        if (e[i].find(pair(num_k0-1, num_k1)) != e[i].end()){
                            dp[i][pair(num_k0, num_k1)] += e[i][pair(num_k0-1, num_k1)];
                        }
                    }
                    if (num_k1 > 0){
                        if (e[i].find(pair(num_k0, num_k1-1)) != e[i].end()){   
                            dp[i][pair(num_k0, num_k1)] += e[i][pair(num_k0, num_k1-1)];
                        }
                    }
                }
                e[i][pair(0, 0)] -= 1;
                for (auto [x,c]: dp[i]) e[i][x] += c;
                tansaku[i] = 2;
            }
        }
        mint ans = 0;
        for (auto[x,c] : dp[0]){
            ans += c;
        }
        cout << ans.val() << '\n';
    }
}

// Entry point: reads the number of test cases and runs solve() once per case.
int main() {
    // Untie cin from cout and drop C stdio synchronisation for faster I/O.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int t;
    std::cin >> t;
    for (; t != 0; --t) {
        solve();
    }
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Wrong Answer
time: 1ms
memory: 3752kb

input:

2
8 2
1 2
3 1
4 6
3 5
2 4
8 5
5 7
4 3
1 2
1 3
2 4

output:

0
1

result:

wrong answer 1st lines differ - expected: '2', found: '0'