QOJ.ac

QOJ

ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间
#670569 | #9476. 012 Grid | bulijiojiodibuliduo | AC ✓ | 165ms | 23208kb | C++17 | 13.6kb | 2024-10-23 22:21:09 | 2024-10-23 22:21:09

Judging History

你现在查看的是最新测评结果

  • [2024-10-23 22:21:09]
  • 评测
  • 测评结果:AC
  • 用时:165ms
  • 内存:23208kb
  • [2024-10-23 22:21:09]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;
// Loop shorthands: rep walks i upward over [a, n), per walks i downward over [a, n).
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
// Abbreviations for common container / pair members.
#define pb push_back
#define eb emplace_back
#define mp make_pair
// Expands to the begin/end iterator pair of x (function-like: only triggers on `all(`).
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
// Container size as a signed int, to avoid signed/unsigned comparisons in loops.
#define SZ(x) ((int)(x).size())
// Short type aliases used throughout the file.
typedef vector<int> VI;
typedef basic_string<int> BI;
typedef long long ll;
typedef pair<int,int> PII;
typedef double db;
// Global Mersenne Twister, seeded once from random_device at program start.
mt19937 mrand(random_device{}()); 
// Modulus used by powmod (same numeric value as MOD and md defined further down).
const ll mod=998244353;
// Draws one value from the global mrand engine and reduces it modulo x.
int rnd(int x) {
	const auto draw = mrand();
	return draw % x;
}
// Computes a^b modulo `mod` by binary exponentiation; b must be non-negative.
ll powmod(ll a,ll b) {
	assert(b>=0);
	ll result=1;
	ll base=a%mod;
	while (b>0) {
		if (b&1) result=result*base%mod;
		base=base*base%mod;
		b>>=1;
	}
	return result;
}
// Euclid's algorithm: repeatedly replaces (a, b) by (b, a % b) until b is 0.
ll gcd(ll a,ll b) {
	while (b) {
		const ll rest=a%b;
		a=b;
		b=rest;
	}
	return a;
}
// head

template<int MOD, int RT> struct mint {
	static const int mod = MOD;
	static constexpr mint rt() { return RT; } // primitive root for FFT
	int v; explicit operator int() const { return v; } // explicit -> don't silently convert to int
	mint():v(0) {}
	mint(ll _v) { v = int((-MOD < _v && _v < MOD) ? _v : _v % MOD);
		if (v < 0) v += MOD; }
	bool operator==(const mint& o) const {
		return v == o.v; }
	friend bool operator!=(const mint& a, const mint& b) { 
		return !(a == b); }
	friend bool operator<(const mint& a, const mint& b) { 
		return a.v < b.v; }
   
	mint& operator+=(const mint& o) { 
		if ((v += o.v) >= MOD) v -= MOD; 
		return *this; }
	mint& operator-=(const mint& o) { 
		if ((v -= o.v) < 0) v += MOD; 
		return *this; }
	mint& operator*=(const mint& o) { 
		v = int((ll)v*o.v%MOD); return *this; }
	mint& operator/=(const mint& o) { return (*this) *= inv(o); }
	friend mint pow(mint a, ll p) {
		mint ans = 1; assert(p >= 0);
		for (; p; p /= 2, a *= a) if (p&1) ans *= a;
		return ans; }
	friend mint inv(const mint& a) { assert(a.v != 0); 
		return pow(a,MOD-2); }
		
	mint operator-() const { return mint(-v); }
	mint& operator++() { return *this += 1; }
	mint& operator--() { return *this -= 1; }
	friend mint operator+(mint a, const mint& b) { return a += b; }
	friend mint operator-(mint a, const mint& b) { return a -= b; }
	friend mint operator*(mint a, const mint& b) { return a *= b; }
	friend mint operator/(mint a, const mint& b) { return a /= b; }
};

// Modulus for the mint-based code (same numeric value as `mod` and `md` in this file).
const int MOD=998244353; 
// Concrete modular-integer type used by the simp tables and the solve* routines.
using mi = mint<MOD,5>; // 5 is primitive root for both common mods

