题目描述
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log(m+n)).
输入与输出
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
}
};
样例
- nums1 =
[1, 3]
, nums2 =[2]
, the median is2.0
. - nums1 =
[1, 2]
, nums2 =[3, 4]
, the median is (2 + 3) / 2 =2.5
.
题解与分析
解法一
一个直观的解法是归并两个有序数组,直接找出中位数。
C++ 代码如下:
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int size1 = nums1.size();
int size2 = nums2.size();
int mid = (size1 + size2) / 2 + 1;
int *arr = new int[mid];
int i = 0, j = 0, k = 0;
// 归并过程
while (k < mid) {
if (i >= size1)
arr[k++] = nums2[j++];
else if (j >= size2)
arr[k++] = nums1[i++];
else if (nums1[i] < nums2[j])
arr[k++] = nums1[i++];
else
arr[k++] = nums2[j++];
}
return (size1 + size2) % 2 == 0 ? (arr[mid - 1] + arr[mid - 2]) / 2.0 : arr[mid - 1];
}
};
该解法的时间复杂度为 O(m + n)。
解法二
解法一的时间复杂度并没有达到题目的要求,需要找到更加高效的解法。从题目中要求的 O(log(m + n)) 复杂度联想到二分查找算法。显然我们无法直接查找中位数,需要分析中位数的数学性质来找到一个方便二分查找的条件。
设数组 nums1
的长度为 m
,数组 nums2
的长度为 n
,这里不妨设 m <= n
。设整数 i, j
满足 0 <= i <= m && 0 <= j <= n
。i,j
把数组 nums1,nums2
分成两部分:
left | right |
---|---|
nums1[0] ... nums1[i - 1] |
nums1[i] ... nums1[m - 1] |
nums2[0] ... nums2[j - 1] |
nums2[j] ... nums2[n - 1] |
如果满足 (i == 0 || j == n || nums1[i - 1] <= nums2[j]) && (i == m || j == 0 || nums2[j - 1] <= nums1[i])
,上表中左侧数字均小于右侧数字。如果进一步要求左侧数字的总数为 (m + n) / 2
,则中间轴附近的数字就是中位数。
由上述讨论结果可知,i + j = (m + n) / 2
。又由于 m <= n
,上述条件可以简化为 (i == 0 || nums1[i - 1] <= nums2[j]) && (i == m || nums2[j - 1] <= nums1[i])
。至此,我们找到一个可以用来二分查找的性质。
下面讨论具体的二分查找策略:查找变量为 i
,查找区间为 [0, m]
。如果不满足 nums1[i - 1] <= nums2[j]
,由上表可以看出 i
过大,需要降低上限至 i - 1
。如果不满足 nums2[j - 1] <= nums2[i]
,由上表可以看出 i
过小,需要提高下限至 i + 1
。直至满足条件退出。
查找结束后,如果两个数组长度和为奇数,中位数是上表右侧数字中的最小值,换言之,min(nums1[i], nums2[j])
。如果是偶数,中位数是上表左侧数字中的最大值与右侧数字中的最小值的平均值。
C++ 代码如下:
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int size1 = nums1.size();
int size2 = nums2.size();
vector<int> &A = size1 <= size2 ? nums1 : nums2;
vector<int> &B = size1 <= size2 ? nums2 : nums1;
int m = min(size1, size2);
int n = max(size1, size2);
int imin = 0, imax = m, half = (m + n) / 2;
int maxLeft, minRight;
while (imin <= imax) {
int i = (imin + imax) / 2;
int j = half - i;
if (i < m && B[j - 1] > A[i])
imin = i + 1;
else if (i > 0 && A[i - 1] > B[j])
imax = i - 1;
else {
if (i == m)
minRight = B[j];
else if (j == n)
minRight = A[i];
else
minRight = A[i] < B[j] ? A[i] : B[j];
if ((m + n) % 2 == 1)
return minRight;
if (i == 0)
maxLeft = B[j - 1];
else if (j == 0)
maxLeft = A[i - 1];
else
maxLeft = A[i - 1] > B[j - 1] ? A[i - 1] : B[j - 1];
return (maxLeft + minRight) / 2.0;
}
}
}
};
该解法的时间复杂度为 O(log(m + n)。