字母的最长公共子序列

📘 题目描述

给定两个字符串 $s,t$,求它们的最长公共子序列(LCS)的长度。

  • $1 \le |s| \le 5000$
  • $1 \le |t| \le 5 \cdot 10^5$

💡 解题思路

由于 $|s|$ 较小,$|t|$ 较大,无法使用传统的 $O(|s| \cdot |t|)$ 的二维动态规划。因此需调整状态定义和优化转移方式


🔁 状态设计

观察到 LCS 的长度至多为 $|s|$,我们设计:

  • $f_{i,j}$:表示使用 $s$ 的前 $i$ 个字符,构造长度为 $j$ 的公共子序列,在字符串 $t$ 中所能匹配到的最小位置

状态转移思想:用 $s$ 的前 $i$ 个字符构造一个长度为 $j$ 的子序列,所需匹配的最小位置越小越优。


🔄 状态转移

我们有两种选择:

  1. 不选 $s_i$:

    • $f_{i,j} = f_{i-1,j}$
    • 即当前状态继承上一行的最优位置
  2. $s_i$:

    • 从 $f_{i-1, j-1}$ 的位置向后查找字符 $s_i$ 在 $t$ 中的下一个出现位置
    • pos = f[i - 1][j - 1]
    • 在 $t[pos+1 \sim m]$ 中找到第一个等于 $s_i$ 的字符位置 $x$
    • 更新 $f_{i,j} = x$

🔍 如何高效查找 $x$?

预处理 $t$ 中每个字符的出现位置:

vector<int> ch[26];
for (int i = 0; i < m; i++) 
{
    ch[t[i] - 'a'].push_back(i);
}

查找:在 ch[s[i] - 'a'] 中二分查找第一个大于 pos 的值(使用 upper_bound)。

✅ 答案判断

在所有 $j$ 满足 f[n][j] 存在的位置中,最大值 $j$ 即为 LCS 的长度。

⏱️ 复杂度分析

  • 时间复杂度:$O(n^2 \log m)$,其中 $n = |s|, m = |t|$
    • 每个状态转移最多二分一次
  • 空间复杂度:$O(n^2 + m)$
  • $n^2$ 为 DP 数组,$m$ 为位置表存储

实现代码

#include <bits/stdc++.h>
using namespace std;

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int n, m;
    cin >> n >> m;
    string a, b;
    cin >> a >> b;
    a = " " + a, b = " " + b;
    // 预处理每个字母的所有位置,一定是单调递增的
    vec<vec<int>> pos(26);
    for (int i = 1; i <= m; i++) 
        pos[b[i] - 'a'].push_back(i);

    // 初始化所有状态值为 -1 表示不存在
    vec<vec<int>> f(n + 1, vec<int>(n + 1, -1));

    f[0][0] = 0;

    for (int i = 1; i <= n; i++)
    {
        f[i][0] = f[i - 1][0]; // f[i][0] 都是 0 不需要任何字符

        for (int j = 1; j <= i; j++)
        {
            // 不用 a[i] 继承上一个状态
            f[i][j] = f[i - 1][j];

            // 如果 now 等于 -1 就不用求使用 a[i] 的情况了
            int now = f[i - 1][j - 1];
            if (now == -1) continue;

            // 二分在 t[now + 1, m] 中找第一个等于 a[i] 的
            auto it = upper_bound(pos[a[i] - 'a'].begin(), pos[a[i] - 'a'].end(), now);

            // 保证存在
            if (it != pos[a[i] - 'a'].end())
            {
                if (f[i][j] != -1) f[i][j] = min(f[i][j], *it);
                else f[i][j] = *it;
            }

        }
    }
    int ans = -1;
    for (int j = 1; j <= n; j++)
    {
        if (f[n][j] != -1)
        {
            ans = max(ans, j);
        }
    }
    cout << ans;
    return 0;
}