QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#540439 | #7780. Dark LaTeX vs. Light LaTeX | AmiyaCast | TL | 1ms | 9368kb | C++14 | 7.0kb | 2024-08-31 17:01:31 | 2024-08-31 17:01:33 |
Judging History
answer
#include<bits/stdc++.h>
// Shorthand macros used throughout the file.
#define ll long long
// NOTE: despite the name, pii expands to make_pair (a function), not a pair type.
#define pii make_pair
#define pb push_back
// rep: inclusive ascending loop i = a..b; per: inclusive descending loop i = b..a.
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define per(i,b,a) for(int i=b;i>=a;--i)
// Large sentinel constant; not referenced anywhere in the visible code.
const ll inf = 1145141919810;
using namespace std;
// Reads one integer from stdin: skips every non-digit (any '-' seen while
// skipping makes the result negative), then accumulates consecutive digits.
// Not called in the visible code.
inline ll read() {
    ll sign = 1, value = 0;
    char ch;
    for (ch = getchar(); ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sign = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + (ch - '0');
    return sign * value;
}
// Writes x in decimal to stdout (leading '-' for negatives), most significant
// digit first, by recursing on x / 10 before emitting the last digit.
inline void print(ll x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x >= 10)
        print(x / 10);
    putchar('0' + x % 10);
}
// Writes x followed by a newline.
inline void pprint(ll x) {
    print(x);
    putchar('\n');
}
const int N = 1e4 + 7;
// Suffix array over a 1-indexed, NUL-terminated string s[1..n], with the LCP
// array and a sparse table giving O(1) longest-common-prefix queries between
// any two suffixes. Every array has N entries, so n must stay below N - 1
// (index n + 1 is read as a zero sentinel during doubling).
struct SA {
    // rk[i]: rank (1-based) of suffix i; sa[r]: start of the r-th smallest suffix.
    // lstrk: ranks from the previous round; lstsa: suffixes ordered by second key.
    // w: doubling step; m: largest rank value the counting sort buckets.
    // cnt: counting-sort buckets; h[r]: LCP of suffixes sa[r - 1] and sa[r] (h[1] = 0).
    // f[r][j]: minimum of h[r .. r + 2^j - 1].
    int rk[N], sa[N], n, lstrk[N], lstsa[N], w, m, cnt[N], h[N], f[N][30];
    // t: private copy of the indexed string (printed by debug()).
    char t[N];

    // Builds the index for s[1..n], where n = strlen(s + 1).
    void init(char *s) {
        n = strlen(s + 1);
        strcpy(t + 1, s + 1);
        memset(cnt, 0, sizeof cnt);
        memset(sa, 0, sizeof sa);
        memset(rk, 0, sizeof rk);
        memset(h, 0, sizeof h);

        // Round 0: counting sort by first character. The unsigned cast keeps
        // bytes >= 128 from becoming negative bucket indices.
        m = 255;
        for (int i = 1; i <= n; ++i) ++cnt[rk[i] = (unsigned char)s[i]];
        for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i;
        // Copy whole arrays (rk[0] == 0 and the zero tail included). Offsetting
        // only the destination/source by one would overrun both arrays.
        memcpy(lstrk, rk, sizeof rk);
        int classes = 0;
        for (int i = 1; i <= n; ++i)
            rk[sa[i]] = (lstrk[sa[i]] == lstrk[sa[i - 1]]) ? classes : ++classes;

        // Prefix doubling; stop as soon as every suffix has a distinct rank.
        for (w = 1; w < n && classes < n; w <<= 1) {
            m = classes;  // ranks are now exactly 1..classes
            // Order by second key: suffixes with no partner at i + w come first.
            int filled = 0;
            for (int i = n; i > n - w; --i) lstsa[++filled] = i;
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w) lstsa[++filled] = sa[i] - w;
            // Stable counting sort by first key.
            memset(cnt, 0, sizeof cnt);
            for (int i = 1; i <= n; ++i) ++cnt[rk[lstsa[i]]];
            for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
            for (int i = n; i >= 1; --i) sa[cnt[rk[lstsa[i]]]--] = lstsa[i];
            memcpy(lstrk, rk, sizeof rk);
            classes = 0;
            for (int i = 1; i <= n; ++i) {
                const int cur = sa[i], prv = sa[i - 1];
                // Equal first halves imply both suffixes have length >= w, so
                // cur + w and prv + w are at most n + 1 (a zero entry).
                const bool same = lstrk[cur] == lstrk[prv] &&
                                  lstrk[cur + w] == lstrk[prv + w];
                rk[cur] = same ? classes : ++classes;
            }
        }

        // Kasai: h[rk[i + 1]] >= h[rk[i]] - 1, so k drops by at most one per step.
        for (int i = 1, k = 0; i <= n; ++i) {
            if (rk[i] == 1) {  // smallest suffix has no predecessor; h[1] stays 0
                k = 0;
                continue;
            }
            if (k) --k;
            const int j = sa[rk[i] - 1];
            // Terminates at the latest on s[n + 1] == '\0' (i != j).
            while (s[i + k] == s[j + k]) ++k;
            h[rk[i]] = k;
        }

