merage sort
Divide-and-conquer
merge sort 的核心理念是 Divide-and-conquer ,这个范式的核心是把问题分割成跟原问题相似的子问题,然后,递归的解决这些子问题,最后把这些子问题的结论合并得到原始问题的答案。Divide-and-conquer 分三步:
- Divide 把问题分割成跟原来的问题一致但是规模变小了的子问题。
- Conquer 递归的解决子问题。如果问题足够小了,直接解决子问题。
- Combine 把子问题的解决方案合并的到原问题的解决方案。
如图所示:
进一步扩展成更多的递归步骤:
Merge Sort
Merge Sort 采用的就是Divide-and-conquer ,对于一个数组a[p....r]有以下三步:
- Divide 找到p和r的中点q;
- Conquer 递归的对每次生成的子数组进行排序;
- Combine 把两个排序好的子数组合并成一个排序好的数组。
举一个具体例子:
对于一个数组 array[0..7] [14, 7, 3, 12, 9, 11, 6, 2] .
- 在divide 阶段,我们计算出来一个中值q=3
- 在conquer 阶段,我们分别对两个子数组array[0..3], [14, 7, 3, 12] 和 array[4..7], [9, 11, 6, 2] 进行排序。当我们完成conquer ,两个数组都是有序的分别是 [3, 7, 12, 14] 和 [2, 6, 9, 11],最后我们得到完整的数组[3, 7, 12, 14, 2, 6, 9, 11]
- 最后在combine阶段,我们合并(merge) 两个数组,得到最终的有序数组[2, 3, 6, 7, 9, 11, 12, 14]
我们用图形来演示一下这个过程:
再来两个动图演示一下:
Java 实现
public void mergeSort(int[] A, int start, int end, int[] temp) {
if(start >= end) {
return;
}
int left = start;
int right = end;
int mid = (start + end) / 2;
mergeSort(A, start, mid, temp);
mergeSort(A, mid + 1, end, temp);
merge(A, start, mid, end, temp)
}
public void merge(int[] A, int start, int mid, int end, int[] temp) {
int left = start;
int right = mid + 1;
int index = start;
while(left <= mid && right <= end) {
if(A[left] < A[right]) {
temp[index++] = A[left++];
} else {
temp[index++] = A[right++];
}
}
while(left <= mid) {
temp[index++] = A[left++];
}
while (right <= end) {
temp[index++] = A[right++];
}
for(index = start; index <= end; index++) {
A[index] = temp[index];
}
}
LeetCode 衍生问题
算法前面说的比较清楚,直接说两个比较难的问题:
Merge k sorted linked lists and return it as one sorted list. Analyze and describe its complexity.
public class ListNode {
int val;
ListNode next;
ListNode(int x) {val = x;}
}
解法1:
ListNode dummy = new ListNode(0);
public ListNode mergeKLists(ListNode[] list) {
if(lists.length == 0) return null;
int i = 0;
int j = lists.length - 1;
while( j != 0 ) {
while(i < j) {
lists[i] = mergeTwo(lists[i++], lists[j--]);
}
i = 0;
}
return lists[0];
}
public ListNode mergeTwo(ListNode node1, ListNode node2) {
ListNode head = dummy;
while(node1 != null && node2 != null) {
if(node1.val < node2.val) {
head.next = node1;
node1 = node1.next;
}
head = head.next;
}
if(node1 != null) {
head.next = node1;
}
if(node2 != null) {
head.next = node2;
}
return dummy.next;
}
解法2,我们借助java的优先级队列(PriorityQueue)来解这个问题:
public ListNode mergeKLists(ListNode[] lists) {
PriorityQueue<ListNode> heap = new PriorityQueue<>(new Comparator<ListNode>() {
@Override
public int compare(ListNode o1, ListNode o2) {
return o1.val - o2.val;
}
});
for(ListNode n : lists) {
if(n != null) {
heap.add(n);
}
}
ListNode head = heap.peek();
while(!heap.isEmpty()) {
ListNode min = heap.poll();
if (min.next != null) {
heap.add(min.next);
}
min.next = heap.peek();
}
return head;
}
优先级队列实现了一种通过堆排序来解决问题的方案。
第二个问题 Merge Intervals:
Given a collection of intervals, merge all overlapping intervals.
For example,
Given
[1,3],[2,6],[8,10],[15,18]
,return
[1,6],[8,10],[15,18]
.
数据结构如下:
public class Interval {
int start;
int end;
Interval() {start = 0; end = 0;}
Interval(int s, int e) {start = s; end = e;}
};
解法1:
public List<Interval> merge(List<Interval> intervals) {
if (intervals == null || intervals.size() <=1 ) return intervals;
List<Interval> merged = new ArrayList<>();
//对输入进行一下排序
Collections.sort(intervals, new Comparator<Interval>() {
@Override
public int compare(Interval o1, Interval o2) {
if(o1.start != o2.start) {
return o1.start - o2.start;
}
return o1.end - o2.end;
}
});
int startMin = intervals.get(0).start;
int endMax = intervals.get(0).end;
for (int i = 1; i < intervals.size(); i++) {
int start = intervals.get(i).start;
int end = intervals.get(i).end;
if (start > endMax) {
merged.add(new Interval(startMin, endMax));
//完成一个节点的添加;
startMin = start;
endMax = end;
} else {
//吞掉一个节点;
endMax = Math.max(end, endMax);
}
}
merged.add(new Interval(startMin, endMax));
return merged;
}
解法2(原地排序,空间占用率低):
public List<Interval> merge(List<Interval> intervals) {
if (intervals == null || intervals.size() < 2) {
return intervals;
}
intervals.sort(new Comparator<Interval>() {
@Override
public int compare(Interval o1, Interval o2) {
return o1.start - o2.start;
}
});
int length = 1;
for(int i = 1; i < intervals.size(); i++) {
if(intervals.get(length - 1).end < intervals.get(i).start) {
intervals.set(length++, intervals.get(i));
} else {
intervals.get(length - 1).end = Math.max(intervals.get(length - 1).end, intervals.get(i).end);
}
}
return intervals.subList(0, length);
}