QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#334630 | #7997. 树 V 图 | sycqwq | WA | 107ms | 145008kb | C++14 | 4.0kb | 2024-02-22 10:41:44 | 2024-02-22 10:41:44 |
Judging History
answer
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=3005,mod=998244353;
int n,k,_,bk[maxn],f[maxn][maxn],g[maxn][maxn],vis[maxn];
vector<int> e[maxn];
int a[maxn];
void dfs1(int x,int col)
{
vis[x]=1;
for(auto v:e[x])
if(!vis[v]&&a[v]==col)
dfs1(v,col);
}
int chk()
{
memset(vis,0,sizeof vis);
memset(bk,0,sizeof bk);
for(int i=1;i<=n;i++)
if(!vis[i])
{
if(bk[a[i]])
return 0;
bk[a[i]]=1;
dfs1(i,a[i]);
}
for(int i=1;i<=k;i++)
if(!bk[i])
return 0;
return 1;
}
int h1[maxn],h2[maxn],sum[maxn],siz[maxn],sum1[maxn];
// h1 限制 h2 距离
int query(int *sum,int l,int r)
{
if(l>r)
return 0;
if(l<=0)
return sum[r];
return (sum[r]-sum[l-1]+mod)%mod;
}
void dfs(int x,int fa)
{
siz[x]=1;
for(auto v:e[x])
{
if(v==fa)
continue;
dfs(v,x);
}
g[x][siz[x]]=1;
f[x][0]=1;
for(auto v:e[x])
{
if(v==fa)
continue;
siz[x]+=siz[v];
if(a[v]!=a[x])
{
memset(sum,0,sizeof sum);
sum[0]=f[v][0];
for(int j=1;j<=siz[v];j++)
sum[j]=(sum[j-1]+f[v][j])%mod;
sum1[0]=g[x][0];
for(int j=1;j<=siz[x];j++)
sum1[j]=sum1[j-1]+g[x][j];
for(int j=0;j<=siz[x];j++)
{
h2[j]=1ll*f[x][j]*f[v][j]%mod;
if(a[x]<a[v])
(h2[j]+=1ll*f[x][j]*f[v][j-1]%mod)%=mod;
else
(h2[j]+=1ll*f[x][j]*f[v][j+1]%mod)%=mod;
// h1[j]=1ll*g[x][j]*f[v][j]%mod;
h1[j]=1ll*g[x][j]*query(sum,j,siz[v])%mod;
if(a[x]<a[v])
(h1[j]+=1ll*query(sum1,j,siz[x])*f[v][j-1]%mod)%=mod;
// else
// (h1[j]+=1ll*f[x][j]*f[v][j+1]%mod)%=mod;
// if(x==1&&j==0)
// {
// cout<<"QWQQQ"<<query(sum,j-1+(a[v]<a[x]),siz[v])<<' '<<f[x][j]<<' '<<g[x][j]<<' '<<query(sum,j,siz[x])<<' '<<sum1[siz[x]]<<' '<<sum1[j]<<'\n';
// }
}
}
else
{
memset(sum,0,sizeof sum);
sum[0]=g[v][0];
for(int j=1;j<=siz[v];j++)
sum[j]=(sum[j-1]+g[v][j])%mod;
sum1[0]=g[x][0];
for(int j=1;j<=siz[x];j++)
sum1[j]=(sum1[j-1]+g[x][j])%mod;
for(int j=0;j<=siz[x];j++)
{
// if(j!=0)
(h1[j]+=1ll*g[v][j+1]*query(sum1,j,siz[x])%mod)%=mod;
(h1[j]+=1ll*g[x][j]*query(sum,j+2,siz[v])%mod)%=mod;
if(j)
(h2[j]+=1ll*f[v][j-1]*query(sum1,j,siz[x]))%=mod;
(h2[j]+=1ll*f[x][j]*query(sum,j+1,siz[v]))%=mod;
}
}
memcpy(g[x],h1,sizeof(h1));
memcpy(f[x],h2,sizeof(h2));
memset(h1,0,sizeof h1);
memset(h2,0,sizeof h2);
}
// for(int i=0;i<=siz[x];i++)
// {
// cout<<"###"<<x<<' '<<i<<' '<<f[x][i]<<'\n';
// }
}
signed main()
{
// freopen("in.in","r",stdin);
// freopen("out.out","w",stdout);
ios::sync_with_stdio(0),cin.tie(0);
cin>>_;
while(_--)
{
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
cin>>n>>k;
for(int i=1;i<=n;i++)
e[i].clear();
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
e[x].push_back(y),e[y].push_back(x);
}
for(int i=1;i<=n;i++)
cin>>a[i];
if(!chk())
{
cout<<"QQQ"<<0<<'\n';
continue;
}
// cerr<<"OK\n";
dfs(1,1);
int s=0;
for(int i=0;i<=n;i++)
(s+=f[1][i])%=mod;
cout<<s<<'\n';
}
return 0;
}
詳細信息
Test #1:
score: 0
Wrong Answer
time: 107ms
memory: 145008kb
input:
10 15 2 10 5 3 5 12 5 10 9 11 7 3 8 2 4 7 1 15 14 8 13 15 6 2 1 4 8 11 15 1 1 1 1 2 1 1 1 2 2 1 2 1 1 1 15 3 8 11 12 8 1 3 13 15 5 9 10 13 6 12 14 4 4 9 15 5 11 10 2 14 7 2 6 3 3 2 3 2 2 3 2 1 2 1 1 3 1 2 1 15 5 1 7 5 2 11 9 6 8 13 3 14 12 3 1 8 9 5 10 10 11 5 1 12 13 10 15 11 4 3 3 3 2 3 2 1 2 2 2 ...
output:
0 0 0 3 0 0 0 0 0 0
result:
wrong answer 1st numbers differ - expected: '11', found: '0'