QOJ.ac

QOJ

ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间
#717732 | #7509. 01tree | hcywoi | WA | 2ms | 3860kb | C++23 | 6.1kb | 2024-11-06 18:50:06 | 2024-11-06 18:50:07

Judging History

你现在查看的是最新测评结果

  • [2024-11-06 18:50:07]
  • 评测
  • 测评结果:WA
  • 用时:2ms
  • 内存:3860kb
  • [2024-11-06 18:50:06]
  • 提交

answer

#include <bits/stdc++.h>

// Signed `long long` (at least 64 bits) shorthand used throughout the file.
using i64 = long long int;

// Computes a^b by repeated squaring.
// T must be copy-initialisable from the integer 1 and support `*=`.
// b is halved with truncating division each round, so the loop always
// terminates; for b < 0 the result is not a meaningful power.
template<class T>
T qmi(T a, i64 b) {
  T result = 1;
  while (b != 0) {
    if (b % 2 != 0) {
      result *= a;
    }
    b /= 2;
    a *= a;
  }
  return result;
}

// Returns (a * b) mod p in [0, p) for p > 0, without a 128-bit type.
//
// The quotient floor(a*b / p) is estimated in long double. The exact
// remainder a*b - q*p is small, so its value is recovered from the low
// 64 bits of the product. That subtraction is done in unsigned
// arithmetic, where wrap-around is well defined; the previous signed
// version overflowed, which is undefined behaviour.
// Accuracy depends on long double precision (80-bit on x86 handles
// moduli well beyond 2^60; where long double == double, p must be much
// smaller).
i64 mul(i64 a, i64 b, i64 p) {
  using u64 = unsigned long long;
  // Reduce first so the quotient estimate is accurate for any input,
  // including negative or >= p operands.
  a %= p;
  if (a < 0) {
    a += p;
  }
  b %= p;
  if (b < 0) {
    b += p;
  }
  const i64 q = static_cast<i64>(1.L * a * b / p);
  const u64 wrapped = static_cast<u64>(a) * static_cast<u64>(b)
                    - static_cast<u64>(q) * static_cast<u64>(p);
  // The true difference fits in i64, so this conversion recovers it.
  i64 res = static_cast<i64>(wrapped) % p;
  if (res < 0) {
    res += p;
  }
  return res;
}

// Residue class modulo either the compile-time constant P (when P > 0)
// or a runtime modulus installed with setmod() (when P <= 0).
// The stored value `x` is kept in [0, getmod()).
template<int P>
struct modint {
  int x;
  constexpr modint() : x{} {}
  // Implicit, which lets expressions such as `2 * m` or `v * i` mix
  // plain integers with modint values.
  constexpr modint(i64 x) : x{norm(x % getmod())} {}

  // Runtime modulus, consulted only when P <= 0. `inline` (C++17) makes
  // this a definition. The original bare declaration had no out-of-class
  // definition, so any odr-use (e.g. via setmod) failed at link time.
  static inline int mod = 0;
  constexpr static int getmod() {
    // `if constexpr` discards the runtime branch when P > 0, so `mod` is
    // not referenced at all for fixed-modulus instantiations.
    if constexpr (P > 0) {
      return P;
    } else {
      return mod;
    }
  }
  constexpr static void setmod(int m) {
    mod = m;
  }
  // Maps a value from [-mod, 2*mod) into [0, mod) with one correction.
  constexpr int norm(int x) const {
    if (x < 0) {
      x += getmod();
    }
    if (x >= getmod()) {
      x -= getmod();
    }
    return x;
  }
  constexpr int val() const {
    return x;
  }
  explicit constexpr operator int() const {
    return x;
  }
  constexpr modint operator-() const {
    modint res;
    res.x = norm(getmod() - x);
    return res;
  }
  // Multiplicative inverse as x^(mod-2) (Fermat's little theorem):
  // correct only when the modulus is prime and x != 0.
  constexpr modint inv() const {
    assert(x != 0);
    return qmi(*this, getmod() - 2);
  }
  constexpr modint &operator*= (modint v) & {
    // Widen to long long so the product of two residues cannot overflow.
    x = 1LL * x * v.x % getmod();
    return *this;
  }
  constexpr modint &operator+= (modint v) & {
    x = norm(x + v.x);
    return *this;
  }
  constexpr modint &operator-= (modint v) & {
    x = norm(x - v.x);
    return *this;
  }
  constexpr modint &operator/= (modint v) & {
    return *this *= v.inv();
  }
  friend constexpr modint operator- (modint a, modint b) {
    modint res = a;
    res -= b;
    return res;
  }
  friend constexpr modint operator+ (modint a, modint b) {
    modint res = a;
    res += b;
    return res;
  }
  friend constexpr modint operator* (modint a, modint b) {
    modint res = a;
    res *= b;
    return res;
  }
  friend constexpr modint operator/ (modint a, modint b) {
    modint res = a;
    res /= b;
    return res;
  }
  // Stream I/O can never be constant-evaluated, so these are not constexpr.
  friend std::istream &operator>> (std::istream& is, modint& a) {
    i64 v;
    is >> v;
    a = modint(v);
    return is;
  }
  friend std::ostream &operator<< (std::ostream& os, const modint& a) {
    return os << a.val();
  }
  friend constexpr bool operator== (modint a, modint b) {
    return a.val() == b.val();
  }
  friend constexpr bool operator!= (modint a, modint b) {
    return a.val() != b.val();
  }
};

