QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#312149 | #7997. 树 V 图 | chenxinyang2006 | WA | 0ms | 74432kb | C++14 | 5.4kb | 2024-01-23 14:28:59 | 2024-01-23 14:29:00 |
Judging History
answer
#include <bits/stdc++.h>
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
#define per(i,j,k) for(int i=(j);i>=(k);i--)
#define uint unsigned int
#define ll long long
#define ull unsigned long long
#define db double
#define ldb long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define mkp make_pair
#define eb emplace_back
#define SZ(S) (int)S.size()
#define mod 998244353
//#define mod 1000000007
#define inf 0x3f3f3f3f
#define linf 0x3f3f3f3f3f3f3f3f
using namespace std;
template <class T>
void chkmax(T &x,T y){
if(x < y) x = y;
}
template <class T>
void chkmin(T &x,T y){
if(x > y) x = y;
}
inline int popcnt(int x){
return __builtin_popcount(x);
}
inline int ctz(int x){
return __builtin_ctz(x);
}
template <int P>
class mod_int
{
using Z = mod_int;
private:
static int mo(int x) { return x < 0 ? x + P : x; }
public:
int x;
int val() const { return x; }
mod_int() : x(0) {}
template <class T>
mod_int(const T &x_) : x(x_ >= 0 && x_ < P ? static_cast<int>(x_) : mo(static_cast<int>(x_ % P))) {}
bool operator==(const Z &rhs) const { return x == rhs.x; }
bool operator!=(const Z &rhs) const { return x != rhs.x; }
Z operator-() const { return Z(x ? P - x : 0); }
Z pow(long long k) const
{
Z res = 1, t = *this;
while (k)
{
if (k & 1)
res *= t;
if (k >>= 1)
t *= t;
}
return res;
}
Z &operator++()
{
x < P - 1 ? ++x : x = 0;
return *this;
}
Z &operator--()
{
x ? --x : x = P - 1;
return *this;
}
Z operator++(int)
{
Z ret = x;
x < P - 1 ? ++x : x = 0;
return ret;
}
Z operator--(int)
{
Z ret = x;
x ? --x : x = P - 1;
return ret;
}
Z inv() const { return pow(P - 2); }
Z &operator+=(const Z &rhs)
{
(x += rhs.x) >= P && (x -= P);
return *this;
}
Z &operator-=(const Z &rhs)
{
(x -= rhs.x) < 0 && (x += P);
return *this;
}
Z operator-()
{
return -x;
}
Z &operator*=(const Z &rhs)
{
x = 1ULL * x * rhs.x % P;
return *this;
}
Z &operator/=(const Z &rhs) { return *this *= rhs.inv(); }
#define setO(T, o) \
friend T operator o(const Z &lhs, const Z &rhs) \
{ \
Z res = lhs; \
return res o## = rhs; \
}
setO(Z, +) setO(Z, -) setO(Z, *) setO(Z, /)
#undef setO
friend istream& operator>>(istream& is, mod_int& x)
{
long long tmp;
is >> tmp;
x = tmp;
return is;
}
friend ostream& operator<<(ostream& os, const mod_int& x)
{
os << x.val();
return os;
}
};
using Z = mod_int<mod>;
Z power(Z p,ll k){
Z ans = 1;
while(k){
if(k % 2 == 1) ans *= p;
p *= p;
k /= 2;
}
return ans;
}
int T,n,k;
int _u[3005],_v[3005];
int a[3005],occ[3005];
int cnt;
int head[3005];
struct eg{
int to,nxt;
}edge[6005];
void make(int u,int v){
edge[++cnt].to = v;
edge[cnt].nxt = head[u];
head[u] = cnt;
}
vector <int> S;
Z f[3005][3005],g[3005][3005],tmp[3005];//f[u][i] 在 u 子树内,g[u][i] 在 u 子树外
void dfs(int u,int _f){
for(int i = head[u];i;i = edge[i].nxt){
int v = edge[i].to;
if(v == _f) continue;
dfs(v,u);
}
S.clear();
for(int i = head[u];i;i = edge[i].nxt){
int v = edge[i].to;
if(v == _f) continue;
S.eb(v);
}
f[u][0] = 1;
for(int v:S){
if(a[u] == a[v]) f[u][0] *= g[v][1];
else if(a[u] < a[v]) f[u][0] *= f[v][0];
else f[u][0] *= f[v][0] + f[v][1];
}
fill(f[u] + 1,f[u] + n + 1,0);
fill(g[u] + 1,g[u] + n + 1,1);
for(int v:S){
rep(i,1,n){
if(a[u] == a[v]){
f[u][i] *= g[v][i - 1];
f[u][i] += g[u][i] * f[v][i - 1];
g[u][i] *= g[v][i - 1];
}else if(a[u] < a[v]){
f[u][i] *= f[v][i] + f[v][i - 1];
g[u][i] *= f[v][i] + f[v][i - 1];
}else{
f[u][i] *= f[v][i] + f[v][i + 1];
g[u][i] *= f[v][i] + f[v][i + 1];
}
}
}
/* rep(i,0,n) printf("%d ",f[u][i].val());
printf("\n");
rep(i,1,n) printf("%d ",g[u][i].val());
printf("\n");*/
}
void solve(){
scanf("%d%d",&n,&k);
cnt = 0;
fill(head,head + n + 1,0);
fill(occ,occ + k + 1,0);
rep(i,1,n - 1){
scanf("%d%d",&_u[i],&_v[i]);
make(_u[i],_v[i]);make(_v[i],_u[i]);
}
rep(u,1,n){
scanf("%d",&a[u]);
occ[a[u]]++;
}
rep(i,1,n - 1) if(a[_u[i]] == a[_v[i]]) occ[a[_u[i]]]--;
rep(i,1,k){
if(occ[i] > 1){
printf("0\n");
return;
}
}
dfs(1,0);
Z ans = 0;
rep(i,0,n) ans += f[1][i];
int cur = 0;
rep(i,1,k){
if(occ[i]) cur++;
else ans *= cur;
}
printf("%d\n",ans.val());
}
int main(){
// freopen("test.in","r",stdin);
scanf("%d",&T);
while(T--) solve();
return 0;
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 0
Wrong Answer
time: 0ms
memory: 74432kb
input:
10 15 2 10 5 3 5 12 5 10 9 11 7 3 8 2 4 7 1 15 14 8 13 15 6 2 1 4 8 11 15 1 1 1 1 2 1 1 1 2 2 1 2 1 1 1 15 3 8 11 12 8 1 3 13 15 5 9 10 13 6 12 14 4 4 9 15 5 11 10 2 14 7 2 6 3 3 2 3 2 2 3 2 1 2 1 1 3 1 2 1 15 5 1 7 5 2 11 9 6 8 13 3 14 12 3 1 8 9 5 10 10 11 5 1 12 13 10 15 11 4 3 3 3 2 3 2 1 2 2 2 ...
output:
0 0 0 9 0 2 0 0 2 12
result:
wrong answer 1st numbers differ - expected: '11', found: '0'