# 222. Count Complete Tree Nodes

Given a complete binary tree, count the number of nodes.

Note:

Definition of a complete binary tree from Wikipedia: In a complete binary tree, every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. It can have between 1 and 2^h nodes inclusive at the last level h (counting the root as level 0).

Example:

Input:

        1
       / \
      2   3
     / \  /
    4  5 6

Output: 6
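To experiment with the example, a complete tree with n nodes can be built from a level-order list, because node i's children sit at indices 2i + 1 and 2i + 2. A plain traversal then gives an O(n) baseline to check the faster approaches against. This is a minimal sketch: `TreeNode` mirrors the usual LeetCode node shape, and `build_complete` and `count_naive` are hypothetical helpers, not part of the problem's API.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed shape, mirroring LeetCode's)."""

    def __init__(self, val: int = 0) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_complete(values: List[int]) -> Optional[TreeNode]:
    """Build a complete binary tree from a level-order list.

    Node i gets children 2*i + 1 and 2*i + 2, so every level is filled
    from the left, which is exactly the 'complete' shape.
    """
    nodes = [TreeNode(v) for v in values]
    for i, node in enumerate(nodes):
        if 2 * i + 1 < len(nodes):
            node.left = nodes[2 * i + 1]
        if 2 * i + 2 < len(nodes):
            node.right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def count_naive(root: Optional[TreeNode]) -> int:
    """O(n) baseline: visit every node once."""
    if root is None:
        return 0
    return 1 + count_naive(root.left) + count_naive(root.right)


root = build_complete([1, 2, 3, 4, 5, 6])
print(count_naive(root))  # 6
```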

# Solution

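Approach 1: binary search -- the leftmost node of the bottom-most level always exists, so binary-search the bottom level for the first position where nodes begin to null out. To translate a bottom-level index into a root-to-leaf path, write the index in binary (padded to h bits, where h is the depth of the bottom level) and go left at a '0' and right at a '1'. Time: the bottom level has at most 2^h positions, so the search makes log(2^h) = O(h) probes; each probe walks O(h) edges, for O(h^2) = O(log^2 n) in total.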
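The index-to-path encoding can be tried on its own. A minimal sketch, with `index_to_path` as a hypothetical helper that spells the path with 'L'/'R' characters instead of walking a real tree:

```python
def index_to_path(index: int, depth: int) -> str:
    """Map a bottom-level slot to its root-to-slot path.

    `depth` is the number of edges from the root to the bottom level, so the
    level has 2**depth slots numbered 0 .. 2**depth - 1 from the left.
    'L' stands for a '0' bit (go left) and 'R' for a '1' bit (go right),
    read from the most significant bit down.
    """
    if not 0 <= index < (1 << depth):
        raise ValueError("index out of range for this depth")
    # format(0, "b") is "0", so a depth-0 tree (root only) needs an empty path.
    bits = format(index, "b").zfill(depth) if depth > 0 else ""
    return "".join("L" if bit == "0" else "R" for bit in bits)


# Bottom level two edges below the root: four slots, left to right.
for i in range(4):
    print(i, index_to_path(i, 2))
# 0 LL
# 1 LR
# 2 RL
# 3 RR
```

In the example tree, slot 2 (RL) is node 6 and slot 3 (RR) is the first empty one, so the count is (2^2 - 1) + 3 = 6.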

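Approach 2: recursive solution -- the total node count is the count of the left subtree plus the count of the right subtree, plus 1 for the root. If the leftmost and rightmost paths from a node have the same length h (counted in nodes), that subtree is perfect and has 2^h - 1 nodes, so there is no need to recurse into it. Passing left_height and right_height down reuses what has already been calculated: the left child's leftmost path and the right child's rightmost path are each exactly one node shorter than the parent's.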
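Approach 2 stays fast because, in a complete tree, at least one of a node's two subtrees is always perfect: if the last level reaches into the right subtree, the left one is full; otherwise the right one is full. So at each level only one branch keeps recursing, which gives O(h) calls of O(h) work each, O(h^2) overall. The sketch below checks that property by brute force for every size from 1 to 500. `build_complete`, `size`, `levels`, `is_perfect` and `one_child_is_perfect` are hypothetical helpers, and perfection is tested by node count (2^L - 1 nodes for L levels) rather than by the path-length shortcut the approach itself relies on.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_complete(n: int) -> Optional[TreeNode]:
    """Complete tree with n nodes: node i's children are 2i + 1 and 2i + 2."""
    nodes = [TreeNode(i) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def size(node: Optional[TreeNode]) -> int:
    return 0 if node is None else 1 + size(node.left) + size(node.right)


def levels(node: Optional[TreeNode]) -> int:
    """Number of levels (0 for an empty tree)."""
    return 0 if node is None else 1 + max(levels(node.left), levels(node.right))


def is_perfect(node: Optional[TreeNode]) -> bool:
    """Perfect means every level is full: exactly 2^levels - 1 nodes."""
    return size(node) == (1 << levels(node)) - 1


def one_child_is_perfect(n: int) -> bool:
    root = build_complete(n)
    return is_perfect(root.left) or is_perfect(root.right)


print(all(one_child_is_perfect(n) for n in range(1, 501)))  # True
```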

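Approach 3: iterative solution -- let h be the height of the current subtree and measure the height of its right subtree along that subtree's leftmost path. If it equals h - 1, the last level reaches into the right subtree, so the left subtree is perfect (2^h - 1 nodes): add those nodes plus the root and continue in the right subtree. Otherwise the right subtree is perfect with height h - 2 (2^(h-1) - 1 nodes): add those plus the root and continue in the left subtree. Each step costs O(h) and h drops by one per step, so the total is O(h^2).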
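Running the loop by hand on the example tree (height h = 2) shows how the pieces add up. The sketch below records each step; it is a standalone restatement rather than the `Solution` method, and it tracks the height h of the current subtree directly (the Python code's `height_left` is h - 1). `TreeNode`, `build_complete`, `height_from_left` and `trace_count` are hypothetical names; `height_from_left` is written as a loop but follows the same convention as `_height_from_left` (an empty tree has height -1).

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_complete(values: List[int]) -> Optional[TreeNode]:
    """Complete tree from a level-order list (children of i: 2i + 1, 2i + 2)."""
    nodes = [TreeNode(v) for v in values]
    for i, node in enumerate(nodes):
        if 2 * i + 1 < len(nodes):
            node.left = nodes[2 * i + 1]
        if 2 * i + 2 < len(nodes):
            node.right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def height_from_left(node: Optional[TreeNode]) -> int:
    """Edges on the leftmost path; -1 for an empty tree."""
    h = -1
    while node:
        h += 1
        node = node.left
    return h


def trace_count(root: Optional[TreeNode]) -> List[Tuple[int, str, int]]:
    """Record (node value, branch taken, nodes added) for each loop step."""
    steps = []
    h = height_from_left(root)  # height of the current subtree
    while root:
        if height_from_left(root.right) == h - 1:
            # Left subtree is perfect with height h - 1: 2^h - 1 nodes + root.
            steps.append((root.val, "right", 1 << h))
            root = root.right
        else:
            # Right subtree is perfect with height h - 2: 2^(h-1) - 1 nodes + root.
            steps.append((root.val, "left", 1 << (h - 1)))
            root = root.left
        h -= 1
    return steps


steps = trace_count(build_complete([1, 2, 3, 4, 5, 6]))
print(steps)  # [(1, 'right', 4), (3, 'left', 1), (6, 'right', 1)]
print(sum(added for _, _, added in steps))  # 6
```

The first step counts the root plus the perfect left subtree {2, 4, 5} (2^2 = 4 nodes), the second counts node 3 plus its empty right subtree (2^0 = 1), and the last counts the leaf 6.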

# Code (Python)

Approach 1:

    def countNodes(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        if not root:
            return 0
        num_levels = 1
        node = root
        while node.left:
            node = node.left
            num_levels += 1
        # Binary-search the bottom-level indices for the first missing node.
        # Index 0 (the leftmost node) always exists, so the search starts at 1.
        left, right = 1, 2 ** (num_levels - 1)
        while left < right:
            mid = (left + right) // 2
            if self._find_levels(mid, num_levels, root) == num_levels:
                # The node at mid exists, so the first gap lies to its right.
                # (Searching for the last existing node with `left = mid` instead
                # can loop forever, because mid can equal left.)
                left = mid + 1
            else:
                right = mid
        # 2^(num_levels - 1) - 1 nodes above the bottom level, plus `left` bottom nodes
        return 2 ** (num_levels - 1) - 1 + left
    
    def _find_levels(self, index, total_levels, root):
        # Translate the bottom-level index to a path: write it in binary, left-padded
        # with zeros to total_levels - 1 bits, then go left at '0' and right at '1'.
        path = bin(index)[2:].zfill(total_levels - 1)
        node = root
        num_levels = 1
        for char in path:
            if char == '0':
                node = node.left
            else:
                node = node.right
            if not node:
                return num_levels
            num_levels += 1
        return num_levels
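
`_find_levels` builds a binary string for every probe. A sketch of the same binary search that reads the path bits directly with shifts is below; it is written as standalone functions rather than `Solution` methods, and `TreeNode`, `build_complete`, `exists` and `count_nodes_bsearch` are hypothetical names. Here `depth` is the bottom level's distance from the root in edges, i.e. `num_levels - 1` in the code above.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_complete(n: int) -> Optional[TreeNode]:
    """Complete tree with n nodes: node i's children are 2i + 1 and 2i + 2."""
    nodes = [TreeNode(i) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def exists(root: TreeNode, index: int, depth: int) -> bool:
    """True if bottom-level slot `index` (depth edges below root) holds a node.

    Bits of `index` are read from the most significant one down:
    0 goes left, 1 goes right.
    """
    node: Optional[TreeNode] = root
    for shift in range(depth - 1, -1, -1):
        node = node.right if (index >> shift) & 1 else node.left
        if node is None:
            return False
    return True


def count_nodes_bsearch(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    depth, node = 0, root
    while node.left:  # the leftmost path reaches the bottom level
        node = node.left
        depth += 1
    # Slot 0 always exists; find the first empty slot in [1, 2^depth].
    lo, hi = 1, 1 << depth
    while lo < hi:
        mid = (lo + hi) // 2  # always < 2^depth, so it fits in `depth` bits
        if exists(root, mid, depth):
            lo = mid + 1
        else:
            hi = mid
    # 2^depth - 1 nodes above the bottom level, plus `lo` bottom nodes.
    return (1 << depth) - 1 + lo


print([count_nodes_bsearch(build_complete(n)) for n in range(8)])
# [0, 1, 2, 3, 4, 5, 6, 7]
```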

Approach 2:

    def countNodes(self, root):
        if not root:
            return 0
        return self._count_nodes(root, -1, -1)
    
    def _count_nodes(self, node, left_height, right_height):
        # Heights count nodes on the leftmost/rightmost path; -1 means "not computed yet".
        if left_height == -1:
            left_height = 0
            n = node
            while n:
                left_height += 1
                n = n.left
        if right_height == -1:
            right_height = 0
            n = node
            while n:
                right_height += 1
                n = n.right
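        if left_height == right_height:
            # Leftmost and rightmost paths have the same length: the tree is perfect.
            return (1 << left_height) - 1
        # The left child's leftmost path and the right child's rightmost path are one
        # node shorter than this node's, so those heights are passed down, not recomputed.
        return 1 + self._count_nodes(node.left, left_height - 1, -1) + self._count_nodes(node.right, -1, right_height - 1)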
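
The cost of this recursion can be measured directly. In a non-perfect subtree, one child is perfect and returns after a single call, so a tree with L levels should trigger at most 2L - 1 calls, each walking at most L nodes. The sketch below restates Approach 2 as a standalone function with a call counter and checks both the count and that bound; `TreeNode`, `build_complete` and `count_with_calls` are hypothetical names, and `n.bit_length()` is used as the number of levels of a complete tree with n nodes.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_complete(n: int) -> Optional[TreeNode]:
    """Complete tree with n nodes: node i's children are 2i + 1 and 2i + 2."""
    nodes = [TreeNode(i) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def count_with_calls(root: Optional[TreeNode]) -> Tuple[int, int]:
    """Approach 2 restated as a function; returns (node count, recursive calls)."""
    calls = 0

    def walk(node: Optional[TreeNode], left_height: int, right_height: int) -> int:
        nonlocal calls
        calls += 1
        if left_height == -1:  # -1 means "not computed yet"
            left_height, n = 0, node
            while n:
                left_height += 1
                n = n.left
        if right_height == -1:
            right_height, n = 0, node
            while n:
                right_height += 1
                n = n.right
        if left_height == right_height:  # perfect subtree
            return (1 << left_height) - 1
        return (1 + walk(node.left, left_height - 1, -1)
                + walk(node.right, -1, right_height - 1))

    total = walk(root, -1, -1) if root else 0
    return total, calls


print(count_with_calls(build_complete(8)))  # (8, 7)
```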

Approach 3:

    def countNodes(self, root):
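        total = 0
        max_height = self._height_from_left(root)  # edges on the leftmost path; -1 if empty
        height_left = max_height - 1  # height of the current root's left subtree
        while root:
            if height_left == self._height_from_left(root.right):
                # Both subtrees reach the bottom level, so the left subtree is perfect:
                # add its 2^(height_left + 1) - 1 nodes plus the root; any gap is on the right.
                total += 1 << (height_left + 1)
                root = root.right
            else:
                # The right subtree is perfect and one level shorter:
                # add its 2^height_left - 1 nodes plus the root; the gap is on the left.
                total += 1 << height_left
                root = root.left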
            height_left -= 1
        return total
    
    def _height_from_left(self, node):
        if not node:
            return -1
        return self._height_from_left(node.left) + 1
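
The same left-or-right decision is also often written recursively. A sketch under the same convention that an empty tree has height -1; `left_height` and `count_nodes_rec`, along with the `TreeNode`/`build_complete` scaffolding, are hypothetical names.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_complete(n: int) -> Optional[TreeNode]:
    """Complete tree with n nodes: node i's children are 2i + 1 and 2i + 2."""
    nodes = [TreeNode(i) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def left_height(node: Optional[TreeNode]) -> int:
    """Edges on the leftmost path; -1 for an empty tree."""
    h = -1
    while node:
        h += 1
        node = node.left
    return h


def count_nodes_rec(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    h = left_height(root)
    if left_height(root.right) == h - 1:
        # Left subtree is perfect (2^h - 1 nodes): add it and the root, recurse right.
        return (1 << h) + count_nodes_rec(root.right)
    # Right subtree is perfect with height h - 2 (2^(h-1) - 1 nodes): recurse left.
    return (1 << (h - 1)) + count_nodes_rec(root.left)


print([count_nodes_rec(build_complete(n)) for n in range(8)])  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Unlike the loop, this version recomputes the current height on every call; that adds another O(h) per step and leaves the O(h^2) bound unchanged.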

# Code (C++)

Approach 1:

Approach 2:

    /**
     * Definition for a binary tree node.
     * struct TreeNode {
     *     int val;
     *     TreeNode *left;
     *     TreeNode *right;
     *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
     * };
     */
    class Solution {
    public:
        int countNodes(TreeNode* root) {
            if (root == nullptr)
                return 0;
            // Count nodes along the leftmost and the rightmost paths.
            TreeNode *node = root;
            int leftHeight = 0;
            while (node)
            {
                leftHeight++;
                node = node->left;
            }
            node = root;
            int rightHeight = 0;
            while (node)
            {
                rightHeight++;
                node = node->right;
            }
            // Equal path lengths mean the tree is perfect: 2^h - 1 nodes.
            if (leftHeight == rightHeight)
                return (1 << leftHeight) - 1;
            return 1 + countNodes(root->left) + countNodes(root->right);
        }
    };