QOJ.ac

QOJ

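| ID | 提交记录ID (submission) | 题目 (problem) | Hacker | Owner | 结果 (result) | 提交时间 (submitted) | 测评时间 (judged) |
|---|---|---|---|---|---|---|---|
| #1235 | #426021 | #7219. The Mighty Spell | ASnown | caijianhong | Failed. | 2024-11-22 15:02:52 | 2024-11-22 15:02:52 |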

詳細信息

Extra Test:

Accepted
time: 0ms
memory: 6700kb

input:

60 3
2 3 2 2 3 1 3 1 2 3 2 2 3 1 2 2 2 1 2 3 1 3 3 2 2 2 3 3 1 2 1 1 2 1 1 1 1 1 1 1 3 2 1 2 3 3 3 1 2 2 2 2 3 3 1 1 3 2 2 2

output:

628462825

result:

ok answer is '628462825'

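| ID | 题目 (problem) | 提交者 (submitter) | 结果 (result) | 用时 (time) | 内存 (memory) | 语言 (language) | 文件大小 (code size) | 提交时间 (submitted) | 测评时间 (judged) |
|---|---|---|---|---|---|---|---|---|---|
| #426021 | #7219. The Mighty Spell | caijianhong | AC ✓ | 642ms | 49560kb | C++14 | 3.8kb | 2024-05-30 20:21:43 | 2024-11-22 16:49:49 |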

answer

#include <bits/stdc++.h>
using namespace std;
#ifndef LOCAL
// Judge build: the token `endl` becomes a plain "\n" string (so output is not
// flushed on every line), and debug(...) compiles to nothing.
#define endl "\n"
#define debug(...) void(0)
#else
// Local build: debug(...) forwards its printf-style arguments to stderr.
#define debug(...) fprintf(stderr, __VA_ARGS__)
#endif
typedef long long LL;
// SFINAE helper: names `int` when T is an integral type, and is ill-formed
// otherwise. Used as `template <class T, must_int<T> = 0>` to restrict
// templates to integer arguments.
template <typename T>
using must_int = typename enable_if<is_integral<T>::value, int>::type;
// Integer modulo the compile-time constant `umod`. The stored value `v` is
// always the canonical residue in [0, umod).
template <unsigned umod>
struct modint {
  // Signed copy of the modulus, so that signed integers can be reduced with %.
  static constexpr int mod = umod;
  // Canonical residue.
  unsigned v;

  modint() : v(0) {}

  // Implicit conversion from any integral type; negative values wrap into range.
  template <class T, must_int<T> = 0>
  modint(T x) {
    const auto rem = x % mod;
    v = rem < 0 ? rem + mod : rem;
  }

  modint operator+() const { return *this; }
  modint operator-() const {
    modint zero;
    zero -= *this;
    return zero;
  }

  // Residue as a plain int.
  friend int raw(const modint &self) { return static_cast<int>(self.v); }
  friend ostream &operator<<(ostream &os, const modint &self) {
    return os << raw(self);
  }

  modint &operator+=(const modint &rhs) {
    const unsigned sum = v + rhs.v;
    v = sum >= umod ? sum - umod : sum;
    return *this;
  }
  modint &operator-=(const modint &rhs) {
    // Unsigned wrap-around below zero is detected by the result exceeding umod.
    const unsigned diff = v - rhs.v;
    v = diff >= umod ? diff + umod : diff;
    return *this;
  }
  modint &operator*=(const modint &rhs) {
    const unsigned long long prod = static_cast<unsigned long long>(v) * rhs.v;
    v = static_cast<unsigned>(prod % umod);
    return *this;
  }
  // Division via Fermat inverse rhs^(mod-2); only valid for a prime modulus.
  modint &operator/=(const modint &rhs) {
    assert(rhs.v);
    const modint inverse = qpow(rhs, mod - 2);
    return *this *= inverse;
  }

  // Binary exponentiation: base^exponent for a non-negative integral exponent.
  template <class T, must_int<T> = 0>
  friend modint qpow(modint base, T exponent) {
    modint acc = 1;
    while (exponent) {
      if (exponent & 1) acc *= base;
      exponent >>= 1;
      base *= base;
    }
    return acc;
  }

  friend modint operator+(modint lhs, const modint &rhs) {
    lhs += rhs;
    return lhs;
  }
  friend modint operator-(modint lhs, const modint &rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend modint operator*(modint lhs, const modint &rhs) {
    lhs *= rhs;
    return lhs;
  }
  friend modint operator/(modint lhs, const modint &rhs) {
    lhs /= rhs;
    return lhs;
  }
  bool operator==(const modint &rhs) const { return v == rhs.v; }
  bool operator!=(const modint &rhs) const { return !(*this == rhs); }
};
// Working modular type: modulus 1000000007 when ONLINE_JUDGE is defined,
// 998244353 otherwise.
// NOTE(review): the printed answer therefore depends on this macro — confirm
// which modulus the problem requires before building without ONLINE_JUDGE.
#ifndef ONLINE_JUDGE
using mint = modint<998244353>;
#else
using mint = modint<1000000007>;
#endif
// Input: n, m, and the sequence a[1..n] (stored 1-indexed).
int n, m;
int a[200010];
// cnt[c][i] = number of j in [1, i] with a[j] == c; filled by init().
// NOTE(review): the first dimension only admits values below 60 — confirm
// this matches the problem's bound on m.
int cnt[60][200010];
// f[i]: per-length weight (init()); b[i]: per-length accumulator (solve());
// qp2[i] = 2^i (init()).
mint f[200010];
mint b[200010];
mint qp2[200010];
// Cubic weight g(l) = 2l^3 + 3l^2 + 3l + 3, with g(0) defined as 0.
mint g(mint l) {
  if (l == 0) return 0;
  // Horner form of the same polynomial.
  return ((2 * l + 3) * l + 3) * l + 3;
}
void init() {
  f[1] = g(1);
  mint pre = f[1];
  for (int i = 2; i <= n; i++) {
    f[i] = g(i) - g(i - 1) - pre;
    pre += f[i];
  }
  for (int c = 1; c <= m; c++) {
    for (int i = 1; i <= n; i++) cnt[c][i] = cnt[c][i - 1] + (a[i] == c);
  }
  qp2[0] = 1;
  for (int i = 1; i <= n; i++) qp2[i] = qp2[i - 1] * 2;
}
// Number of positions outside the closed range [l, r] whose value is c.
int calc(int l, int r, int c) {
  const int *pre = cnt[c];
  return pre[l - 1] + (pre[n] - pre[r]);
}
// Sweeps the left endpoint l from n down to 1 and returns
//   sum_{len=1..n} b[len] * f[len] / 2^len.
//
// For the current l, `vec` lists (in increasing r) every r at which the set of
// values in a[l..r] changes. An entry (r, mask, ci, co) says that for r' in
// [r, next entry's r) that set is `mask` (bit c set iff value c occurs), with
//   ci = product over c in mask     of 2^{cnt[c][n]},
//   co = product over c not in mask of (2^{cnt[c][n]} - 1), c in [1, m].
// ci * co is added to b[len] for every covered length len = r' - l + 1 through
// a difference array, prefix-summed at the end.
//
// The per-value factors (2^{cnt[c][n]} - 1), their inverses and the
// "product over all values but one" are loop-invariant. They are precomputed
// once instead of being rebuilt or re-inverted (modular exponentiation) at
// every step. Inverse powers of two are likewise accumulated incrementally.
// Assumes every a[i] lies in [1, m].
mint solve() {
  // w[c] = 2^{cnt[c][n]} - 1, inv_w[c] = 1 / w[c],
  // all_but[c] = product of w[d] over d in [1, m], d != c.
  vector<mint> w(m + 1), inv_w(m + 1), all_but(m + 1, mint(1));
  for (int c = 1; c <= m; c++) {
    w[c] = qp2[cnt[c][n]] - 1;
    inv_w[c] = mint(1) / w[c];
  }
  for (int c = 1; c <= m; c++)
    for (int d = 1; d <= m; d++)
      if (d != c) all_but[c] *= w[d];

  vector<tuple<int, LL, mint, mint>> vec, nxt;  // nxt's buffer is reused
  for (int l = n; l >= 1; l--) {
    const int x = a[l];
    const LL bit = 1ll << x;
    const mint pw = qp2[cnt[x][n]];  // 2^{cnt[x][n]}
    nxt.clear();
    // The segment [l, l] contains only the value x.
    nxt.emplace_back(l, bit, pw, all_but[x]);
    LL lstmask = bit;
    for (const auto &[r, mask, ci, co] : vec) {
      // Prepending x can make consecutive masks coincide; keep the first one.
      if (mask == lstmask) continue;
      if (mask & bit) {
        nxt.emplace_back(r, mask, ci, co);
        lstmask = mask;
      } else {
        // x joins the set: move its factor from `co` into `ci`.
        nxt.emplace_back(r, mask | bit, ci * pw, co * inv_w[x]);
        lstmask = mask | bit;
      }
    }
    swap(vec, nxt);
    // Lengths r - l + 1 .. rn - l all receive ci * co.
    for (size_t j = 0; j < vec.size(); j++) {
      const int rn = j + 1 < vec.size() ? get<0>(vec[j + 1]) : n + 1;
      const auto &[r, mask, ci, co] = vec[j];
      const mint add = ci * co;
      b[r - l + 1] += add;
      b[rn - l + 1] -= add;
    }
  }

  const mint inv2 = mint(1) / 2;
  mint inv_pw = 1;  // 2^{-i}
  mint ans = 0;
  for (int i = 1; i <= n; i++) {
    b[i] += b[i - 1];
    inv_pw *= inv2;
    ans += b[i] * f[i] * inv_pw;
  }
  return ans;
}
// Reads n, m and a[1..n]. Prints 0 when fewer than m distinct values occur;
// otherwise builds the tables and prints solve().
int main() {
#ifndef LOCAL
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
#endif
  cin >> n >> m;
  for (int i = 1; i <= n; ++i) cin >> a[i];
  const set<int> distinct(a + 1, a + n + 1);
  if (distinct.size() < static_cast<size_t>(m)) {
    cout << 0 << endl;
    return 0;
  }
  init();
  cout << solve() << endl;
}