Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.
**Note:**
You may assume k is always valid, 1 ≤ k ≤ BST's total elements.
Example 1:
Input: root = [3,1,4,null,2], k = 1
```
    3
   / \
  1   4
   \
    2
```
Output: 1
Example 2:
Input: root = [5,3,6,2,4,null,null,1], k = 3
```
        5
       / \
      3   6
     / \
    2   4
   /
  1
```
Output: 3
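The examples write the tree in level order, with null marking a missing child. To try the solutions locally, one option is a small builder like the sketch below. `TreeBuilder`, `buildTree`, and `inOrder` are hypothetical helper names, not part of the problem, and the `TreeNode` class simply mirrors the definition in the solution comments.

```java
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

// Mirrors the TreeNode definition used by the solutions.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

class TreeBuilder {
    // Hypothetical helper: builds a tree from a level-order array where
    // null marks a missing child, e.g. {3, 1, 4, null, 2}.
    static TreeNode buildTree(Integer[] levelOrder) {
        if (levelOrder.length == 0 || levelOrder[0] == null) return null;
        TreeNode root = new TreeNode(levelOrder[0]);
        Queue<TreeNode> parents = new LinkedList<>();
        parents.add(root);
        int i = 1;
        while (!parents.isEmpty() && i < levelOrder.length) {
            TreeNode parent = parents.poll();
            // The next slot is the left child, the one after it the right child.
            if (levelOrder[i] != null) {
                parent.left = new TreeNode(levelOrder[i]);
                parents.add(parent.left);
            }
            i++;
            if (i < levelOrder.length && levelOrder[i] != null) {
                parent.right = new TreeNode(levelOrder[i]);
                parents.add(parent.right);
            }
            i++;
        }
        return root;
    }

    // Recursive in-order listing; for a valid BST the output is ascending.
    static void inOrder(TreeNode node, List<Integer> out) {
        if (node == null) return;
        inOrder(node.left, out);
        out.add(node.val);
        inOrder(node.right, out);
    }
}
```

Listing the tree built from Example 2 with `TreeBuilder.inOrder` should give [1, 2, 3, 4, 5, 6], which is a quick way to confirm the input really is a BST.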
Follow up:
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?
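One common answer to the follow-up, which the two approaches below do not address, is to store in every node the size of its subtree. The kth smallest can then be found by walking down from the root and comparing k with the size of the left subtree, in O(h) time for tree height h, while insert and delete only update sizes along the path they touch. The sketch below assumes a plain, unbalanced BST; `SizedNode` and `OrderStatisticBst` are illustrative names.

```java
// Hypothetical size-augmented node: `size` counts the nodes in this subtree.
class SizedNode {
    int val;
    int size = 1;
    SizedNode left, right;
    SizedNode(int val) { this.val = val; }
}

class OrderStatisticBst {
    private SizedNode root;

    private static int size(SizedNode n) { return n == null ? 0 : n.size; }

    // Unbalanced BST insert that keeps subtree sizes up to date.
    // Equal values go to the right.
    void insert(int val) { root = insert(root, val); }

    private SizedNode insert(SizedNode node, int val) {
        if (node == null) return new SizedNode(val);
        if (val < node.val) node.left = insert(node.left, val);
        else node.right = insert(node.right, val);
        node.size = 1 + size(node.left) + size(node.right);
        return node;
    }

    // Deletes one occurrence of val (no-op if absent), fixing sizes on the way up.
    void delete(int val) { root = delete(root, val); }

    private SizedNode delete(SizedNode node, int val) {
        if (node == null) return null;
        if (val < node.val) node.left = delete(node.left, val);
        else if (val > node.val) node.right = delete(node.right, val);
        else {
            if (node.left == null) return node.right;
            if (node.right == null) return node.left;
            // Two children: copy the in-order successor, then delete it.
            SizedNode succ = node.right;
            while (succ.left != null) succ = succ.left;
            node.val = succ.val;
            node.right = delete(node.right, succ.val);
        }
        node.size = 1 + size(node.left) + size(node.right);
        return node;
    }

    // k is 1-based and assumed valid (1 <= k <= number of nodes).
    int kthSmallest(int k) {
        SizedNode cur = root;
        while (true) {
            int leftSize = size(cur.left);
            if (k <= leftSize) {
                cur = cur.left;          // answer lies in the left subtree
            } else if (k == leftSize + 1) {
                return cur.val;          // cur itself is the kth smallest
            } else {
                k -= leftSize + 1;       // skip the left subtree and cur
                cur = cur.right;
            }
        }
    }
}
```

Because this tree is never rebalanced, h can degrade to n; a self-balancing tree (for example AVL or red-black) that maintains the same size field would keep each operation at O(log n).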
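Approach 1: Heap (size-k max-heap)
Nth-smallest / top-k element problems generally fall into the category that a PriorityQueue can solve.
- Traverse the tree, keeping at most k nodes in a max-heap: whenever the heap grows past k, poll the largest node.
- After the traversal the heap holds the k smallest nodes, so the node at the top of the heap is the kth smallest.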
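The bounded-heap idea does not depend on the tree at all. Here it is applied to a plain array as a minimal sketch; `BoundedHeapDemo` is an illustrative name. `Collections.reverseOrder()` turns Java's default min-heap `PriorityQueue` into a max-heap.

```java
import java.util.Collections;
import java.util.PriorityQueue;

class BoundedHeapDemo {
    // Returns the kth smallest value (1-based) of vals by keeping at most k
    // values in a max-heap. Evicting the largest whenever the heap exceeds k
    // leaves the k smallest values, whose maximum (the top) is the kth smallest.
    static int kthSmallest(int[] vals, int k) {
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
        for (int v : vals) {
            maxHeap.add(v);
            if (maxHeap.size() > k) {
                maxHeap.poll(); // drop the current largest
            }
        }
        return maxHeap.peek();
    }
}
```

For example, `BoundedHeapDemo.kthSmallest(new int[]{5, 3, 6, 2, 4, 1}, 3)` returns 3. Each add or poll costs O(log k), so the scan is O(n log k), and the heap never holds more than k + 1 values.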
Approach 2: In-order traversal
In a BST, an in-order traversal visits nodes in ascending order, so the kth node visited is the answer.
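The in-order idea can also be written iteratively with an explicit stack, which avoids recursion and stops as soon as the kth node is popped. This is an alternative sketch rather than the recursive version shown later; `IterativeInOrder` is an illustrative name and `TreeNode` mirrors the definition in the solution comments.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

class IterativeInOrder {
    // Visits nodes in ascending order with an explicit stack and returns the
    // value of the kth node visited (k is 1-based and assumed valid).
    static int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode cur = root;
        while (cur != null || !stack.isEmpty()) {
            // Walk as far left as possible, remembering the path.
            while (cur != null) {
                stack.push(cur);
                cur = cur.left;
            }
            cur = stack.pop();      // next node in ascending order
            if (--k == 0) {
                return cur.val;     // stop early at the kth visit
            }
            cur = cur.right;
        }
        throw new IllegalArgumentException("k exceeds the number of nodes");
    }
}
```

On Example 1 this returns 1 for k = 1 and 3 for k = 3. It runs in O(h + k) time with O(h) extra space for tree height h.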
1. PriorityQueue Code
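```java
import java.util.Comparator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        if (root == null) return 0;
        // Max-heap by node value: the largest kept node sits on top.
        // Integer.compare avoids the overflow that i2.val - i1.val can hit.
        Comparator<TreeNode> cmp = new Comparator<TreeNode>() {
            public int compare(TreeNode i1, TreeNode i2) {
                return Integer.compare(i2.val, i1.val);
            }
        };
        PriorityQueue<TreeNode> queue = new PriorityQueue<TreeNode>(k, cmp);
        traverseTree(root, queue, k);
        // The heap now holds the k smallest nodes; its top is the kth smallest.
        return queue.peek().val;
    }

    // Level-order (BFS) traversal; any visiting order works for the heap approach.
    private void traverseTree(TreeNode root, PriorityQueue<TreeNode> priorityQueue, int k) {
        Queue<TreeNode> queue = new LinkedList<TreeNode>();
        queue.add(root);
        while (!queue.isEmpty()) {
            TreeNode curNode = queue.poll();
            priorityQueue.add(curNode);
            // Evict the largest node once more than k are kept.
            if (priorityQueue.size() > k) {
                priorityQueue.poll();
            }
            if (curNode.left != null) {
                queue.add(curNode.left);
            }
            if (curNode.right != null) {
                queue.add(curNode.right);
            }
        }
    }
}
```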
2. In-Order Traversal Code
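```java
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        if (root == null) {
            return 0;
        }
        // One-element array so every recursive call shares the same countdown.
        int[] counter = { k };
        return getTarget(root, counter).val;
    }

    // In-order search: counter[0] is decremented at each visited node, and the
    // node visited while counter[0] == 1 is returned; null means "not found yet".
    private TreeNode getTarget(TreeNode root, int[] counter) {
        TreeNode target = null;
        // 1. Search the left subtree (smaller values) first.
        if (root.left != null) {
            target = getTarget(root.left, counter);
        }
        // 2. Visit the current node only if the answer was not on the left.
        if (target == null) {
            if (counter[0] == 1) {
                target = root;
            }
            counter[0]--;
        }
        // 3. Search the right subtree only if the answer is still missing.
        if (target == null && root.right != null) {
            target = getTarget(root.right, counter);
        }
        return target;
    }
}
```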