「Codeforces 1073G」Yet Another LCP Problem

题目链接:Codeforces 1073G

定义 $\text{LCP}(s, t)$ 字符串 $s$ 和 $t$ 的最长公共前缀,再定义 $s[x\dots y]$ 为字符串 $s$ 从位置 $x$ 到 $y$ 的子串。

给定一个长度为 $n$ 的字符串 $s$ 和 $q$ 个询问。每次询问给出两个长度分别为 $k_i, l_i$ 的序列 $a, b$。你需要计算 $\sum_{i = 1} ^ k \sum_{j = 1} ^ l \text{LCP}(s[a_i \dots n], s[b_j \dots n])$ 的值。

数据范围:$1 \le n, q, \sum k_i, \sum l_i \le 2 \times 10 ^ 5$,$1 \le k_i, l_i \le n$。


Solution

乍一看本题和「AHOI 2013」差异题解)非常相似;但是每次处理的复杂度为 $\mathcal O(n)$,总复杂度 $\mathcal O(nq)$ 是完全无法接受的。

但是那道题的方法还是有借鉴意义的。在建立后缀数组后,我们考虑将 $a$ 和 $b$ 都按照 $rk(i)$ 从小到大排序。

我们把贡献拆成 $2$ 部分:第一部分是 $rk(b_j) \le rk(a_i)$ 的贡献;第二部分是 $rk(b_j) > rk(a_i)$ 的贡献。这两部分贡献的本质是相同的,下文只讲述 $rk(b_j) \le rk(a_i)$ 的贡献计算方法。

首先把 $a_i$ 按排序后的顺序加入,并把所有满足 $rk(b_j) \le rk(a_i)$ 的 $b_j$ 加入。这些 $b_j$ 对答案的贡献可以用 $\text{RMQ}$ 在 $O(1)$ 的时间内求出。贡献为 $\text{lcp}(a_i, b_j)$ 的值。

接下来考虑把 $a_{i + 1}$ 加入后的影响。我们要将所有的贡献对 $\text{lcp}(a_i, a_{i + 1})$ 取 $\text{min}$,即把所有贡献大于 $\text{lcp}(a_i, a_{i + 1})$ 的贡献都改为 $\text{lcp}(a_i, a_{i + 1})$。这是权值线段树的基本操作。

时间复杂度:$\mathcal O(n \log n)$,默认所有变量同阶。


Code

#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 2e5 + 5;

int n, q, a[N], b[N];
char s[N];

template <int S>
struct SuffixArray {
    static const int N = S << 1, logN = 20;
    int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N], lg[N], f[N][logN];
    void clear() {
        memset(a, 0, sizeof(a));
        memset(sa, 0, sizeof(sa));
        memset(rk, 0, sizeof(rk));
        memset(height, 0, sizeof(height));
    }
    void radixSort() {
        for (int i = 1; i <= m; i++) bin[i] = 0;
        for (int i = 1; i <= n; i++) bin[rk[i]]++;
        for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
        for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
    }
    template <class Tp>
    void build(Tp *_a, int _n, int _m) {
        memset(a, 0, sizeof(a));
        memset(sa, 0, sizeof(sa));
        memset(rk, 0, sizeof(rk));
        memset(height, 0, sizeof(height));
        n = _n, m = _m;
        std::copy(_a + 1, _a + n + 1, a + 1);
        for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
        radixSort();
        for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
            p = 0;
            for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
            for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
            radixSort();
            std::swap(rk, tmp);
            p = rk[sa[1]] = 1;
            for (int i = 2; i <= n; i++) {
                rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
            }
        }
        for (int i = 1, k = 0; i <= n; i++) {
            k -= (k > 0);
            int j = sa[rk[i] - 1];
            for (; a[i + k] == a[j + k]; k++);
            height[rk[i]] = k;
        }
    }
    void buildST() {
        for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
        for (int i = 1; i <= n; i++) f[i][0] = height[i];
        for (int j = 1; (1 << j) <= n; j++) {
            for (int i = 1; i + (1 << j) - 1 <= n; i++) {
                f[i][j] = std::min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
            }
        }
    }
    int lcp(int l, int r) {
        if (l == r) return n - l + 1;
        l = rk[l], r = rk[r];
        if (l > r) std::swap(l, r);
        l++;
        int k = lg[r - l + 1];
        return std::min(f[l][k], f[r - (1 << k) + 1][k]);
    }
};

