Median (or kth Smallest) of a sorted 2D array

Given a 2D array in which each row is sorted in ascending order, find the median of the whole 2D array.

For example,

             A= 2, 4, 5, 6
                1, 2, 2, 4
                3, 4, 4, 5
                1, 2, 3, 3

Then the merged array would be [1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 6]. So the 5th smallest element is 2, the 10th smallest is 4, and so on. With 16 elements, the median is the average of the 8th and 9th smallest: (3+3)/2 = 3.

Since the rows are sorted, merging them into one sorted array makes the median easy to read off. This is the straightforward solution: repeatedly merge pairs of rows (as we do in the merge phase of merge sort) until a single sorted array of m*n elements remains. Counting positions from 1, the median is the ((m*n+1)/2)th element if m*n is odd; if m*n is even, the median is the average of the (m*n/2)th and (m*n/2+1)th elements. Merging pairs in rounds takes about lg(n) rounds of O(nm) work each, so the complexity of this solution is O(nm lg(n)). Can we do better?
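The merge-based baseline described above can be sketched as follows. This is only an illustration under stated assumptions: the class name RowMergeMedian and its method names are hypothetical, the input is assumed to have at least one row, and the median is returned as a double so that a fractional average is not truncated.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper class for illustration; the names are not from the original post.
class RowMergeMedian {

    // Standard two-way merge of two ascending arrays, as in merge sort.
    static int[] merge(int[] a, int[] b) {
        int[] out = new int[a.length + b.length];
        int i = 0, j = 0, t = 0;
        while (i < a.length && j < b.length) {
            out[t++] = (a[i] <= b[j]) ? a[i++] : b[j++];
        }
        while (i < a.length) out[t++] = a[i++];
        while (j < b.length) out[t++] = b[j++];
        return out;
    }

    // Merge rows pairwise in rounds until one sorted array remains.
    // Assumes rows.length >= 1.
    static int[] mergeAll(int[][] rows) {
        List<int[]> current = new ArrayList<>(Arrays.asList(rows));
        while (current.size() > 1) {
            List<int[]> next = new ArrayList<>();
            for (int i = 0; i + 1 < current.size(); i += 2) {
                next.add(merge(current.get(i), current.get(i + 1)));
            }
            // An odd row out is carried over unchanged to the next round.
            if (current.size() % 2 == 1) {
                next.add(current.get(current.size() - 1));
            }
            current = next;
        }
        return current.get(0);
    }

    // Median of all elements; averages the two middle values for an even count.
    static double median(int[][] rows) {
        int[] all = mergeAll(rows);
        int total = all.length;
        if (total % 2 == 1) {
            return all[total / 2];
        }
        return (all[total / 2 - 1] + all[total / 2]) / 2.0;
    }

    public static void main(String[] args) {
        int[][] A = {{2, 4, 5, 6}, {1, 2, 2, 4}, {3, 4, 4, 5}, {1, 2, 3, 3}};
        System.out.println(Arrays.toString(mergeAll(A)));
        System.out.println(median(A)); // prints 3.0
    }
}
```

A baseline like this is also handy for cross-checking a faster solution on small inputs.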

Yes, there is a better way if we reformulate the problem. We can find the kth smallest element of the 2D array by using a min heap. With N = n*m elements in total and ranks counted from 1, the median is then the ((N+1)/2)th smallest element when N is odd, or the average of the (N/2)th and (N/2+1)th smallest elements when N is even.
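Off-by-one mistakes in the median rank are easy to make, so here is a small sketch of the rank arithmetic. It assumes ranks are counted from 1 over all N = n*m elements; the class and method names are hypothetical.

```java
// Hypothetical helper for illustration; the names are not from the original post.
class MedianRanks {

    // Returns the 1-based rank(s) whose values define the median of `total` elements:
    // one rank when total is odd, the two middle ranks when total is even.
    static int[] medianRanks(int total) {
        if (total <= 0) {
            throw new IllegalArgumentException("total must be positive");
        }
        if (total % 2 == 1) {
            return new int[]{(total + 1) / 2};
        }
        return new int[]{total / 2, total / 2 + 1};
    }

    public static void main(String[] args) {
        // For the 4x4 example there are 16 elements, so the median
        // averages the 8th and 9th smallest.
        System.out.println(java.util.Arrays.toString(medianRanks(16))); // prints [8, 9]
    }
}
```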

For example, we can start by building a min heap from the first column of the 2D array in O(n lg n) time. At each iteration, extract the min from the heap and insert the next element from the row of the extracted element (if the min is at the end of its row, nothing is inserted). The heap therefore always holds the smallest not-yet-extracted element of every row, which ensures we traverse the elements in ascending order. That is, the element extracted at the kth iteration is the kth smallest element. The complexity of this traversal is O(k lg n) if k>n. We use a small data structure to keep the row and column index of an element along with its value when it is added to the heap.

For example, let's find the 5th smallest element of A. Assume A[i][j] is kept as {A[i][j],i,j} in the heap. We first build the heap from A[0..n-1][0], minHeap = [{1,3,0}, {1,1,0}, {2,0,0}, {3,2,0}] (shown in sorted order). We remove the min element {1,3,0} from the heap and insert the next element (col=1) from the same row (=3) as the min element. Now, minHeap = [{1,1,0}, {2,0,0}, {2,3,1}, {3,2,0}]. In subsequent iterations we remove {1,1,0} and add {2,1,1}, remove {2,0,0} and add {4,0,1}, and remove {2,3,1} and add {3,3,2}. We have now removed 4 mins, so the last removed min was the 4th smallest. At the k=5th iteration we remove {2,1,1} and add {2,1,2}. So the 5th smallest element is 2, at A[1][1]. Elements with equal values may leave the heap in a different order depending on how ties are broken, but the kth smallest value is the same either way.
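The walk-through above can be reproduced with a short trace. This is a minimal sketch, not the implementation given later in this post: it stores each heap entry as an int[] {value, row, col} instead of a dedicated class, and the class and method names are hypothetical. Ties between equal values may come out of the heap in a different row order than in the walk-through, but the sequence of values is the same.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical trace helper for illustration; the names are not from the original post.
class HeapTraceSketch {

    // Returns the k smallest values of a row-sorted matrix in ascending order.
    // Each heap entry is {value, row, col}, ordered by value only.
    static List<Integer> firstK(int[][] a, int k) {
        PriorityQueue<int[]> minHeap =
                new PriorityQueue<>((x, y) -> Integer.compare(x[0], y[0]));

        // Seed the heap with the first element of every non-empty row.
        for (int i = 0; i < a.length; i++) {
            if (a[i].length > 0) {
                minHeap.offer(new int[]{a[i][0], i, 0});
            }
        }

        List<Integer> extracted = new ArrayList<>();
        while (!minHeap.isEmpty() && extracted.size() < k) {
            int[] e = minHeap.poll();
            extracted.add(e[0]);
            int row = e[1];
            int nextCol = e[2] + 1;
            // Replace the extracted element with its right neighbour, if any.
            if (nextCol < a[row].length) {
                minHeap.offer(new int[]{a[row][nextCol], row, nextCol});
            }
        }
        return extracted;
    }

    public static void main(String[] args) {
        int[][] A = {{2, 4, 5, 6}, {1, 2, 2, 4}, {3, 4, 4, 5}, {1, 2, 3, 3}};
        // The last value printed is the 5th smallest element.
        System.out.println(firstK(A, 5)); // prints [1, 1, 2, 2, 2]
    }
}
```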

Below is the O(k lg n) implementation of finding the median by using kthSmallestElement on the sorted 2D array (when k>n). if k

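// needed at the top of the enclosing source file
import java.util.PriorityQueue;

// Median of all n*m elements; the rank passed to kthSmallestElement is counted from 1.
public static int median(int[][] A){
	int n = A.length;
	int m = A[0].length;
	int total = n*m;
	
	if(total%2 == 0){
		//even count: average the two middle elements
		int mid1 = kthSmallestElement(A, total/2);
		int mid2 = kthSmallestElement(A, total/2+1);
		//integer division: a fractional average is rounded toward zero
		return (mid1+mid2)/2;
	}
	else{
		//odd count: the single middle element
		return kthSmallestElement(A, total/2+1);
	}
}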

// Returns the kth smallest element of A, with k counted from 1 (1 <= k <= n*m).
public static int kthSmallestElement(int[][] A, int k){
	int n = A.length;
	int m = A[0].length;
	if(k < 1 || k > n*m){
		throw new IllegalArgumentException("k must be between 1 and n*m");
	}
	MatrixElement kthSmallest = null;
	
	PriorityQueue<MatrixElement> minHeap = new PriorityQueue<MatrixElement>();
	
	//add column 0 into minHeap - O(nlgn)
	for(int i = 0; i<n; i++){
		minHeap.offer(new MatrixElement(A[i][0], i, 0));
	}
	
	//extract min from minHeap and insert next element from the same row of the extracted min
	int count = 0;
	while(!minHeap.isEmpty() && count < k){
		kthSmallest = minHeap.poll();
		count++;
		//if the extracted element is not the last in its row, insert its right neighbour
		if(kthSmallest.col+1 < m){
			minHeap.offer(new MatrixElement(A[kthSmallest.row][kthSmallest.col+1], kthSmallest.row, kthSmallest.col+1));
		}
	}
	
	//after k extractions, the last extracted element is the kth smallest
	return kthSmallest.val;
}

public static class MatrixElement implements Comparable<MatrixElement>{
	public int val;
	public int row;
	public int col;
	
	public MatrixElement(int val, int row, int col){
		this.val = val;
		this.row = row;
		this.col = col;
	}
	@Override
	public int compareTo(MatrixElement o) {
		return Integer.compare(this.val, o.val);
	}
}