「Codeforces 593D」Happy Tree Party

题目链接:Codeforces 593D

今天是 Bogdan 的生日,他的母亲送给他一棵有 $n​$ 个节点的树,每条边上有一个数字 $x_i​$。有 $m​$ 个客人参加了 Bogdan 的生日排队。第 $i​$ 个客人到达后,他会进行如下 $2​$ 种操作中的恰好一种:

  1. 选择一个数字 $y_i$ 和两个节点 $a_i,b_i$。接下来他沿着 $a_i$ 到 $b_i$ 的最短路径行走。每经过一条边 $j$,他就把当前的数字 $y_i$ 替换成 $\left\lfloor\frac{y_i}{x_j}\right\rfloor$。最后求出 $y_i$ 的值。
  2. 选择一条边 $p_i$,将这条边上的值 $x_{p_i}$ 替换成 $c_i$。其中 $c_i<x_{p_i}$。

由于 Bogdan 非常好客,他希望编写一个程序执行所有的操作,并对每个操作 $1$ 求出其结果 $y_i$ 的值。

数据范围:$2\le n \le 2\times 10^5$,$1\le m\le 2\times 10^5$,$1\le x_i,y_i\le 10^{18}$,$1\le c_i<x_{p_i}$。


Solution

显然用树链剖分是可以做的,但是鉴于其码量大、常数大,以及没有利用本题的性质,再此不赘述了。接下来将介绍一种好写、复杂度优越的做法。

首先我们发现,如果不考虑边权为 $1$ 的边,我们至多只会进行 $\mathcal O(\log y_i)$ 次除法。换言之,如果这棵树上没有权值为 $1$ 的边,我们可以暴力往上跳求出最后的 $y_i$。

又因为修改是单调下降的,于是一条边的边权只可能变小。对于一条权值为 $1$ 的边,我们把他两侧的点缩起来,使用并查集即可实现。但是需要注意合并的方向:一定是深度大的点合并到深度小的点上,否则复杂度是错的。

时间复杂度:$\mathcal O(m\log y_i)$。


Code

#include <cstdio>
#include <algorithm>

const int N = 2e5 + 5, M = 4e5 + 5;

int n, m, tot, lnk[N], sta[M], ter[M], nxt[M], up[N], fa[N], f[N], dep[N];
long long val[M];

void add(int u, int v, long long w) {
    ter[++tot] = v, sta[tot] = u, nxt[tot] = lnk[u], lnk[u] = tot, val[tot] = w;
}
int find(int x) {
    return f[x] == x ? x : f[x] = find(f[x]);
}
void merge(int u, int v) {
    f[find(u)] = find(v);
}
void dfs(int u, int p) {
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (v == p) continue;
        fa[v] = u, up[v] = i, dep[v] = dep[u] + 1;
        if (val[i] == 1) {
            merge(v, u);
            dep[v] = dep[find(v)];
        }
        dfs(v, u);
    }
}
long long query(int u, int v) {
    long long ans = 1;
    for (; (u = find(u)) != (v = find(v)); u = fa[u]) {
        if (dep[u] < dep[v]) {
            std::swap(u, v);
        }
        if (ans > 1e18 / val[up[u]]) {
            return 0;
        }
        ans *= val[up[u]];
    }
    return ans;
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; i++) {
        int u, v;
        long long w;
        scanf("%d%d%lld", &u, &v, &w);
        add(u, v, w), add(v, u, w);
    }
    for (int i = 1; i <= n; i++) {
        f[i] = i;
    }
    dep[1] = 1;
    dfs(1, 0);
    for (int i = 1; i <= m; i++) {
        int opt;
        scanf("%d", &opt);
        if (opt == 1) {
            int u, v;
            long long w;
            scanf("%d%d%lld", &u, &v, &w);
            long long ans = query(u, v);
            printf("%lld\n", ans ? w / ans : 0);
        } else {
            int x;
            long long w;
            scanf("%d%lld", &x, &w);
            x <<= 1;
            val[x - 1] = val[x] = w;
            if (w == 1) {
                int u = find(sta[x]), v = find(ter[x]);
                if (dep[u] > dep[v]) {
                    std::swap(u, v);
                }
                merge(v, u);
                dep[v] = dep[find(v)];
            }
        }
    }
    return 0;
}

发表评论