QOJ.ac

QOJ

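| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|----|---------|-----------|--------|------|--------|----------|-----------|-------------|------------|
| #525257 | #7512. Almost Prefix Concatenation | mhw# | Compile Error | / | / | C++23 | 4.7kb | 2024-08-20 15:03:40 | 2024-08-20 15:03:40 |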

Judging History

你现在查看的是最新测评结果 (You are viewing the latest judging result.)

  • [2024-08-20 15:03:40]
  • 评测 (judged)
  • [2024-08-20 15:03:40]
  • 提交 (submitted)

answer

#include<bits/stdc++.h>
// File-wide shorthand: from here on every `int` token is rewritten to `long long`.
// Consequence: while this macro is active, `int main()` expands to
// `long long main()`, which is ill-formed ("'::main' must return 'int'");
// main has to be spelled `signed main()` instead.
#define int long long
#define ll long long
// NOTE(review): `pii` expands to the function `make_pair` (not a pair type);
// `rep`/`per` are inclusive ascending/descending loops. None of these three
// macros, nor `inf`, is referenced in the code visible here.
#define pii make_pair
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define per(i,a,b) for(int i=b;i>=a;--i)
const ll inf = 1145141919810;
using namespace std;
// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; any '-' among them makes
// the result negative. Returns 0 if EOF is reached before any digit (the old
// version looped forever at EOF because `char` cannot hold EOF distinctly).
inline ll read(){
    ll x=0,f=1;
    int c=getchar();   // int, not char, so EOF is distinguishable from data
    while (c!=EOF && (c<'0' || c>'9')){
        if (c=='-')  f=-1;
        c=getchar();
    }
    while (c>='0' && c<='9'){
        x=x*10+(c-'0');
        c=getchar();
    }
    return x*f;
}
// Writes x in decimal to stdout (no trailing newline).
// Digits are produced from the unsigned magnitude so that LLONG_MIN is
// handled: negating it as a signed value (`x = -x`) overflows, which is UB.
inline void print(ll x){
	unsigned long long mag = static_cast<unsigned long long>(x);
	if(x < 0){
		putchar('-');
		mag = 0ULL - mag;   // well-defined modular negation
	}
	char digits[20];        // 2^64 - 1 has 20 decimal digits
	int len = 0;
	do {
		digits[len++] = static_cast<char>('0' + mag % 10);
		mag /= 10;
	} while(mag);
	while(len) putchar(digits[--len]);
}
// Modulus applied to every DP value.
const int mod = 998244353;
// Prints x in decimal followed by a newline.
inline void pprint(ll x){
	print(x);
	putchar('\n');
}
// Capacity of the 1-indexed character buffers; the combined string built in
// main (S + '#' + T plus its NUL terminator) must fit below it.
const int N = 2e6 + 8;
// Suffix array (prefix doubling with radix sort) plus an LCP sparse table.
//
// init(s) expects a 1-indexed, NUL-terminated string s[1..n] whose characters
// lie in 1..127 (the first radix pass uses m = 127 buckets). Afterwards:
//   sa[r]   : start position of the r-th smallest suffix, r = 1..n
//   rk[i]   : rank of the suffix starting at i, i = 1..n (0 beyond n)
//   h[r]    : LCP of suffixes sa[r - 1] and sa[r] (Kasai); h[1] = 0
//   f[r][j] : min(h[r .. r + 2^j - 1]), the range-minimum table used by lcp()
// Inside this struct `int` is plain int again: the file-wide
// `#define int long long` is lifted here (and re-defined after the struct),
// presumably to keep these large arrays small.
struct SA {

#undef int
	// f needs 21 levels: init() and lcp() use level j = floor(log2(len)) with
	// len up to n < N < 2^21, so j reaches 20 as soon as n >= 2^20.
	int rk[N << 1], sa[N], n, lstrk[N << 1], lstsa[N], w, m = 127, cnt[N], h[N], f[N][21];
#define siz n * sizeof(int)
	char t[N << 1];
	// Builds sa, rk, h and f for s[1..n]. Every array that is read past index n
	// is cleared first, so init() may be called more than once.
	void init(char *s) {
		n = strlen(s + 1);
		strcpy(t + 1, s + 1);
		m = 127;
		memset(cnt, 0, sizeof cnt);
		memset(sa, 0, sizeof sa);
		memset(rk, 0, sizeof rk);
		// lstrk[j] for j > n must be 0: a suffix with no second half
		// (sa[i] + w > n) then never compares equal to a real rank (>= 1).
		memset(lstrk, 0, sizeof lstrk);
		memset(h, 0, sizeof h);
		// Round 0: counting sort by the first character.
		for(int i = 1; i <= n; ++i) ++cnt[rk[i] = s[i]];
		for(int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
		for(int i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i;
		memcpy(lstrk + 1, rk + 1, siz);
		for(int p = 0, i = 1; i <= n; ++i)
			if(lstrk[sa[i]] == lstrk[sa[i - 1]])
				rk[sa[i]] = p;
			else
				rk[sa[i]] = ++p;
		// Doubling rounds: order by (rank[i], rank[i + w]) using the previous ranks.
		for(w = 1; w < n; w <<= 1, m = n) {
			// Suffixes without a second half sort first by the second key.
			for(int p = 0, i = n; i >= n - w + 1; --i) lstsa[++p] = i;
			for(int p = w, i = 1; i <= n; ++i)
				if(sa[i] > w)  lstsa[++p] = sa[i] - w;
			memset(cnt, 0, sizeof cnt);
			for(int i = 1; i <= n; ++i) ++cnt[rk[lstsa[i]]];
			for(int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
			for(int i = n; i >= 1; --i) sa[cnt[rk[lstsa[i]]]--] = lstsa[i];
			memcpy(lstrk + 1, rk + 1, siz);
			for(int p = 0, i = 1; i <= n; ++i)
				if(lstrk[sa[i]] == lstrk[sa[i - 1]] && lstrk[sa[i] + w] == lstrk[sa[i - 1] + w])
					rk[sa[i]] = p;
				else
					rk[sa[i]] = ++p;
		}
		// Kasai: h[rk[i]] >= h[rk[i - 1]] - 1, so the match length k carries over.
		for(int i = 1, k = 0; i <= n; ++i) {
			// The smallest suffix has no predecessor (sa[0] == 0 is not a real
			// suffix): leave h[1] = 0 and restart the carried match length.
			if(rk[i] == 1) {
				k = 0;
				continue;
			}
			if(k) --k;
			while(s[i + k] == s[sa[rk[i] - 1] + k]) ++k;
			h[rk[i]] = k;
		}
		memset(f, 0x3f, sizeof f);
		for(int i = 1; i <= n; ++i) f[i][0] = h[i];
		for(int j = 1; (1 << j) <= n; ++j)
			for(int i = 1; i <= n - (1 << j) + 1; ++i)
				f[i][j] = min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
	}
	// Length of the longest common prefix of the suffixes starting at x and y.
	// Both positions are expected to lie in 1..n (rk is 0 outside that range).
	int lcp(int x, int y) {
		if(x == y)
			return n - y + 1;
		x = rk[x], y = rk[y];
		if(x >= y) swap(x, y);
		// min(h[x + 1 .. y]) via two overlapping power-of-two windows.
		int k = log2(y - (x + 1) + 1);
		return min(f[x + 1][k], f[y - (1 << k) + 1][k]);
	}
}sa;

// Re-enable the file-wide `int` -> `long long` rewrite (it was lifted for SA).
#define int long long

// 1-indexed input strings S (s), T (tt) and the combined string p = S + '#' + T
// that main builds for the suffix array.
char s[N], tt[N], p[N];
// n = |S|, m = |T| (set in main via strlen).
ll n, m;
// f[i]: longest piece length starting at S position i that main accepts,
// computed as min(n - i + 1, m, lcp1 + 1 + lcp2), i.e. a prefix of T with at
// most one mismatching character.
ll f[N];
// Segment-tree node for range-add / point-query. Index j = 0, 1, 2 of add[]
// and w[] holds one of the three DP components that main updates in parallel.
struct Node{
	int l, r;         // inclusive range of positions covered by this node
	ll add[3], w[3];  // add: pending lazy additions; w: accumulated value (read at leaves)
}t[N << 1];

void down(int p){
	for(int j = 0; j <= 2; ++j){
		if(t[p].add[j]){
			ll add = t[p].add[j];
			t[p << 1].w[j] = (add + t[p << 1].w[j]) % mod;
			t[p << 1 | 1].w[j] = (add + t[p << 1 | 1].w[j]) % mod;
			t[p << 1].add[j] = (add + t[p << 1].add[j]) % mod;
			t[p << 1 | 1].add[j] = (add + t[p << 1 | 1].add[j]) % mod;
			t[p].add[j] = 0;
		}
	}
}
// Builds the segment-tree node p over positions [l, r] and, recursively, its
// children; every node starts with zeroed tags and values.
void build(int p, int l, int r){
	t[p] = Node{l, r};   // add[] and w[] are value-initialised to 0
	if(l != r){
		const int mid = (l + r) >> 1;
		build(2 * p, l, mid);
		build(2 * p + 1, mid + 1, r);
	}
}
// Range add: adds k to component tp (0, 1 or 2) of every position in [x, y]
// that lies under node p. Fully covered nodes record the addition in their lazy
// tag. Values are kept reduced mod `mod`, matching down().
void ch(int p, int x, int y, int k, int tp){
	if(x > y) return;   // empty range: otherwise we would descend below a leaf
	int l = t[p].l, r = t[p].r;
	if(x <= l && r <= y){
		t[p].w[tp] = (t[p].w[tp] + k) % mod;
		t[p].add[tp] = (t[p].add[tp] + k) % mod;
		return ;
	}
	down(p);
	const int mid = l + r >> 1;
	if(x <= mid) ch(p << 1, x, y, k, tp);
	if(y >= mid + 1) ch(p << 1 | 1, x, y, k, tp);
}
// Point query: walks from node p down to the leaf for position x, pushing
// pending lazy additions along the way, and returns that leaf's component tp.
ll ask(int p, int x, int tp){
	while(t[p].l != t[p].r){
		down(p);
		const int mid = (t[p].l + t[p].r) >> 1;
		p = (x <= mid) ? (p << 1) : (p << 1 | 1);
	}
	return t[p].w[tp];
}

// dp[i][k], k = 0, 1, 2: over all ways main accepts to split S[1..i] into pieces,
// the sum of (number of pieces)^k, mod `mod`.
ll dp[N][3];
// NOTE(review): `pre` is not referenced in the code visible here.
ll pre[3];
int main(){
	scanf("%s", s + 1);
	scanf("%s", tt + 1);
	n = strlen(s + 1);
	m = strlen(tt + 1);
	for(int i = 1; i <= n; ++i)
		p[i] = s[i];
	p[n + 1] = '#';
	for(int i = 1; i <= m; ++i)
		p[i + n + 1] = tt[i];
	int st = n + 2;
	sa.init(p);
//	cout << p + 1 << endl;
	for(int i = 1; i <= n; ++i){
		ll lcp1 = sa.lcp(n + 2, i);
		ll st = i + lcp1 + 1;
		ll lcp2 = 0;
		if(st <= n) lcp2 = sa.lcp(n + 2 + lcp1 + 1, st);
		f[i] = min(min(n - i + 1, m), lcp1 + 1 + lcp2);
		f[i] %= mod;
	}
	build(1, 1, n);
	dp[0][0] = 1;
	dp[0][1] = 0;
	dp[0][2] = 0;
	for(int i = 1; i <= n; ++i){
		ch(1, i, i + f[i] - 1, (dp[i - 1][2] + 2 * dp[i - 1][1] % mod + dp[i - 1][0]) % mod, 2);
		ch(1, i, i + f[i] - 1, (dp[i - 1][1] + dp[i - 1][0]) % mod, 1);
		ch(1, i, i + f[i] - 1, dp[i - 1][0], 0);
		dp[i][2] = ask(1, i, 2) % mod;
		dp[i][1] = ask(1, i, 1) % mod;
		dp[i][0] = ask(1, i, 0) % mod;
 	}
	cout << dp[n][2] << endl;
	return 0;
}
/*
ababaab
aba
*/

Details

answer.code:33:12: warning: extra tokens at end of #undef directive
   33 | #undef int long long
      |            ^~~~
cc1plus: error: ‘::main’ must return ‘int’
answer.code: In function ‘int main()’:
answer.code:149:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  149 |         scanf("%s", s + 1);
      |         ~~~~~^~~~~~~~~~~~~
answer.code:150:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  150 |         scanf("%s", tt + 1);
      |         ~~~~~^~~~~~~~~~~~~~