namespace simp {
	vector<mi> fac,ifac,invn;
	void check(int x) {
		if (fac.empty()) {
			fac={mi(1),mi(1)};
			ifac={mi(1),mi(1)};
			invn={mi(0),mi(1)};
		}
		while (SZ(fac)<=x) {
			int n=SZ(fac),m=SZ(fac)*2;
			fac.resize(m);
			ifac.resize(m);
			invn.resize(m);
			for (int i=n;i<m;i++) {
				fac[i]=fac[i-1]*mi(i);
				invn[i]=mi(MOD-MOD/i)*invn[MOD%i];
				ifac[i]=ifac[i-1]*invn[i];
			}
		}
	}
	mi gfac(int x) {
		if (x<0) return 0;
		check(x); return fac[x];
	}
	mi ginv(int x) {
		assert(x>0);
		check(x); return invn[x];
	}
	mi gifac(int x) {
		if (x<0) return 0;
		check(x); return ifac[x];
	}
	mi binom(int n,int m) {
		if (m < 0 || m > n) return mi(0);
		return gfac(n)*gifac(m)*gifac(n - m);
	}
	mi binombf(int n,int m) {
		if (m < 0 || m > n) return mi(0);
		if (m>n-m) m=n-m;
		mi p=1,q=1;
		for (int i=1;i<=m;i++) p=p*(n+1-i),q=q*i;
		return p/q;
	}
}

// The two integers read from standard input by main().
int n,m;

// Returns C(dx+dy, dx) with dx = x1-x2 and dy = y2-y1, i.e. the number of ways to
// interleave dx unit steps that decrease x with dy unit steps that increase y.
// The result is 0 whenever dx or dy is negative (binom rejects those arguments).
mi way(int x1,int y1,int x2,int y2) {
	const int steps_x=x1-x2;
	const int steps_y=y2-y1;
	return simp::binom(steps_x+steps_y,steps_x);
}
// 2x2 determinant of way() counts for sources (x1,y1),(x3,y3) and sinks (x2,y2),(x4,y4):
// way(1->2)*way(3->4) - way(1->4)*way(3->2).
// NOTE(review): presumably the Lindstrom-Gessel-Viennot count of non-crossing path pairs.
mi gao(int x1,int y1,int x2,int y2,int x3,int y3,int x4,int y4) {
	const mi straight=way(x1,y1,x2,y2)*way(x3,y3,x4,y4);
	const mi crossed=way(x1,y1,x4,y4)*way(x3,y3,x2,y2);
	return straight-crossed;
}
// Modulus for the plain-int helpers and the polynomial routines below (same value as MOD).
const int md = 998244353;
// x <- x + y, followed by a single conditional subtraction of md.
inline void add(int &x, int y) {
  const int sum = x + y;
  x = (sum >= md) ? sum - md : sum;
}

// x <- x - y, followed by a single conditional addition of md.
inline void sub(int &x, int y) {
  const int diff = x - y;
  x = (diff < 0) ? diff + md : diff;
}

// Product of x and y modulo md, computed in 64 bits so the product cannot overflow.
inline int mul(int x, int y) {
  const long long wide = (long long) x * y;
  return wide % md;
}

// x^y modulo md by square-and-multiply over the bits of y.
inline int power(int x, int y) {
  int acc = 1;
  while (y) {
    if (y & 1) {
      acc = mul(acc, x);
    }
    x = mul(x, x);
    y >>= 1;
  }
  return acc;
}

// Modular inverse of a modulo md via the extended Euclidean algorithm.
// `a` is first normalized into [0, md); an input congruent to 0 yields 0.
inline int inv(int a) {
  a %= md;
  if (a < 0) {
    a += md;
  }
  int rem = md;       // remainder paired with coef_prev
  int coef_prev = 0;  // coefficient of the original a that produces rem (mod md)
  int coef_cur = 1;   // coefficient of the original a that produces the current a (mod md)
  while (a) {
    const int q = rem / a;
    const int next_rem = rem - q * a;
    rem = a;
    a = next_rem;
    const int next_coef = coef_prev - q * coef_cur;
    coef_prev = coef_cur;
    coef_cur = next_coef;
  }
  return (coef_prev < 0) ? coef_prev + md : coef_prev;
}