// Prime modulus for every answer, and the modint specialisation using it.
constexpr int P{998244353};
using mint = modint<P>;

// Lazily grown tables of factorials, inverse factorials and modular
// inverses, plus a binomial-coefficient helper. When a query index
// exceeds the tables, they are grown to twice that index.
struct Comb {
  int n;                       // largest index currently tabulated
  std::vector<mint> fact;      // fact[i] = i!
  std::vector<mint> invefact;  // invefact[i] = 1 / i!
  std::vector<mint> inve;      // inve[i] = 1 / i; inve[0] is a 0 placeholder

  Comb() : n{0}, fact{1}, invefact{1}, inve{0} {}
  Comb(int n) : Comb() {
    init(n);
  }

  // Extends every table so indices 0..m are valid; no-op if already so.
  void init(int m) {
    if (m > n) {
      fact.resize(m + 1);
      invefact.resize(m + 1);
      inve.resize(m + 1);

      for (int k = n + 1; k <= m; ++k) {
        fact[k] = fact[k - 1] * k;
      }
      // One modular inversion, then walk down: 1/(k-1)! = k * (1/k!)
      // and 1/k = (k-1)! * (1/k!).
      invefact[m] = fact[m].inv();
      for (int k = m; k >= n + 1; --k) {
        invefact[k - 1] = invefact[k] * k;
        inve[k] = invefact[k] * fact[k - 1];
      }
      n = m;
    }
  }

  mint fac(int m) {
    if (n < m) {
      init(m * 2);
    }
    return fact[m];
  }
  mint invfac(int m) {
    if (n < m) {
      init(m * 2);
    }
    return invefact[m];
  }
  mint inv(int m) {
    if (n < m) {
      init(m * 2);
    }
    return inve[m];
  }
  // C(n, m), defined as 0 whenever m < 0 or m > n.
  mint binom(int n, int m) {
    if (m < 0 || m > n) {
      return 0;
    }
    return fac(n) * invfac(m) * invfac(n - m);
  }
} comb;

struct Mo {
  int X, Y;
  int ix, iy;
  mint ans;

  Mo(int A, int B) {
    X = A, Y = B;
    ix = 0, iy = 0;
    calc();
  }

  void calc() {
    ans = 0;
    for (int i = 0; i <= iy; i++) {
      ans += comb.binom(ix, i) * comb.binom(X - ix, Y - i);
    }
  }

  void move(int x, int y) {
    while (iy < y) {
      iy++;
      ans += comb.binom(ix, iy) * comb.binom(X - ix, Y - iy);
    }
    while (iy > y) {
      ans -= comb.binom(ix, iy) * comb.binom(X - ix, Y - iy);
      iy--;
    }
    while (ix < x) {
      if (ix < 0) {
        calc();
      }
      ans -= comb.binom(ix, iy) * comb.binom(X - 1 - ix, Y - 1 - iy);
      ix++;
    }
    while (ix > x) {
      if (ix == 0) {
        calc();
      }
      ix--;
      ans += comb.binom(ix, iy) * comb.binom(X - 1 - ix, Y - 1 - iy);
    }
  }
};

