938. Range Sum of BST
Easy
Problem:
Given the root node of a binary search tree and two integers low and high, return the sum of the values of all nodes whose value lies in the inclusive range [low, high].

Input: root = [10,5,15,3,7,null,18], low = 7, high = 15
Output: 32
Explanation: Nodes 7, 10, and 15 are in the range [7, 15]. 7 + 10 + 15 = 32.
https://leetcode.com/problems/range-sum-of-bst/
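The input list uses LeetCode's level-order serialization, where null marks a missing child. The sketch below checks the expected output by rebuilding the example tree and collecting the in-range values. It assumes a TreeNode class that mirrors LeetCode's definition. build_tree and in_range_values are hypothetical helpers written for illustration; they are not part of the problem's API.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Children are listed left first, then right.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def in_range_values(root: Optional[TreeNode], low: int, high: int) -> List[int]:
    """Hypothetical helper: in-order walk, so the in-range values come out sorted."""
    out: List[int] = []

    def walk(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        walk(node.left)
        if low <= node.val <= high:
            out.append(node.val)
        walk(node.right)

    walk(root)
    return out


root = build_tree([10, 5, 15, 3, 7, None, 18])
print(in_range_values(root, 7, 15))       # [7, 10, 15]
print(sum(in_range_values(root, 7, 15)))  # 32
```

Because a BST's in-order walk is sorted, the collected values also show which nodes fall inside the range.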
Solution:
```python
from typing import Optional

# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def rangeSumBST(self, root: Optional[TreeNode], low: int, high: int) -> int:
        result = 0  # an empty tree falls through and returns 0

        def dfs(node):
            nonlocal result
            if not node:
                return
            # Visits every node, without using the BST ordering.
            dfs(node.left)
            dfs(node.right)
            if low <= node.val <= high:
                result += node.val

        dfs(root)
        return result
```
Optimization: Pruning
The dfs function below uses the BST ordering to skip subtrees that cannot contain in-range values:
If node.val < low, every value in the left subtree is smaller still, so that whole subtree is out of range and the function returns only dfs(node.right).
If node.val > high, every value in the right subtree is larger still, so that whole subtree is out of range and the function returns only dfs(node.left).
Otherwise node.val lies in [low, high], and the function returns node.val + dfs(node.left) + dfs(node.right).
Unlike the first solution, which visits every node, this version never descends into a subtree that lies entirely outside the range.
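To see the pruning concretely, the sketch below runs two versions on the example tree and records the nodes each one reaches. full_sum visits every node; pruned_sum follows the pruned recurrence. Both function names and the visited parameter are illustration-only assumptions, not part of the solution's signature.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def full_sum(node: Optional[TreeNode], low: int, high: int, visited: List[int]) -> int:
    """Visit every node and add the in-range values (no pruning)."""
    if node is None:
        return 0
    visited.append(node.val)
    own = node.val if low <= node.val <= high else 0
    return own + full_sum(node.left, low, high, visited) + full_sum(node.right, low, high, visited)


def pruned_sum(node: Optional[TreeNode], low: int, high: int, visited: List[int]) -> int:
    """Same recurrence as the pruned dfs, recording each node it reaches."""
    if node is None:
        return 0
    visited.append(node.val)
    if node.val < low:   # left subtree is entirely < low
        return pruned_sum(node.right, low, high, visited)
    if node.val > high:  # right subtree is entirely > high
        return pruned_sum(node.left, low, high, visited)
    return node.val + pruned_sum(node.left, low, high, visited) + pruned_sum(node.right, low, high, visited)


# Example tree [10,5,15,3,7,null,18]
root = TreeNode(10, TreeNode(5, TreeNode(3), TreeNode(7)), TreeNode(15, None, TreeNode(18)))
full_visits, pruned_visits = [], []
print(full_sum(root, 7, 15, full_visits), full_visits)        # 32 [10, 5, 3, 7, 15, 18]
print(pruned_sum(root, 7, 15, pruned_visits), pruned_visits)  # 32 [10, 5, 7, 15, 18]
```

Node 3 is never reached because 5 < low rules out the whole left subtree of 5. On a tree this small the saving is one node. When the range covers only a small part of a large BST, the pruned search can skip entire subtrees.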
```python
from typing import Optional

# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def rangeSumBST(self, root: Optional[TreeNode], low: int, high: int) -> int:
        def dfs(node):
            if not node:
                return 0
            if node.val < low:
                # Left subtree is entirely < low; only the right can contribute.
                return dfs(node.right)
            elif node.val > high:
                # Right subtree is entirely > high; only the left can contribute.
                return dfs(node.left)
            return node.val + dfs(node.left) + dfs(node.right)

        return dfs(root)
```
Iteration: BFS/DFS
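```python
from typing import Optional

# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def rangeSumBST(self, root: Optional[TreeNode], low: int, high: int) -> int:
        stack, total = [root], 0  # "total" avoids shadowing the built-in sum()
        # pop() takes from the end (DFS); pop(0) takes from the front (BFS)
        while stack:
            node = stack.pop()
            if node:
                # The left subtree can hold in-range values only if node.val > low.
                if node.val > low:
                    stack.append(node.left)
                # The right subtree can hold in-range values only if node.val < high.
                if node.val < high:
                    stack.append(node.right)
                if low <= node.val <= high:
                    total += node.val
        return total
```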
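Taking nodes from the front of the container instead of the end turns this stack-based DFS into a BFS. With a Python list, pop(0) is O(n) because every remaining element shifts left, whereas collections.deque.popleft() is O(1). Below is a minimal sketch of the BFS variant under that assumption. range_sum_bfs is a hypothetical standalone name, not the LeetCode method signature.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def range_sum_bfs(root: Optional[TreeNode], low: int, high: int) -> int:
    """Level-order (BFS) range sum with the same pruning rules as the stack version."""
    total = 0
    queue = deque([root])
    while queue:
        node = queue.popleft()  # O(1), unlike list.pop(0)
        if node is None:
            continue
        if node.val > low:   # the left subtree may still hold values >= low
            queue.append(node.left)
        if node.val < high:  # the right subtree may still hold values <= high
            queue.append(node.right)
        if low <= node.val <= high:
            total += node.val
    return total


# Example tree [10,5,15,3,7,null,18]
root = TreeNode(10, TreeNode(5, TreeNode(3), TreeNode(7)), TreeNode(15, None, TreeNode(18)))
print(range_sum_bfs(root, 7, 15))  # 32
```

The visiting order changes, but the sum does not. Each node is still added at most once, and subtrees outside the range are still skipped.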