// Number-theoretic transform modulo md, plus polynomial product and power-series inverse.
namespace ntt {
// base:     number of bits the rev/roots tables currently cover.
// root:     element of order exactly 2^max_base, found by init().
// max_base: exponent of 2 in md - 1; -1 until init() has run.
int base = 1, root = -1, max_base = -1;
// rev:   bit-reversal permutation for `base` bits.
// roots: twiddle factors; entries [i, 2i) are the ones used for butterfly half-length i.
vector<int> rev = {0, 1}, roots = {0, 1};

// Computes max_base and searches upward from 2 for the first value whose
// multiplicative order is exactly 2^max_base.
void init() {
  int temp = md - 1;
  max_base = 0;
  while (temp % 2 == 0) {
    temp >>= 1;
    ++max_base;
  }
  root = 2;
  while (true) {
    // order divides 2^max_base but not 2^(max_base-1)  =>  order is exactly 2^max_base
    if (power(root, 1 << max_base) == 1 && power(root, 1 << (max_base - 1)) != 1) {
      break;
    }
    ++root;
  }
}

// Grows the rev/roots tables so transforms of length 2^nbase are possible.
// Asserts that nbase does not exceed max_base.
void ensure_base(int nbase) {
  if (max_base == -1) {
    init();
  }
  if (nbase <= base) {
    return;
  }
  assert(nbase <= max_base);
  rev.resize(1 << nbase);
  for (int i = 0; i < 1 << nbase; ++i) {
    // reversal of i = reversal of (i >> 1) shifted down, with i's low bit moved to the top
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (nbase - 1));
  }
  roots.resize(1 << nbase);
  while (base < nbase) {
    // z is used to derive the odd-indexed entries of the next, twice-as-long level
    int z = power(root, 1 << (max_base - 1 - base));
    for (int i = 1 << (base - 1); i < 1 << base; ++i) {
      roots[i << 1] = roots[i];
      roots[i << 1 | 1] = mul(roots[i], z);
    }
    ++base;
  }
}

// In-place forward transform of a. The length is read through __builtin_ctz,
// so it is expected to be a power of two.
void dft(vector<int> &a) {
  int n = a.size(), zeros = __builtin_ctz(n);
  ensure_base(zeros);
  // rev is built for `base` bits; dropping the low bits gives the reversal for `zeros` bits
  int shift = base - zeros;
  for (int i = 0; i < n; ++i) {
    if (i < rev[i] >> shift) {
      swap(a[i], a[rev[i] >> shift]);
    }
  }
  for (int i = 1; i < n; i <<= 1) {
    for (int j = 0; j < n; j += i << 1) {
      for (int k = 0; k < i; ++k) {
        // butterfly on the pair (j+k, j+k+i)
        int x = a[j + k], y = mul(a[j + k + i], roots[i + k]);
        a[j + k] = (x + y) % md;
        a[j + k + i] = (x + md - y) % md;
      }
    }
  }
}

// Returns the product of polynomials a and b (a.size() + b.size() - 1 coefficients).
vector<int> multiply(vector<int> a, vector<int> b) {
  int need = a.size() + b.size() - 1, nbase = 0;
  while (1 << nbase < need) {
    ++nbase;
  }
  ensure_base(nbase);
  int sz = 1 << nbase;
  a.resize(sz);
  b.resize(sz);
  // squaring: one forward transform is enough when both padded inputs are identical
  bool equal = a == b;
  dft(a);
  if (equal) {
    b = a;
  } else {
    dft(b);
  }
  int inv_sz = inv(sz);
  for (int i = 0; i < sz; ++i) {
    // pointwise product, with the 1/sz scaling of the inverse transform folded in
    a[i] = mul(mul(a[i], b[i]), inv_sz);
  }
  // inverse transform = reverse all entries except the first, then forward transform
  reverse(a.begin() + 1, a.end());
  dft(a);
  a.resize(need);
  return a;
}

// Power-series inverse of a modulo x^n (n = a.size()), by recursing on the first
// ceil(n/2) coefficients and applying one Newton step b <- b * (2 - a * b).
vector<int> inverse(vector<int> a) {
  int n = a.size(), m = (n + 1) >> 1;
  if (n == 1) {
    return vector<int>(1, inv(a[0]));
  } else {
    vector<int> b = inverse(vector<int>(a.begin(), a.begin() + m));
    int need = n << 1, nbase = 0;
    while (1 << nbase < need) {
      ++nbase;
    }
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    dft(a);
    dft(b);
    int inv_sz = inv(sz);
    for (int i = 0; i < sz; ++i) {
      // (2 - a*b) * b in the transform domain, pre-scaled by 1/sz
      a[i] = mul(mul(md + 2 - mul(a[i], b[i]), b[i]), inv_sz);
    }
    reverse(a.begin() + 1, a.end());
    dft(a);
    a.resize(n);
    return a;
  }
}
}

