数据结构与算法--二叉查找树
上节中学习了基于链表的顺序查找和有序数组的二分查找,其中前者在插入删除时更有优势,而后者在查找上效率更高。能不能将这两个优点结合起来呢?这就是接下来要学的二叉查找树。
首先,二叉查找树是一棵二叉树,每个结点都只有左后两个链接或者称为子结点、子树。每个结点的键都大于其左子树任意结点的键,同时小于其右子树任意结点的键。
如图结点E的左子树都比E小,其右子树都比E要大。
在经典的二叉树实现中,我们会增加一个结点计数器,用来表示以此结点为根的子树中的结点总数。一棵二叉查找树表示唯一的一组有序键,但是一组有序键可以用多棵二叉查找树表示。如下
它们表示的都是同一组有序键:ACEHMRSX,细心的你可能已经发现这其实就是二叉树中序遍历的结果。而且由于加入了结点计数器,对于每个结点都有
size(x) = size(x.left) + size(x.right) + 1
上式+1
是因为要算上父结点x。
基本实现
我们先来定义二叉查找树。
package Chap8;
public class BST<Key extends Comparable<Key>, Value> {
private Node root;
private class Node {
private Key key;
private Value value;
private Node left, right;
private int N; // 结点计数器,以该结点为根的子树结点总数
public Node(Key key, Value value, int N) {
this.key = key;
this.value = value;
this.N = N;
}
}
public int size() {
return size(root);
}
private int size(Node node) {
if (node == null) {
return 0;
} else {
return node.N;
}
}
}
查找
采用递归算法比较容易理解:从根结点开始,如果树是空的,则返回null表示查找未命中;如果被查找的键和当前根结点相等,查找命中,否则就递归地在适当的子树里继续查找——具体来说就是如果被查找的键小于根结点的键就在其左子树继续查找;如果大于根结点就在其右子树继续查找。根据上面表述,已经可以写出get(Key key)
方法了。
public Value get(Key key) {
return get(root, key);
}
private Value get(Node node, Key key) {
if (node == null) {
return null;
}
// 和当前结点比较
int cmp = key.compareTo(node.key);
// 递归在左子树查找
if (cmp < 0) {
return get(node.left, key);
// 递归在右子树查找
} else if (cmp > 0) {
return get(node.right, key);
// 查找命中返回值
} else {
return node.value;
}
}
看上图,分别是查找命中和未命中的轨迹。
get方法可以使用非递归实现,通常性能更佳。
public Value get(Key key) {
Node cur = root;
while (cur != null) {
int cmp = key.compareTo(cur.key);
if (cmp < 0) {
cur = cur.left;
} else if (cmp > 0) {
cur = cur.right;
} else {
return cur.value;
}
}
return null;
}
插入
put
方法和get
方法如出一辙,也是采用了递归的方式:从根结点开始如果树是空的,就返回一个含有该键值对的新结点;如果被查找的键小于根结点的键,就在其左子树中插入该键,否则在右子树插入该键。
public void put(Key key, Value value) {
// 更新root
// 第一次put:本来null的root = new Node
// 以后的put:root = root
root = put(root, key, value);
}
private Node put(Node node, Key key, Value value) {
if (node == null) {
return new Node(key, value, 1); // 新结点size当然是1
}
int cmp = key.compareTo(node.key);
// 在node的左子树插入
if (cmp < 0) {
node.left = put(node.left, key, value);
// 在node的右子树插入
} else if (cmp > 0) {
node.right = put(node.right, key, value);
// 键已经存在,更新
} else {
node.value = value;
}
// 插入后更新以node为根的子树总结点数
node.N = size(node.left) + size(node.right) + 1;
// 除了第一次put返回新结点外,都是返回root
return node;
}
注意递归算法中有返回值,插入时是从上往下的查找,然后在树的底部插入,然后在递归方法返回的过程中,是自下而上逐渐更新查找路径上的每个结点node的,包括node.left
或者node.right
,及node.N
,所以在公有(public)的 get方法中有root = put(root, key, value)
,表示更新后的root传递给原root。
看上图,递归查找和get方法一样,直到遇到空树,在这里是M的左子结点,然后新建结点接到M的左子树上。之后是递归方法的返回,在返回过程中不断更新了每个结点的左子结点、右子结点、以此结点为根的子树总结点数(实际上很多结点的这些值并没有变化,但这些操作又是必须的)
最大/最小键
如果根结点的左子结点为空,那么该根结点就是最小的键;如果左子结点不为空,那么一直沿着左链接深入,直到遇到某个结点没有左子结点了,那么此时该结点的键就是最小的。最大键的实现就是不断深入右子树,直到某结点的右子结点为空。在代码中将left换成right即可。
// 递归实现min
public Key min() {
return min(root).key;
}
// 递归实现max
public Key max() {
return max(root).key;
}
private Node min(Node node) {
if (node.left == null) {
return node;
} else {
return min(node.left);
}
}
private Node max(Node node) {
if (node.right == null) {
return node;
} else {
return max(node.right);
}
}
当然可以采用非递归的版本。
public Key min() {
Node node = root;
while (node.left != null) {
node = node.left;
}
return node.key;
}
public Key max() {
Node node = root;
while (node.right != null) {
node = node.right;
}
return node.key;
}
向上/向下取整
floor(Key key)
返回小于等于key的键;ceiling(Key key)
返回大于等于key的键。
floor方法:如果key等于根结点的键那么直接返回根结点的键;如果key小于根结点,则小于等于key的最大键一定在根结点的左子树中;如果key大于根结点,那么必须当右子树中存在小于等于key的结点时,小于等于key的键才存在于右子树中,若不存在则小于等于key的键就是根结点本身。
这两个方法是镜像的,理解了floor就将能顺理成章写出ceiling。
则ceiling方法就是:key等于根结点的键那么直接返回根结点的键;如果key大于根结点,则大于等于key的最大键一定在根结点的右子树中;如果key小于根结点,那么必须当左子树中存在大于等于key的结点时,大于等于key的键才存在于左子树中,若不存在则大于等于key的键就是根结点本身。
实现如下
public Key floor(Key key) {
Node node = floor(root, key);
if (node == null) {
return null;
} else {
return node.key;
}
}
private Node floor(Node node, Key key) {
if (node == null) {
return null;
}
int cmp = key.compareTo(node.key);
// 和根结点相等直接返回根结点
if (cmp == 0) {
return node;
// 比根结点小,肯定在左子树中
} else if (cmp < 0) {
return floor(node.left, key);
// 比根结点大,若在右子树中就返回右子树相应结点,否则就是根结点本身
} else {
Node temp = floor(node.right, key);
if (temp != null) {
return temp;
} else {
return node;
}
}
}
public Key ceiling(Key key) {
Node node = ceiling(root, key);
if (node == null) {
return null;
} else {
return node.key;
}
}
private Node ceiling(Node node, Key key) {
if (node == null) {
return null;
}
int cmp = key.compareTo(node.key);
// 和根结点相等直接返回根结点
if (cmp == 0) {
return node;
// 比根结点大,肯定在右子树中
} else if (cmp > 0) {
return ceiling(node.right, key);
// 比根结点小,若在左子树中就返回左子树相应结点,否则就是根结点本身
} else {
Node temp = ceiling(node.left, key);
if (temp != null) {
return temp;
} else {
return node;
}
}
}
查找G,开始时G < S,所以小于等于G的最大键肯定在S的左子树中,然后G > E,则小于等于G的最大键可能存在于E的右子树中,经查找后不存在小于等于G的键,所以最后返回的是根结点E。
由于floor和ceiling方法实现十分相似,如果理解了floor的查找轨迹,ceiling也应该不在话下。
选择和排名
select(k)
:假设我们想知道排名为k的键是什么(即树中正好有k个键小于它)。如果左子树中的结点数t大于k,那么继续递归地在左子树中查找排名为k的键;如果t等于k,就返回根结点的键(根结点的左子树结点总数刚好就是根结点的排名),如果t小于k,得在右子树递归地查找排名为k - t -1的键(因为左子树结点个数为t,加上根结点1,共t + 1个,而k - t - 1+ t + 1 = k)依然能保证查找到的是排名为k的键。
rank(Key key)
:此方法可返回给定键的排名。是select方法的逆方法。如果给定键和根结点的键相同,就返回左子树的结点数(根结点左子树的结点数刚好是根结点的排名);如果给定的键小于根结点,递归运算返回该键在左子树中的排名;如果给定的键大于根结点,返回t + 1 + 该键在右子树中的排名
(t + 1是根结点的左子树及根结点,所以三者加起来才是该键的正确排名)
public Key select(int k) {
if (k < 0 || k >= size()) {
throw new IllegalArgumentException("argument to select() is invalid: " + k);
}
return select(root, k).key;
}
private Node select(Node node, int k) {
if (node == null) {
return null;
}
int t = size(node.left);
// 左子树的结点数大于k,继续在左子树查找
if (t > k) {
return select(node.left, k);
// 左子树结点数小于k,得在右子树查找
} else if (t < k) {
return select(node.right, k - t - 1);
// 左子树的结点数刚好等于k,找到,排名为k的就是这个根结点
} else {
return node;
}
}
public int rank(Key key) {
return rank(root, key);
}
private int rank(Node node, Key key) {
if (node == null) {
return 0;
}
int cmp = key.compareTo(node.key);
// 比根结点小,应该在左子树中继续查找
if (cmp < 0) {
return rank(node.left, key);
// 比根结点大,应该在右子树中查找,算排名时加上左子树和根结点的结点总和
} else if (cmp > 0) {
return 1 + size(node.left) + rank(node.right, key);
// 和根结点相等,找到,排名就是其左子树结点总数
} else {
return size(node.left);
}
}
S的左子树结点个数为6,大于3,所以在S的左子树中继续查找,E的左子树结点个数为2,小于3;所以应该在E的右子树中查找排名为k - t - 1 = 3 - 2 - 1 = 0的结点。R的左子树结点个数为2,大于0,应该在R的左子树中查找;H的左子树结点个数为0,且正在查找排名为0的结点,返回H。看图中,有序键为ACEHMRSX,H确实是排名3。
再看rank(O)
,还是用上面的图,O小于S所以在左子树中查找,O大于E,转右子树,O在右子树中的排名是2(HMR中有H和M小于O),则最后返回1 + size(e.left) + 2 = 5,三个值分别是结点e、结点e的左子树结点个数、O在右子树中的排名。
删除
删除最小/最大键
先看简单的情况,删除最小最大键,其实思路和查找最小最大键类似。也是不断深入左子树,直到某个结点没有左子结点,现在要做的就是删除该结点,比如该结点为x,其父结点为t,有t.lelf == x。只要使x的右结点(不管是不是空)成为t的新的左结点即可,也就是t.left = x.right
,原左结点会被垃圾回收,达到删除的目的。删除最小键的操作轨迹如下图左边所示。
删除最大键是删除最小键的镜像问题,就不赘述了。
public void deleteMin() {
root = deleteMin(root);
}
private Node deleteMin(Node node) {
if (node.left == null) {
return node.right;
}
// 其实就是node.left = node.left.right
node.left = deleteMin(node.left);
node.N = size(node.left) + size(node.right) + 1;
return node;
}
public void deleteMax() {
root = deleteMax(root);
}
private Node deleteMax(Node node) {
if (node.right == null) {
return node.left;
}
node.right = deleteMax(node.right);
node.N = size(node.left) + size(node.right) + 1;
return node;
}
删除任意键
如果要删除的键只有一个子结点或者没有子结点,可以按照上述方法删除,但是如果要删除的结点既有左子结点又有右子结点呢?删除后将要同时处理两棵子树,但是被删除结点的父结点只会空出一条链接出来。换个角度想想,二叉查找树的中序遍历序列就是有序键的集合,所以删除了该结点,可以用该结点的后继或者前驱结点取代它。这里我们打算用后继结点取代被删除结点的位置。具体步骤如下
- 如果被删除的结点只有一个子结点或者没有子结点,比如被删除结点为x,其父结点为t。若x没有左结点则
t.left = x.right
,或者x没有右结点则t.right = x.left
。 - 如果被删除的结点有左右子结点。先将被删除的结点保存为t,其右子结点为t.right,然后找到右子树中的最小结点,该结点就是被删除结点t的后继结点,设为x。t和m之间再无其他键,所以m取代t的位置后,剔除m后的t的右子树中所有结点仍然大于m,所以只需让m的右子树连接剔除m后的t的右子树,m的左子树连接t的左子树即可。
比如下图右边删除结点E,E的左右子结点都不为空,E的右子结点是R,然后在子树R中找到最小值H,H就为E的后继结点。然后H取代E的位置,剔除H(调用deleteMin(R)即可)后的E的右子树还剩下R、M,让H的右子树和他们相连,再让H的左子树和E的左子树相连,OK~
根据描述写出如下代码
private Node delete(Node node, Key key) {
if (node == null) {
return null;
}
int cmp = key.compareTo(node.key);
// key大于当前根结点,在右子树查找
if (cmp > 0) {
node.right = delete(node.right, key);
// key小于当前根结点,在左子树查找
} else if (cmp < 0) {
node.left = delete(node.left, key);
// 找到给定的key
} else {
// 如果根结点只有一个子结点或者没有子结点,按照删除最小最大键的做法即可
if (node.left == null) {
return node.right;
}
if (node.right == null) {
return node.left;
}
// 根结点的两个子结点都不为空
// 要删除的结点用t保存
Node t = node;
// t的后继结点取代t的位置
node = min(t.right);
node.right = deleteMin(t.right);
node.left = t.left;
}
node.N = size(node.left) + size(node.right) + 1;
return node;
}
范围查找
要查找某个范围内的所有键,首先需要一个遍历二叉树所有结点的方法,我们多次提到二叉查找树的中序遍历序列就是有序键的集合。所以得到如下思路:中序遍历二叉查找树,如果该键落在范围内,加入到集合中。当然如果某个根结点的键小于该范围的最小值,其左子树肯定也不会在范围内;同样某个结点的键大于该范围的最大值,其右子树肯定也不会在范围内。这两种情况都无需递归遍历了,直接跳过。所以为了减少比较操作,在递归遍历前加上判断条件。
public Set<Key> keys() {
return keys(min(), max());
}
public Set<Key> keys(Key low, Key high) {
Set<Key> set = new LinkedHashSet<>();
keys(root, set, low, high);
return set;
}
private void keys(Node node, Set<Key> set, Key low, Key high) {
if (node == null) {
return;
}
int cmplow = low.compareTo(node.key);
int cmphigh = high.compareTo(node.key);
// 当前结点比low大,左子树中可能还有结点落在范围内的,所以应该遍历左子树
if (cmplow < 0) {
keys(node.left, set, low, high);
}
// 在区间[low, high]之间的加入队列
if (cmplow <= 0 && cmphigh >= 0) {
set.add(node.key);
}
// 当前结点比high小,右子树中可能还有结点落在范围内,所以应该遍历右子树
if (cmphigh > 0) {
keys(node.right, set, low, high);
}
}
结合一个图例理解下keys方法,图中是查找[F, T]
范围内的所有键。首先S大于F,在左子树中查找,发现E在范围外,跳过E及其子树;回到E的右子树R,R大于F,会在左子树中继续查找,H在范围内,所以加入到集合中,然后到H的右子树....以此类推,最后被加入到集合的元素有HLMPRS(中序遍历得到的,所以有序)
values
的实现和keys
完全类似,不再赘述。
至于求某范围内的键的个数size(Key low, Key high)
。在有序数组的二分查找中已经有实现,直接拿过来用。如下
public int size(Key low, Key high) {
if (high.compareTo(low) < 0) {
return 0;
}
if (contains(high)) {
return rank(high) - rank(low) + 1;
} else {
return rank(high) - rank(low);
}
}
代码测试
先重写toString
,格式化打印所有键值对。
@Override
public String toString() {
Iterator<Key> keys = keys().iterator();
Iterator<Value> values = values().iterator();
if (!keys.hasNext()) {
return "{}";
}
StringBuilder sb = new StringBuilder();
sb.append("{");
while (true) {
Key key = keys.next();
Value value = values.next();
sb.append(key).append("=").append(value);
if (!keys.hasNext()) {
return sb.append("}").toString();
}
sb.append(", ");
}
}
来测试下代码。
public static void main(String[] args) {
BST<Integer, Double> st = new BST<>();
st.put(1, 5567.5);
st.put(5, 10000.0);
st.put(3, 4535.5);
st.put(7, 7000.0);
st.put(12, 2500.0);
st.put(10, 4500.0);
st.put(17, 15000.5);
st.put(15, 12000.5);
st.deleteMax(); // 17
st.deleteMin(); // 1
st.delete(12); // 剩下[3, 5, 7, 10, 15]
System.out.println("符号表的长度为" + st.size());
System.out.println("[3, 6]之间有" + st.size(3, 6) + "个键");
System.out.println("比9小的键的数量为" + st.rank(9));
System.out.println("排在第4位置的键为" + st.select(4));
System.out.println("大于等于8的最小键为" + st.ceiling(8));
System.out.println("小于等于8的最大键为" + st.floor(8));
System.out.println("符号表所有的键和对应的值为:" + st.keys() + " -> " + st.values());
System.out.println("键2和键8之间的所有键及对应的值:" + st.keys(2, 8) + " -> " + st.values(2, 8));
System.out.println(st);
/*
符号表的长度为5
[3, 6]之间有2个键
比9小的键的数量为3
排在第4位置的键为15
大于等于8的最小键为10
小于等于8的最大键为7
符号表所有的键和对应的值为:[3, 5, 7, 10, 15] -> [4535.5, 10000.0, 7000.0, 4500.0, 12000.5]
键2和键8之间的所有键及对应的值:[3, 5, 7] -> [4535.5, 10000.0, 7000.0]
{3=4535.5, 5=10000.0, 7=7000.0, 10=4500.0, 15=12000.5}
*/
}
by @sunhaiyu
2017.10.17