Kth smallest/largest element in an array

Given an array of integers, find the kth smallest element in the array as efficiently as possible.

For example, if A = [2, 1, 0, 3, -1, 3] and k = 3, then the 3rd smallest element is 1 (sorted, A is [-1, 0, 1, 2, 3, 3]). It is also the (6-3+1) = 4th largest element.

A trivial solution is to sort the array and pick the kth element. But how would you sort the array? Quicksort? How does Quicksort work? What is its complexity? O(n lg n) on average, right? What if n is very, very large and we only care about the kth smallest element? We will see how to avoid a full sort with a selection algorithm built on Quicksort's partitioning step, known as QuickSelect. The same technique finds the minimum, maximum, and median elements. The kth smallest element is called the kth order statistic.
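For comparison, here is a minimal sketch of the trivial sort-based solution, assuming a 1-based k as in the example above. The class and method names (SortBaseline, kthSmallestBySorting) are illustrative only. It costs O(n log n) time and, because it sorts a copy, O(n) extra space.

```java
import java.util.Arrays;

class SortBaseline {
    // Baseline: sort a copy of the array and index it. k is 1-based (k = 1 is the minimum).
    static int kthSmallestBySorting(int[] a, int k) {
        if (k < 1 || k > a.length) {
            throw new IllegalArgumentException("k must be between 1 and " + a.length);
        }
        int[] copy = Arrays.copyOf(a, a.length); // leave the caller's array untouched
        Arrays.sort(copy);                       // O(n log n)
        return copy[k - 1];
    }

    public static void main(String[] args) {
        System.out.println(kthSmallestBySorting(new int[]{2, 1, 0, 3, -1, 3}, 3)); // 1
    }
}
```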

The idea is to recursively partition the array around a pivot so that the left side of the partition holds elements less than the pivot and the right side holds elements greater than or equal to the pivot. Since every element on the left is smaller than the pivot, the pivot is the kth smallest element if and only if the left side together with the pivot contains exactly k elements, i.e. the pivot lands at position k.

  1. Select an initial pivot. This is an important step, because the worst-case performance of the algorithm depends on choosing a pivot that splits the array evenly. For simplicity, the walkthrough below uses the middle element of the array as the pivot and the implementation picks a random element, but we could do better. For example, using the median of medians as the pivot guarantees worst-case O(n) time. See the separate article on median of medians for how to find it in O(n) time. The idea is simple: divide the array into n/5 (rounded up) contiguous groups of at most 5 elements and find the median of each group (sorting 5 elements takes constant time). The median of medians is then the median of those n/5 group medians, found recursively with the same selection algorithm.
  2. Partition the array in place around the pivot so that elements less than the pivot end up to its left and elements greater than or equal to the pivot end up to its right.
  3. Let m be the number of elements on the left, including the pivot, and compare m with k.
    1. If m == k, the pivot is the kth smallest element; return it.
    2. Else, if m > k, look for the kth smallest element in the left partition recursively, using the same procedure.
    3. Else (m < k), the m elements on the left, pivot included, all come before the kth smallest element in sorted order. So look for the (k-m)th smallest element in the right partition, again recursively with the same procedure.
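The median-of-medians pivot rule from step 1 can be sketched as a deterministic selection routine. This is a minimal sketch, assuming groups of 5 and a 1-based k; the names MedianOfMedians, select, filter and kthSmallest are hypothetical. Unlike the in-place version, it copies elements into new arrays to keep the partitioning step short, and it splits three ways (less than, equal to, greater than the pivot) so that duplicates cannot stall the recursion.

```java
import java.util.Arrays;

class MedianOfMedians {
    // Returns the kth smallest element of a, with k 0-based. Worst-case O(n) time.
    static int select(int[] a, int k) {
        if (a.length <= 5) {
            int[] sorted = a.clone();
            Arrays.sort(sorted);
            return sorted[k];
        }
        // 1. Split into groups of 5 (the last may be shorter) and take each group's median.
        int groups = (a.length + 4) / 5;
        int[] medians = new int[groups];
        for (int g = 0; g < groups; g++) {
            int[] group = Arrays.copyOfRange(a, 5 * g, Math.min(5 * g + 5, a.length));
            Arrays.sort(group); // at most 5 elements: constant work
            medians[g] = group[(group.length - 1) / 2];
        }
        // 2. The pivot is the median of the group medians, found recursively.
        int pivot = select(medians, (groups - 1) / 2);

        // 3. Three-way split by value: < pivot, == pivot, > pivot.
        int less = 0, equal = 0;
        for (int x : a) {
            if (x < pivot) less++;
            else if (x == pivot) equal++;
        }
        if (k < less) {
            return select(filter(a, pivot, -1, less), k);
        } else if (k < less + equal) {
            return pivot;
        } else {
            return select(filter(a, pivot, 1, a.length - less - equal), k - less - equal);
        }
    }

    // Collects the elements whose comparison with pivot has the given sign (-1 or 1).
    private static int[] filter(int[] a, int pivot, int sign, int count) {
        int[] out = new int[count];
        int n = 0;
        for (int x : a) {
            if (Integer.signum(Integer.compare(x, pivot)) == sign) out[n++] = x;
        }
        return out;
    }

    // 1-based wrapper matching the article's convention.
    static int kthSmallest(int[] a, int k) {
        if (k < 1 || k > a.length) {
            throw new IllegalArgumentException("k must be between 1 and " + a.length);
        }
        return select(a, k - 1);
    }

    public static void main(String[] args) {
        System.out.println(kthSmallest(new int[]{2, 1, 0, 3, -1, 3}, 3)); // 1
    }
}
```

The median of medians always has a constant fraction of the elements on each side, which is what bounds the worst case at O(n); in practice its constant factor is larger than that of a random pivot.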

For example, take A = [2, 0, 1, 3, -1, 3] (the earlier array with 1 and 0 swapped), n = 6, l = 0, r = n-1 = 5, and k = 4. Let the pivot index be q = (l+r)/2 = 2. Partition the array around the pivot as follows:

Swap the pivot (A[q] = A[2] = 1) with the rightmost element A[r].
Keep a pointer i that marks the end of the left partition. Initially i = l-1.
Keep another pointer j that scans the array from l to r-1.

         A =  |2, 0, 1, 3, -1, |3
             ^ ^     ^          ^
             i j     q          r

Scan the array from j = l to r-1.
If A[j] is less than the pivot (= 1), swap it with the leftmost
element of the right partition (at index i+1) and extend the
left partition's boundary to i+1. Otherwise, just advance j.

         A =  | 2, 0, 3, 3, -1, |1 
             ^  ^                ^ 
             i  j                r 

A[j]>pivot, so j++

         A =   | 2, 0, 3, 3, -1, |1
              ^     ^             ^
              i     j             r

A[j]<pivot, so swap(i+1, j) and i++, j++

         A =   0, | 2, 3, 3, -1, |1
               ^       ^          ^
               i       j          r 

A[j]>pivot, so j++

         A =   0, | 2, 3, 3, -1, |1
               ^          ^       ^
               i          j       r 

A[j]>pivot, so j++

         A =   0, | 2, 3, 3, -1, |1
               ^              ^   ^
               i              j   r 

A[j]<pivot, so swap(i+1, j) and i++. This was the last position (j == r-1), so the scan stops.

         A =   0, -1, | 3, 3, 2, |1
                   ^          ^   ^
                   i          j   r 

Now move the pivot (currently at index r) between the two partitions,
i.e. swap(i+1, r), and return the final pivot index q = i+1.

         A =  0, -1, 1, 3, 2, 3
                     ^ 
                     q                              

So m = 3 elements lie on the left of the current pivot index (q = 2), pivot included. Since k = 4 > m, we look in the right partition A[q+1..r] = [3, 2, 3] for the (k-m) = 4-3 = 1st smallest element, applying the same partitioning procedure to that partition. That element is 2, which is indeed the 4th smallest element of A.
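The walkthrough can be reproduced with the following sketch. It assumes the middle-element pivot and the strict < comparison used in the trace (the implementation below uses a random pivot and <= instead), and it replaces the recursion with a loop, which works because each step continues into only one side. The class and method names are illustrative.

```java
import java.util.Arrays;

class MiddlePivotQuickSelect {
    static void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    // Partition a[l..r] around its middle element, as in the walkthrough:
    // elements < pivot end up left of the returned index, elements >= pivot right of it.
    static int partition(int[] a, int l, int r) {
        int q = l + (r - l) / 2;
        swap(a, q, r);            // park the pivot at the right end
        int pivot = a[r];
        int i = l - 1;            // end of the "< pivot" region
        for (int j = l; j < r; j++) {
            if (a[j] < pivot) {
                swap(a, ++i, j);  // grow the left region by one
            }
        }
        swap(a, i + 1, r);        // move the pivot between the two regions
        return i + 1;
    }

    // kth smallest (1-based) of a, rearranging a in place.
    static int kthSmallest(int[] a, int k) {
        if (k < 1 || k > a.length) {
            throw new IllegalArgumentException("k must be between 1 and " + a.length);
        }
        int l = 0, r = a.length - 1;
        while (true) {
            if (l == r) return a[l];
            int q = partition(a, l, r);
            int m = q - l + 1;    // elements in a[l..q], pivot included
            if (k == m) return a[q];
            if (k < m) {
                r = q - 1;        // answer is in the left part, same k
            } else {
                l = q + 1;        // skip the m elements that come before it
                k -= m;
            }
        }
    }

    public static void main(String[] args) {
        int[] a = {2, 0, 1, 3, -1, 3};
        int q = partition(a, 0, a.length - 1);
        System.out.println(Arrays.toString(a) + " q=" + q);               // [0, -1, 1, 3, 2, 3] q=2
        System.out.println(kthSmallest(new int[]{2, 0, 1, 3, -1, 3}, 4)); // 2
    }
}
```

With a middle pivot the worst case is still O(n^2) on inputs that keep placing extreme values in the middle, which is one reason to prefer a random or median-of-medians pivot.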

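Below is a simple implementation of the algorithm described above. With a random pivot it finds the kth smallest element in expected O(n) time; the worst case is O(n^2), when the pivots keep splitting the array badly. It allocates no extra arrays, and because the recursion is a tail call it can be rewritten as a loop to use O(1) extra space. Note that the algorithm works in place but is not stable: it reorders the array and does not preserve the relative order of equal elements.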

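import java.util.concurrent.ThreadLocalRandom;

public class KthSmallest {

    private static void swap(final int[] input, final int i, final int j) {
        final int temp = input[i];
        input[i] = input[j];
        input[j] = temp;
    }

    // Lomuto partition of A[p..r] around the pivot A[r]: elements <= pivot end up
    // to its left, larger ones to its right (the text uses <; both are valid).
    // Returns the pivot's final index.
    private static int partition(final int[] A, final int p, final int r) {
        final int pivot = A[r];
        int i = p - 1;

        for (int j = p; j < r; j++) {
            if (A[j] <= pivot) {
                swap(A, ++i, j);
            }
        }

        swap(A, i + 1, r);
        return i + 1;
    }

    // Moves a uniformly random element of A[p..r] to A[r], then partitions.
    private static int randomizedPartition(final int[] A, final int p, final int r) {
        final int i = ThreadLocalRandom.current().nextInt(p, r + 1);
        swap(A, i, r);
        return partition(A, p, r);
    }

    // Returns the kth smallest (1-based) element of A[p..r], rearranging A in place.
    public static int kthSmallest(final int[] A, final int p, final int r, final int k) {
        if (k < 1 || k > r - p + 1) {
            throw new IllegalArgumentException("k must be between 1 and " + (r - p + 1));
        }
        if (p == r) {
            return A[p]; // a single element is its own 1st smallest
        }
        final int q = randomizedPartition(A, p, r);

        final int m = q - p + 1; // elements in A[p..q], pivot included
        if (k == m) {
            return A[q];
        } else if (k < m) {
            return kthSmallest(A, p, q - 1, k);
        } else {
            return kthSmallest(A, q + 1, r, k - m);
        }
    }
}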

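We can use the same procedure to find the kth largest element by finding the (n-k+1)th smallest element (i.e. the kth smallest element counting from the right). The algorithm can also find the median of the array: select the ((n+1)/2)th smallest element when n is odd; when n is even, the median is usually taken as the mean of the (n/2)th and (n/2+1)th smallest elements.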
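A minimal sketch of the kth-largest and median reductions, built on a compact randomized QuickSelect. The names OrderStatistics, kthLargest and median are hypothetical; k is 1-based, and every call rearranges the array it is given.

```java
import java.util.concurrent.ThreadLocalRandom;

class OrderStatistics {
    static void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    // kth smallest (1-based) by randomized QuickSelect; rearranges a in place.
    static int kthSmallest(int[] a, int k) {
        if (k < 1 || k > a.length) {
            throw new IllegalArgumentException("k must be between 1 and " + a.length);
        }
        int l = 0, r = a.length - 1;
        while (l < r) {
            swap(a, ThreadLocalRandom.current().nextInt(l, r + 1), r); // random pivot to a[r]
            int i = l; // next slot for an element smaller than the pivot
            for (int j = l; j < r; j++) {
                if (a[j] < a[r]) {
                    swap(a, i++, j);
                }
            }
            swap(a, i, r);     // pivot lands at its final index i
            int m = i - l + 1; // elements in a[l..i], pivot included
            if (k == m) {
                return a[i];
            } else if (k < m) {
                r = i - 1;
            } else {
                l = i + 1;
                k -= m;
            }
        }
        return a[l];
    }

    // kth largest = (n - k + 1)th smallest
    static int kthLargest(int[] a, int k) {
        return kthSmallest(a, a.length - k + 1);
    }

    // Median; for even n, the mean of the two middle elements.
    static double median(int[] a) {
        int n = a.length;
        if (n % 2 == 1) {
            return kthSmallest(a, (n + 1) / 2);
        }
        int lower = kthSmallest(a, n / 2);
        int upper = kthSmallest(a, n / 2 + 1); // a is still a permutation of the input
        return (lower + (double) upper) / 2.0;
    }

    public static void main(String[] args) {
        System.out.println(kthLargest(new int[]{2, 1, 0, 3, -1, 3}, 4)); // 1
        System.out.println(median(new int[]{2, 1, 0, 3, -1, 3}));        // 1.5
    }
}
```

Calling kthSmallest twice for an even-length median is the simplest correct choice; a tighter version could take the minimum of the right partition after the first selection instead.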