using ntt::multiply;
using ntt::inverse;

// Coefficient-wise a <- a + b (mod md); a is zero-extended when b is longer.
vector<int>& operator += (vector<int> &a, const vector<int> &b) {
  if (b.size() > a.size()) {
    a.resize(b.size());
  }
  size_t pos = 0;
  for (int term : b) {
    add(a[pos++], term);
  }
  return a;
}

// Polynomial sum as a new vector; delegates to operator+= on a copy of a.
vector<int> operator + (const vector<int> &a, const vector<int> &b) {
  vector<int> sum(a);
  sum += b;
  return sum;
}

// Coefficient-wise a <- a - b (mod md); a is zero-extended when b is longer.
vector<int>& operator -= (vector<int> &a, const vector<int> &b) {
  if (b.size() > a.size()) {
    a.resize(b.size());
  }
  size_t pos = 0;
  for (int term : b) {
    sub(a[pos++], term);
  }
  return a;
}

// Polynomial difference as a new vector; delegates to operator-= on a copy of a.
vector<int> operator - (const vector<int> &a, const vector<int> &b) {
  vector<int> diff(a);
  diff -= b;
  return diff;
}

// In-place polynomial product a <- a * b with coefficients modulo md.
// Uses the quadratic schoolbook method when the shorter operand has fewer than 128
// coefficients and ntt::multiply otherwise.
//
// Fixes over the previous version:
//  * The schoolbook branch accumulates into a separate buffer, so `x *= x` is safe.
//    Previously `a` was cleared and resized while `b` (the same object) was still
//    being read, giving wrong results and out-of-range writes.
//  * Two empty operands now leave `a` empty instead of requesting
//    a.size() + b.size() - 1 == size_t(-1) elements.
// When exactly one operand is empty the result is, as before, a vector of
// a.size() + b.size() - 1 zeros.
vector<int>& operator *= (vector<int> &a, const vector<int> &b) {
  if (a.empty() && b.empty()) {
    return a;
  }
  if (min(a.size(), b.size()) < 128) {
    vector<int> prod(a.size() + b.size() - 1, 0);
    for (size_t i = 0; i < a.size(); ++i) {
      for (size_t j = 0; j < b.size(); ++j) {
        add(prod[i + j], mul(a[i], b[j]));
      }
    }
    a.swap(prod);
  } else {
    // multiply takes its arguments by value, so aliasing is harmless here
    a = multiply(a, b);
  }
  return a;
}

// Polynomial product as a new vector; delegates to operator*= on a copy of a.
vector<int> operator * (const vector<int> &a, const vector<int> &b) {
  vector<int> product(a);
  product *= b;
  return product;
}

// In-place polynomial quotient a <- a div b (Euclidean division, remainder discarded).
// Works on the reversed polynomials: the reversed quotient is the reversed dividend
// times the power-series inverse of the reversed divisor, truncated to n - m + 1 terms.
// When a is shorter than b the quotient is empty.
vector<int>& operator /= (vector<int> &a, const vector<int> &b) {
  const int n = a.size(), m = b.size();
  if (n < m) {
    a.clear();
    return a;
  }
  const int quot_len = n - m + 1;
  vector<int> rev_b(b.rbegin(), b.rend());
  rev_b.resize(quot_len);
  reverse(a.begin(), a.end());
  a *= inverse(rev_b);
  a.erase(a.begin() + quot_len, a.end());
  reverse(a.begin(), a.end());
  return a;
}

// Polynomial quotient as a new vector; delegates to operator/= on a copy of a.
vector<int> operator / (const vector<int> &a, const vector<int> &b) {
  vector<int> quotient(a);
  quotient /= b;
  return quotient;
}

// In-place polynomial remainder a <- a mod b, computed as a - (a div b) * b and
// truncated to b.size() - 1 coefficients. a is left untouched when it is shorter than b.
vector<int>& operator %= (vector<int> &a, const vector<int> &b) {
  const int n = a.size(), m = b.size();
  if (n < m) {
    return a;
  }
  const vector<int> multiple = (a / b) * b;
  a.resize(m - 1);
  for (int i = 0; i + 1 < m; ++i) {
    sub(a[i], multiple[i]);
  }
  return a;
}

// Polynomial remainder as a new vector; delegates to operator%= on a copy of a.
vector<int> operator % (const vector<int> &a, const vector<int> &b) {
  vector<int> remainder(a);
  remainder %= b;
  return remainder;
}

