「POJ 3415」Common Substrings

题目链接:POJ 3415

字符串 $T$ 的子串定义为:

$$ T(i,k) = T_iT_{i + 1}\cdots T_{t + k - 1}, 1 \le i \le i + k - 1 \le \lvert T \rvert $$

给定两个字符串 $A, B$ 和一个整数 $K$,我们定义 $S$ 为三元组 $(i, j, k)$ 集合:

$$ S = \{(i, j, k) \mid k \ge K, A(i, k) = B(j, k)\} $$

你需要求出集合 $S$ 的大小 $\lvert S \rvert$。

数据范围:$1 \le \lvert A \rvert, \lvert B \rvert \le 10 ^ 5$,$1 \le K \le \min(\lvert A \rvert, \lvert B \rvert)$。


Solution

按照套路,我们将两个字符串拼接起来并建立后缀数组,按照 $height$ 分组。

接下来我们要统计每组中的后缀的最长公共前缀之和。按照 $rk$ 从小到大扫描一遍,每遇到一个 $A$ 串的后缀就和排名在其前面的 $B$ 串后缀进行统计。对 $B$ 串同样统计一遍。

直接计算的复杂度是 $\mathcal O(n ^ 2)$ 的,注意到两个后缀的 $\text{lcp}$ 是一段区间的 $height$ 的最小值,我们可以用单调栈来维护一下。栈内每个元素维护 $2$ 个值,记录当前的 $height$ 值和不小于该 $height$ 值的 $B​$ 串后缀的数量。

每次加入新的后缀就要把栈内元素进行合并,并重新计算前缀答案。具体实现详见代码。

时间复杂度:$\mathcal O(n \log n)$。


Code

#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 2e5 + 5;

int n, n1, n2, k, stk[N], num[N];
char s[N];

template <int S>
struct SuffixArray {
    static const int N = S << 1;
    int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N];
    void clear() {
        memset(a, 0, sizeof(a));
        memset(sa, 0, sizeof(sa));
        memset(rk, 0, sizeof(rk));
        memset(height, 0, sizeof(height));
    }
    void radixSort() {
        for (int i = 1; i <= m; i++) bin[i] = 0;
        for (int i = 1; i <= n; i++) bin[rk[i]]++;
        for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
        for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
    }
    template <class Tp>
    void build(Tp *_a, int _n, int _m) {
        n = _n, m = _m;
        std::copy(_a + 1, _a + n + 1, a + 1);
        for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
        radixSort();
        for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
            p = 0;
            for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
            for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
            radixSort();
            std::swap(rk, tmp);
            p = rk[sa[1]] = 1;
            for (int i = 2; i <= n; i++) {
                rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
            }
        }
        for (int i = 1, k = 0; i <= n; i++) {
            k -= (k > 0);
            int j = sa[rk[i] - 1];
            for (; a[i + k] == a[j + k]; k++);
            height[rk[i]] = k;
        }
    }
};

SuffixArray<N> A;

void read(int &_n, int p) {
    scanf("%s", s + n + 1);
    _n = strlen(s + n + 1);
    s[n += _n + 1] = '#' + p;
}
int main() {
    while (scanf("%d", &k) && k) {
        n = 0;
        read(n1, 1), read(n2, 2);
        A.clear();
        A.build(s, n, 255);
        int top = 0;
        long long ans = 0, sum = 0;
        for (int i = 1; i <= n; i++) {
            if (A.height[i] < k) {
                top = sum = 0;
            } else {
                int cnt = 0;
                if (A.sa[i - 1] <= n1) {
                    cnt++;
                    sum += A.height[i] - k + 1;
                }
                for (; top && A.height[stk[top]] >= A.height[i]; top--) {
                    cnt += num[top];
                    sum -= 1LL * num[top] * (A.height[stk[top]] - A.height[i]);
                }
                stk[++top] = i, num[top] = cnt;
                if (A.sa[i] > n1) ans += sum;
            }
        }
        for (int i = 1; i <= n; i++) {
            if (A.height[i] < k) {
                top = sum = 0;
            } else {
                int cnt = 0;
                if (A.sa[i - 1] > n1) {
                    cnt++;
                    sum += A.height[i] - k + 1;
                }
                for (; top && A.height[stk[top]] >= A.height[i]; top--) {
                    cnt += num[top];
                    sum -= 1LL * num[top] * (A.height[stk[top]] - A.height[i]);
                }
                stk[++top] = i, num[top] = cnt;
                if (A.sa[i] <= n1) ans += sum;
            }
        }
        printf("%lld\n", ans);
    }
    return 0;
}

发表评论