「算法笔记」可持久化线段树

线段树这种数据结构可以可持久化。所谓可持久化,就是可以访问某一个历史版本,我们需要运用不同版本之间的共同性质来降低复杂度。其中主席树是一种可持久化权值线段树,常用于求区间第 $k$ 小值。


可持久化线段树

例题

你需要维护一个长度为 $n$ 的序列 $a_i$,进行 $m$ 种操作,操作类型如下:

  1. 在某个历史版本上修改某一个位置的值。
  2. 访问某个历史版本上的某一个位置的值。

对于每次操作,我们需要生成一个与历史版本完全一样的版本并在这个新的版本上修改。

数据范围:$1\le n,m\le 10^6$,$\vert a_i\vert\le 10^9$。

思路分析

我们最简单的思路就是暴力建立 $m$ 棵线段树,然后每次直接到对应版本查询。这样一来空间复杂度肯定承受不了,直接 $\text{MLE}$。那么我们如何维护历史版本呢?

由于每次修改会使得它到根的路径被修改。线段树的层数为 $\mathcal O(\log n)$ 层,所真正被修改的节点只有 $\mathcal O(\log n)$ 个,我们只需要新建 $\mathcal O(\log n)$ 个节点!这些节点只需要保存从新版本 $i$ 的根节点 $root_i$ 到需要修改的那个叶子节点的路径,对于不在此路径上的左儿子或者右儿子,只要接到原版本对应区间的对应儿子就可以啦!这样我们就可以保证,从对应版本的根节点出发,一定可以访问到这个版本的任何一个节点。

这样一来我们就可以只建立一棵线段树维护所有的版本了。对于每次修改插入一个路径,采用动态开点的方法节约内存。

上述过程的具体实现详见代码。

时间复杂度:$\mathcal O(m\log n)$

代码

#include <cstdio>

const int N = 1e6 + 5, M = 2e7 + 5;

int n, m, idx, a[N], rt[N], val[M], ls[M], rs[M];

void build(int &p, int l, int r) {
    p = ++idx;
    if (l == r) {
        val[p] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(ls[p], l, mid);
    build(rs[p], mid + 1, r);
}
void modify(int &p, int l, int r, int u, int x, int v) {
    p = ++idx, ls[p] = ls[u], rs[p] = rs[u], val[p] = val[u];
    if (l == r) {
        val[p] = v;
        return;
    }
    int mid = (l + r) >> 1;
    if (x <= mid) {
        modify(ls[p], l, mid, ls[u], x, v);
    } else {
        modify(rs[p], mid + 1, r, rs[u], x, v);
    }
}
int query(int p, int l, int r, int x) {
    if (l == r) {
        return val[p];
    }
    int mid = (l + r) >> 1;
    if (x <= mid) {
        return query(ls[p], l, mid, x);
    } else {
        return query(rs[p], mid + 1, r, x);
    }
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    build(rt[0], 1, n);
    for (int i = 1; i <= m; i++) {
        int opt, u, x;
        scanf("%d%d%d", &opt, &u, &x);
        if (opt == 1) {
            int v;
            scanf("%d", &v);
            modify(rt[i], 1, n, rt[u], x, v);
        } else {
            printf("%d\n", query(rt[u], 1, n, x));
            rt[i] = rt[u];
        }
    }
    return 0;
}

权值线段树

我们一般的线段树是用于维护区间的,而权值线段树每个叶子节点维护的是某个元素出现的次数,一条线段代表着这个区间内所有数字的出现次数总和。

可以利用权值线段树查询整体第 $k$ 小值。具体方法如下:

我们先对序列进行离散化。建立一棵权值线段树。从根节点开始查找,设当前左子树的大小为 $ls$,右子树的大小为 $rs$,如果 $k\le ls$,那么第 $k$ 小值是在左子树中;否则该区间的第 $k$ 小值为右子树的第 $k-ls$ 小值。

时间复杂度:$\mathcal O(n \log n)$


静态主席树

例题

给定一个长度为 $n$ 的序列 $a_i$,有 $m$ 次询问,每次询问区间 $[l,r]$ 内的第 $k$ 小值。

数据范围:$1\le n,m\le 2\times 10^5$,$\vert a_i\vert\le 10^9$

思路分析

前缀和思想

我们利用之前提出的权值线段树的方法,查询 $[l,r]$ 就是令 $[1,r]$ 和 $[1,l-1]$ 两棵线段树的 $size$ 相减。由于线段树是完全二叉树,具有结构稳定的性质,所以这 $n$ 棵权值线段树是长得完全一样的,可以相减。所以我们可以建立 $n$ 棵权值线段树,第 $x$ 棵表示 $a_i(i\in[1,x])$ 组成权值线段树。

运用可持久化线段树

但是发现暴力开 $n$ 棵线段树的空间肯定是不行的,考虑如何优化。

我们很容易发现,每次加进来一个新的数字,只有这个权值对应的节点到根的路径会发生变化,这不就是可持久化线段树吗?

所以我们只需要先建立第 $0$ 棵线段树(空树),然后对每个前缀 $i$ 在第 $i-1$ 棵权值线段树的基础上新建 $\mathcal O(\log n)$ 个节点形成第 $i$ 棵权值线段树。

每次查询时,用第 $r$ 棵线段树减去 $l-1$ 棵线段树,具体过程和权值线段树的查找过程相同!

时间复杂度:$\mathcal O((n+m)\log n)$

代码

#include <cstdio>
#include <algorithm>

const int N = 2e5 + 5, M = 4e6 + 5;

int n, m, idx, a[N], b[N], rt[N], sum[M], ls[M], rs[M];

void build(int &rt, int l, int r) {
    rt = ++idx;
    if (l == r) {
        return;
    }
    int mid = (l + r) >> 1;
    build(ls[rt], l, mid);
    build(rs[rt], mid + 1, r);
}
void modify(int &rt, int l, int r, int u, int x, int v) {
    rt = ++idx, ls[rt] = ls[u], rs[rt] = rs[u], sum[rt] = sum[u] + v;
    if (l == r) {
        return;
    }
    int mid = (l + r) >> 1;
    if (x <= mid) {
        modify(ls[rt], l, mid, ls[u], x, v);
    } else {
        modify(rs[rt], mid + 1, r, rs[u], x, v);
    }
}
int query(int l, int r, int L, int R, int k) {
    if (l == r) {
        return l;
    }
    int mid = (l + r) >> 1;
    int sz = sum[ls[R]] - sum[ls[L]];
    if (k <= sz) {
        return query(l, mid, ls[L], ls[R], k);
    } else {
        return query(mid + 1, r, rs[L], rs[R], k - sz);
    }
}
int discretize() {
    for (int i = 1; i <= n; i++) b[i] = a[i];
    std::sort(b + 1, b + n + 1);
    return std::unique(b + 1, b + n + 1) - (b + 1);
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    int sz = discretize();
    build(rt[0], 1, sz);
    for (int i = 1; i <= n; i++) {
        int x = std::lower_bound(b + 1, b + sz + 1, a[i]) - b;
        modify(rt[i], 1, sz, rt[i - 1], x, 1);
    }
    for (int i = 1; i <= m; i++) {
        int l, r, k;
        scanf("%d%d%d", &l, &r, &k);
        int ans = b[query(1, sz, rt[l - 1], rt[r], k)];
        printf("%d\n", ans);
    }
    return 0;
}

习题

1 条评论

  1. realSpongeBob

    %%%Siyuan

发表评论