// Formal derivative: coefficient i of the input contributes i * a[i] at index i - 1.
// The result has a.size() - 1 coefficients. An empty input now yields an empty
// result; previously it tried to build a vector of size_t(-1) elements and threw.
vector<int> derivative(const vector<int> &a) {
  if (a.empty()) {
    return vector<int>();
  }
  const int n = a.size();
  vector<int> b(n - 1);
  for (int i = 1; i < n; ++i) {
    b[i - 1] = mul(a[i], i);
  }
  return b;
}

// Formal antiderivative with zero constant term: result[k] = a[k-1] / k for k >= 1.
// The result has a.size() + 1 coefficients.
vector<int> primitive(const vector<int> &a) {
  const int n = a.size();
  vector<int> integral(n + 1), recip(n + 1);
  for (int k = 1; k <= n; ++k) {
    if (k == 1) {
      recip[k] = 1;
    } else {
      // 1/k = -(md/k) * 1/(md%k)  (mod md); md%k < k has been filled already
      recip[k] = mul(md - md / k, recip[md % k]);
    }
    integral[k] = mul(a[k - 1], recip[k]);
  }
  return integral;
}

// Power-series logarithm, truncated to a.size() terms: integrates a' / a.
// NOTE(review): presumably requires a[0] == 1 - confirm with callers.
vector<int> logarithm(const vector<int> &a) {
  vector<int> result = primitive(derivative(a) * inverse(a));
  result.resize(a.size());
  return result;
}

vector<int> exponent(const vector<int> &a) {
  vector<int> b(1, 1);
  while (b.size() < a.size()) {
    vector<int> c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    add(c[0], 1);
    vector<int> old_b = b;
    b.resize(b.size() << 1);
    c -= logarithm(b);
    c *= old_b;
    for (int i = b.size() >> 1; i < b.size(); ++i) {
      b[i] = c[i];
    }
  }
  b.resize(a.size());
  return b;
}

// Returns a^m truncated to a.size() coefficients (mod md), for a non-negative exponent m.
// With p the index of the first non-zero coefficient, a = x^p * a[p] * u where u has
// constant term 1, so a^m = x^(m*p) * a[p]^m * exp(m * log u).
//
// Fixes over the previous version:
//  * m == 0 returns the constant polynomial 1 immediately. Previously, with m == 0
//    and p > 0, the normalization loop read a[i + p] past the end of a.
//  * An empty a returns an empty vector instead of writing b[0] on an empty vector.
vector<int> power(vector<int> a, int m) {
  int n = a.size(), p = -1;
  vector<int> b(n);
  if (n == 0) {
    return b;
  }
  if (m == 0) {
    b[0] = 1;
    return b;
  }
  for (int i = 0; i < n; ++i) {
    if (a[i]) {
      p = i;
      break;
    }
  }
  if (p == -1) {
    // zero polynomial raised to a non-zero power stays zero
    return b;
  }
  if ((long long) m * p >= n) {
    // the lowest surviving term would already lie beyond the truncation length
    return b;
  }
  int mu = power(a[p], m), di = inv(a[p]);
  const int len = n - m * p;
  vector<int> c(len);
  for (int i = 0; i < len; ++i) {
    c[i] = mul(a[i + p], di);
  }
  c = logarithm(c);
  for (int i = 0; i < len; ++i) {
    c[i] = mul(c[i], m);
  }
  c = exponent(c);
  for (int i = 0; i < len; ++i) {
    b[i + m * p] = mul(c[i], mu);
  }
  return b;
}

// Power-series square root, truncated to a.size() terms. Starts from b = 1 and doubles
// the known prefix each round; the new upper-half coefficients are (a / b) / 2.
// NOTE(review): the seed b[0] = 1 presumably requires a[0] == 1 - confirm with callers.
vector<int> sqrt(const vector<int> &a) {
  const int half_inverse = (md + 1) >> 1;  // 1/2 modulo md
  vector<int> b(1, 1);
  while (b.size() < a.size()) {
    const size_t lo = b.size();
    const size_t hi = lo << 1;
    vector<int> c(a.begin(), a.begin() + min(a.size(), hi));
    b.resize(hi);
    c *= inverse(b);
    for (size_t i = lo; i < hi; ++i) {
      b[i] = mul(c[i], half_inverse);
    }
  }
  b.resize(a.size());
  return b;
}

