# 230. Kth Smallest Element in a BST

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note:
You may assume k is always valid: 1 ≤ k ≤ the total number of elements in the BST.

Example:

Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3

Follow up:
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

# Solution

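Approach 1: iterative in-order traversal with an explicit stack, stopping as soon as the k-th node is popped. O(h + k) time and O(h) space, where h is the height of the tree.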

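Approach 2: recursive in-order traversal. Either collect values into a list until k of them have been gathered (O(k) extra space), or keep only a running counter (O(1) extra space). Both variants also use O(h) space for the recursion stack.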

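Follow up: let each node keep an additional value: the number of nodes in its subtree that are smaller than itself (i.e. the size of its left subtree). The k-th smallest element can then be found in O(h) by walking down from the root, instead of traversing k nodes. Alternatively, keep these counts in an external data structure such as a hash table keyed by node. Inserts and deletes must keep the counts up to date.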

# Code (Python)

Approach 1:

class Solution:
    def kthSmallest1(self, root, k):
        """Return the k-th smallest value (1-based) of a BST, iteratively.

        Walks the tree in order using an explicit stack and stops as soon as
        the k-th node is popped. Falls off the end (returning None) if the
        tree has fewer than k nodes.

        :type root: TreeNode
        :type k: int
        :rtype: int
        """
        # iterative
        pending = []

        def push_left_spine(cur):
            # Push cur and all of its left descendants so the smallest
            # not-yet-visited node ends up on top of the stack.
            while cur:
                pending.append(cur)
                cur = cur.left

        push_left_spine(root)
        remaining = k
        while pending:
            current = pending.pop()
            remaining -= 1
            if remaining == 0:
                return current.val
            # Everything in the right subtree comes next in sorted order.
            push_left_spine(current.right)

Approach 2:

    def kthSmallest(self, root, k):
        """Return the k-th smallest value (1-based) of a BST, recursively.

        Asks ``_inorder`` for the values in sorted order (it stops collecting
        once k are gathered) and picks the k-th one.

        :type root: TreeNode
        :type k: int
        :rtype: int
        """
        # recursive
        collected = []
        self._inorder(root, collected, k)
        kth_index = k - 1
        return collected[kth_index]
    
    def _inorder(self, node, values, k):
        if not node or len(values) >= k: # early cutoff
            return
        if node.left:
            self._inorder(node.left, values, k)
        values.append(node.val)
        if node.right:
            self._inorder(node.right, values, k)

Follow up:

 class TreeNode:
     """Binary tree node augmented with the size of its left subtree."""

     def __init__(self, x):
         self.val = x
         self.left = self.right = None
         # How many nodes in this node's subtree are smaller than val, i.e.
         # the size of the left subtree. Starts at 0 for a fresh leaf; any
         # insert/delete code must keep it up to date.
         self.nodes_less_than = 0

    def kth_smallest_modified(self, root, k):
        """Return the k-th smallest value (1-based) of an augmented BST.

        Each node's ``nodes_less_than`` is the size of its left subtree, so
        the node's rank inside its own subtree is ``nodes_less_than + 1``.
        The loop moves one level down per iteration, taking O(h) steps.
        Assumes 1 <= k <= number of nodes; otherwise it walks off the tree
        and fails on a None node.
        """
        current = root
        while True:
            rank = current.nodes_less_than + 1
            if k < rank:
                current = current.left
            elif k > rank:
                # Skip this node and its whole left subtree.
                k -= rank
                current = current.right
            else:
                return current.val

# Code (C++)

Approach 1:

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
// Iteration
class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        stack<TreeNode*> st;
        TreeNode *node = root;
        int count = 0;
        while (node || !st.empty())
        {
            if (node)
            {
                st.push(node);
                node = node->left;
            }
            else
            {
                node = st.top();
                st.pop();
                count++;
                if (count == k)
                    break;
                node = node->right;
            }
        }
        return node->val;
    }
};

Approach 2:

// Recursion
class Solution {
    int count = 0;
public:
    int kthSmallest(TreeNode* root, int k) {
        if (root == NULL)
            return 0;
        int leftVal = kthSmallest(root->left, k);
        count++;
        if (count == k)
            return root->val;
        if (count < k)
            return kthSmallest(root->right, k);
        return leftVal;
    }
};
class Solution {
    int kth = 0;    // value of the k-th smallest node, once found
    int count = 0;  // nodes visited so far in the current in-order walk
private:
    /// In-order walk that records the k-th visited value in `kth` and stops
    /// descending as soon as it has been found.
    void getKthSmallest(TreeNode* root, int k) {
        if (root == nullptr || count >= k)
            return;
        getKthSmallest(root->left, k);
        if (count >= k)  // already found in the left subtree
            return;
        if (++count == k)
        {
            kth = root->val;
            return;
        }
        getKthSmallest(root->right, k);
    }
public:
    /// Returns the k-th smallest value (1-based) in the BST rooted at `root`,
    /// or 0 if k < 1 or the tree has fewer than k nodes. Both members are
    /// reset on every call, so no state leaks between queries.
    int kthSmallest(TreeNode* root, int k) {
        count = 0;
        kth = 0;  // defined result when k is out of range
        getKthSmallest(root, k);
        return kth;
    }
};