void solve() {
  int n;
  std::cin >> n;

  std::vector<std::vector<int>> adj(n);
  for (int i = 0; i < n - 1; i++) {
    int u, v;
    std::cin >> u >> v;
    u--;
    v--;
    adj[u].push_back(v);
    adj[v].push_back(u);
  }

  std::string a, b;
  std::cin >> a >> b;

  std::vector<int> s1(n), t1(n), sq(n), tq(n);

  auto dfs = [&](auto self, int x, int p, int d) -> void {
    if (a[x] == '?') {
      sq[x]++;
    } else if (a[x] == d + '0') {
      s1[x]++;
    }
    if (b[x] == '?') {
      tq[x]++;
    } else if (b[x] == d + '0') {
      t1[x]++;
    }

    for (auto y : adj[x]) {
      if (y == p) {
        continue;
      }
      self(self, y, x, d ^ 1);
      sq[x] += sq[y];
      s1[x] += s1[y];
      tq[x] += tq[y];
      t1[x] += t1[y];
    }
  };
  dfs(dfs, 0, -1, 0);

  int X = sq[0] + tq[0];
  int Y = sq[0] + s1[0] - t1[0];

  std::vector<std::array<int, 2>> ask;
  for (int i = 0; i < n; i++) {
    int x = sq[i] + tq[i];
    int y = sq[i] + s1[i] - t1[i];
    ask.push_back({x, y});
  }

  constexpr int B = 710;
  std::sort(ask.begin(), ask.end(),
    [&](std::array<int, 2> i, std::array<int, 2> j) {
      if (i[0] / B != j[0] / B) {
        return i[0] < j[0];
      }
      return i[1] < j[1];
    });

  mint ans = 0;
  int ix = 0, iy = 0;
  Mo m1(X, Y), m2(X - 1, Y - 1);
  for (auto [x, y] : ask) {
    m1.move(x, y);
    m2.move(x - 1, y - 1);

    mint G = m1.ans, H = m2.ans;
    ans += 2 * (y * G - x * H) + x * comb.binom(X - 1, Y - 1) - y * comb.binom(X, Y);
  }

  std::cout << ans << "\n";
}

// Reads the number of test cases and solves each one in turn.
int main() {
  // All I/O goes through iostreams, so C stdio sync can be disabled.
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int testCount;
  std::cin >> testCount;
  for (; testCount != 0; --testCount) {
    solve();
  }

  return 0;
}

详细信息

Test #1:

score: 100
Accepted
time: 0ms
memory: 3860kb

input:

3
2
1 2
00
11
3
1 2
2 3
???
???
3
1 2
2 3
??1
0?0

output:

1
16
1

result:

ok 3 number(s): "1 16 1"

Test #2:

score: -100
Wrong Answer
time: 2ms
memory: 3612kb

input:

1000
23
1 2
1 3
1 4
2 5
5 6
4 7
3 8
4 9
8 10
8 11
8 12
1 13
7 14
10 15
7 16
7 17
5 18
18 19
12 20
9 21
21 22
6 23
00?10?0000??1?00111?010
011??1?10?01?110?0??101
6
1 2
1 3
1 4
4 5
3 6
000?10
1???01
25
1 2
2 3
2 4
4 5
5 6
2 7
4 8
5 9
7 10
8 11
11 12
5 13
11 14
3 15
6 16
14 17
1 18
4 19
6 20
4 21
5 22...

output:

211417
48
80044
8948
672
110
2871
28
3014
11760
112
2
13
988
9480
11918
2454528
0
402110
7
208
176487
123552
38
161519
29402
8137256
134485620
4538
3
1280708
10932
674142
11718
170
478
39
9095816
100644
51596
7
0
2
21126
5740928
801001
656425
587637
122088
160
5084
4611242
82142
4579994
7203944
2995...

result:

wrong answer 1st numbers differ - expected: '53545', found: '211417'