LeetCode 300 最长上升子序列
题目描述
给定一个无序的整数数组,找到其中最长上升子序列的长度。
说明:
- 可能会有多种最长上升子序列的组合,你只需要输出对应的长度即可。
- 你算法的时间复杂度应该为 $O(n^2)$ 。
进阶:
你能将算法的时间复杂度降低到 $O(n log n)$ 吗?
来源:力扣(LeetCode)
链接:https://leetcode-cn.com/problems/longest-increasing-subsequence
样例
样例输入
[10,9,2,5,3,7,101,18]
样例输出
4
样例解释
最长的上升子序列是 [2,3,7,101],它的长度是 4。
算法与数据结构
动态规划
贪心
二分查找
题解
LIS 是非常经典的动态规划,本文给出 2 种不同时间复杂度的解法,它们的思想都值得学习。
如果 f(n) 是数组 nums[0 .. n - 1]
中最长的上升子序列,f(n – 1) 是数组 nums[0 .. n - 2]
中最长的上升子序列,依次类推,f(1) 是数组 nums[0]
的最长上升子序列。假设已经有人帮我们解决了 f(1)、f(2)……f(n – 1) 的问题,我们怎么解决 f(n) 呢?
显然我们应该考虑 nums[n - 1]
的值,然后根据 f(1)、f(2)……f(n – 1) 的值来决定 f(n) 的值。可是,要考虑 nums[n - 1]
能不能与现有的 f(x) 构成最长上升子序列,我们必须将它与它的前一个元素 nums[n - 2]
比较大小,而根据上面的定义,我们还不知道 f(x) 到底有哪些元素,只知道到元素 nums[x - 1]
为止的最长上升子序列,不知道这个子序列是不是包含了最后的数。因此,我们需要规定 f(n) 为:如果包含了最后的数,那么最长的子序列应该是什么,即,最后这个数必须包含在子序列当中。
于是,我们终于可以开始梳理动态规划的四要素了。
首先,定义状态:我们定义状态 dp[i]
表示以 nums[i]
结尾的 LIS 的长度。
由状态的定义,可以自然得到最终结果是所有 dp[i]
中最大的。
接下去考虑初始条件:当序列长度为 1 时,自己构成一个序列,所以初始化所有 dp[i] = 1
。
int[] dp = new int[nums.length]; for (int i = 0; i < nums.length; i++) { dp[i] = 1; }
最后是状态转移方程,为了求 dp[i]
,我们可以尝试在所有的 dp[j]
结尾加上 nums[i]
看看是不是可以让长度更长,所以,遍历 j = 0; j < i; j++
,如果 nums[i] > nums[j]
说明当前数字大于待比较序列的最后一个数字,可以加上去,于是 LIS 长度变为 dp[j] + 1
,只需要记录循环过程中出现的最大的 dp[j] + 1
即可。
for (int i = 0; i < nums.length; i++) { for (int j = 0; j < i; j++) { if (nums[i] > nums[j]) { dp[i] = Math.max(dp[i], dp[j] + 1); } } }
最后,找出所有 dp[i]
中的最大值,就是答案。注意,最长子序列不一定是以最后一个元素结尾的,所以不能用 dp[length - 1]
作为答案。
int ans = 0; for (int i = 0; i < nums.length; i++) { ans = Math.max(ans, dp[i]); } return ans;
这样的时间复杂度是平方级别的,有没有办法进一步优化呢?
(我既然这么问了,那显然是可以的)
回顾我们人类怎么找最长上升子序列的:对于一个上升子序列,如果当前最后一个元素越小,那么就越有利于往后添加新的元素,子序列的长度自然也就越长。
这是一种贪心的思想。
我们维护一个 dp[]
数组,定义状态 dp[i]
为长度 i + 1 的所有的上升子序列中,结尾元素最小的那个子序列的结尾元素。例如长度为 3 的子序列有 2 个,分别是 {2, 4, 5}
和 {1, 3, 6}
,那么 dp[2]
就是 5。
这样一来,答案就呼之欲出了,显然 dp.size()
的长度就是答案,因为 存在 这样一个长度为 dp.size()
的子序列,而不存在 dp.size() + 1
的子序列(否则就一定会有 dp[dp.size()]
的值)。
接下来考虑初始状态,dp[0]
的值是可以唯一确定的,只需要遍历整个 nums[]
数组,找到最小的那个就可以了。
但是等一下,这样的话,考察长度为 2、3、4 的时候,岂不是又变成了暴力枚举?因此,下面要运用贪心的思想来做。
我们贪婪地设置 dp[0] = nums[0]
:长度为 1 的最长子序列,就第 1 个元素自己呗。至于后面有没有可能遇到比它更小的?可能。但这事我们留到后面再说,后面在状态转移的时候自然会解决这个问题。
下面我以样例为例,讲一讲是怎么做状态转移的。
[10, 9, 2, 5, 3, 7, 101, 18]
- 我们贪婪地设置
dp[0] = nums[0] = 10
,当前dp[]
为[10]
; - i = 1,看到
nums[1] = 9
,由于 9 比 10 小,所以更利于往后增加序列的长度,所以我们把 10 换成 9,当前dp[]
为[9]
; - i = 2,看到
nums[2] = 2
,由于 2 比 9 小,所以更利于往后增加序列的长度,所以我们把 9 换成 2,当前dp[]
为[2]
; - i = 3,看到
nums[3] = 5
,由于 5 比dp[]
中每一个元素都要大,所以直接在数组尾部加上 5,当前dp[]
为[2, 5]
; - i = 4,看到
nums[4] = 3
,由于 3 比 5 小,所以更利于往后增加序列的长度,所以我们把 5 换成 3,当前dp[]
为[2, 3]
; - i = 5,看到
nums[5] = 7
,由于 7 比dp[]
中每一个元素都要大,所以直接在数组尾部加上 7,当前dp[]
为[2, 3, 7]
; - i = 6,看到
nums[6] = 101
,由于 101 比dp[]
中每一个元素都要大,所以直接在数组尾部加上 101,当前dp[]
为[2, 3, 7, 101]
; - i = 7,看到
nums[7] = 18
,由于 18 比 101 小,所以更利于往后增加序列的长度,所以我们把 101 换成 18,当前dp[]
为[2, 3, 7, 18]
; - 找完了所有的 i,
dp.size() = 4
就是答案。
通过上面的过程,我们可以发现,对于每一个 nums[i]
,我们要做两件事情:
- 如果它比
dp[]
数组中的每一个元素都要大,则将它加入到数组尾部。事实上容易发现,在我们的维护规则下,dp[]
是单调递增的,因此只需要和当前数组的最后一个数字比大小就可以了。
if (dp.get(dp.size() - 1) < nums[i]) { dp.add(nums[i]); }
- 如果它不是比
dp[]
数组中每一个元素都要大,那么找到dp[]
数组中第一个比它(nums[i]
)大的元素dp[x]
,用nums[i]
替换dp[x]
。
这个过程可以用二分查找,因此查找的时间复杂度是 $O(n log n)$,而除此以外我们只对数组进行了一次遍历,是线性的动态规划,所以总的时间复杂度是 $O(n log n)$。
在实现二分查找的时候,Java 可以使用 Arrays.binarySearch()
或者 Collections.binarySearch()
方法。这个方法接收的参数是 (list, key)
,表示在 list
数组中找 key
这个数,也可以是增加了查找范围的 (list, startPosition, endPosition, key)
,两者的区别只在不指名范围时是整个数组的范围。
需要特别说明的是这个函数的返回值:
- 如果找到,则返回该元素第一次出现的下标;
- 如果没找到,则返回
(-(insertion point) - 1)
,其中insertion point
表示这个值应该被插入在什么位置才能使得整个数组依旧有序,即,数组中第一个比该值大的元素的下标,或者数组的长度(如果所有元素都小于该值)。
在这样的定义下,这个方法保证了,当且仅当 key
被找到时,返回值大于等于 0。(再次强调,“当且仅当”,即,若 key
被找到,则返回值大于等于 0,若返回值大于等于 0,则 key
被找到)。
所以,我们可以利用返回值的正负来进行不同的操作。
int searchResult = Collections.binarySearch(dp, nums[i]); if (searchResult < 0) { int insertionPoint = -(searchResult + 1); dp.set(insertionPoint, nums[i]); }
对于 C++,可以使用 lower_bound()
。
dp[lower_bound(dp, dp + pos + 1, a[i]) - dp] = a[i];
注意 C++ 的这个函数,返回值只有正数,因此不能直接知道是不是存在 key
,它可能只是返回的是这个值应该存在的地方。不过好在,就算存在,那用相同的值替换,也没关系。
完整代码
O(n^2) 时间复杂度
int[] dp = new int[nums.length]; for (int i = 0; i < nums.length; i++) { dp[i] = 1; } for (int i = 0; i < nums.length; i++) { for (int j = 0; j < i; j++) { if (nums[i] > nums[j]) { dp[i] = Math.max(dp[i], dp[j] + 1); } } } int ans = 0; for (int i = 0; i < nums.length; i++) { ans = Math.max(ans, dp[i]); } return ans;
O(n log n) 时间复杂度
if (nums.length == 0) { return 0; } ArrayList<Integer> dp = new ArrayList<>(); dp.add(nums[0]); for (int i = 1; i < nums.length; i++) { if (dp.get(dp.size() - 1) < nums[i]) { dp.add(nums[i]); } else { int searchResult = Collections.binarySearch(dp, nums[i]); if (searchResult < 0) { int insertionPoint = -(searchResult + 1); dp.set(insertionPoint, nums[i]); } } } return dp.size();