QOJ.ac

QOJ

ID题目提交者结果用时内存语言文件大小提交时间测评时间
#334630#7997. 树 V 图sycqwqWA 107ms145008kbC++144.0kb2024-02-22 10:41:442024-02-22 10:41:44

Judging History

你现在查看的是最新测评结果

  • [2024-02-22 10:41:44]
  • 评测
  • 测评结果:WA
  • 用时:107ms
  • 内存:145008kb
  • [2024-02-22 10:41:44]
  • 提交

answer

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=3005,mod=998244353;
int n,k,_,bk[maxn],f[maxn][maxn],g[maxn][maxn],vis[maxn];
vector<int> e[maxn];
int a[maxn];
void dfs1(int x,int col)
{
    vis[x]=1;
    for(auto v:e[x])
        if(!vis[v]&&a[v]==col)
            dfs1(v,col);
}
int chk()
{
    memset(vis,0,sizeof vis);
    memset(bk,0,sizeof bk);
    for(int i=1;i<=n;i++)
        if(!vis[i])
        {
            if(bk[a[i]])
                return 0;
            bk[a[i]]=1;
            dfs1(i,a[i]);
        }
    for(int i=1;i<=k;i++)
        if(!bk[i])
            return 0;
    return 1;
}
int h1[maxn],h2[maxn],sum[maxn],siz[maxn],sum1[maxn];
// h1 限制 h2 距离
int query(int *sum,int l,int r)
{
    if(l>r)
        return 0;
    if(l<=0)
        return sum[r];
    return (sum[r]-sum[l-1]+mod)%mod;
}
void dfs(int x,int fa)
{
    siz[x]=1;
    for(auto v:e[x])
    {
        if(v==fa)
            continue;
        dfs(v,x);
    }
    g[x][siz[x]]=1;
    f[x][0]=1;
    for(auto v:e[x])
    {
        if(v==fa)
            continue;
        siz[x]+=siz[v];
        if(a[v]!=a[x])
        {
            memset(sum,0,sizeof sum);
            sum[0]=f[v][0];
            for(int j=1;j<=siz[v];j++)
                sum[j]=(sum[j-1]+f[v][j])%mod;
            sum1[0]=g[x][0];
            for(int j=1;j<=siz[x];j++)
                sum1[j]=sum1[j-1]+g[x][j];
            for(int j=0;j<=siz[x];j++)
            {
                h2[j]=1ll*f[x][j]*f[v][j]%mod;
                if(a[x]<a[v])
                    (h2[j]+=1ll*f[x][j]*f[v][j-1]%mod)%=mod;
                else
                    (h2[j]+=1ll*f[x][j]*f[v][j+1]%mod)%=mod;
                // h1[j]=1ll*g[x][j]*f[v][j]%mod;
                
                h1[j]=1ll*g[x][j]*query(sum,j,siz[v])%mod;
                if(a[x]<a[v])
                    (h1[j]+=1ll*query(sum1,j,siz[x])*f[v][j-1]%mod)%=mod;
                // else
                //     (h1[j]+=1ll*f[x][j]*f[v][j+1]%mod)%=mod;
                // if(x==1&&j==0)
                // {
                //     cout<<"QWQQQ"<<query(sum,j-1+(a[v]<a[x]),siz[v])<<' '<<f[x][j]<<' '<<g[x][j]<<' '<<query(sum,j,siz[x])<<' '<<sum1[siz[x]]<<' '<<sum1[j]<<'\n';
                // }
            }
        }
        else
        {
            memset(sum,0,sizeof sum);
            sum[0]=g[v][0];
            for(int j=1;j<=siz[v];j++)
                sum[j]=(sum[j-1]+g[v][j])%mod;
            sum1[0]=g[x][0];
            for(int j=1;j<=siz[x];j++)
                sum1[j]=(sum1[j-1]+g[x][j])%mod;
            for(int j=0;j<=siz[x];j++)
            {
                // if(j!=0)
                (h1[j]+=1ll*g[v][j+1]*query(sum1,j,siz[x])%mod)%=mod;
                (h1[j]+=1ll*g[x][j]*query(sum,j+2,siz[v])%mod)%=mod;
                if(j)
                    (h2[j]+=1ll*f[v][j-1]*query(sum1,j,siz[x]))%=mod;
                (h2[j]+=1ll*f[x][j]*query(sum,j+1,siz[v]))%=mod;
            }
        }
        memcpy(g[x],h1,sizeof(h1));
        memcpy(f[x],h2,sizeof(h2));
        memset(h1,0,sizeof h1);
        memset(h2,0,sizeof h2);
    }
    // for(int i=0;i<=siz[x];i++)
    // {
    //     cout<<"###"<<x<<' '<<i<<' '<<f[x][i]<<'\n';
    // }
}
signed main()
{
    // freopen("in.in","r",stdin);
    // freopen("out.out","w",stdout);
    ios::sync_with_stdio(0),cin.tie(0);
    cin>>_;
    
    while(_--)
    {
        memset(f,0,sizeof(f));
        memset(g,0,sizeof(g));
        cin>>n>>k; 
        for(int i=1;i<=n;i++)
            e[i].clear();
        for(int i=1;i<n;i++)
        {
            int x,y;
            cin>>x>>y;
            e[x].push_back(y),e[y].push_back(x);
        }
        for(int i=1;i<=n;i++)
            cin>>a[i];
        if(!chk())
        {
            cout<<"QQQ"<<0<<'\n';
            continue;
        }
        // cerr<<"OK\n";
        dfs(1,1);
        int s=0;
        for(int i=0;i<=n;i++)
            (s+=f[1][i])%=mod;
        cout<<s<<'\n';
    }
    return 0;
}

詳細信息

Test #1:

score: 0
Wrong Answer
time: 107ms
memory: 145008kb

input:

10
15 2
10 5
3 5
12 5
10 9
11 7
3 8
2 4
7 1
15 14
8 13
15 6
2 1
4 8
11 15
1 1 1 1 2 1 1 1 2 2 1 2 1 1 1
15 3
8 11
12 8
1 3
13 15
5 9
10 13
6 12
14 4
4 9
15 5
11 10
2 14
7 2
6 3
3 2 3 2 2 3 2 1 2 1 1 3 1 2 1
15 5
1 7
5 2
11 9
6 8
13 3
14 12
3 1
8 9
5 10
10 11
5 1
12 13
10 15
11 4
3 3 3 2 3 2 1 2 2 2 ...

output:

0
0
0
3
0
0
0
0
0
0

result:

wrong answer 1st numbers differ - expected: '11', found: '0'