        // Sparse table for range-minimum over h.
        memset(f, 0x3f, sizeof f);
        for (int i = 1; i <= n; ++i) f[i][0] = h[i];
        for (int j = 1; (1 << j) <= n; ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i)
                f[i][j] = min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
    }

    // Dumps n, sa, rk and every suffix in sorted order to stdout.
    void debug() {
        puts("------debug_SA--------");
        cout << "len = " << n << endl;
        for (int i = 1; i <= n; ++i)
            cout << sa[i] << " ";
        puts("");
        for (int i = 1; i <= n; ++i)
            cout << rk[i] << " ";
        puts("");
        for (int i = 1; i <= n; ++i) {
            printf("%2d: ", i);
            for (int j = sa[i]; j <= n; ++j)
                putchar(t[j]);
            puts("");
        }
        puts("------End_Debug------");
    }

    // Length of the longest common prefix of the suffixes starting at x and y.
    int lcp(int x, int y) {
        if (x == y)
            return n - y + 1;
        x = rk[x], y = rk[y];
        if (x > y) swap(x, y);
        // floor(log2(y - x)) in integer arithmetic (y - x >= 1 here).
        const int k = 31 - __builtin_clz((unsigned)(y - x));
        return min(f[x + 1][k], f[y - (1 << k) + 1][k]);
    }
} sa2, sa[3];
// s, t: the two input strings, 1-indexed. p: s + '#' + t, 1-indexed (built in init()).
// q: not referenced anywhere in the visible code.
char s[N], t[N], p[N], q[N];
// n = |s|, m = |t| (set in init()).
int n, m;
// sum[1][r] / sum[2][r]: how many of the r smallest suffixes of p (in sa2 order)
// start inside s / inside t. Row 0 is unused.
ll sum[3][N];
void init() {
scanf("%s", s + 1);
scanf("%s", t + 1);
n = strlen(s + 1);
m = strlen(t + 1);
//s#t
rep(i, 1, n) {
p[i] = s[i];
}
p[n + 1] = '#';
rep(i, 1, m) {
p[i + n + 1] = t[i];
}
sa2.init(p);
// cout << p + 1 << endl;
//s, t
sa[1].init(s);
sa[2].init(t);
// puts("FINISH SA INIT");
rep(i, 1, n + 1 + m) {//s#t
if(sa2.sa[i] >= n + 2) {
sum[2][i]++;
}
if(sa2.sa[i] <= n) {
sum[1][i]++;
}
sum[1][i] += sum[1][i - 1];
sum[2][i] += sum[2][i - 1];
}
// sa2.debug();
// puts("FINISH INIT");
}
// Counts the occurrences inside s of p[id .. id + len - 1]: locates, by two
// binary searches over sa2, the rank interval of suffixes sharing at least
// `len` characters with suffix `id`, then counts those that start in s.
// Returns 0 when len <= 0.
ll get1(int id, int len) {
    if (len <= 0) return 0;
    const int self = sa2.rk[id];
    auto matches = [&](int rank) { return sa2.lcp(sa2.sa[rank], id) >= len; };
    // Leftmost matching rank in [1, self].
    int lo = 1, hi = self;
    while (lo < hi) {
        const int mid = (lo + hi) >> 1;
        if (matches(mid))
            hi = mid;
        else
            lo = mid + 1;
    }
    const int first = lo;
    // Rightmost matching rank in [self, n + m + 1].
    lo = self;
    hi = n + m + 1;
    while (lo < hi) {
        const int mid = (lo + hi + 1) >> 1;
        if (matches(mid))
            lo = mid;
        else
            hi = mid - 1;
    }
    return sum[1][lo] - sum[1][first - 1];
}
// opt == 2: counts occurrences in t of s[st .. st + len - 1].
// opt == 1: counts occurrences in s of t[st .. st + len - 1].
// Works on sa2 (the index of s + '#' + t): finds the rank interval of suffixes
// sharing >= len characters with the query suffix and counts, via sum[opt],
// how many of them start in the other string. Returns 0 when len <= 0.
ll get(int st, int len, int opt) {
    if (len <= 0) return 0;
    const int id = (opt == 2) ? st : st + n + 1;  // position of the query inside p
    const int self = sa2.rk[id];
    auto matches = [&](int rank) { return sa2.lcp(sa2.sa[rank], id) >= len; };
    // Leftmost matching rank in [1, self].
    int lo = 1, hi = self;
    while (lo < hi) {
        const int mid = (lo + hi) >> 1;
        if (matches(mid))
            hi = mid;
        else
            lo = mid + 1;
    }
    const int first = lo;
    // Rightmost matching rank in [self, n + m + 1].
    lo = self;
    hi = n + m + 1;
    while (lo < hi) {
        const int mid = (lo + hi + 1) >> 1;
        if (matches(mid))
            lo = mid;
        else
            hi = mid - 1;
    }
    return sum[opt][lo] - sum[opt][first - 1];
}
// Returns, sorted ascending, every start position in the string indexed by
// sa[opt] (opt = 1: s, opt = 2: t) at which the length-`len` substring that
// starts at `id` occurs. Expects 1 <= len <= sa[opt].n - id + 1.
vector <int> get_pos(int id, int len, int opt) {
    SA &idx = sa[opt];
    const int self = idx.rk[id];
    // Leftmost rank in [1, self] whose suffix shares >= len characters with suffix id.
    int lo = 1, hi = self;
    while (lo < hi) {
        const int mid = (lo + hi) >> 1;
        if (idx.lcp(idx.sa[mid], id) >= len)
            hi = mid;
        else
            lo = mid + 1;
    }
    const int first = lo;
    // Rightmost such rank. The bound is this string's own length: ranks past
    // idx.n are not suffixes of it (the previous n + m + 1 bound only worked
    // because those slots happened to hold zeros).
    lo = self;
    hi = idx.n;
    while (lo < hi) {
        const int mid = (lo + hi + 1) >> 1;
        if (idx.lcp(idx.sa[mid], id) >= len)
            lo = mid;
        else
            hi = mid - 1;
    }
    const int last = lo;
    vector<int> vec;
    vec.reserve(last - first + 1);
    for (int r = first; r <= last; ++r)
        vec.push_back(idx.sa[r]);
    sort(vec.begin(), vec.end());
    return vec;
}
void slv() {
ll f = 1;
ll ans = 0;
//先搞一样长的
for(int i = 1; i <= n; ++i) {
int lcp = sa[1].h[i];
for(int k = lcp + 1; k <= n - sa[1].sa[i] + 1; ++k) {
ans += get(sa[1].sa[i], k, 2) * get1(sa[1].sa[i], k);
// cout << ans << endl;
}
}
// pprint(ans);
//s更长的
for(int i = 1; i <= n; ++i) { //遍历s的每一个子串
int lcp = sa[1].h[i];//st = lcp + 1
for(int k = lcp + 1; k <= n - sa[1].sa[i] + 1; ++k) {
// cout << "Substring: " << sa[1].sa[i] << " " << sa[1].sa[i] + k - 1 << endl;
vector<int> vec = get_pos(sa[1].sa[i], k, 1);//找到了这个子串的所有出现位置,k是长度
// cout << "Pos :" << endl;
// for(auto x: vec) {
// cout << x << " ";
// }
// puts("");
rep(j, 0, (int)vec.size() - 2) { //看这两个子串
rep(jj, j + 1, (int)vec.size() - 1) {
int x = vec[j], y = vec[jj];
int len = y - 1 - (x + k) + 1;
ans += get(x + k, len, 2);
}
}
// cout << ans << endl;
}
}
// pprint(ans);
//t更长的
for(int i = 1; i <= m; ++i) {
int lcp = sa[2].h[i];
for(int k = lcp + 1; k <= m - sa[2].sa[i] + 1; ++k) {
// cout << "Substring: " << sa[2].sa[i] << " " << sa[2].sa[i] + k - 1 << endl;
vector<int> vec = get_pos(sa[2].sa[i], k, 2);
// cout << "Pos :" << endl;
// for(auto x: vec) {
// cout << x << " ";
// }
// puts("");
if(vec.size() == 1) continue;
rep(j, 0, (int)vec.size() - 2) { //看这两个子串
rep(jj, j + 1, (int)vec.size() - 1) {
int x = vec[j], y = vec[jj];
int len = y - 1 - (x + k) + 1;
ans += get(x + k, len, 1);
}
}
// cout << ans << endl;
}
}
pprint(ans);
}
// Entry point: read the input and build the indexes, then compute and print the answer.
int main() {
    init();
    slv();
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 0ms
memory: 9368kb
input:
abab ab
output:
8
result:
ok 1 number(s): "8"
Test #2:
score: 0
Accepted
time: 0ms
memory: 9364kb
input:
abab abaaab
output:
29
result:
ok 1 number(s): "29"
Test #3:
score: 0
Accepted
time: 1ms
memory: 8648kb
input:
abcd abcde
output:
10
result:
ok 1 number(s): "10"
Test #4:
score: 0
Accepted
time: 1ms
memory: 8772kb
input:
aaba ba
output:
6
result:
ok 1 number(s): "6"
Test #5:
score: 0
Accepted
time: 0ms
memory: 8836kb
input:
babababaaabbaabababbbaabbbababbaaaaa aaaabbaababbab
output:
1161
result:
ok 1 number(s): "1161"
Test #6:
score: -100
Time Limit Exceeded
input:
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa...