Kth smallest/minimum element in a BST – Rank of a BST node

Given a binary search tree, find the kth smallest element in it.

A quick solution is a modified inorder traversal that carries an extra counter k. Each time the traversal visits a node (that is, when the recursion unwinds back to it after finishing its left subtree), we decrement k. When k reaches 0, the node just visited is the kth smallest. This is an O(n) time algorithm.

For example, for the following BST the nodes in inorder traversal are: 1, 2, 3, 4, 5, 6. So, the 3rd smallest is 3.

                                       
                                    4
                                 /     \
                                2       5 
                              /   \       \
                             1     3       6
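
The inorder approach can be sketched as follows. This is a minimal illustration, not a definitive implementation: the class InorderKthDemo, its nested Node type, and the helper names are hypothetical, and the counter is passed as a one-element array so the recursive calls can share it.

```java
class InorderKthDemo {
    // Hypothetical minimal node type for this sketch.
    static final class Node {
        final int key;
        Node left, right;
        Node(int key) { this.key = key; }
    }

    // Returns the node with the kth smallest key (1-based), or null if it does not exist.
    static Node kthSmallest(Node root, int k) {
        if (k <= 0) {
            return null;
        }
        return visit(root, new int[] {k});
    }

    // Inorder traversal; remaining[0] counts how many nodes are still to be visited.
    private static Node visit(Node node, int[] remaining) {
        if (node == null) {
            return null;
        }
        Node found = visit(node.left, remaining);
        if (found != null) {
            return found; // already found in the left subtree
        }
        if (--remaining[0] == 0) {
            return node; // this is the kth node in sorted order
        }
        return visit(node.right, remaining);
    }

    // Builds the example tree from the text by hand.
    static Node exampleTree() {
        Node root = new Node(4);
        root.left = new Node(2);
        root.right = new Node(5);
        root.left.left = new Node(1);
        root.left.right = new Node(3);
        root.right.right = new Node(6);
        return root;
    }

    public static void main(String[] args) {
        System.out.println(kthSmallest(exampleTree(), 3).key); // prints 3
    }
}
```

The traversal stops as soon as the counter hits zero, but in the worst case (k close to n) it still visits every node.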

However, with some augmented data in the tree and some inexpensive bookkeeping while building it, we can find the kth smallest in O(h) time, where h is the height of the tree; this is O(lg n) when the tree is balanced. Observe that:

Within the subtree rooted at a node, the number of nodes smaller than that node is equal to the size of its left subtree.

This lets us assign a rank to each node. Notice that

              size(node) = size(node.left)+size(node.right)+1;
              rank(node) = size(node.left)+1
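
For a tree that already exists, the two formulas above can be evaluated with a single post-order pass. The sketch below is only an illustration under assumptions: the class SubtreeSizeDemo, its Node type with a size field, and the helpers computeSizes and localRank are hypothetical names. Note that the rank computed here is relative to the node's own subtree.

```java
class SubtreeSizeDemo {
    // Hypothetical node type: size is the number of nodes in the subtree rooted here.
    static final class Node {
        final int key;
        int size;
        Node left, right;
        Node(int key) { this.key = key; }
    }

    // Size of a possibly empty subtree.
    static int size(Node node) {
        return node == null ? 0 : node.size;
    }

    // Post-order pass: size(node) = size(node.left) + size(node.right) + 1.
    static int computeSizes(Node node) {
        if (node == null) {
            return 0;
        }
        node.size = computeSizes(node.left) + computeSizes(node.right) + 1;
        return node.size;
    }

    // rank(node) = size(node.left) + 1, relative to the subtree rooted at node.
    static int localRank(Node node) {
        return size(node.left) + 1;
    }

    // Builds the example tree from the text by hand.
    static Node exampleTree() {
        Node root = new Node(4);
        root.left = new Node(2);
        root.right = new Node(5);
        root.left.left = new Node(1);
        root.left.right = new Node(3);
        root.right.right = new Node(6);
        return root;
    }

    public static void main(String[] args) {
        Node root = exampleTree();
        computeSizes(root);
        System.out.println(root.size);       // prints 6
        System.out.println(localRank(root)); // prints 4
    }
}
```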

So, we can update the size of the left subtree of each node either while building the tree or in a later phase. Finding the kth smallest is then a binary search on the tree, using the fact that the kth smallest element has rank k. If the current node's rank is greater than k, we search only in its left subtree; if it is less than k, we search its right subtree for the (k - rank)th smallest element. Below is an implementation of this idea with the BST augmented with sizes that are updated during each insert. It runs in O(h) time for a tree of height h, which is O(lg n) when the tree is balanced:

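// Inserts node into the BST rooted at root and returns the root of the resulting tree.
// Assumes a fresh node starts with lcount == 0 and size == 1.
public static TreeNode insert(final TreeNode root, final TreeNode node) {
    if (node == null) {
        return root; // nothing to insert; keep the existing tree
    }
    if (root == null) {
        return node; // the new node becomes the root
    }

    TreeNode current = root;
    TreeNode parent = root;
    while (current != null) {
        // Every node on the search path gains one descendant.
        current.size++;
        parent = current;
        if (current.key > node.key) {
            current.lcount++; // the new node lands in current's left subtree
            current = current.left;
        } else {
            current = current.right;
        }
    }

    // Attach the new node below the last node on the search path.
    if (parent.key > node.key) {
        parent.left = node;
    } else {
        parent.right = node;
    }
    node.parent = parent; // parent link is what OS_RANK walks up

    return root;
}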

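// select(k): returns the node holding the kth smallest key (1-based),
// or null if the subtree contains no such node.
public static TreeNode kthSmallestElement(final TreeNode root, final int k) {
    if (root == null) {
        return null; // k is out of range for this subtree
    }

    // Rank of root within its own subtree.
    final int rank = root.lcount + 1;

    if (rank == k) {
        return root;
    } else if (rank > k) {
        return kthSmallestElement(root.left, k);
    } else {
        // Skip root and its left subtree, which account for rank nodes.
        return kthSmallestElement(root.right, k - rank);
    }
}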

 

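Calculating the rank in a later phase

If every node stores its subtree size and a parent pointer, the rank of a node x in the whole tree can be computed by walking from x up to the root: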

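// Returns the rank (1-based position in sorted order) of x within the whole tree.
public static int OS_RANK(final TreeNode x) {
    // Nodes in x's left subtree precede x; the +1 counts x itself.
    int r = (x.left != null ? x.left.size : 0) + 1;
    TreeNode y = x;
    while (y.parent != null) {
        // If y is a right child, its parent and the parent's left subtree precede x.
        if (y.parent.right == y) {
            r += (y.parent.left != null ? y.parent.left.size : 0) + 1;
        }
        y = y.parent;
    }
    return r;
}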
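
To see how selecting by rank and computing a rank fit together, here is a small self-contained sketch that builds the example tree and checks that the two operations are inverses of each other. It is an illustration under assumptions, not the article's TreeNode class: OrderStatisticDemo, its Node type (with size and parent fields but no lcount), and the method names are hypothetical.

```java
class OrderStatisticDemo {
    // Hypothetical node type carrying the subtree size and a parent pointer.
    static final class Node {
        final int key;
        int size = 1;
        Node left, right, parent;
        Node(int key) { this.key = key; }
    }

    static int size(Node node) {
        return node == null ? 0 : node.size;
    }

    // Standard BST insert that bumps the size of every node on the search path.
    static Node insert(Node root, int key) {
        Node node = new Node(key);
        if (root == null) {
            return node;
        }
        Node current = root;
        while (true) {
            current.size++;
            if (key < current.key) {
                if (current.left == null) { current.left = node; break; }
                current = current.left;
            } else {
                if (current.right == null) { current.right = node; break; }
                current = current.right;
            }
        }
        node.parent = current;
        return root;
    }

    // Iterative select: node with the kth smallest key (1-based), or null.
    static Node select(Node root, int k) {
        Node current = root;
        while (current != null) {
            int rank = size(current.left) + 1;
            if (k == rank) {
                return current;
            } else if (k < rank) {
                current = current.left;
            } else {
                k -= rank; // skip current and its left subtree
                current = current.right;
            }
        }
        return null;
    }

    // Rank of x in the whole tree, found by walking up the parent pointers.
    static int rank(Node x) {
        int r = size(x.left) + 1;
        for (Node y = x; y.parent != null; y = y.parent) {
            if (y.parent.right == y) {
                r += size(y.parent.left) + 1;
            }
        }
        return r;
    }

    public static void main(String[] args) {
        Node root = null;
        for (int key : new int[] {4, 2, 5, 1, 3, 6}) {
            root = insert(root, key);
        }
        Node third = select(root, 3);
        System.out.println(third.key);   // prints 3
        System.out.println(rank(third)); // prints 3
    }
}
```

Inserting the keys in the order 4, 2, 5, 1, 3, 6 reproduces the example tree shown earlier. Both operations follow a single root-to-node path, so their cost is proportional to the height of the tree.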