"""
The node number in last level: 1 ~ 2^d-1
Use 2 binary search, first search if idx exist in last level, second in 'exist function'
Time O(2d)
1. find depth of root
2. return 1 if d == 0
3. Use 'exist' function to find if left idx exist
4. Implement exist(idx, root, d)
"""
class Solution:
    """Count the nodes of a complete binary tree by binary-searching its last level."""

    def countNodes(self, root: "Optional[TreeNode]") -> int:
        """Return the number of nodes in the complete binary tree rooted at *root*.

        Args:
            root: Root of a complete binary tree (every level full except
                possibly the last, which is filled from the left), or None for
                an empty tree. Only the ``left`` and ``right`` attributes of
                nodes are accessed.

        Returns:
            The total node count; 0 when *root* is None.
        """
        if not root:
            return 0
        d = self.getDepth(root)
        if d == 0:
            return 1
        # Position 0 of the last level always exists: getDepth reached it by
        # following left children. So search positions 1 .. 2^d - 1 for the
        # first missing one; the upper bound 2^d means "last level is full".
        left, right = 1, 1 << d
        while left < right:
            mid = (left + right) // 2
            if self.exist(mid, root, d):
                left = mid + 1
            else:
                right = mid
        # `left` is now the number of nodes on the last level; the levels
        # above it are full and contribute 2^d - 1 nodes.
        return ((1 << d) - 1) + left

    def getDepth(self, node) -> int:
        """Return the number of edges on the leftmost path starting at *node*.

        *node* must not be None. For a complete tree this is the index of
        the last level.
        """
        d = 0
        while node.left:
            node = node.left
            d += 1
        return d

    def exist(self, idx: int, node, d: int) -> bool:
        """Return True if last-level position *idx* exists below *node*.

        Args:
            idx: 0-based, left-to-right position on level *d*.
            node: Root of the (complete) tree.
            d: Depth of the last level, as returned by getDepth.

        Each of the d steps halves the candidate range [left, right] of
        last-level positions: positions in the lower half lie in the left
        subtree, the rest in the right subtree.
        """
        left, right = 0, (1 << d) - 1
        for _ in range(d):
            mid = (left + right) // 2
            if idx <= mid:
                node = node.left
                right = mid
            else:
                node = node.right
                left = mid + 1
        return node is not None