310. Minimum Height Trees

Medium

Problem:

Given a tree of n nodes labeled from 0 to n - 1 and its undirected edges, return a list of all roots that give the tree its minimum height. A tree rooted at such a node is called a minimum height tree (MHT).

Input: n = 4, edges = [[1,0],[1,2],[1,3]]
Output: [1]
Explanation: When node 1 is the root, the tree has height 1, and no other root gives a smaller height, so node 1 is the only MHT root.
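To see where this output comes from, a brute-force sketch can root the tree at every node, measure the height with breadth-first search, and keep the smallest. The helper name `tree_height` is hypothetical, and running one BFS per node is only meant for checking small inputs by hand.

```python
from collections import deque
from typing import Dict, List


def tree_height(n: int, edges: List[List[int]], root: int) -> int:
    """Return the height (edges on the longest root-to-leaf path) with `root` as the root."""
    graph: Dict[int, List[int]] = {i: [] for i in range(n)}
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    # Breadth-first search records each node's distance from the root.
    dist = {root: 0}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for nxt in graph[node]:
            if nxt not in dist:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return max(dist.values())


edges = [[1, 0], [1, 2], [1, 3]]
heights = [tree_height(4, edges, r) for r in range(4)]
print(heights)  # [2, 1, 2, 2]
```

Only root 1 reaches the minimum height of 1, which matches the expected output.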

https://leetcode.com/problems/minimum-height-trees/

Solution:

By repeatedly removing the current leaf nodes, the node or nodes that remain are the most central ones. Using any of them as the root gives the tree its minimum height.

  1. Since the graph is undirected, either endpoint of an edge can end up as the parent once the tree is rooted. Hence, each edge is inserted in both directions into an adjacency dictionary named graph.

  2. Identify the leaf nodes and add them to leaves. A leaf is a node whose list in graph contains exactly one neighbor.

  3. Keep removing leaf nodes, one full layer at a time, until at most two nodes remain. For each leaf, use pop() to take its only neighbor out of the leaf's list, then remove the leaf from that neighbor's list as well. Both sides must be updated because every edge was stored in both directions.

  4. Finally, one or two nodes are left. A tree always has either a single center or two adjacent centers, so the while loop stops as soon as two or fewer nodes remain instead of waiting for a single node.
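The steps above can be traced layer by layer. The following sketch applies the same leaf-trimming idea but records every layer it removes. As an assumption of this sketch (not what the solution code does), it tracks a degree count per node instead of popping entries from the neighbor lists. The name `trim_layers` is hypothetical.

```python
from typing import List


def trim_layers(n: int, edges: List[List[int]]) -> List[List[int]]:
    """Hypothetical helper: each layer of leaves removed, then the remaining centers."""
    if n <= 1:
        return [[0]]

    # Step 1: bidirectional adjacency list, plus a degree count per node.
    graph: List[List[int]] = [[] for _ in range(n)]
    for i, j in edges:
        graph[i].append(j)
        graph[j].append(i)
    degree = [len(nbrs) for nbrs in graph]

    # Step 2: the first leaves are the nodes with exactly one neighbor.
    leaves = [i for i in range(n) if degree[i] == 1]

    # Steps 3 and 4: peel one full layer per round until at most two nodes remain.
    layers = []
    remaining = n
    while remaining > 2:
        layers.append(leaves)
        remaining -= len(leaves)
        next_leaves = []
        for leaf in leaves:
            degree[leaf] = 0  # mark the leaf as removed
            for nbr in graph[leaf]:
                if degree[nbr] > 0:  # skip neighbors removed in earlier rounds
                    degree[nbr] -= 1
                    if degree[nbr] == 1:
                        next_leaves.append(nbr)
        leaves = next_leaves

    layers.append(leaves)  # the surviving one or two centers
    return layers


print(trim_layers(4, [[1, 0], [1, 2], [1, 3]]))  # [[0, 2, 3], [1]]
```

On the example, the first layer removes nodes 0, 2, and 3 at once, and node 1 is all that remains.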

Why Stop at Two Nodes?

Consider a tree that is a straight line, like a linked list. With an odd number of nodes the center is a single node, and with an even number of nodes the center is two nodes. So once two or fewer nodes are left, they are the MHT roots.

For example, in the path [0-1-2-3-4] (five nodes), node 2 is the only MHT root.

In the path [0-1-2-3-4-5] (six nodes), nodes 2 and 3 are both MHT roots.
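A path is the easiest case to check by hand. In this sketch (the helpers `peel_path` and `path_centers` are hypothetical), each round strips both ends of a path labeled 0 to k - 1, since those are its only leaves, and the result is compared with the middle index (k - 1) / 2 rounded down and up.

```python
from typing import List


def peel_path(k: int) -> List[int]:
    """Simulate leaf trimming on the path 0-1-...-(k-1); its only leaves are the two ends."""
    nodes = list(range(k))
    while len(nodes) > 2:
        nodes = nodes[1:-1]  # remove both end nodes in one round
    return nodes


def path_centers(k: int) -> List[int]:
    """Closed form: the middle index (k - 1) / 2, rounded down and rounded up."""
    lo, hi = (k - 1) // 2, k // 2
    return [lo] if lo == hi else [lo, hi]


print(peel_path(5), path_centers(5))  # [2] [2]
print(peel_path(6), path_centers(6))  # [2, 3] [2, 3]
```

An odd k leaves one center and an even k leaves two, which is why trimming stops once two or fewer nodes remain.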

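import collections
from typing import List


class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        # A single node (no edges) is its own minimum height tree.
        if n <= 1:
            return [0]

        # Undirected adjacency list: store every edge in both directions.
        graph = collections.defaultdict(list)
        for i, j in edges:
            graph[i].append(j)
            graph[j].append(i)

        # First layer of leaves: nodes with exactly one neighbor.
        # Nodes are labeled 0 to n - 1, so range(n) covers all of them.
        leaves = []
        for i in range(n):
            if len(graph[i]) == 1:
                leaves.append(i)

        # Remove leaf nodes layer by layer until at most two nodes remain.
        remaining = n
        while remaining > 2:
            remaining -= len(leaves)
            new_leaves = []
            for leaf in leaves:
                # A leaf has exactly one neighbor; detach the edge on both sides.
                neighbor = graph[leaf].pop()
                graph[neighbor].remove(leaf)

                # A neighbor left with one edge is a leaf of the next layer.
                if len(graph[neighbor]) == 1:
                    new_leaves.append(neighbor)
            leaves = new_leaves

        # The one or two nodes left are the centers, i.e. the MHT roots.
        return leaves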
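For comparison, a different technique (not the one used in this solution) reaches the same roots: in a tree, the MHT roots are the middle one or two nodes of any longest path, and such a path can be found with two breadth-first searches. A minimal sketch, with hypothetical helper names `farthest` and `centers_by_diameter`:

```python
from collections import deque
from typing import List


def farthest(graph: List[List[int]], start: int) -> List[int]:
    """BFS from `start`; return the path from a farthest node back to `start`."""
    parent = {start: None}
    queue = deque([start])
    last = start
    while queue:
        last = queue.popleft()
        for nxt in graph[last]:
            if nxt not in parent:
                parent[nxt] = last
                queue.append(nxt)
    # BFS dequeues nodes in order of distance, so `last` is a farthest node.
    path = []
    node = last
    while node is not None:
        path.append(node)
        node = parent[node]
    return path


def centers_by_diameter(n: int, edges: List[List[int]]) -> List[int]:
    """Alternative technique: the MHT roots are the middle node(s) of a longest path."""
    graph: List[List[int]] = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    # The first BFS finds one end of a longest path; the second walks the whole path.
    end = farthest(graph, 0)[0]
    path = farthest(graph, end)
    k = len(path)
    return sorted({path[(k - 1) // 2], path[k // 2]})


print(centers_by_diameter(4, [[1, 0], [1, 2], [1, 3]]))  # [1]
```

This version leaves the adjacency lists untouched and returns the roots in sorted order, while the leaf-trimming solution returns them in the order the last layer was discovered.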
