## 230. Kth Smallest Element in a BST

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note:
You may assume k is always valid, 1 ≤ k ≤ BST's total elements.

Example:

```
Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3
```
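The input `[5,3,6,2,4,null,null,1]` is LeetCode's level-order serialization, where `null` marks a missing child. The sketch below lets you try the solutions locally by decoding that list into a tree. `build_tree` and `inorder` are hypothetical helper names that are not part of the problem, and Python's `None` stands in for `null`. The inorder listing also shows why every approach below works: a BST's values come out sorted, so the kth one is the answer.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int):
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a LeetCode-style level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        parent = queue.popleft()
        # The next two entries are this parent's left and right children.
        if i < len(values) and values[i] is not None:
            parent.left = TreeNode(values[i])
            queue.append(parent.left)
        i += 1
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            queue.append(parent.right)
        i += 1
    return root


def inorder(node: Optional[TreeNode]) -> List[int]:
    """Return the tree's values in inorder (ascending for a BST)."""
    if node is None:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


root = build_tree([5, 3, 6, 2, 4, None, None, 1])
print(inorder(root))          # [1, 2, 3, 4, 5, 6]
print(inorder(root)[3 - 1])   # 3
```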

What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

### Solution

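Both approaches rely on the fact that an inorder traversal of a BST visits the values in ascending order. The kth node visited is therefore the answer.

Approach 1: iterative inorder traversal with an explicit stack, stopping as soon as the kth node is popped. This takes O(h + k) time and O(h) space, where h is the height of the tree.

Approach 2: recursive inorder traversal with an early cutoff once k nodes have been visited. In addition to the O(h) recursion stack, it needs either O(1) extra space (a counter, as in the C++ versions) or O(k) extra space (a list of visited values, as in the Python version).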
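Any recursive or stack-based traversal still uses O(h) space for its call stack or explicit stack. If O(1) extra space is strictly required, one option is Morris inorder traversal. This is a different technique from the two approaches above: it temporarily points each inorder predecessor's `right` pointer back at its successor, so no stack is needed. The sketch below assumes the usual `TreeNode` fields, and `kth_smallest_morris` is a hypothetical name. It keeps walking after it finds the answer so that every temporary thread is removed and the tree is left unchanged. That costs O(n) time instead of O(h + k).

```python
from typing import Optional


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_morris(root: Optional[TreeNode], k: int) -> Optional[int]:
    """Morris inorder traversal with O(1) extra space; returns None if k is out of range."""
    node = root
    answer = None
    while node:
        if node.left is None:
            # No left subtree: visit this node, then go right (possibly along a thread).
            k -= 1
            if k == 0:
                answer = node.val
            node = node.right
        else:
            # Find the inorder predecessor: the rightmost node of the left subtree.
            pred = node.left
            while pred.right is not None and pred.right is not node:
                pred = pred.right
            if pred.right is None:
                pred.right = node   # first arrival: add a temporary thread back to node
                node = node.left
            else:
                pred.right = None   # second arrival: left subtree is done, remove the thread
                k -= 1
                if k == 0:
                    answer = node.val
                node = node.right
    return answer


root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_morris(root, 3))  # 3
```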

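Follow up: let each node also store how many nodes in its subtree are smaller than itself, which is the number of nodes in its left subtree. kthSmallest can then walk down from the root in O(h) time without traversing anything. At each node it compares k with the stored count to decide whether to return, go left, or go right. Alternatively, keep the counts in an external data structure such as a hash table keyed by node. Either way, every insert and delete must also update the counts of the affected nodes along its root-to-node path.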
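Here is a minimal sketch of the follow-up, assuming distinct values and a plain (unbalanced) BST. It stores the size of each node's whole subtree rather than only the left-subtree count described above. The two carry the same information for the kth search, because the left count is just `size(node.left)`. Subtree sizes are easy to recompute on the way back up from a recursive insert or delete. `Node`, `size`, `insert`, `delete`, and `kth_smallest` are hypothetical names. Each operation costs O(h). Guaranteeing h = O(log n) would additionally require a self-balancing tree such as an AVL or red-black tree, which is not shown here.

```python
from typing import Optional


class Node:
    def __init__(self, val: int):
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None
        self.size = 1  # number of nodes in the subtree rooted here


def size(node: Optional[Node]) -> int:
    return node.size if node else 0


def insert(node: Optional[Node], val: int) -> Node:
    if node is None:
        return Node(val)
    if val < node.val:
        node.left = insert(node.left, val)
    else:
        node.right = insert(node.right, val)
    node.size = 1 + size(node.left) + size(node.right)  # refresh on the way back up
    return node


def delete(node: Optional[Node], val: int) -> Optional[Node]:
    if node is None:
        return None  # value not present
    if val < node.val:
        node.left = delete(node.left, val)
    elif val > node.val:
        node.right = delete(node.right, val)
    else:
        if node.left is None:
            return node.right
        if node.right is None:
            return node.left
        # Two children: copy the inorder successor's value, then delete the successor.
        succ = node.right
        while succ.left:
            succ = succ.left
        node.val = succ.val
        node.right = delete(node.right, succ.val)
    node.size = 1 + size(node.left) + size(node.right)
    return node


def kth_smallest(root: Optional[Node], k: int) -> int:
    node = root
    while node:
        left = size(node.left)  # values in this subtree smaller than node.val
        if k == left + 1:
            return node.val
        if k <= left:
            node = node.left
        else:
            k -= left + 1       # skip the left subtree and this node
            node = node.right
    raise IndexError("k out of range")


root = None
for v in [5, 3, 6, 2, 4, 1]:
    root = insert(root, v)
print(kth_smallest(root, 3))  # 3
root = delete(root, 2)
print(kth_smallest(root, 3))  # 4
```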

### Code (Python)

Approach 1:

```python
class Solution:
    def kthSmallest1(self, root, k):
        """
        :type root: TreeNode
        :type k: int
        :rtype: int
        """
        # iterative
        stack = []
        node = root
        while node:
            stack.append(node)
            node = node.left
        while stack:
            node = stack.pop()
            k -= 1
            if k == 0:
                return node.val
            if node.right:
                node = node.right
                while node:
                    stack.append(node)
                    node = node.left
```

Approach 2:

```python
    def kthSmallest(self, root, k):
        # recursive
        values = []
        self._inorder(root, values, k)
        return values[k-1]

    def _inorder(self, node, values, k):
        if not node or len(values) >= k:  # early cutoff
            return
        if node.left:
            self._inorder(node.left, values, k)
        values.append(node.val)
        if node.right:
            self._inorder(node.right, values, k)
```
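The same early-stopping recursion can also be sketched with a Python generator. This variant is not the code above. In it, `inorder_values` (a hypothetical helper) yields values lazily, and `itertools.islice` stops consuming it after the kth value. The rest of the tree is never visited and no list of values is built. It assumes a `TreeNode` with `val`, `left`, and `right` as in the problem.

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(node: Optional[TreeNode]) -> Iterator[int]:
    """Lazily yield the BST's values in ascending order."""
    if node is not None:
        yield from inorder_values(node.left)
        yield node.val
        yield from inorder_values(node.right)


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    # Skip the first k-1 values and take the next one; the generator is
    # abandoned afterwards, so nodes after the kth are never visited.
    return next(islice(inorder_values(root), k - 1, None))


root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest(root, 3))  # 3
```

The suspended chain of generators plays the role of the explicit stack in Approach 1, so the space cost is still O(h).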

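Follow up:

```python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None
        self.nodes_less_than = 0  # size of the left subtree; inserts/deletes must keep it updated


class Solution:
    def kth_smallest_modified(self, root, k):
        # Walk down from the root using the stored counts (assumes k is valid).
        node = root
        while True:
            if node.nodes_less_than == k-1:
                return node.val                # exactly k-1 smaller values: this is the kth
            elif node.nodes_less_than > k-1:
                node = node.left               # the kth smallest is in the left subtree
            else:
                k -= node.nodes_less_than + 1  # skip the left subtree and this node
                node = node.right
```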

### Code (C++)

Approach 1:

```cpp
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
// Iteration
class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        stack<TreeNode*> st;
        TreeNode *node = root;
        int count = 0;
        while (node || !st.empty())
        {
            if (node)
            {
                st.push(node);
                node = node->left;
            }
            else
            {
                node = st.top();
                st.pop();
                count++;
                if (count == k)
                    break;
                node = node->right;
            }
        }
        return node->val;
    }
};
```

Approach 2:

```cpp
// Recursion
class Solution {
    int count = 0;  // nodes visited so far; not reset between calls on the same object
public:
    int kthSmallest(TreeNode* root, int k) {
        if (root == NULL)
            return 0;
        int leftVal = kthSmallest(root->left, k);
        count++;
        if (count == k)
            return root->val;                    // this node is the kth visited
        if (count < k)
            return kthSmallest(root->right, k);
        return leftVal;                          // the kth node was found in the left subtree
    }
};
```

Alternatively, a helper can record the answer in a member variable:

```cpp
class Solution {
    int kth;
    int count;
private:
    void getKthSmallest(TreeNode* root, int k) {
        if (root == NULL)
            return;
        getKthSmallest(root->left, k);
        count++;
        if (count == k)
        {
            kth = root->val;
            return;
        }
        if (count < k)
        {
            getKthSmallest(root->right, k);
        }
    }
public:
    int kthSmallest(TreeNode* root, int k) {
        count = 0;
        kth = 0;
        getKthSmallest(root, k);
        return kth;
    }
};
```