Binary Search Tree (BST) insert, delete, successor, predecessor, traversal, unique trees

From Wikipedia: a binary search tree is a rooted binary tree whose internal nodes each store a key (and optionally an associated value) and each have two distinguished subtrees, commonly denoted left and right. The tree additionally satisfies the binary search tree property: the key in each node must be greater than all keys stored in its left subtree and smaller than all keys stored in its right subtree. (The leaves (final nodes) of the tree contain no key and have no structure to distinguish them from one another. Leaves are commonly represented by a special leaf or nil symbol, a NULL pointer, etc.)

package test;

/**
 * A node of an (unbalanced) binary search tree, plus static helpers that operate on the tree
 * rooted at a given node.
 *
 * <p>Keys are kept unique: inserting an existing key replaces that node's {@code data}.
 * The {@code size} field holds the number of nodes in the subtree rooted at the node. It is
 * maintained by {@code insert}, {@code insert2}, {@code delete} and {@code delete2}, and
 * {@code select} and {@code rank} rely on it.
 *
 * <p>The {@code parent} field is maintained by the same four methods. It is used by
 * {@code successor}, {@code predecessor} and {@code delete2}.
 *
 * <p>The {@code height} field is initialised to 1 but is not updated by any method here; use
 * {@code height(TreeNode)} to compute a subtree's height.
 */
public class TreeNode {
    int key;
    int height;
    // number of nodes in the subtree rooted at this node (including itself)
    int size;
    TreeNode left;
    TreeNode right;
    TreeNode parent;
    // optional value associated with the key
    Object data = null;

    /** Creates a leaf node holding {@code key} and no data. */
    public TreeNode(final int key) {
        this.key = key;
        this.size = 1;
        this.height = 1;
        this.left = null;
        this.right = null;
    }

    /** Creates a leaf node holding {@code key} and the associated value {@code val}. */
    public TreeNode(final int key, Object val) {
        this.key = key;
        this.size = 1;
        this.height = 1;
        this.left = null;
        this.right = null;
        this.data = val;
    }

    /**
     * Iteratively inserts {@code key} into the tree rooted at {@code root}.
     *
     * <p>If the key is already present, its data is replaced. No duplicate is added and no
     * sizes change, which keeps the tree consistent with {@link #insert2} and {@link #isBST}.
     *
     * @param root root of the tree; may be {@code null}
     * @param key  key to insert
     * @param val  value stored in the node's {@code data}
     * @return the node holding {@code key}. This is the new root when {@code root} was
     *         {@code null}; otherwise it is the new (or already existing) node, not the root.
     */
    public TreeNode insert(TreeNode root, int key, int val) {
        if (root == null) {
            return new TreeNode(key, val);
        }

        // Keys are unique: update in place and leave all subtree sizes untouched.
        TreeNode existing = search(root, key);
        if (existing != null) {
            existing.data = val;
            return existing;
        }

        TreeNode parent = null;
        TreeNode cur = root;
        while (cur != null) {
            parent = cur;
            // every node on the search path gains exactly one descendant
            cur.size++;
            cur = key < cur.key ? cur.left : cur.right;
        }

        TreeNode node = new TreeNode(key, val);
        node.parent = parent;
        if (key < parent.key) {
            parent.left = node;
        } else {
            parent.right = node;
        }
        return node;
    }

    /**
     * Recursively inserts {@code key} into the tree rooted at {@code root}, replacing the data
     * if the key already exists. Maintains {@code size} and {@code parent} along the path.
     *
     * @return the (possibly new) root of the tree
     */
    public static TreeNode insert2(TreeNode root, int key, int val) {
        if (root == null) {
            return new TreeNode(key, val);
        }

        if (key < root.key) {
            root.left = insert2(root.left, key, val);
            root.left.parent = root;
        } else if (key > root.key) {
            root.right = insert2(root.right, key, val);
            root.right.parent = root;
        } else {
            root.data = val;
        }

        root.size = 1 + size(root.left) + size(root.right);
        return root;
    }