// Product of all[l..r] (inclusive) by splitting the range in half recursively.
// An empty range (l > r) yields an empty vector.
vector<int> multiply_all(int l, int r, vector<vector<int>> &all) {
  if (l > r) {
    return vector<int>();
  }
  if (l == r) {
    return all[l];
  }
  const int mid = (l + r) >> 1;
  return multiply_all(l, mid, all) * multiply_all(mid + 1, r, all);
}

// Multipoint evaluation: returns f(x[i]) mod md for every i.
// Builds a product tree `up` over the linear factors (X - x[i]) stored heap-style
// (leaves at n..2n-1, node i = product of its two children), then pushes f down the
// tree by repeated remainders; the remainder at leaf i is the constant f(x[i]).
//
// Fixes over the previous version:
//  * Points are reduced modulo md before negation, so x[i] > md no longer produces a
//    negative leaf coefficient (inputs already in [0, md) behave exactly as before).
//  * An empty f is treated as the zero polynomial instead of indexing an empty vector.
vector<int> evaluate(const vector<int> &f, const vector<int> &x) {
  const int n = x.size();
  if (!n) {
    return vector<int>();
  }
  vector<vector<int>> up(n * 2);
  for (int i = 0; i < n; ++i) {
    up[i + n] = vector<int>{(md - x[i] % md) % md, 1};
  }
  for (int i = n - 1; i; --i) {
    up[i] = up[i << 1] * up[i << 1 | 1];
  }
  vector<vector<int>> down(n * 2);
  down[1] = f % up[1];
  for (int i = 2; i < n * 2; ++i) {
    down[i] = down[i >> 1] % up[i];
  }
  vector<int> y(n);
  for (int i = 0; i < n; ++i) {
    // a leaf remainder is only empty when f itself was empty
    y[i] = down[i + n].empty() ? 0 : down[i + n][0];
  }
  return y;
}

// Lagrange interpolation: returns the polynomial through the points (x[i], y[i]) mod md.
// With P = prod (X - x[i]), the weight of point i is y[i] / P'(x[i]); the weighted
// sum of the partial products is then assembled bottom-up over the same product tree.
// y is expected to hold at least x.size() values.
//
// Fix over the previous version: with no points the function returns an empty vector
// (matching evaluate) instead of indexing up[1] on an empty tree.
vector<int> interpolate(vector<int> x, const vector<int> &y) {
  const int n = x.size();
  if (!n) {
    return vector<int>();
  }
  vector<vector<int>> up(n * 2);
  for (int i = 0; i < n; ++i) {
    x[i] %= md;
    up[i + n] = vector<int>{(md - x[i]) % md, 1};
  }
  for (int i = n - 1; i; --i) {
    up[i] = up[i << 1] * up[i << 1 | 1];
  }
  // a[i] = P'(x[i]), turned into the weight y[i] / P'(x[i]) right below
  vector<int> a = evaluate(derivative(up[1]), x);
  for (int i = 0; i < n; ++i) {
    a[i] = mul(y[i], inv(a[i]));
  }
  vector<vector<int>> down(n * 2);
  for (int i = 0; i < n; ++i) {
    down[i + n] = vector<int>(1, a[i]);
  }
  for (int i = n - 1; i; --i) {
    // combine children: each side is multiplied by the other side's factor product
    down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
  }
  return down[1];
}

// Evaluates, for 0 <= i < n-1 and 0 <= j < m-1,
//   sum C(i+j,i)^2  -  sum C(i+j,i-1)*C(i+j,i+1)
// and then subtracts way(n,1,1,m) and adds 1.
// Each double sum is one convolution: with f[i] = 1/(i!)^2 the coefficient h[k] of f*g
// times (k!)^2 equals the sum of C(k,i)^2 over i+j = k; the second sum uses
// f[i] = 1/((i-1)!(i+1)!) in the same way (gifac(-1) is 0, so i = 0 drops out).
mi solve00(int n,int m) {
	// sum over k of h[k] * (k!)^2
	auto weighted=[](const VI &h) {
		mi acc=0;
		for (int k=0;k<SZ(h);k++) acc+=h[k]*simp::gfac(k)*simp::gfac(k);
		return acc;
	};
	VI f(n-1),g(m-1);

	for (int i=0;i<n-1;i++) f[i]=(int)(simp::gifac(i)*simp::gifac(i));
	for (int j=0;j<m-1;j++) g[j]=(int)(simp::gifac(j)*simp::gifac(j));
	mi ans=0;
	ans+=weighted(f*g);

	for (int i=0;i<n-1;i++) f[i]=(int)(simp::gifac(i-1)*simp::gifac(i+1));
	for (int j=0;j<m-1;j++) g[j]=(int)(simp::gifac(j-1)*simp::gifac(j+1));
	ans-=weighted(f*g);

	ans=ans-way(n,1,1,m)+1;
	return ans;
}

