LeetCode 300 最长上升子序列
题目描述
给定一个无序的整数数组,找到其中最长上升子序列的长度。
说明:
- 可能会有多种最长上升子序列的组合,你只需要输出对应的长度即可。
- 你算法的时间复杂度应该为 $O(n^2)$ 。
进阶:
你能将算法的时间复杂度降低到 $O(n log n)$ 吗?
来源:力扣(LeetCode)
链接:https://leetcode-cn.com/problems/longest-increasing-subsequence
样例
样例输入
[10,9,2,5,3,7,101,18]
样例输出
4
样例解释
最长的上升子序列是 [2,3,7,101],它的长度是 4。
算法与数据结构
动态规划
贪心
二分查找
题解
LIS 是非常经典的动态规划,本文给出 2 种不同时间复杂度的解法,它们的思想都值得学习。
如果 f(n) 是数组 nums[0 .. n - 1] 中最长的上升子序列,f(n – 1) 是数组 nums[0 .. n - 2] 中最长的上升子序列,依次类推,f(1) 是数组 nums[0] 的最长上升子序列。假设已经有人帮我们解决了 f(1)、f(2)……f(n – 1) 的问题,我们怎么解决 f(n) 呢?
显然我们应该考虑 nums[n - 1] 的值,然后根据 f(1)、f(2)……f(n – 1) 的值来决定 f(n) 的值。可是,要考虑 nums[n - 1] 能不能与现有的 f(x) 构成最长上升子序列,我们必须将它与它的前一个元素 nums[n - 2] 比较大小,而根据上面的定义,我们还不知道 f(x) 到底有哪些元素,只知道到元素 nums[x - 1] 为止的最长上升子序列,不知道这个子序列是不是包含了最后的数。因此,我们需要规定 f(n) 为:如果包含了最后的数,那么最长的子序列应该是什么,即,最后这个数必须包含在子序列当中。
于是,我们终于可以开始梳理动态规划的四要素了。
首先,定义状态:我们定义状态 dp[i] 表示以 nums[i] 结尾的 LIS 的长度。
由状态的定义,可以自然得到最终结果是所有 dp[i] 中最大的。
接下去考虑初始条件:当序列长度为 1 时,自己构成一个序列,所以初始化所有 dp[i] = 1。
int[] dp = new int[nums.length];
for (int i = 0; i < nums.length; i++) {
dp[i] = 1;
}
最后是状态转移方程,为了求 dp[i],我们可以尝试在所有的 dp[j] 结尾加上 nums[i] 看看是不是可以让长度更长,所以,遍历 j = 0; j < i; j++,如果 nums[i] > nums[j] 说明当前数字大于待比较序列的最后一个数字,可以加上去,于是 LIS 长度变为 dp[j] + 1,只需要记录循环过程中出现的最大的 dp[j] + 1 即可。
for (int i = 0; i < nums.length; i++) {
for (int j = 0; j < i; j++) {
if (nums[i] > nums[j]) {
dp[i] = Math.max(dp[i], dp[j] + 1);
}
}
}
最后,找出所有 dp[i] 中的最大值,就是答案。注意,最长子序列不一定是以最后一个元素结尾的,所以不能用 dp[length - 1] 作为答案。
int ans = 0;
for (int i = 0; i < nums.length; i++) {
ans = Math.max(ans, dp[i]);
}
return ans;
这样的时间复杂度是平方级别的,有没有办法进一步优化呢?
(我既然这么问了,那显然是可以的)
回顾我们人类怎么找最长上升子序列的:对于一个上升子序列,如果当前最后一个元素越小,那么就越有利于往后添加新的元素,子序列的长度自然也就越长。
这是一种贪心的思想。
我们维护一个 dp[] 数组,定义状态 dp[i] 为长度 i + 1 的所有的上升子序列中,结尾元素最小的那个子序列的结尾元素。例如长度为 3 的子序列有 2 个,分别是 {2, 4, 5} 和 {1, 3, 6},那么 dp[2] 就是 5。
这样一来,答案就呼之欲出了,显然 dp.size() 的长度就是答案,因为 存在 这样一个长度为 dp.size() 的子序列,而不存在 dp.size() + 1 的子序列(否则就一定会有 dp[dp.size()] 的值)。
接下来考虑初始状态,dp[0] 的值是可以唯一确定的,只需要遍历整个 nums[] 数组,找到最小的那个就可以了。
但是等一下,这样的话,考察长度为 2、3、4 的时候,岂不是又变成了暴力枚举?因此,下面要运用贪心的思想来做。
我们贪婪地设置 dp[0] = nums[0]:长度为 1 的最长子序列,就第 1 个元素自己呗。至于后面有没有可能遇到比它更小的?可能。但这事我们留到后面再说,后面在状态转移的时候自然会解决这个问题。
下面我以样例为例,讲一讲是怎么做状态转移的。
[10, 9, 2, 5, 3, 7, 101, 18]
- 我们贪婪地设置
dp[0] = nums[0] = 10,当前dp[]为[10]; - i = 1,看到
nums[1] = 9,由于 9 比 10 小,所以更利于往后增加序列的长度,所以我们把 10 换成 9,当前dp[]为[9]; - i = 2,看到
nums[2] = 2,由于 2 比 9 小,所以更利于往后增加序列的长度,所以我们把 9 换成 2,当前dp[]为[2]; - i = 3,看到
nums[3] = 5,由于 5 比dp[]中每一个元素都要大,所以直接在数组尾部加上 5,当前dp[]为[2, 5]; - i = 4,看到
nums[4] = 3,由于 3 比 5 小,所以更利于往后增加序列的长度,所以我们把 5 换成 3,当前dp[]为[2, 3]; - i = 5,看到
nums[5] = 7,由于 7 比dp[]中每一个元素都要大,所以直接在数组尾部加上 7,当前dp[]为[2, 3, 7]; - i = 6,看到
nums[6] = 101,由于 101 比dp[]中每一个元素都要大,所以直接在数组尾部加上 101,当前dp[]为[2, 3, 7, 101]; - i = 7,看到
nums[7] = 18,由于 18 比 101 小,所以更利于往后增加序列的长度,所以我们把 101 换成 18,当前dp[]为[2, 3, 7, 18]; - 找完了所有的 i,
dp.size() = 4就是答案。
通过上面的过程,我们可以发现,对于每一个 nums[i],我们要做两件事情:
- 如果它比
dp[]数组中的每一个元素都要大,则将它加入到数组尾部。事实上容易发现,在我们的维护规则下,dp[]是单调递增的,因此只需要和当前数组的最后一个数字比大小就可以了。
if (dp.get(dp.size() - 1) < nums[i]) {
dp.add(nums[i]);
}
- 如果它不是比
dp[]数组中每一个元素都要大,那么找到dp[]数组中第一个比它(nums[i])大的元素dp[x],用nums[i]替换dp[x]。
这个过程可以用二分查找,因此查找的时间复杂度是 $O(n log n)$,而除此以外我们只对数组进行了一次遍历,是线性的动态规划,所以总的时间复杂度是 $O(n log n)$。
在实现二分查找的时候,Java 可以使用 Arrays.binarySearch() 或者 Collections.binarySearch() 方法。这个方法接收的参数是 (list, key),表示在 list 数组中找 key 这个数,也可以是增加了查找范围的 (list, startPosition, endPosition, key),两者的区别只在不指名范围时是整个数组的范围。
需要特别说明的是这个函数的返回值:
- 如果找到,则返回该元素第一次出现的下标;
- 如果没找到,则返回
(-(insertion point) - 1),其中insertion point表示这个值应该被插入在什么位置才能使得整个数组依旧有序,即,数组中第一个比该值大的元素的下标,或者数组的长度(如果所有元素都小于该值)。
在这样的定义下,这个方法保证了,当且仅当 key 被找到时,返回值大于等于 0。(再次强调,“当且仅当”,即,若 key 被找到,则返回值大于等于 0,若返回值大于等于 0,则 key 被找到)。
所以,我们可以利用返回值的正负来进行不同的操作。
int searchResult = Collections.binarySearch(dp, nums[i]);
if (searchResult < 0) {
int insertionPoint = -(searchResult + 1);
dp.set(insertionPoint, nums[i]);
}
对于 C++,可以使用 lower_bound()。
dp[lower_bound(dp, dp + pos + 1, a[i]) - dp] = a[i];
注意 C++ 的这个函数,返回值只有正数,因此不能直接知道是不是存在 key,它可能只是返回的是这个值应该存在的地方。不过好在,就算存在,那用相同的值替换,也没关系。
完整代码
O(n^2) 时间复杂度
int[] dp = new int[nums.length];
for (int i = 0; i < nums.length; i++) {
dp[i] = 1;
}
for (int i = 0; i < nums.length; i++) {
for (int j = 0; j < i; j++) {
if (nums[i] > nums[j]) {
dp[i] = Math.max(dp[i], dp[j] + 1);
}
}
}
int ans = 0;
for (int i = 0; i < nums.length; i++) {
ans = Math.max(ans, dp[i]);
}
return ans;
O(n log n) 时间复杂度
if (nums.length == 0) {
return 0;
}
ArrayList<Integer> dp = new ArrayList<>();
dp.add(nums[0]);
for (int i = 1; i < nums.length; i++) {
if (dp.get(dp.size() - 1) < nums[i]) {
dp.add(nums[i]);
} else {
int searchResult = Collections.binarySearch(dp, nums[i]);
if (searchResult < 0) {
int insertionPoint = -(searchResult + 1);
dp.set(insertionPoint, nums[i]);
}
}
}
return dp.size();