    /** Returns the node with the largest key in the tree, or {@code null} if it is empty. */
    public static TreeNode max(TreeNode root) {
        if (root == null) {
            return null;
        }
        while (root.right != null) {
            root = root.right;
        }
        return root;
    }

    /** Returns the node with the smallest key in the tree, or {@code null} if it is empty. */
    public static TreeNode min(TreeNode root) {
        if (root == null) {
            return null;
        }
        while (root.left != null) {
            root = root.left;
        }
        return root;
    }

    /**
     * Returns the inorder successor of {@code node} using parent pointers, or {@code null}
     * if {@code node} is the maximum (or {@code null}).
     */
    public static TreeNode successor(TreeNode node) {
        if (node == null) {
            return null;
        }
        if (node.right != null) {
            return min(node.right);
        }
        // Climb while we are a right child; the first ancestor reached from its left
        // subtree is the successor. Uses pointers only, so it does not depend on key order.
        TreeNode parent = node.parent;
        while (parent != null && node == parent.right) {
            node = parent;
            parent = parent.parent;
        }
        return parent;
    }

    /**
     * Returns the inorder successor of {@code node} by searching down from {@code root}
     * (no parent pointers needed), or {@code null} if there is none.
     */
    public static TreeNode successor2(TreeNode root, TreeNode node) {
        if (node == null) {
            return null;
        }
        if (node.right != null) {
            return min(node.right);
        }
        TreeNode successor = null;
        while (root != null) {
            if (root.key > node.key) {
                // candidate: the last node where we turned left
                successor = root;
                root = root.left;
            } else if (root.key < node.key) {
                root = root.right;
            } else {
                break;
            }
        }
        return successor;
    }

    /**
     * Returns the inorder predecessor of {@code node} using parent pointers, or {@code null}
     * if {@code node} is the minimum (or {@code null}).
     */
    public static TreeNode predecessor(TreeNode node) {
        if (node == null) {
            return null;
        }
        if (node.left != null) {
            return max(node.left);
        }
        // Climb while we are a left child; the first ancestor reached from its right
        // subtree is the predecessor.
        TreeNode parent = node.parent;
        while (parent != null && node == parent.left) {
            node = parent;
            parent = parent.parent;
        }
        return parent;
    }

    /**
     * Returns the inorder predecessor of {@code node} by searching down from {@code root}
     * (no parent pointers needed), or {@code null} if there is none.
     */
    public static TreeNode predecessor2(TreeNode root, TreeNode node) {
        if (node == null) {
            return null;
        }
        if (node.left != null) {
            return max(node.left);
        }
        TreeNode pred = null;
        while (root != null) {
            if (root.key > node.key) {
                root = root.left;
            } else if (root.key < node.key) {
                // candidate: the last node where we turned right
                pred = root;
                root = root.right;
            } else {
                break;
            }
        }
        return pred;
    }

    /**
     * Recursively deletes {@code key} from the tree rooted at {@code root}.
     *
     * <p>A node with two children is replaced by its inorder successor: both the key and the
     * data are copied into the node, and then the successor is deleted from the right subtree.
     * Maintains {@code size} and {@code parent}.
     *
     * @return the (possibly new) root of the tree
     */
    public static TreeNode delete(TreeNode root, int key) {
        if (root == null) {
            return null;
        }

        if (key < root.key) {
            root.left = delete(root.left, key);
            if (root.left != null) {
                root.left.parent = root;
            }
        } else if (key > root.key) {
            root.right = delete(root.right, key);
            if (root.right != null) {
                root.right.parent = root;
            }
        } else {
            if (root.left == null || root.right == null) {
                // zero or one child: splice the child (possibly null) into our place
                TreeNode child = root.left != null ? root.left : root.right;
                if (child != null) {
                    child.parent = root.parent;
                }
                return child;
            }

            TreeNode successor = min(root.right);
            root.key = successor.key;
            root.data = successor.data;
            root.right = delete(root.right, successor.key);
            if (root.right != null) {
                root.right.parent = root;
            }
        }

        root.size = size(root.left) + size(root.right) + 1;
        return root;
    }

