938. Range Sum of BST

Easy

Problem:

Given the root node of a binary search tree (BST) and two integers low and high, return the sum of the values of all nodes whose value lies in the inclusive range [low, high].

Input: root = [10,5,15,3,7,null,18], low = 7, high = 15
Output: 32
Explanation: Nodes 7, 10, and 15 are in the range [7, 15]. 7 + 10 + 15 = 32.
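LeetCode writes the tree in level order, with null marking a missing child. For local testing, the sketch below rebuilds the example tree and checks the expected output by brute force. The helpers `build_tree` and `in_range_values` are hypothetical; they are not part of the problem or of the solutions that follow.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a LeetCode-style level-order list."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are the left and right children of `node`
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def in_range_values(node: Optional[TreeNode], low: int, high: int) -> List[int]:
    """Brute force: in-order traversal collecting every value in [low, high]."""
    if node is None:
        return []
    middle = [node.val] if low <= node.val <= high else []
    return in_range_values(node.left, low, high) + middle + in_range_values(node.right, low, high)


root = build_tree([10, 5, 15, 3, 7, None, 18])
picked = in_range_values(root, 7, 15)
print(picked, sum(picked))  # [7, 10, 15] 32
```

Because an in-order traversal of a BST visits values in ascending order, the collected list comes out sorted.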

https://leetcode.com/problems/range-sum-of-bst/

Solution:

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def rangeSumBST(self, root: Optional[TreeNode], low: int, high: int) -> int:
        result = 0

        def dfs(node):
            nonlocal result

            # An empty subtree contributes nothing (also covers an empty tree)
            if not node:
                return

            # Visit every node, with no pruning: O(n) time
            dfs(node.left)
            dfs(node.right)

            # Add the node's value only if it lies in [low, high]
            if low <= node.val <= high:
                result += node.val

        dfs(root)
        return result

Optimization: Pruning

The function dfs uses the properties of a BST to optimize the search. If the current node's value is less than low, then we know that all the values in its left subtree will also be less than low, so we only need to search the right subtree. Conversely, if the current node's value is greater than high, then all the values in its right subtree will also be greater than high, so we only need to search the left subtree.

  1. If the current node's value is less than low, the function returns dfs(node.right), the sum of the in-range values in the right subtree, because the node itself and every value in its left subtree are less than low and thus out of the desired range.

  2. If the current node's value is greater than high, the function returns dfs(node.left), the sum of the in-range values in the left subtree, because the node itself and every value in its right subtree are greater than high and thus out of the desired range.

  3. If the current node's value is within the range [low, high], the function returns the current node's value plus the in-range sums of both subtrees (node.val + dfs(node.left) + dfs(node.right)).
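To see what pruning saves, the sketch below records which non-null nodes each strategy visits on the example tree. The function names `full_sum` and `pruned_sum` and the `visited` list are illustrative additions, not part of the original solutions; `pruned_sum` follows the three rules above.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def full_sum(node: Optional[TreeNode], low: int, high: int, visited: List[int]) -> int:
    """Visit every node, adding the values that fall in [low, high]."""
    if node is None:
        return 0
    visited.append(node.val)
    own = node.val if low <= node.val <= high else 0
    return own + full_sum(node.left, low, high, visited) + full_sum(node.right, low, high, visited)


def pruned_sum(node: Optional[TreeNode], low: int, high: int, visited: List[int]) -> int:
    """Skip subtrees that the BST ordering proves are out of range."""
    if node is None:
        return 0
    visited.append(node.val)
    if node.val < low:   # node and its whole left subtree are too small
        return pruned_sum(node.right, low, high, visited)
    if node.val > high:  # node and its whole right subtree are too large
        return pruned_sum(node.left, low, high, visited)
    return node.val + pruned_sum(node.left, low, high, visited) + pruned_sum(node.right, low, high, visited)


# Example tree [10,5,15,3,7,null,18], built by hand
root = TreeNode(10, TreeNode(5, TreeNode(3), TreeNode(7)), TreeNode(15, None, TreeNode(18)))

full_visits, pruned_visits = [], []
print(full_sum(root, 7, 15, full_visits), full_visits)        # 32 [10, 5, 3, 7, 15, 18]
print(pruned_sum(root, 7, 15, pruned_visits), pruned_visits)  # 32 [10, 5, 7, 15, 18]
```

On this small tree only node 3 is skipped (5 < 7, so its left subtree is never entered). With a larger tree and a narrow range, whole subtrees are skipped. The worst case is still O(n) when every node is in range.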

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def rangeSumBST(self, root: Optional[TreeNode], low: int, high: int) -> int:
        def dfs(node):
            if not node:
                return 0

            if node.val < low:
                # Node and its left subtree are all too small
                return dfs(node.right)

            elif node.val > high:
                # Node and its right subtree are all too large
                return dfs(node.left)

            return node.val + dfs(node.left) + dfs(node.right)

        return dfs(root)

Iteration: BFS/DFS

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def rangeSumBST(self, root: Optional[TreeNode], low: int, high: int) -> int:
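        # `total` avoids shadowing the built-in sum()
        stack, total = [root], 0

        # DFS: pop(), BFS: pop(0)
        while stack:
            node = stack.pop()
            if node:
                # The left subtree can hold in-range values only if node.val > low
                if node.val > low:
                    stack.append(node.left)
                # The right subtree can hold in-range values only if node.val < high
                if node.val < high:
                    stack.append(node.right)
                if low <= node.val <= high:
                    total += node.val
        return total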
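The comment above suggests pop(0) for BFS. list.pop(0) shifts every remaining element, so each call costs O(n). A collections.deque with popleft() gives the same first-in, first-out order in O(1) per pop. Below is a minimal sketch of that variant with the same pruning rules. The standalone function name `range_sum_bfs` is illustrative.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def range_sum_bfs(root: Optional[TreeNode], low: int, high: int) -> int:
    queue = deque([root])
    total = 0
    while queue:
        node = queue.popleft()  # FIFO order gives a level-by-level traversal
        if node is None:
            continue
        if node.val > low:   # left subtree may still hold values >= low
            queue.append(node.left)
        if node.val < high:  # right subtree may still hold values <= high
            queue.append(node.right)
        if low <= node.val <= high:
            total += node.val
    return total


root = TreeNode(10, TreeNode(5, TreeNode(3), TreeNode(7)), TreeNode(15, None, TreeNode(18)))
print(range_sum_bfs(root, 7, 15))  # 32
print(range_sum_bfs(None, 7, 15))  # 0
```

Both traversal orders produce the same sum. The choice matters only for memory: the stack holds roughly one root-to-leaf path of pending siblings, while the queue can hold an entire level.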
