QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#622726 | #7780. Dark LaTeX vs. Light LaTeX | Catreap | TL | 3ms | 109340kb | C++23 | 7.9kb | 2024-10-09 01:05:00 | 2024-10-09 01:05:01

Judging History

You are viewing the latest judging result.

  • [2024-11-25 20:53:52]
  • Hack succeeded; test data added automatically
  • (/hack/1258)
  • [2024-10-09 01:05:01]
  • Judged
  • Result: TL
  • Time: 3ms
  • Memory: 109340kb
  • [2024-10-09 01:05:00]
  • Submitted

answer

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <map>
#include <vector>

// 64-bit integer shorthand, used wherever products of two ints must not overflow.
using ll = long long;

// Array bound for the 1-indexed strings and every per-string table.
// (Presumably max length 5000 plus slack — confirm against the problem limits.)
constexpr int N = 5005;

// Bases and moduli of the two independent polynomial hashes used by build()/calc().
constexpr int B1 = 7393, B2 = 2333;
constexpr int M1 = 998244353, M2 = 1000000007;

void build(char* str, int n, std::pair<int, int>* h) {
    h[0] = {0, 0};
    for (int i = 1; i <= n; i++) {
        h[i].first = ((ll)h[i - 1].first * B1 + str[i]) % M1;
    }
    for (int i = 1; i <= n; i++) {
        h[i].second = ((ll)h[i - 1].second * B2 + str[i]) % M2;
    }
}

std::pair<int, int> calc(std::pair<int, int>* h, int l, int r) {
    static int pow1[N], pow2[N];
    if (!pow1[0]) {
        pow1[0] = pow2[0] = 1;
        for (int i = 1; i < N; i++) {
            pow1[i] = (ll)pow1[i - 1] * B1 % M1;
        }
        for (int i = 1; i < N; i++) {
            pow2[i] = (ll)pow2[i - 1] * B2 % M2;
        }
    }
    return {((ll)h[r].first - (ll)h[l - 1].first * pow1[r - l + 1] % M1 + M1) % M1,
            ((ll)h[r].second - (ll)h[l - 1].second * pow2[r - l + 1] % M2 + M2) % M2};
}

// Z-function of the 0-indexed string s[0..n-1]: for every 1 <= i < n, z[i] becomes
// the length of the longest common prefix of s and s[i..n-1]. z[0] is not written.
void calcZ(char *s, int n, int *z) {
    // [boxL, boxR] is the rightmost window found so far that equals a prefix of s.
    int boxL = 0, boxR = 0;
    for (int pos = 1; pos < n; ++pos) {
        const int inside = boxR - pos + 1;  // window characters at or after pos
        int len;
        if (pos <= boxR && z[pos - boxL] < inside) {
            // The mirrored value ends strictly inside the window, so it is exact.
            len = z[pos - boxL];
        } else {
            // Start from what the window guarantees, then extend by direct comparison.
            len = std::max(0, inside);
            while (pos + len < n && s[len] == s[pos + len]) {
                ++len;
            }
        }
        z[pos] = len;
        if (pos + len - 1 > boxR) {
            boxL = pos;
            boxR = pos + len - 1;
        }
    }
}

// KMP failure function of the 1-indexed string s[1..n].
// On return, for 1 <= k <= n, nxt[k] is the length of the longest proper border
// (a proper prefix that is also a suffix) of s[1..k]. nxt[0] is not touched.
void calcNext(char *s, int n, int *nxt) {
    // nxt[1] is always 0, and the fallback loop below reads it (j = nxt[1]).
    // Set it explicitly so the result does not depend on a zero-filled buffer.
    if (n >= 1) {
        nxt[1] = 0;
    }
    for (int i = 1, j = 0; i < n; i++) {
        // j is the longest border of s[1..i]; try to extend it with s[i + 1],
        // falling back through shorter borders on mismatch.
        while (j && s[i + 1] != s[j + 1]) {
            j = nxt[j];
        }
        if (s[i + 1] == s[j + 1]) {
            j++;
        }
        nxt[i + 1] = j;
    }
}