    /**
     * Removes {@code node} from its tree using parent pointers and decrements the
     * {@code size} of every ancestor.
     *
     * <p>A node with two children keeps its identity: it receives its successor's key and data,
     * and the successor node is unlinked instead.
     *
     * <p>This method cannot return a new root. When {@code node} is a root with one child, the
     * child's contents are therefore pulled up into {@code node}.
     *
     * @throws IllegalArgumentException if {@code node} is the only node of its tree; use
     *         {@link #delete(TreeNode, int)} for that case
     */
    public static void delete2(TreeNode node) {
        if (node == null) {
            return;
        }

        if (node.left != null && node.right != null) {
            // The successor is the leftmost node of the right subtree and has no left child.
            TreeNode successor = min(node.right);
            node.key = successor.key;
            node.data = successor.data;
            node = successor;
        }

        // node now has at most one child
        TreeNode child = node.left != null ? node.left : node.right;
        TreeNode parent = node.parent;

        if (parent == null) {
            if (child == null) {
                throw new IllegalArgumentException(
                        "cannot remove the only node of a tree with delete2; use delete(root, key)");
            }
            // Root with one child: move the child's contents into the root node.
            node.key = child.key;
            node.data = child.data;
            node.left = child.left;
            node.right = child.right;
            node.size = child.size;
            node.height = child.height;
            if (node.left != null) {
                node.left.parent = node;
            }
            if (node.right != null) {
                node.right.parent = node;
            }
            return;
        }

        if (child != null) {
            child.parent = parent;
        }
        if (parent.left == node) {
            parent.left = child;
        } else {
            parent.right = child;
        }
        node.parent = null;
        node.left = null;
        node.right = null;

        // every ancestor of the removed node lost one descendant
        for (TreeNode p = parent; p != null; p = p.parent) {
            p.size--;
        }
    }

    /** Returns the subtree size of {@code node}, treating {@code null} as an empty tree. */
    private static int size(TreeNode node) {
        return node == null ? 0 : node.size;
    }

    /** Returns the node with the largest key {@code <= key}, or {@code null} if none exists. */
    public static TreeNode floor(TreeNode root, int key) {
        TreeNode best = null;
        while (root != null) {
            if (root.key == key) {
                return root;
            }
            if (root.key > key) {
                root = root.left;
            } else {
                // root is a candidate; a larger one may exist to the right
                best = root;
                root = root.right;
            }
        }
        return best;
    }

    /** Returns the node with the smallest key {@code >= key}, or {@code null} if none exists. */
    public static TreeNode ceiling(TreeNode root, int key) {
        TreeNode best = null;
        while (root != null) {
            if (root.key == key) {
                return root;
            }
            if (root.key < key) {
                root = root.right;
            } else {
                // root is a candidate; a smaller one may exist to the left
                best = root;
                root = root.left;
            }
        }
        return best;
    }

    /**
     * Returns the node with the k-th smallest key, where {@code k} is 0-based ({@code k == 0}
     * is the minimum). Relies on {@code size} being maintained.
     *
     * @return the node, or {@code null} if {@code k} is outside {@code [0, size(root))}
     */
    public static TreeNode select(TreeNode root, int k) {
        while (root != null) {
            // number of keys smaller than root.key within this subtree
            int leftSize = size(root.left);
            if (k < leftSize) {
                root = root.left;
            } else if (k > leftSize) {
                k -= leftSize + 1;
                root = root.right;
            } else {
                return root;
            }
        }
        return null;
    }

    /** Returns the number of keys in the tree that are strictly less than {@code key}. */
    public static int rank(TreeNode root, int key) {
        if (root == null) {
            return 0;
        }
        if (root.key > key) {
            return rank(root.left, key);
        } else if (root.key < key) {
            return 1 + size(root.left) + rank(root.right, key);
        } else {
            return size(root.left);
        }
    }

    /** Returns whether the tree satisfies the strict BST property (no duplicate keys). */
    public static boolean isBST(TreeNode node) {
        // long bounds so that Integer.MIN_VALUE / MAX_VALUE keys are accepted
        return isBST(node, Long.MAX_VALUE, Long.MIN_VALUE);
    }