// Sums, for d = 1 .. m-2, the weight (m-d-1) times the 2x2 binomial determinant
//   C(n+d-2, n-1)^2 - C(n+d-2, n-2) * C(n+d-2, n).
mi solve02(int n,int m) {
	mi total=0;
	for (int d=1;d<=m-2;d++) {
		const int top=n+d-2;
		const mi middle=simp::binom(top,n-1);
		const mi det=middle*middle-simp::binom(top,n-2)*simp::binom(top,n);
		total+=(m-d-1)*det;
	}
	return total;
}

// Sums, for i = 1 .. m-1, the path-pair determinant
//   gao((n-1,i) -> (0,m-1), (n,i+1) -> (1,m))
// minus the single-path count way((n-1,i) -> (0,m-1)).
mi solve01(int n,int m) {
	mi total=0;
	int col=1;
	while (col<m) {
		const mi pairs=gao(n-1,col,0,m-1,n,col+1,1,m);
		const mi single=way(n-1,col,0,m-1);
		total+=pairs-single;
		++col;
	}
	return total;
}
// Reads n and m, adds up the contributions of the individual cases (labelled by the
// original author as solve11, solve00/22, solve02/20 and solve10) and prints the total
// modulo 998244353.
// Fix over the previous version: the scanf result is checked, so a failed read exits
// with status 1 instead of continuing with n = m = 0.
int main() {
	if (scanf("%d%d",&n,&m)!=2) {
		return 1;
	}
	// 1x1 is answered directly.
	// NOTE(review): presumably because solve00 would otherwise multiply two empty vectors.
	if (n==1&&m==1) {
		printf("%d\n",3);
		return 0;
	}
	mi ans=3;
	ans+=2*(simp::binom(n+m,n)-2);
	// solve11
	ans+=simp::binom(n+m-2,n-1)*simp::binom(n+m-2,n-1)-
		simp::binom(n+m-2,n)*simp::binom(n+m-2,n-2);
	ans=ans-simp::binom(n+m-2,n-1)*2+1;
	// solve00, solve22
	ans+=2*solve00(n,m);
	// solve02, solve20
	ans+=solve02(n,m)+solve02(m,n);
	// solve10
	ans+=2*(solve01(n,m)+solve01(m,n));
	printf("%d\n",(int)ans);
	return 0;
}

詳細信息

Test #1:

score: 100
Accepted
time: 0ms
memory: 4100kb

input:

2 2

output:

11

result:

ok "11"

Test #2:

score: 0
Accepted
time: 0ms
memory: 3828kb

input:

20 23

output:

521442928

result:

ok "521442928"

Test #3:

score: 0
Accepted
time: 134ms
memory: 23108kb

input:

200000 200000

output:

411160917

result:

ok "411160917"

Test #4:

score: 0
Accepted
time: 0ms
memory: 4104kb

input:

8 3

output:

2899

result:

ok "2899"

Test #5:

score: 0
Accepted
time: 0ms
memory: 3828kb

input:

10 9

output:

338037463

result:

ok "338037463"

Test #6:

score: 0
Accepted
time: 0ms
memory: 4088kb

input:

3 3

output:

64

result:

ok "64"

Test #7:

score: 0
Accepted
time: 0ms
memory: 3828kb

input:

9 4

output:

39733

result:

ok "39733"

Test #8:

score: 0
Accepted
time: 0ms
memory: 3896kb

input:

36 33

output:

545587245

result:

ok "545587245"

Test #9:

score: 0
Accepted
time: 0ms
memory: 3828kb

input:

35 39

output:

62117944

result:

ok "62117944"

Test #10:

score: 0
Accepted
time: 0ms
memory: 3804kb

input:

48 10

output:

264659761

result:

ok "264659761"

Test #11:

score: 0
Accepted
time: 0ms
memory: 3836kb

