QOJ.ac

QOJ

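| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|---|---|---|---|---|---|---|---|---|---|
| #649613 | #7895. Graph Partitioning 2 | tumu1t | WA | 0ms | 3836kb | C++20 | 6.0kb | 2024-10-18 03:30:12 | 2024-10-18 03:30:13 |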

Judging History

You are viewing the latest judging result.

  • [2024-10-18 03:30:13]
  • Judged
  • Result: WA
  • Time: 0ms
  • Memory: 3836kb
  • [2024-10-18 03:30:12]
  • Submitted

answer

#include <bits/stdc++.h>
typedef long long LL;
using std::cin, std::cout, std::endl, std::vector, std::pair;

/// Binary exponentiation: returns a^b using O(log b) multiplications.
///
/// @param a  base; T must be constructible from the integer 1 and support `*=`.
/// @param b  exponent; must be non-negative (a negative value would never reach 0
///           under arithmetic right shift, so it is rejected up front).
/// @return   a raised to the power b (for T = Mint, reduced modulo its modulus).
template <typename T>
constexpr T power(T a, long long b)
{
    assert(b >= 0);
    T res = 1;
    while (b != 0)
    {
        if (b & 1)
            res *= a;
        b >>= 1;
        // Square only while bits remain: the final squaring is never used and,
        // for built-in signed T, could overflow (UB) even when res itself fits.
        if (b != 0)
            a *= a;
    }
    return res;
}
/// Modular integer stored as a normalized residue in [0, getMOD()).
///
/// @tparam P  compile-time modulus; when P == 0 the runtime modulus MOD is used
///            instead (Mint<0>::MOD must be defined out of class; setMOD changes it).
/// inv() and operator/ compute x^(MOD-2) via the free function `power`, which is
/// the inverse only when the modulus is prime.
/// NOTE: sums of two residues are formed in `int`, so the modulus must not exceed
/// roughly INT_MAX / 2.
template <int P>
class Mint
{
public:
    int x;          // residue, kept in [0, getMOD())
    static int MOD; // runtime modulus; only consulted when P == 0

    constexpr Mint() : x() {}
    constexpr Mint(long long v) : x{norm(static_cast<int>(v % getMOD()))} {}

    constexpr static void setMOD(int newMod) { MOD = newMod; }
    constexpr static int getMOD()
    {
        // `if constexpr` keeps MOD from being odr-used when P > 0. Only Mint<0>::MOD
        // is given a definition, so naming MOD for other P relied on the compiler
        // folding the dead branch to avoid an undefined-reference link error.
        if constexpr (P > 0)
            return P;
        else
            return MOD;
    }

    /// Brings v into [0, getMOD()) with a single correction step.
    /// Assumes v lies in [-getMOD(), 2 * getMOD()), which holds for the sum or
    /// difference of two residues and for the remainder of a `%`.
    constexpr int norm(int v) const
    {
        if (v < 0)
            return v + getMOD();
        if (v >= getMOD())
            return v - getMOD();
        return v;
    }
    constexpr int val() const { return x; }
    explicit constexpr operator int() const { return x; }
    constexpr Mint operator-() const
    {
        Mint res;
        res.x = norm(getMOD() - x);
        return res;
    }
    /// Multiplicative inverse as x^(MOD-2) (Fermat); x must be nonzero.
    constexpr Mint inv() const
    {
        assert(x != 0);
        return power(*this, getMOD() - 2);
    }
    constexpr Mint &operator*=(Mint rhs) &
    {
        // Widen to 64 bits before multiplying to avoid int overflow.
        x = 1LL * x * rhs.x % getMOD();
        return *this;
    }
    constexpr Mint &operator+=(Mint rhs) &
    {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr Mint &operator-=(Mint rhs) &
    {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr Mint &operator/=(Mint rhs) & { return *this *= rhs.inv(); }
    friend constexpr Mint operator*(Mint lhs, Mint rhs)
    {
        Mint res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr Mint operator+(Mint lhs, Mint rhs)
    {
        Mint res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr Mint operator-(Mint lhs, Mint rhs)
    {
        Mint res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr Mint operator/(Mint lhs, Mint rhs)
    {
        Mint res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr std::istream &operator>>(std::istream &is, Mint &a)
    {
        long long v = 0; // initialized: C++17 disallows uninitialized locals in constexpr functions
        is >> v;
        a = Mint(v);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const Mint &a) { return os << a.val(); }
    friend constexpr bool operator==(Mint lhs, Mint rhs) { return lhs.val() == rhs.val(); }
    friend constexpr bool operator!=(Mint lhs, Mint rhs) { return lhs.val() != rhs.val(); }
};

// Out-of-class storage for the runtime modulus of Mint<0> (changed via setMOD).
template <>
int Mint<0>::MOD = 998244353;

// Cinv<V, Mod>: V^{-1} modulo Mod as a constexpr variable template,
// intended for inverses of small compile-time constants.
template <int V, int Mod>
constexpr Mint<Mod> Cinv = Mint<Mod>(V).inv();

// Modulus used by Z. Change it here; to make the modulus a runtime variable,
// set P to 0 so that Z becomes Mint<0>, which reads Mint<0>::MOD.
constexpr int P = 998244353;
using Z = Mint<P>;

void Solve1(int n, int k)
{
    vector<vector<int>> G(n);
    for (int i = 1; i <= n - 1; i++)
    {
        int x, y;
        cin >> x >> y;
        x -= 1;
        y -= 1;
        G[x].emplace_back(y);
        G[y].emplace_back(x);
    }

    auto Dfs = [&](auto self, int at, int from) -> pair<int, vector<Z>>
    {
        vector<Z> Ans(2);
        Ans[1] = 1;
        int subTree = 1;
        for (auto &v : G[at])
        {
            if (v == from)
                continue;
            else
            {
                const auto &[sz, SonAns] = self(self, v, at);
                vector<Z> newAns(std::min<int>(subTree + sz, k + 1) + 1);
                for (int i = 1; i < (int)Ans.size(); i++)
                {
                    for (int j = 0; j < (int)SonAns.size() && i + j < (int)newAns.size(); j++)
                    {
                        newAns[i + j] += SonAns[j] * Ans[i];
                    }
                }
                Ans = std::move(newAns);
                subTree += sz;
            }
        }
        if ((int)Ans.size() >= k + 2)
            Ans[0] += Ans[k + 1];
        if ((int)Ans.size() >= k + 1)
            Ans[0] += Ans[k];
        return {subTree, Ans};
    };

    cout << Dfs(Dfs, 0, -1).second[0] << endl;
}

// Sparse tree DP, used when k is large relative to n (k * k >= n).
// Reads the n - 1 (1-indexed) edges of the current test case from stdin and
// prints the number of ways (mod P) to cut the tree into connected parts whose
// sizes are all k or k + 1.
// Dp[v] maps "size of the still-open part containing v" -> number of ways; key 0
// means v's part has already been closed. Zero-valued entries are pruned, so each
// map only holds sizes that are actually reachable.
void Solve2(int n, int k)
{
    // Adjacency list built from 1-indexed input edges.
    vector<vector<int>> G(n);
    for (int i = 1; i <= n - 1; i++)
    {
        int x, y;
        cin >> x >> y;
        x -= 1;
        y -= 1;
        G[x].emplace_back(y);
        G[y].emplace_back(x);
    }
    vector<std::map<int, Z>> Dp(n);
    // Post-order DFS filling Dp[at]; `from` is the parent (-1 for the root).
    // NOTE(review): recursion depth equals the tree height (up to n) — confirm the
    // judge's stack limit is large enough for path-like trees.
    auto Dfs = [&](auto self, int at, int from) -> void
    {
        // This reference stays valid across the recursive calls: Dp is never resized.
        std::map<int, Z> &Ans = Dp[at];
        // Initially the open part is just `at` itself.
        Ans[1] = 1;
        for (auto &v : G[at])
        {
            if (v == from)
                continue;
            else
            {
                std::map<int, Z> newAns;
                self(self, v, at);
                // Attach child v: j == 0 keeps v's closed part separate,
                // j > 0 merges v's open part into at's open part.
                for (auto &[i, d1] : Ans)
                {
                    // Keys here are always >= 1 (key 0 is only added after the loop); defensive guard.
                    if (i != 0)
                    {
                        for (auto &[j, d2] : Dp[v])
                        {
                            // Map keys ascend, so no later j can fit either.
                            if (i + j > k + 1)
                                break;
                            newAns[i + j] += d1 * d2;
                        }
                    }
                }
                Ans = std::move(newAns);
            }
        }
        // Optionally close at's open part when it has k or k + 1 vertices; the
        // open entries for k and k + 1 are kept as well.
        if (auto it = Ans.find(k); it != Ans.end() && (*it).second != 0)
            Ans[0] += (*it).second;
        if (auto it = Ans.find(k + 1); it != Ans.end() && (*it).second != 0)
            Ans[0] += (*it).second;
        // Ans[0] may have just been inserted by operator[]; drop it if it is zero.
        if (Ans[0] == 0)
            Ans.erase(0);
        // Prune every zero-valued entry (collected first, then erased, to avoid
        // erasing while iterating).
        vector<int> E;
        for (auto &[i, d1] : Ans)
            if (d1 == 0)
                E.emplace_back(i);
        for (auto i : E)
            Ans.erase(i);
    };
    Dfs(Dfs, 0, -1);
    // Answer: ways in which the root's part is closed (operator[] yields 0 if absent).
    cout << Dp[0][0] << endl;
}

/// Handles one test case: reads n and k, then dispatches to the dense-array DP
/// (Solve1, arrays of length up to k + 2 per vertex) when k * k < n, otherwise to
/// the sparse map DP (Solve2). Each solver reads the tree's n - 1 edges itself and
/// prints the answer, so exactly one of them must run for every test case.
void Solve()
{
    int n, k;
    cin >> n >> k;
    // Keep debug hooks out of this dispatch: an `if` whose body is commented out
    // silently captures the if/else below, skipping both solvers and leaving the
    // edges unread.
    if (static_cast<long long>(k) * k < n)
        Solve1(n, k);
    else
        Solve2(n, k);
}

// Entry point: enables fast unsynchronized iostreams, reads the number of test
// cases, and runs Solve() once per case.
int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int cases = 1;
    std::cin >> cases;
    for (int c = 0; c < cases; ++c)
    {
        Solve();
    }
    return 0;
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Wrong Answer
time: 0ms
memory: 3836kb

input:

2
8 2
1 2
3 1
4 6
3 5
2 4
8 5
5 7
4 3
1 2
1 3
2 4

output:


result:

wrong answer 1st lines differ - expected: '2', found: ''