struct SegmentTree {
    #define lson p << 1
    #define rson p << 1 | 1
    static const int M = N << 2;
    int sum[M];
    long long ans[M];
    bool tag[M];
    void pushup(int p) {
        sum[p] = sum[lson] + sum[rson];
        ans[p] = ans[lson] + ans[rson];
    }
    void pushdown(int p) {
        if (tag[p]) {
            sum[lson] = sum[rson] = ans[lson] = ans[rson] = 0;
            tag[lson] = tag[rson] = true;
            tag[p] = false;
        }
    }
    void erase(int p, int l, int r, int x, int y) {
        if (x > y) return;
        if (x <= l && r <= y) {
            sum[p] = ans[p] = 0, tag[p] = true;
            return;
        }
        pushdown(p);
        int mid = (l + r) >> 1;
        if (y <= mid) {
            erase(lson, l, mid, x, y);
        } else if (x > mid) {
            erase(rson, mid + 1, r, x, y);
        } else {
            erase(lson, l, mid, x, mid);
            erase(rson, mid + 1, r, mid + 1, y);
        }
        pushup(p);
    }
    void modify(int p, int l, int r, int x, int v) {
        if (l == r) {
            sum[p] += v;
            ans[p] += 1LL * l * v;
            return;
        }
        pushdown(p);
        int mid = (l + r) >> 1;
        if (x <= mid) {
            modify(lson, l, mid, x, v);
        } else {
            modify(rson, mid + 1, r, x, v);
        }
        pushup(p);
    }
    int querySum(int p, int l, int r, int x, int y) {
        if (x > y) return 0;
        if (x <= l && r <= y) return sum[p];
        pushdown(p);
        int mid = (l + r) >> 1;
        if (y <= mid) {
            return querySum(lson, l, mid, x, y);
        } else if (x > mid) {
            return querySum(rson, mid + 1, r, x, y);
        } else {
            return querySum(lson, l, mid, x, mid) + querySum(rson, mid + 1, r, mid + 1, y);
        }
    }
    #undef lson
    #undef rson
};

SuffixArray<N> A;
SegmentTree seg;

long long solve(int n, int m) {
    std::sort(a + 1, a + n + 1, [](int a, int b) {
        return A.rk[a] < A.rk[b];
    });
    std::sort(b + 1, b + m + 1, [](int a, int b) {
        return A.rk[a] < A.rk[b];
    });
    long long ans = 0;
    for (int i = 1, j = 1; i <= n; i++) {
        if (i > 1) {
            int x = A.lcp(a[i - 1], a[i]);
            int s = seg.querySum(1, 0, A.n, x + 1, A.n);
            seg.erase(1, 0, A.n, x + 1, A.n);
            seg.modify(1, 0, A.n, x, s);
        }
        for (; j <= m && A.rk[b[j]] <= A.rk[a[i]]; j++) {
            seg.modify(1, 0, A.n, A.lcp(a[i], b[j]), 1);
        }
        ans += seg.ans[1];
    }
    seg.erase(1, 0, A.n, 0, A.n);
    for (int i = 1, j = 1; i <= m; i++) {
        if (i > 1) {
            int x = A.lcp(b[i - 1], b[i]);
            int s = seg.querySum(1, 0, A.n, x + 1, A.n);
            seg.erase(1, 0, A.n, x + 1, A.n);
            seg.modify(1, 0, A.n, x, s);
        }
        for (; j <= n && A.rk[a[j]] < A.rk[b[i]]; j++) {
            seg.modify(1, 0, A.n, A.lcp(a[j], b[i]), 1);
        }
        ans += seg.ans[1];
    }
    seg.erase(1, 0, A.n, 0, A.n);
    return ans;
}
int main() {
    scanf("%d%d%s", &n, &q, s + 1);
    A.build(s, n, 255);
    A.buildST();
    for (int i = 1; i <= q; i++) {
        int k, l;
        scanf("%d%d", &k, &l);
        for (int i = 1; i <= k; i++) scanf("%d", &a[i]);
        for (int i = 1; i <= l; i++) scanf("%d", &b[i]);
        printf("%lld\n", solve(k, l));
    }
    return 0;
}

发表评论