input:

46 30

output:

880000821

result:

ok "880000821"

Test #12:

score: 0
Accepted
time: 0ms
memory: 3808kb

input:

25 24

output:

280799864

result:

ok "280799864"

Test #13:

score: 0
Accepted
time: 0ms
memory: 3800kb

input:

17 10

output:

624958192

result:

ok "624958192"

Test #14:

score: 0
Accepted
time: 5ms
memory: 3964kb

input:

4608 9241

output:

322218996

result:

ok "322218996"

Test #15:

score: 0
Accepted
time: 4ms
memory: 4192kb

input:

3665 6137

output:

537704652

result:

ok "537704652"

Test #16:

score: 0
Accepted
time: 5ms
memory: 4208kb

input:

4192 6186

output:

971816887

result:

ok "971816887"

Test #17:

score: 0
Accepted
time: 4ms
memory: 4448kb

input:

4562 4403

output:

867628411

result:

ok "867628411"

Test #18:

score: 0
Accepted
time: 5ms
memory: 3988kb

input:

8726 3237

output:

808804305

result:

ok "808804305"

Test #19:

score: 0
Accepted
time: 5ms
memory: 3968kb

input:

5257 8166

output:

488829288

result:

ok "488829288"

Test #20:

score: 0
Accepted
time: 5ms
memory: 4036kb

input:

8013 7958

output:

215666893

result:

ok "215666893"

Test #21:

score: 0
Accepted
time: 5ms
memory: 3924kb

input:

8837 5868

output:

239261227

result:

ok "239261227"

Test #22:

score: 0
Accepted
time: 5ms
memory: 3948kb

input:

8917 5492

output:

706653412

result:

ok "706653412"

Test #23:

score: 0
Accepted
time: 5ms
memory: 4248kb

input:

9628 5378

output:

753685501

result:

ok "753685501"

Test #24:

score: 0
Accepted
time: 154ms
memory: 22348kb

input:

163762 183794

output:

141157510

result:

ok "141157510"

Test #25:

score: 0
Accepted
time: 71ms
memory: 12784kb

input:

83512 82743

output:

114622013

result:

ok "114622013"

Test #26:

score: 0
Accepted
time: 76ms
memory: 12312kb

input:

84692 56473

output:

263907717

result:

ok "263907717"

Test #27:

score: 0
Accepted
time: 38ms
memory: 8016kb

input:

31827 74195

output:

200356808

result:

ok "200356808"

Test #28:

score: 0
Accepted
time: 159ms
memory: 22456kb

input:

189921 163932

output:

845151158

result:

ok "845151158"

Test #29:

score: 0
Accepted
time: 72ms
memory: 12932kb

input:

27157 177990

output:

847356039

result:

ok "847356039"

Test #30:

score: 0
Accepted
time: 73ms
memory: 12968kb

input:

136835 39390

output:

962822638

result:

ok "962822638"

Test #31:

score: 0
Accepted
time: 63ms
memory: 12452kb

input:

118610 18795

output:

243423874

result:

ok "243423874"

Test #32:

score: 0
Accepted
time: 75ms
memory: 12588kb

input:

122070 19995

output:

531055604

result:

ok "531055604"

Test #33:

score: 0
Accepted
time: 78ms
memory: 13096kb

input:

20031 195670

output:

483162363

result:

ok "483162363"

Test #34:

score: 0
Accepted
time: 132ms
memory: 23036kb

input:

199992 199992

output:

262099623

result:

ok "262099623"

Test #35:

score: 0
Accepted
time: 165ms
memory: 23088kb

input:

200000 199992

output:

477266520

result:

ok "477266520"

Test #36:

score: 0
Accepted
time: 162ms
memory: 23208kb

input:

199999 199996

output:

165483205

result:

ok "165483205"

Test #37:

score: 0
Accepted
time: 0ms
memory: 3884kb

input:

1 1

output:

3

result:

ok "3"

Test #38:

score: 0
Accepted
time: 14ms
memory: 6268kb

input:

1 100000

output:

8828237

result:

ok "8828237"

Test #39:

score: 0
Accepted
time: 12ms
memory: 6220kb

input:

100000 2

output:

263711286

result:

ok "263711286"

Test #40:

score: 0
Accepted
time: 0ms
memory: 3812kb

input:

50 50

output:

634767411

result:

ok "634767411"

Extra Test:

score: 0
Extra Test Passed