    /** Checks that every key in the subtree lies strictly between {@code min} and {@code max}. */
    private static boolean isBST(TreeNode node, long max, long min) {
        if (node == null) {
            return true;
        }
        if (node.key >= max || node.key <= min) {
            return false;
        }
        return isBST(node.left, node.key, min) && isBST(node.right, max, node.key);
    }

    /**
     * Returns the height of the subtree rooted at {@code node}, counted in edges: -1 for an
     * empty tree and 0 for a single node.
     */
    public static int height(TreeNode node) {
        if (node == null) {
            return -1;
        }
        return 1 + Math.max(height(node.left), height(node.right));
    }

    /** Iterative binary search; returns the node holding {@code key}, or {@code null}. */
    public static TreeNode search(TreeNode root, int key) {
        while (root != null && root.key != key) {
            root = key < root.key ? root.left : root.right;
        }
        return root;
    }

    /** Recursive binary search; returns the node holding {@code key}, or {@code null}. */
    public static TreeNode search2(TreeNode root, int key) {
        if (root == null || root.key == key) {
            return root;
        }
        return key < root.key ? search2(root.left, key) : search2(root.right, key);
    }

    /** Prints the keys of the tree in ascending (inorder) order, each preceded by a space. */
    public static void PrintTreeInorder(TreeNode root) {
        if (root == null) {
            return;
        }
        PrintTreeInorder(root.left);
        System.out.print(" " + root.key);
        PrintTreeInorder(root.right);
    }
}

 

How can we find the kth smallest element without changing the original data structure (i.e., without storing subtree sizes in the nodes)?

We can do an inorder traversal (recursive or iterative) and decrement a counter, initialized to k, each time we visit a node. When the counter reaches 0, the node just visited is the kth smallest element.

/**
 * Returns the key of the k-th smallest node (1-based) of the BST rooted at {@code root}.
 *
 * @param root root of the tree; may be null
 * @param k    1-based rank of the wanted key
 * @return the k-th smallest key, or -1 if there is no k-th node (note: ambiguous when -1 is
 *         itself a key in the tree)
 */
public int kthSmallest(TreeNode root, int k) {
    TreeNode kth = MorrisInorderTraversal(root, k);
    // TreeNode stores its key in the `key` field (there is no `val` field)
    return kth != null ? kth.key : -1;
}

/**
 * Returns the k-th smallest node (1-based) of the BST rooted at {@code root} using a Morris
 * inorder traversal, which needs O(1) extra space.
 *
 * <p>The traversal temporarily points the right link of each node's inorder predecessor back
 * to that node (a "thread"). Every thread is removed again before this method returns, so the
 * tree is left unchanged. For that reason the traversal always runs to completion, even after
 * the k-th node has been found. Returning early would leave threads behind and corrupt the tree.
 *
 * @param root root of the tree; may be null
 * @param k    1-based rank of the wanted node
 * @return the k-th smallest node, or null if the tree is empty or k is outside [1, size]
 */
public static TreeNode MorrisInorderTraversal(TreeNode root, int k) {
    TreeNode kth = null;
    TreeNode cur = root;
    while (cur != null) {
        if (cur.left == null) {
            // no left subtree: visit cur, then go right (possibly following a thread)
            k--;
            if (k == 0) {
                kth = cur;
            }
            cur = cur.right;
        } else {
            // find cur's inorder predecessor: the rightmost node of its left subtree
            TreeNode pre = cur.left;
            while (pre.right != null && pre.right != cur) {
                pre = pre.right;
            }

            if (pre.right == null) {
                // first arrival: add a thread back to cur, then traverse the left subtree
                pre.right = cur;
                cur = cur.left;
            } else {
                // returned through the thread: left subtree is done, so remove the thread
                // and visit cur before traversing the right subtree
                pre.right = null;
                k--;
                if (k == 0) {
                    kth = cur;
                }
                cur = cur.right;
            }
        }
    }
    return kth;
}