QOJ.ac

QOJ

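| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|----|---------|-----------|--------|------|--------|----------|-----------|-------------|------------|
| #729043 | #9566. Topology | ucup-team045 | RE | 8ms | 199596kb | C++20 | 4.4kb | 2024-11-09 16:25:01 | 2024-11-09 16:25:07 |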

Judging History

你现在查看的是最新测评结果 (You are viewing the latest judging result.)

  • [2024-11-09 16:25:07]
  • 评测 (Judged)
  • 测评结果 (Result): RE
  • 用时 (Time): 8ms
  • 内存 (Memory): 199596kb
  • [2024-11-09 16:25:01]
  • 提交 (Submitted)

answer

#include<iostream>
#include<cstring>
#include<vector>
#include<stdint.h>
using namespace std;
using LL = long long;
/// Integer modulo the compile-time constant T.
/// Every stored value is kept in the canonical range [0, mod). Sums of two
/// residues are formed in `int`, so mod must satisfy 2 * mod <= INT_MAX.
/// Division/inverse assume gcd(value, mod) == 1 (always true for a prime mod
/// and a nonzero value).
template<const int T>
struct ModInt {
    const static int mod = T;
    int x;  // canonical residue in [0, mod)

    // Reduce any integer (including negative ones) to its residue in [0, mod).
    static int normalize(long long v) { v %= mod; return int(v < 0 ? v + mod : v); }

    // Implicit on purpose: lets plain integers mix with ModInt in expressions.
    ModInt(int x = 0) : x(normalize(x)) {}
    ModInt(long long x) : x(normalize(x)) {}

    int val() const { return x; }

    ModInt operator + (const ModInt &a) const { int x0 = x + a.x; return ModInt(x0 < mod ? x0 : x0 - mod); }
    ModInt operator - (const ModInt &a) const { int x0 = x - a.x; return ModInt(x0 < 0 ? x0 + mod : x0); }
    ModInt operator * (const ModInt &a) const { return ModInt(1LL * x * a.x % mod); }
    ModInt operator / (const ModInt &a) const { return *this * a.inv(); }
    bool operator == (const ModInt &a) const { return x == a.x; }
    bool operator != (const ModInt &a) const { return x != a.x; }

    // Compound assignments return *this so they can be chained like built-ins.
    ModInt &operator += (const ModInt &a) { x += a.x; if (x >= mod) x -= mod; return *this; }
    ModInt &operator -= (const ModInt &a) { x -= a.x; if (x < 0) x += mod; return *this; }
    ModInt &operator *= (const ModInt &a) { x = int(1LL * x * a.x % mod); return *this; }
    ModInt &operator /= (const ModInt &a) { *this = *this / a; return *this; }

    // int on the left: reduce y first so any int value (negative, >= mod) is handled.
    friend ModInt operator + (int y, const ModInt &a) { return ModInt(y) + a; }
    friend ModInt operator - (int y, const ModInt &a) { return ModInt(y) - a; }
    friend ModInt operator * (int y, const ModInt &a) { return ModInt(y) * a; }
    friend ModInt operator / (int y, const ModInt &a) { return ModInt(y) / a; }

    friend std::ostream &operator<<(std::ostream &os, const ModInt &a) { return os << a.x; }
    // Reads an arbitrary integer and stores its canonical residue; t is left
    // unchanged if extraction fails.
    friend std::istream &operator>>(std::istream &is, ModInt &t) {
        long long v = 0;
        if (is >> v) t = ModInt(v);
        return is;
    }

    // Exponentiation by squaring. A negative exponent n means inv()^(-n).
    ModInt pow(int64_t n) const {
        if (n < 0) {
            // b^(-(n+1)) * b == b^(-n); avoids overflow when negating INT64_MIN.
            const ModInt b = inv();
            return b.pow(-(n + 1)) * b;
        }
        ModInt res(1), base(*this);
        while (n) {
            if (n & 1) res *= base;
            base *= base;
            n >>= 1;
        }
        return res;
    }

    // Modular inverse via the extended Euclidean algorithm (returns 0 for x == 0).
    ModInt inv() const {
        int a = x, b = mod, u = 1, v = 0;
        while (b) {
            int t = a / b;
            a -= t * b; std::swap(a, b);
            u -= t * v; std::swap(u, v);
        }
        return ModInt(u);  // the constructor maps a negative u into [0, mod)
    }
};
using mint = ModInt<998244353>;

// Capacity of every per-node array (vertices are indexed 1..n).
constexpr int maxn = 5005;
vector<int> g[maxn];  // g[u]: children of u, in input order
int sz[maxn];         // sz[u]: number of nodes in the subtree of u (set by dfs1)
int dep[maxn];        // only dep[1] is ever assigned (in main)
mint f[maxn];         // f[u] = fact[sz[u]] / mul[u] (set by dfs1)
mint mul[maxn];       // mul[u]: product of subtree sizes inside subtree(u) (set by dfs1)
mint fact[maxn];      // fact[i] = i! modulo 998244353
int n;                // number of vertices read from input

void dfs1(int u){
    mul[u] = 1;
    sz[u] = 1;
    for(auto j : g[u]){
        dfs1(j);
        sz[u] += sz[j];
        mul[u] *= mul[j];
    }
    mul[u] *= sz[u];
    f[u] = fact[sz[u]] / mul[u];
}

// dp[u][i]: number of ways for u to sit at exactly position i, counting only the
// nodes outside u's current subtree (u's own descendants are not included).
mint dp[maxn][maxn];
mint C[maxn][maxn];  // binomial table, filled by main for 0 <= j <= i <= n
mint ans[maxn];      // per-node results printed by main

// Number of ways to merge x items (kept in a fixed relative order) into y gaps,
// i.e. C(x + y - 1, y - 1) ("stars and bars").
// With no gaps (y == 0) only the empty placement (x == 0) is possible. For y < 0
// or x < 0 there is no valid placement, so 0 is returned. Without this guard, C
// would be indexed with negative subscripts, which is out of bounds.
mint get(int x, int y){
    if (y < 0 || x < 0) return 0;
    if (y == 0){
        return (x == 0 ? 1 : 0);
    }
    return C[x + y - 1][y - 1];
}

void dfs2(int u){

    // cout << u << ":\n";
    // for(int i = 1; i <= n; i++) cout << dp[u][i] << " \n"[i == n];

    const int s = g[u].size();
    vector<mint> pre(s), suf(s);
    for(int i = 0; i < s; i++){
        pre[i] = suf[i] = f[g[u][i]];
    }
    for(int i = 1; i < s; i++) pre[i] *= pre[i - 1];
    for(int i = s - 2; i >= 0; i--) suf[i] *= suf[i + 1];
    for(int x = 0; x < s; x++){
        int j = g[u][x];
        mint t = (x - 1 >= 0 ? pre[x - 1] : 1) * (x + 1 < s ? suf[x + 1] : 1);
        mint sum = 0;
        for(int i = 1; i <= n; i++){
            dp[j][i] = sum;
            // [i, n - sz[u] + 1]
            sum += dp[u][i] * t * get(sz[u] - sz[j] - 1, n - sz[u] + 1 - i + 1);
        }
        // 当前点数是 n - sz[j] + 1(包括j本身)
        // 后面的点只能在i后面
        // [j, n - sz[j] + 1]
        ans[j] = dp[j][j] * f[j] * get(sz[j] - 1, n - sz[j] + 1 - j + 1);
        dfs2(j);
    }

}

int main(){

#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif

    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;
    // Vertex 1 is the root; the next n-1 tokens are the parents of vertices 2..n.
    for(int v = 2; v <= n; ++v){
        int parent;
        cin >> parent;
        g[parent].push_back(v);
    }

    // Pascal's triangle: C[i][j] = binom(i, j) for 0 <= j <= i <= n.
    for(int i = 0; i <= n; ++i){
        C[i][0] = 1;
        for(int j = 1; j <= i; ++j) C[i][j] = C[i - 1][j - 1] + C[i - 1][j];
    }

    fact[0] = 1;
    for(int i = 1; i <= n; ++i) fact[i] = fact[i - 1] * i;

    dep[1] = 1;
    dfs1(1);
    ans[1] = f[1];   // the root's answer is simply f[1]
    dp[1][1] = 1;
    dfs2(1);

    for(int i = 1; i <= n; ++i){
        cout << ans[i] << (i == n ? '\n' : ' ');
    }

}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 8ms
memory: 199596kb

input:

4
1 1 2

output:

3 2 1 2

result:

ok 4 number(s): "3 2 1 2"

Test #2:

score: -100
Runtime Error

input:

9
1 1 2 2 3 3 4 5

output:


result: