「Codeforces 204E」Little Elephant and Strings

题目链接:Codeforces 204E

小象非常喜欢字符串。他拥有 $n$ 个包含小写字母的字符串,第 $i$ 个字符串记为 $a_i$。对于每个字符串 $a_i(1 \le i \le n)$,小象想要求出二元组 $(l, r)$ 的对数,其中 $(l, r)$ 需要满足:$1 \le l \le r \le \lvert a_i \rvert$ 且子串 $a_i[l\dots r]$ 是至少 $k$ 个字符串的子串。

数据范围:$1 \le n, k \le 10 ^ 5$,$\sum_{i = 1} ^ n \lvert a_i \rvert \le 10 ^ 5$。


Solution

首先我们把所有字符串平在一起求后缀数组,对于每个字符串 $a_i$ 从前往后考虑,找到当前位置 $j$ 能够往后延伸的最长长度 $len_j$,使得 $s_i[j\dots j + len_j - 1]$ 至少为 $k$ 个字符串的子串。那么这个位置 $j$ 对答案的贡献就是 $len_j$。在计算同一个字符串的下一个位置 $j + 1$ 时,我们不需要重新二分长度 $len_{j + 1}$;注意到去掉位置 $j$ 后的子串 $s_i[j + 1\dots j + len_j - 1]$ 一定是合法的,因此我们有 $len_{j + 1} \ge len_j - 1$。这个过程和求 $height$ 数组很类似。

接下来问题转化为:如何快速求出从位置 $j$ 开始的长度为 $len$ 的子串是否合法?运用 $height$ 数组分组的思想,我们将 $rk(j)$ 往前扩展到 $rk(l)$、往后扩展到 $rk(r)$,使得 $\text{LCP}(rk(l), rk(r)) \ge len$。此时我们只需要判断后缀 $\text{suffix}(sa(rk(l))), \text{suffix}(sa(rk(l) + 1)), \cdots, \text{suffix}(sa(rk(r)))$ 中是否属于不少于 $k$ 个字符串。

这个问题我们是可以在 $\mathcal O(n)$ 的时间内预处理得到的。对于每个 $i$ 预处理 $pos(i)$ 表示最小的位置 $j$ 使得 $\text{suffix}(sa(i\dots j))$ 属于不少于 $k$ 个字符串。可以通过 $\text{Two Pointers}$ 预处理。

于是上述问题只需要满足 $pos(rk(l)) \le rk(r)​$ 则可行。

时间复杂度:$\mathcal O(n \log n)​$,貌似比标算复杂度优秀耶 >_<!


Code

#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 2e5 + 5;

int n, m, k, a[N], len[N], st[N], bl[N], cnt[N], p[N];

template <int S>
struct SuffixArray {
    static const int N = S << 1, logN = 20;
    int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N], lg[N], f[N][logN];
    void clear() {
        memset(a, 0, sizeof(a));
        memset(sa, 0, sizeof(sa));
        memset(rk, 0, sizeof(rk));
        memset(height, 0, sizeof(height));
    }
    void radixSort() {
        for (int i = 1; i <= m; i++) bin[i] = 0;
        for (int i = 1; i <= n; i++) bin[rk[i]]++;
        for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
        for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
    }
    template <class Tp>
    void build(Tp *_a, int _n, int _m) {
        memset(a, 0, sizeof(a));
        memset(sa, 0, sizeof(sa));
        memset(rk, 0, sizeof(rk));
        memset(height, 0, sizeof(height));
        n = _n, m = _m;
        std::copy(_a + 1, _a + n + 1, a + 1);
        for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
        radixSort();
        for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
            p = 0;
            for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
            for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
            radixSort();
            std::swap(rk, tmp);
            p = rk[sa[1]] = 1;
            for (int i = 2; i <= n; i++) {
                rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
            }
        }
        for (int i = 1, k = 0; i <= n; i++) {
            k -= (k > 0);
            int j = sa[rk[i] - 1];
            for (; a[i + k] == a[j + k]; k++);
            height[rk[i]] = k;
        }
    }
    void buildST() {
        for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
        for (int i = 1; i <= n; i++) f[i][0] = height[i];
        for (int j = 1; (1 << j) <= n; j++) {
            for (int i = 1; i + (1 << j) - 1 <= n; i++) {
                f[i][j] = std::min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
            }
        }
    }
    int LCP(int l, int r) {
        if (l == r) return n - sa[l] + 1;
        int k = lg[r - (++l) + 1];
        return std::min(f[l][k], f[r - (1 << k) + 1][k]);
    }
};

SuffixArray<N> A;

int calcL(int x, int len) {
    int l = 1, r = x, ans = 0;
    while (l <= r) {
        int mid = (l + r) >> 1;
        A.LCP(mid, x) >= len ? r = (ans = mid) - 1 : l = mid + 1;
    }
    return ans;
}
int calcR(int x, int len) {
    int l = x, r = n, ans = 0;
    while (l <= r) {
        int mid = (l + r) >> 1;
        A.LCP(x, mid) >= len ? l = (ans = mid) + 1 : r = mid - 1;
    }
    return ans;
}
bool check(int pos, int len) {
    int x = A.rk[pos];
    int l = calcL(x, len), r = calcR(x, len);
    return p[l] <= r;
}
int main() {
    scanf("%d%d", &m, &k);
    for (int i = 1; i <= m; i++) {
        static char t[N];
        scanf("%s", t + 1);
        len[i] = strlen(t + 1);
        st[i] = n + 1;
        for (int j = 1; j <= len[i]; j++) a[++n] = t[j], bl[n] = i;
        a[++n] = i + 256, bl[n] = 0;
    }
    A.build(a, n, m + 256);
    A.buildST();
    for (int i = 1, j = 1, now = 0; i <= n; i++) {
        for (; now < k && j <= n; j++) {
            now += (bl[A.sa[j]] && ++cnt[bl[A.sa[j]]] == 1);
        }
        p[i] = (now >= k ? j - 1 : n + 1);
        now -= (bl[A.sa[i]] && --cnt[bl[A.sa[i]]] == 0);
    }
    for (int i = 1; i <= m; i++) {
        long long ans = 0;
        for (int j = 1, l = 0; j <= len[i]; j++) {
            l -= (l > 0);
            for (; j + l - 1 < len[i] && check(st[i] + j - 1, l + 1); l++);
            ans += l;
        }
        printf("%lld%c", ans, " \n"[i == m]);
    }
    return 0;
}

发表评论