跳转至

线段树合并

线段树的合并与分裂是线段树的常用技巧,常见于权值线段树维护可重集的场景。

例如,树上某些结点处有若干操作,如果需要自下而上地将子节点信息传递给亲节点,而单个结点处的信息又方便用线段树维护时,就可以应用线段树合并的技巧控制整体的复杂度。


合并过程

顾名思义,线段树合并是指建立一棵新的线段树,这棵线段树的每个节点都是两棵原线段树对应节点合并后的结果。它常常被用于维护树上或是图上的信息。

显然,我们不可能真的每次建满一颗新的线段树,因此我们需要使用动态开点线段树。

当我们合并两棵树的时候,我们把重合的节点的值相加,而把不重合的节点保持原样。

线段树合并的过程本质上相当暴力:

  • 假设两颗线段树为 $A$ 和 $B$,我们从 $1$ 号节点开始递归合并。

  • 递归到某个节点时,如果 $A$ 树或者 $B$ 树上的对应节点为空,直接返回另一个树上对应节点,这里运用了动态开点线段树的特性。

  • 如果递归到叶子节点,我们合并两棵树上的对应节点。

  • 最后,根据子节点更新当前节点并且返回


第一种合并方式

第一种我称为新建式,比较直观,我们用一个新的节点存合并的结果。这样会新生成重合节点那么多的新节点。

代码实现如下:

int merge(int a, int b, int l = 1, int r = n) 
{
    if (!a || !b) return a + b; // 如果有一个为空,就返回不为空的;如果都为空就返回空
    int c = ++cnt;
    if (l == r) 
    {
        T[c].v = T[a].v + T[b].v;
        return c; // 新的根 c
    }
    int mid = l + r >> 1;
    T[c].ls = merge(T[a].ls, T[b].ls, l, mid);
    T[c].rs = merge(T[a].rs, T[b].rs, mid + 1, r);
    pushup(c);
    return c;
}

第二种合并方式

第二种我称为挂靠式,也就是把第二棵树直接合并到第一棵树上。这样比较省空间(不要小看这点空间,这个算法挺吃空间的),缺点是会丢失合并前树的信息。

代码实现如下:

int merge(int a, int b, int l = 1, int r = n) 
{
    if (!a || !b) return a + b; // 如果有一个为空,就返回不为空的;如果都为空就返回空
    if (l == r) 
    {
        T[a].v = T[a].v + T[b].v;
        return a; 
    }
    int mid = l + r >> 1;
    T[a].ls = merge(T[a].ls, T[b].ls, l, mid);
    T[a].rs = merge(T[a].rs, T[b].rs, mid + 1, r);
    pushup(a);
    return c;
}

复杂度分析

显然,对于两颗满的线段树,合并操作的复杂度是 $O(n\log n)$ 的。但实际情况下使用的常常是权值线段树,总点数和 $n$ 的规模相差并不大。并且合并时一般不会重复地合并某个线段树,所以我们最终增加的点数大致是 $n\log n$ 级别的。这样,总的复杂度就是 $O(n\log n)$ 级别的。


例题一