int main() {
    static char s[N], t[N];
    static int zs[N][N], zt[N][N];
    static int ns[N][N], nt[N][N];
    static int sum[N][N];
    // std::map<std::pair<int, int>, int> mp_s, mp_t;
    // static std::pair<int, int> hs[N], ht[N];
    scanf("%s%s", s + 1, t + 1);
    int n = strlen(s + 1), m = strlen(t + 1);
    ll ans = 0;
    /*
    build(s, n, hs);
    build(t, m, ht);
    for (int i = 1; i <= n; i++) {
        for (int j = i; j <= n; j++) {
            mp_s[calc(hs, i, j)]++;
        }
    }
    for (int i = 1; i <= m; i++) {
        for (int j = i; j <= m; j++) {
            mp_t[calc(ht, i, j)]++;
        }
    }
    */
    for (int i = 1; i <= n; i++) {
        calcZ(s + i, n - i + 1, zs[i]);
        calcNext(s + i - 1, n - i + 1, ns[i]);
        // printf("zs[%d] = ", i);
        for (int j = 2; i + j <= n; j++) {
            if (zs[i][j]) {
                sum[i + j - 1][i + 1]++;
                sum[i + j - 1][std::min(i + j, i + zs[i][j] + 1)]--;
            }
            // printf("%d ", zs[i][j]);
        }
        // printf("\n");
        /*
        printf("nxt[%d] =", i);
        for (int j = 1; i + j <= n + 1; j++) printf("%d ", ns[i][j]);
        printf("\n");*/
    }
    for (int j = 2; j < n; j++) {
        // printf("sum[%d] = ", j);
        for (int i = 1; i <= n; i++) {
            sum[j][i] += sum[j][i - 1];
            // printf("%d ", sum[j][i]);
        }
        // printf("\n");
        /*
        for (int i = 2; i <= j; i++) {
            auto it = mp_t.find(calc(hs, i, j));
            if (it != mp_t.end()) {
                ans += (ll)sum[j][i] * (it->second);
            }
        }
        */
        // printf("ans = %d\n", ans);
    }
    for (int l = 2; l < n; l++) {
        for (int i = 0, p = 0; i < m; i++) {
            while (p && t[i + 1] != s[l + p]) {
                p = ns[l][p];
            }
            // printf("--%d %d %d \n", l, p, i);
            if (t[i + 1] == s[l + p]) {
                ans += (ll)sum[l + p][l];
                // printf("(%d, %d) = (%d, %d): %d\n", l, l + p, i - p + 1, i + 1, (ll)sum[l + p][l]);
                int pp = ns[l][p + 1];
                while (pp && t[i + 1] == s[l + pp - 1]) {
                    ans += (ll)sum[l + pp - 1][l];
                    // printf("(%d, %d) = (%d, %d): %d\n", l, l + pp, i - pp + 1, i + 1, (ll)sum[l + pp][l]);
                    pp = ns[l][pp];
                }
                p++;
                // printf("+ %d %d %d\n", l + p, i, (ll)sum[l + p][l]);
            }
        }
        // printf("ans[%d] = %d\n", l, ans);
    }
    /*
    for (int i = 1; i <= n; i++) {
        for (int j = i; j <= n; j++) {
            auto it = mp_t.find(calc(hs, i, j));
            if (it != mp_t.end()) {
                ans += it->second;
            }
        }
    }*/
    // for (int l = 1; l <= n; l++) {
    //     for (int i = 0, p = 0; i < m; i++) {
    //         while (p && t[i + 1] != s[l + p]) {
    //             p = ns[l][p];
    //         }
    //         if (t[i + 1] == s[l + p]) {
    //             p++;
    //             ans++;
    //             printf("(%d, %d) = (%d, %d): %d\n", l, l + p, i - p, i, 1);
    //         }
    //     }
    // }
    for (int l = 1; l <= n; l++) {
        for (int i = 0, p = 0; i < m; i++) {
            while (p && t[i + 1] != s[l + p]) {
                p = ns[l][p];
            }
            // printf("--%d %d %d \n", l, p, i);
            if (t[i + 1] == s[l + p]) {
                ans += 1;
                // printf("(%d, %d) = (%d, %d): %d\n", l, l + p, i - p + 1, i + 1, 1);
                int pp = ns[l][p + 1];
                while (pp && t[i + 1] == s[l + pp - 1]) {
                    ans += 1;
                    // printf("(%d, %d) = (%d, %d): %d\n", l, l + pp, i - pp + 1, i + 1, 1);
                    pp = ns[l][pp];
                }
                
                p++;
                // printf("+ %d %d %d\n", l + p, i, (ll)sum[l + p][l]);
            }
        }
        // printf("ans[%d] = %d\n", l, ans);
    }
    // printf("ans = %d\n", ans);
    memset(sum, 0, sizeof(sum));
    for (int i = 1; i <= m; i++) {
        calcZ(t + i, m - i + 1, zt[i]);
        calcNext(t + i - 1, m - i + 1, nt[i]);
        // printf("zt[%d] = ", i);
        for (int j = 2; i + j <= m; j++) {
            if (zt[i][j]) {
                sum[i + j - 1][i + 1]++;
                sum[i + j - 1][std::min(i + j, i + zt[i][j] + 1)]--;
            }
            
            // printf("%d ", zt[i][j]);
        }
        /*
        printf("nxt[%d] = ", i);
        for (int j = 1; i + j <= m + 1; j++) printf("%d ", nt[i][j]);
        printf("\n");*/
        // printf("\n");
    }
    for (int j = 2; j < m; j++) {
        // printf("sum[%d] = ", j);
        for (int i = 1; i <= m; i++) {
            sum[j][i] += sum[j][i - 1];
            // printf("%d ", sum[j][i]);
        }
        // printf("\n");
        /*
        for (int i = 2; i <= j; i++) {
            auto it = mp_s.find(calc(ht, i, j));
            if (it != mp_s.end()) {
                ans += (ll)sum[j][i] * (it->second);
            }
        }*/
        // printf("ans = %d\n", ans);
    }

    for (int l = 2; l < m; l++) {
        for (int i = 0, p = 0; i < n; i++) {
            while (p && s[i + 1] != t[l + p]) {
                p = nt[l][p];
            }
            if (s[i + 1] == t[l + p]) {
                ans += (ll)sum[l + p][l];
                // printf("(%d, %d) = (%d, %d): %d\n", l, l + p, i - p + 1, i + 1, (ll)sum[l + p][l]);
                int pp = nt[l][p + 1];
                while (pp && s[i + 1] == t[l + pp - 1]) {
                    ans += (ll)sum[l + pp - 1][l];
                    // printf("(%d, %d) = (%d, %d): %d\n", l, l + pp, i - pp + 1, i + 1, (ll)sum[l + pp][l]);
                    pp = nt[l][pp];
                }
                p++;
            }
        }
        // printf("ans[%d] = %d\n", l,  ans);
    }
    printf("%lld\n", ans);
    return 0;
}
/*
abab
abaaab
zs[1] = 0 2 0 
zs[2] = 0 1 
zs[3] = 0 
zs[4] = 
zt[1] = 0 1 1 2 0 
zt[2] = 0 0 0 1 
zt[3] = 2 1 0 
zt[4] = 1 0 
zt[5] = 0 
zt[6] = 
*/

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 3ms
memory: 106424kb

input:

abab
ab

output:

8

result:

ok 1 number(s): "8"

Test #2:

score: 0
Accepted
time: 0ms
memory: 106508kb

input:

abab
abaaab

output:

29

result:

ok 1 number(s): "29"

Test #3:

score: 0
Accepted
time: 0ms
memory: 105864kb

input:

abcd
abcde

output:

10

result:

ok 1 number(s): "10"

Test #4:

score: 0
Accepted
time: 0ms
memory: 106408kb

input:

aaba
ba

output:

6

result:

ok 1 number(s): "6"

Test #5:

score: 0
Accepted
time: 0ms
memory: 109340kb

input:

babababaaabbaabababbbaabbbababbaaaaa
aaaabbaababbab

output:

1161

result:

ok 1 number(s): "1161"

Test #6:

score: -100
Time Limit Exceeded

input:

aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa...

output:


result: