Kth Smallest Element in a BST

  • medium
  • tag: BST, recursion
  • similar: kth smallest; in-order, pre-order and post-order traversal

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

  • Note: You may assume k is always valid, 1 ≤ k ≤ the total number of elements in the BST.
  • Follow up: What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

Approach

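  • The first idea is to do an in-order traversal into an array and return arr[k - 1] (k is 1-based).
  • Alternatively, recursively count the nodes in the left and right subtrees.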
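The first idea, in-order traversal into an array, can be sketched as below. This is a minimal version that assumes a LeetCode-style `TreeNode` (defined here only so the snippet compiles on its own); `InorderArraySolution` and `inorder` are hypothetical names. It always visits all n nodes, so it costs O(n) time and O(n) extra space, and the answer sits at index k - 1 because k is 1-based.

```java
import java.util.ArrayList;
import java.util.List;

// Assumed LeetCode-style node: an int value plus left/right children.
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical name: collect all values in sorted order, then index into them.
class InorderArraySolution {
    public int kthSmallest(TreeNode root, int k) {
        List<Integer> sorted = new ArrayList<>();
        inorder(root, sorted);
        return sorted.get(k - 1); // k is 1-based
    }

    // In-order traversal of a BST yields its values in ascending order.
    private void inorder(TreeNode node, List<Integer> out) {
        if (node == null) return;
        inorder(node.left, out);   // smaller values first
        out.add(node.val);         // then this node
        inorder(node.right, out);  // then larger values
    }
}
```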

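idea 1: break the problem down by counting nodes

  • Let cnt be the number of nodes in the left subtree. If k ≤ cnt, the answer is in the left subtree; if k = cnt + 1, it is the root; otherwise it is the (k - cnt - 1)-th smallest element of the right subtree.
  • Because countNodes recounts a whole subtree at every step, this is O(n) for a balanced tree but can degrade to O(n²) for a skewed one.
  • Code: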
public int kthSmallestCount(TreeNode root, int k) {
    int cnt = countNodes(root.left);                        // size of the left subtree
    if (k < cnt + 1) {
        return kthSmallestCount(root.left, k);              // answer lies in the left subtree
    } else if (k > cnt + 1) {
        return kthSmallestCount(root.right, k - cnt - 1);   // skip the left subtree and the root
    }
    return root.val;                                        // k == cnt + 1: root is the k-th smallest
}

private int countNodes(TreeNode root) {
    if (root == null) {
        return 0;
    }
    int left = countNodes(root.left);
    int right = countNodes(root.right);
    return left + right + 1;
}
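Idea 1 recounts the left subtree on every step, and that is also where the follow-up question points: if each node caches the size of its own subtree, countNodes becomes an O(1) lookup and each query drops to O(h). Below is a minimal sketch of that augmentation, assuming distinct keys; `SizedNode`, `SizedBST` and the `size` field are hypothetical names. Delete is left out; it would have to recompute `size` along the affected path the same way insert does.

```java
// Hypothetical node type: a BST node that also caches its subtree size.
class SizedNode {
    int val;
    int size = 1;            // nodes in the subtree rooted here, including this one
    SizedNode left, right;
    SizedNode(int val) { this.val = val; }
}

// Hypothetical size-augmented BST answering kth-smallest queries in O(h).
class SizedBST {
    private SizedNode root;

    public void insert(int val) { root = insert(root, val); }

    // Standard recursive BST insert; sizes are recomputed on the way back up.
    private SizedNode insert(SizedNode node, int val) {
        if (node == null) return new SizedNode(val);
        if (val < node.val) node.left = insert(node.left, val);
        else if (val > node.val) node.right = insert(node.right, val);
        // equal keys are ignored in this sketch
        node.size = 1 + size(node.left) + size(node.right);
        return node;
    }

    private static int size(SizedNode node) { return node == null ? 0 : node.size; }

    // Same branching as kthSmallestCount, but the left-subtree count is a field read.
    public int kthSmallest(int k) {
        SizedNode node = root;
        while (node != null) {
            int cnt = size(node.left);
            if (k < cnt + 1) {
                node = node.left;
            } else if (k > cnt + 1) {
                k -= cnt + 1;
                node = node.right;
            } else {
                return node.val;
            }
        }
        throw new IllegalArgumentException("k is out of range");
    }
}
```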

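idea 2: count directly during the in-order recursion

  • Reference: link
  • This is the fastest approach: the traversal can stop as soon as the k-th node is visited, so it takes O(h + k) time and O(h) stack space, where h is the height of the tree (not O(h) overall, since it still visits k nodes). It can be implemented in many ways, several of them recursive, and it can also be rewritten iteratively.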
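Since the in-order count can also be written iteratively, here is a minimal sketch that uses an explicit stack in place of the call stack. It assumes the same LeetCode-style `TreeNode` (redefined so the snippet stands alone); `IterativeSolution` is a hypothetical name. It returns as soon as the k-th node is popped, so it does O(h + k) work with an O(h) stack.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Assumed LeetCode-style node.
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical name: iterative in-order traversal that stops at the k-th node.
class IterativeSolution {
    public int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode node = root;
        while (node != null || !stack.isEmpty()) {
            // Walk as far left as possible, remembering the path.
            while (node != null) {
                stack.push(node);
                node = node.left;
            }
            node = stack.pop();          // next node in ascending order
            if (--k == 0) return node.val;
            node = node.right;           // continue with its right subtree
        }
        throw new IllegalArgumentException("k is larger than the number of nodes");
    }
}
```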

idea 2.1: a straightforward recursive version

  • Master recursion thoroughly!
  • Code:
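public int kthSmallest(TreeNode root, int k) {
    int[] cnt = new int[]{k};   // nodes still to visit before reaching the k-th one
    int[] val = new int[]{0};   // holds the answer once it is found
    helper(root, cnt, val);
    return val[0];
}

private void helper(TreeNode root, int[] cnt, int[] val) {
    if (root.left != null)  helper(root.left, cnt, val);
    if (cnt[0] <= 0) return;    // answer already found in the left subtree: stop early
    cnt[0]--;                   // visit root in its in-order position
    if (cnt[0] == 0) {
        val[0] = root.val;
        return;
    }
    if (root.right != null)  helper(root.right, cnt, val);
}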

idea 2.2: fastest recursion

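  • This recursion needs careful study; the code below already includes tom47's comments.
    • Points to note:
      • If left is non-null, the result has already come back bottom-up: it is the node returned by some deeper call's root or right branch. A null, by contrast, is returned on the way down (top-down) as soon as a null child is reached.
      • For example, the right subtree of some subtree may have returned its right result, but from the caller's point of view that is simply the value of root's call rec(root.left).
      • Think carefully about each recursive call and how the code before and after it relates to that call.
  • Reference: link
  • Code: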
private int counter;

public int kthSmallestFast(TreeNode root, int k) {
    counter = 0;
    return kthSmallestNodeFast(root, k).val;
}

private TreeNode kthSmallestNodeFast(TreeNode root, int k) {
    if (root == null)
        return null; // if the root is null, there is nowhere else to go and no answer here, so return null
    TreeNode left = kthSmallestNodeFast(root.left, k); // traverse the whole left subtree
    if (left != null)
        return left; // non-null means the counter check below already succeeded in some deeper stack frame, so this node is the answer. null means the counter never reached k: the left subtree did not have enough nodes.
    if (++counter == k)
        return root; // counter == k means we have visited k nodes; by the BST property and in-order traversal, root is the answer
    return kthSmallestNodeFast(root.right, k); // neither root nor the left subtree is the answer, so return whatever the right subtree yields: null or the correct node. null again means the left subtree, root and right subtree together had fewer than k nodes
}
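One drawback of idea 2.2 is the `counter` field: the object carries state between calls, so a single instance is not safe to share across threads or to reenter. Below is a sketch of the same recursion with the counter passed in a one-element array instead (the trick idea 2.1 already uses). `FastRecursionNoField` and `find` are hypothetical names, and `TreeNode` is assumed to be the usual LeetCode shape.

```java
// Assumed LeetCode-style node.
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical name: idea 2.2 without shared mutable instance state.
class FastRecursionNoField {
    public int kthSmallest(TreeNode root, int k) {
        int[] counter = {0};   // visited-node count, local to this call
        return find(root, k, counter).val;
    }

    // Returns the k-th in-order node if it lies in this subtree, else null.
    private TreeNode find(TreeNode root, int k, int[] counter) {
        if (root == null) return null;
        TreeNode left = find(root.left, k, counter);
        if (left != null) return left;        // found deeper in the left subtree
        if (++counter[0] == k) return root;   // this node is the k-th one visited
        return find(root.right, k, counter);  // otherwise search the right subtree
    }
}
```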