luogu P4556 [Vani 有约会] 雨天的尾巴/【模板】线段树合并
题意 有一棵 $n$ 个节点的树,有 $m$ 次操作。每次操作令路径 $x\to y$ 上所有节点增加一个数字 $z$。最后输出每个节点上哪个数字出现的次数最多,如果有多个数字出现次数最多输出值最小的那个。
解题思路
  • 首先使用树上点差分可以将每次操作转化为一个点到根的路径上数字 \(z\) 都出现 \(k\) 次。(\(k \in \{1, -1\}\)
  • 为每个节点开一棵权值线段树维护每个数字的出现次数,以及顺带维护出现次数最多的点的编号。
  • 对于一棵以 \(u\) 为根的子树来讲,就是把其所有儿子的线段树合并到一起。
  • 最后通过一次 DFS 从上至下,进行线段树合并。

空间复杂度:关键就是 新建节点数

  • 每次 modify 新建的节点数是 \(O(\log V)\)
  • 总共有 \(O(m \log V)\) 个节点
  • 合并操作不会新建节点,只是复用指针,所以不额外增加量级

因此总节点数上界:

\[ \text{Nodes} \leq 4m \log V \]

(因为每次操作有 4 次 modify

参考代码
#include <bits/stdc++.h>
#define ls t[p].l
#define rs t[p].r
using namespace std;
constexpr int N = 1e5 + 5;
int n, m, dep[N], f[N][17], tot, root[N * 50], ans[N];
vector<int> e[N];
struct node
{
    int l, r;
    int cnt, res;
} t[N * 50];
void dfs(int u, int fa)
{
    dep[u] = dep[fa] + 1;
    for (auto v : e[u])
    {
        if (v == fa) continue;
        f[v][0] = u;
        for (int j = 1; j <= 16; j++)
            f[v][j] = f[f[v][j - 1]][j - 1];
        dfs(v, u);
    }
}
int lca(int u, int v)
{
    if (dep[u] < dep[v]) swap(u, v);
    int h = dep[u] - dep[v];
    for (int i = 16; i >= 0; i--)
        if (h & (1 << i))
            u = f[u][i];
    if (u == v) return u;
    for (int i = 16; i >= 0; i--)
        if (f[u][i] != f[v][i])
            u = f[u][i], v = f[v][i];
    return f[u][0];
}
void push_up(int p)
{
    if (t[ls].cnt < t[rs].cnt)
    {
        t[p].cnt = t[rs].cnt;
        t[p].res = t[rs].res;
    }
    else
    {
        t[p].cnt = t[ls].cnt;
        t[p].res = t[ls].res;
    }
}
void modify(int &p, int l, int r, int x, int k)
{
    if (!p) p = ++tot;
    if (l == r)
    {
        t[p].cnt += k;
        t[p].res = x;
        return ;
    }
    int mid = l + r >> 1;
    if (x <= mid) modify(ls, l, mid, x, k);
    else modify(rs, mid + 1, r, x, k);
    push_up(p);
}
int merge(int a, int b, int l, int r)
{
    if (!a || !b) return a + b;
    if (l == r)
    {
        t[a].cnt += t[b].cnt;
        return a;
    }
    int mid = l + r >> 1;
    t[a].l = merge(t[a].l, t[b].l, l, mid);
    t[a].r = merge(t[a].r, t[b].r, mid + 1, r);
    push_up(a);
    return a;
}
void calc(int u, int fa)
{
    for (auto v : e[u])
    {
        if (v == fa) continue;
        calc(v, u);
        root[u] = merge(root[u], root[v], 1, 100000);
    }
    ans[u] = t[root[u]].res;
    if (t[root[u]].cnt == 0) ans[u] = 0;
}
int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i < n; i++)
    {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs(1, 0);

    while (m--)
    {
        int x, y, z;
        cin >> x >> y >> z;
        int L = lca(x, y), fL = f[L][0];
        modify(root[x], 1, 100000, z, 1);
        modify(root[y], 1, 100000, z, 1);
        modify(root[fL], 1, 100000, z, -1);
        modify(root[L], 1, 100000, z, -1);
    }
    calc(1, 0);
    for (int i = 1; i <= n; i++)
        cout << ans[i] << "\n";
    return 0;
}

例题二

[HNOI2012] 永无乡
解题思路
  • 使用并查集维护连通性。
  • 为每个节点开一棵权值线段树,线段树上每个节点对应当前权值区间数字的出现次数总和。
  • 合并时,即合并两个集合对应的线段树。执行线段树合并即可。

查询时:

  • 使用线段树查询第 \(k\) 小对应的重要度记作为 \(x\),通过映射找到重要度为 \(x\) 对应的节点编号
  • 注意并查集合并方向与线段树合并方向保持一致
参考代码
#include <bits/stdc++.h>
#define ls t[p].l
#define rs t[p].r
using namespace std;
constexpr int N = 1e5 + 5;
int tot, n, m, root[N * 50], p[N], q;
int find(int x)
{
    return x == p[x] ? x : p[x] = find(p[x]);
}
struct sgt
{
    int l, r, cnt;
} t[N * 50];
void push_up(int p)
{
    t[p].cnt = t[ls].cnt + t[rs].cnt;
}
void modify(int &p, int l, int r, int x)
{
    if (!p) p = ++tot;
    if (l == r)
    {
        t[p].cnt++;
        return ;
    }
    int mid = l + r >> 1;
    if (x <= mid) modify(ls, l, mid, x);
    else modify(rs, mid + 1, r, x);
    push_up(p);
}
int merge(int a, int b, int l = 1, int r = n)
{
    if (!a || !b) return a + b;
    if (l == r)
    {
        t[a].cnt += t[b].cnt;
        return a;
    }
    int mid = l + r >> 1;
    t[a].l = merge(t[a].l, t[b].l, l, mid);
    t[a].r = merge(t[a].r, t[b].r, mid + 1, r);
    push_up(a);
    return a;
}
int query(int p, int l, int r, int k)
{
    if (l == r) return l;
    int mid = l + r >> 1;
    if (t[ls].cnt >= k) return query(ls, l, mid, k);
    else return query(rs, mid + 1, r, k - t[ls].cnt);
}
int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n >> m;
    vector<int> ans(n + 1);
    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        ans[x] = i;
        p[i] = i;
        modify(root[i], 1, n, x);
    }

    while (m--)
    {
        int u, v;
        cin >> u >> v;
        u = find(u), v = find(v);
        p[u] = v;
        root[v] = merge(root[v], root[u]);
    }
    cin >> q;
    while (q--)
    {
        char op;
        cin >> op;
        int x, y;
        cin >> x >> y;
        if (op == 'Q')
        {
            x = find(x);
            if (t[root[x]].cnt < y)
            {
                cout << "-1\n";
                continue;
            }
            int id = query(root[x], 1, n, y);
            cout << ans[id] << "\n";
        }
        else
        {
            x = find(x), y = find(y);
            p[x] = y;
            root[y] = merge(root[y], root[x]);
        }
    }
    return 0;
}