「Codeforces 280D」K-Maximum Subsequence Sum

题目链接:Codeforces 280D

你有一个长度为 $n$ 的序列 $a_i$,接下来进行 $m$ 次操作,操作分为如下 $2$ 种:

  • 0 i val:将第 $i$ 个数 $a_i$ 修改为 $val$。
  • 1 l r k:你需要在序列 $a_l, a_{l + 1}, \cdots, a_r$ 中找出至多 $k$ 个不相交的子序列,使得他们的和最大。形式化地,你需要找出至多 $k$ 对 $(x_1, y_1), (x_2, y_2), \cdots, (x_t, y_t)$(其中 $l\le x_1\le y_1<x_2\le y_2<\cdots<x_t \le y_t\le r$,$0\le t\le k$),使得 $(a_{x_1} + a_{x_1 + 1} + \cdots + a_{y_1}) + (a_{x_2} + a_{x_ 2 + 1} + \cdots + a_{y_2})+\cdots + (a_{x_t} + a_{x_t + 1} + \cdots + a_{y_t})$ 的值最大。

特别地,你可以选择 $0$ 个子序列,这时和式等于 $0$。

数据范围:$1 \le n, m\le 10 ^ 5$,$\vert a_i, val \vert \le 500$,$1\le k\le 20$,求 $k$ 个子序列和的操作不超过 $10^4$ 个。


Solution

我们很容易建立起费用流的模型。

  • 源点像每个点连流量为 $1$、费用为 $0$ 的边。
  • 每个点向下一个点连流量 $1$,费用为 $a_i$ 的边。
  • 每个点向汇点连流量为 $1$、费用为 $0$ 的边。

可以发现多流一个单位的流量就会多出一个区间。那么问题就变成使用不超过 $k$ 个单位的流量,能够得到的最大费用。

但是直接上费用流肯定 $\text{TLE}$,但是注意到每次増广的贡献是一段区间,我们考虑用线段树维护来模拟费用流

我们把増广的过程转化为线段树能处理的问题:

  • 每次増广相当于询问整个区间的最大子段和
  • 每次更新反向弧的费用相当于将区间取反
  • 増广 $k$ 次相当于在线段树上进行上述过程 $k$ 次。

这样一来我们就可以直接在线段树上操作了!

最后考虑一下代码实现问题!由于我们要实现区间最大子段和及其范围、区间取反,因此需要维护的不止左右端点最大值、区间最大答案……对于每个最大值,还需要记录最小值,这样才能实现区间取反的问题。

为了使代码难度降低,我们可以定义一个 $\text{struct}$ 记录每个值的 $l,r,val$ 等信息,重载运算符来合并区间。

时间复杂度:$\mathcal O(mk\log n)$ 且带有大常数。


Code

#include <cstdio>
#include <algorithm>
#define lson p << 1
#define rson p << 1 | 1

const int N = 1e5 + 5;

int n, m, A[N];

struct Data {
    int l, r, val;
    Data(int _l = 0, int _r = 0, int _val = 0) {
        l = _l, r = _r, val = _val;
    }
    Data operator+(const Data &b) const {
        return Data(l, b.r, val + b.val);
    }
    bool operator<(const Data &b) const {
        return val < b.val;
    }
};

struct Node {
    int rev;
    Data sum, lmx, lmn, rmx, rmn, smx, smn;
    Node() {
        rev = 0;
        sum = lmx = lmn = rmx = rmn = smx = smn = Data();
    }
    void init(int pos, int val) {
        rev = 0;
        sum = lmx = lmn = rmx = rmn = smx = smn = Data(pos, pos, val);
    }
    void reverse() {
        rev ^= 1;
        std::swap(lmx, lmn);
        std::swap(rmx, rmn);
        std::swap(smx, smn);
        sum.val *= -1;
        lmx.val *= -1, lmn.val *= -1;
        rmx.val *= -1, rmn.val *= -1;
        smx.val *= -1, smn.val *= -1;
    }
};

struct Segment {
    Node a[N << 2];
    Node merge(Node x, Node y) {
        Node ans;
        ans.sum = x.sum + y.sum;
        ans.lmx = std::max(x.lmx, x.sum + y.lmx);
        ans.lmn = std::min(x.lmn, x.sum + y.lmn);
        ans.rmx = std::max(y.rmx, x.rmx + y.sum);
        ans.rmn = std::min(y.rmn, x.rmn + y.sum);
        ans.smx = std::max(x.rmx + y.lmx, std::max(x.smx, y.smx));
        ans.smn = std::min(x.rmn + y.lmn, std::min(x.smn, y.smn));
        return ans;
    }
    void pushdown(int p) {
        if (a[p].rev) {
            a[lson].reverse(); a[rson].reverse();
            a[p].rev = 0;
        }
    }
    void build(int p, int l, int r) {
        if (l == r) {
            a[p].init(l, A[l]);
            return;
        }
        int mid = (l + r) >> 1;
        build(lson, l, mid);
        build(rson, mid + 1, r);
        a[p] = merge(a[lson], a[rson]);
    }
    void modify(int p, int l, int r, int x, int v) {
        if (l == r) {
            a[p].init(l, v);
            return;
        }
        pushdown(p);
        int mid = (l + r) >> 1;
        if (x <= mid) {
            modify(lson, l, mid, x, v);
        } else {
            modify(rson, mid + 1, r, x, v);
        }
        a[p] = merge(a[lson], a[rson]);
    }
    void reverse(int p, int l, int r, int x, int y) {
        if (l == x && y == r) {
            a[p].reverse();
            return;
        }
        pushdown(p);
        int mid = (l + r) >> 1;
        if (y <= mid) {
            reverse(lson, l, mid, x, y);
        } else if(x > mid) {
            reverse(rson, mid + 1, r, x, y);
        } else {
            reverse(lson, l, mid, x, mid);
            reverse(rson, mid + 1, r, mid + 1, y);
        }
        a[p] = merge(a[lson], a[rson]);
    }
    Node query(int p, int l, int r, int x, int y) {
        if (l == x && y == r) {
            return a[p];
        }
        pushdown(p);
        int mid = (l + r) >> 1;
        if (y <= mid) {
            return query(lson, l, mid, x, y);
        } else if(x > mid) {
            return query(rson, mid + 1, r, x, y);
        } else {
            return merge(query(lson, l, mid, x, mid), query(rson, mid + 1, r, mid + 1, y));
        }
    }
} seg;

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &A[i]);
    }
    seg.build(1, 1, n);
    scanf("%d", &m);
    for (int i = 1; i <= m; i++) {
        int opt;
        scanf("%d", &opt);
        if (!opt) {
            int x, v;
            scanf("%d%d", &x, &v);
            seg.modify(1, 1, n, x, v);
        } else {
            int l, r, k;
            scanf("%d%d%d", &l, &r, &k);
            int ans = 0, tp = 0;
            Data st[25];
            for (int j = 1; j <= k; j++) {
                Node now = seg.query(1, 1, n, l, r);
                if (now.smx.val < 0) {
                    break;
                }
                ans += now.smx.val;
                st[++tp] = now.smx;
                seg.reverse(1, 1, n, now.smx.l, now.smx.r);
            }
            for (int j = 1; j <= tp; j++) {
                seg.reverse(1, 1, n, st[j].l, st[j].r);
            }
            printf("%d\n", ans);
        }
    }
    return 0;
}

发表评论