[Leetcode] Binary Tree Path Problems

Problem Introduction: https://www.youtube.com/watch?v=zIkDfgFAg60

Recursive Solution Template:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
<> res = 0;
public <result> solution(TreeNode root) {
if (root == null) return ..
helper(root, "");
return res;
}
private void helper(TreeNode root, <>path) {
if (root.left == null && root.right == null) { // if reach leaf node
res.add(path);
return;
}
if (root.left != null) helper(root.left, path +...);
if (root.right != null) helper(root.right, path + ...);
}

129. Sum Root to Leaf Numbers

[Recursive DFS Solution] Define a recursive helper function that accepts two inputs: root and prevValue. prevValue holds the number formed by the digits from the tree root down to the current node's parent; the node computes root.val + prevValue * 10 and passes that value to its children.

When the recursion reaches a leaf node, add the completed number to the global result and return; otherwise, recurse into each non-null child.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
 * LeetCode 129: sum of all root-to-leaf numbers, where the digits along each path form one number.
 */
class Solution {
    // Running total of completed root-to-leaf numbers for the current call.
    int res = 0;

    /**
     * @param root root of the tree (may be null)
     * @return sum of the numbers formed along every root-to-leaf path; 0 for an empty tree
     */
    public int sumNumbers(TreeNode root) {
        // Reset so a reused Solution instance does not add onto a previous call's total.
        res = 0;
        if (root == null) return res;
        helper(root, 0);
        return res;
    }

    /**
     * Depth-first walk that carries the number formed by the ancestors' digits.
     *
     * @param root      current (non-null) node
     * @param prevValue number formed by the digits from the tree root down to root's parent
     */
    private void helper(TreeNode root, int prevValue) {
        int cur = prevValue * 10 + root.val; // append this node's digit
        if (root.left == null && root.right == null) {
            res += cur; // leaf: the number is complete
            return;
        }
        if (root.left != null) helper(root.left, cur);
        if (root.right != null) helper(root.right, cur);
    }
}

257. Binary Tree Paths

[Solution] Construct a recursive helper function that takes the path built so far as input.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution {
    /**
     * Collects every root-to-leaf path as a string of node values joined by "->".
     *
     * @param root tree root (may be null)
     * @return one string per leaf, in pre-order leaf order; empty list for an empty tree
     */
    public List<String> binaryTreePaths(TreeNode root) {
        List<String> paths = new ArrayList<>();
        if (root != null) {
            collect(root, "", paths);
        }
        return paths;
    }

    // Appends node's value to prefix; emits the result at a leaf, otherwise recurses with a "->" separator.
    private void collect(TreeNode node, String prefix, List<String> paths) {
        String soFar = prefix + node.val;
        if (node.left == null && node.right == null) {
            paths.add(soFar);
            return;
        }
        String next = soFar + "->";
        if (node.left != null) collect(node.left, next, paths);
        if (node.right != null) collect(node.right, next, paths);
    }
}

113. Path Sum II

[Solution]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/**
 * LeetCode 113: every root-to-leaf path whose node values add up to a target sum.
 */
class Solution {
    // Paths found by the current call to pathSum.
    List<List<Integer>> res = new ArrayList<>();

    /**
     * @param root tree root (may be null)
     * @param sum  target sum for a root-to-leaf path
     * @return list of matching paths, each listing values from root to leaf; empty if none
     */
    public List<List<Integer>> pathSum(TreeNode root, int sum) {
        // Fresh list per call: otherwise a reused instance would also return earlier calls' paths.
        res = new ArrayList<>();
        if (root == null) return res;
        helper(root, new ArrayList<>(), sum);
        return res;
    }

    /**
     * Backtracking DFS.
     *
     * @param root current (non-null) node
     * @param path values from the tree root down to root's parent; restored before returning
     * @param sum  amount that the values from root down to a leaf still have to add up to
     */
    private void helper(TreeNode root, List<Integer> path, int sum) {
        path.add(root.val);
        if (root.left == null && root.right == null) {
            if (root.val == sum) {
                res.add(new ArrayList<>(path)); // snapshot, because path keeps being mutated
            }
        } else {
            if (root.left != null) helper(root.left, path, sum - root.val);
            if (root.right != null) helper(root.right, path, sum - root.val);
        }
        // Undo our add so the caller sees path exactly as it passed it in (backtracking).
        path.remove(path.size() - 1);
    }
}

437. Path Sum III

Leetcode: https://leetcode.com/problems/path-sum-iii/

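Solution: Use a dictionary of prefix sums. The key is a cumulative sum from the root down to some node on the current path, and the value is how many times that sum occurs on the path; it is seeded with {0: 1}. At each node, if curSum - sum is in the dictionary, add its count to the global count. When leaving a node, decrement its prefix sum so that the sum only counts within that node's subtree.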

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
var pathSum = function(root, sum) {
  // Prefix-sum technique: prefixCounts maps a running sum (from the root down to some node on
  // the current path, plus a virtual 0 above the root) to how often it occurs on that path.
  const prefixCounts = new Map();
  prefixCounts.set(0, 1);
  let total = 0;

  const walk = (node, running) => {
    if (!node) return;
    const here = running + node.val;
    // Each earlier prefix equal to here - sum starts a downward path that ends here and sums to `sum`.
    total += prefixCounts.get(here - sum) || 0;
    prefixCounts.set(here, (prefixCounts.get(here) || 0) + 1);
    walk(node.left, here);
    walk(node.right, here);
    // Backtrack: once we leave this subtree, `here` is no longer on the current path.
    prefixCounts.set(here, prefixCounts.get(here) - 1);
  };

  walk(root, 0);
  return total;
};

494. Target Sum

Leetcode: https://leetcode.com/problems/target-sum/

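Solution: Similar to the path sum above, use a dictionary that maps each reachable sum to the number of sign assignments producing it. For each number, build the next layer by adding that number to, and subtracting it from, every previous sum, carrying over the counts. After the last layer, the count stored for the target S is the answer.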

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def findTargetSumWays(self, nums: List[int], S: int) -> int:
    """Count the ways to put '+' or '-' before every element of nums so the total equals S.

    Layered DP: `ways` maps each reachable running sum to the number of sign
    assignments of the numbers processed so far that produce it. The answer is also
    stored on self.cnt. Returns 0 when nums is empty.
    """
    self.cnt = 0
    ways = {0: 1}
    for num in nums:
        layer = collections.defaultdict(int)
        for total, how_many in ways.items():
            layer[total + num] += how_many
            layer[total - num] += how_many
        ways = layer
        # Ways to hit S using every number seen so far; the final layer's value is the answer.
        self.cnt = ways.get(S, 0)
    return self.cnt

Path through root

687. Longest Univalue Path

Leetcode: https://leetcode.com/problems/longest-univalue-path/description/
Solution: https://www.youtube.com/watch?v=asihnVxQuL4

[Solution]

Use a global variable len to keep track of the longest univalue path found so far. At each node, the candidate length is maxLeft + maxRight: the longest matching arm on the left plus the longest matching arm on the right.

The path reported to the parent can extend only one arm. So when returning, we pass max(left, right) + 1 (the +1 is the edge to the parent) rather than left + right + 1.

If the node's value differs from its parent's, the path is disconnected, so we return 0.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/**
 * LeetCode 687: number of edges on the longest path whose nodes all share one value.
 */
class Solution {
    // Longest univalue path (in edges) found so far in the current call.
    int len = 0;

    public int longestUnivaluePath(TreeNode root) {
        len = 0; // reset so a reused instance does not report a previous tree's answer
        if (root == null) return 0;
        longest(root, root.val);
        return len;
    }

    /**
     * @param root current node (may be null)
     * @param val  value of root's parent
     * @return number of edges on the longest equal-valued path that starts at the parent and
     *         goes down through root; 0 if root is null or root.val != val
     */
    private int longest(TreeNode root, int val) {
        if (root == null) return 0;

        int maxLeft = longest(root.left, root.val);
        int maxRight = longest(root.right, root.val);

        // Best path bending at root uses both arms (each is already 0 if that child's value differs).
        len = Math.max(len, maxLeft + maxRight);
        // Only one arm can continue upward, plus the edge from root to its parent.
        if (root.val == val) return Math.max(maxLeft, maxRight) + 1;
        // Value differs from the parent: the chain is broken (e.g. 1-2-1-1).
        return 0;
    }
}

124. Binary Tree Maximum Path Sum

Leetcode: https://leetcode.com/problems/binary-tree-maximum-path-sum/description/

[Solution]

  • Global variable max to keep the answer
  • When reaching a leaf node, update max = leaf.val + 0 + 0
  • When passing a value to the upper level, return Math.max(left, right) + node.val. This extends only the left or the right path into the parent's recursion; negative branches are clamped to 0.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// Solution in discussion(cleaner)
public class Solution {
    // Best path sum seen anywhere in the tree during the current call.
    int maxValue;

    /**
     * LeetCode 124: maximum sum over all non-empty node-to-node paths.
     */
    public int maxPathSum(TreeNode root) {
        maxValue = Integer.MIN_VALUE;
        maxPathDown(root);
        return maxValue;
    }

    /**
     * Returns the best sum of a path that starts at node and only goes downward (0 for null).
     * As a side effect, records in maxValue the best path that bends at node.
     */
    private int maxPathDown(TreeNode node) {
        if (node == null) {
            return 0;
        }
        // A negative branch would only lower the total, so treat it as "not taken" (0).
        int leftGain = maxPathDown(node.left);
        if (leftGain < 0) leftGain = 0;
        int rightGain = maxPathDown(node.right);
        if (rightGain < 0) rightGain = 0;

        int bendingHere = node.val + leftGain + rightGain;
        if (bendingHere > maxValue) maxValue = bendingHere;

        return node.val + (leftGain > rightGain ? leftGain : rightGain);
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def maxPathSum(self, root: TreeNode) -> int:
    """LeetCode 124: maximum sum over all non-empty node-to-node paths; 0 for an empty tree."""
    if not root:
        return 0
    best = -float('inf')

    def down(node):
        """Best sum of a path that starts at `node` and only goes downward (0 for None)."""
        nonlocal best  # !! note: rebinding an enclosing function's variable needs nonlocal
        if not node:
            return 0
        # Clamp each child's contribution at 0: a negative branch is better left out.
        left = max(0, down(node.left))
        right = max(0, down(node.right))
        one_side = node.val + max(left, right)
        both_sides = node.val + left + right
        best = max(best, both_sides, one_side)
        return one_side

    down(root)
    return best

543. Diameter of Binary Tree

Leetcode: https://leetcode.com/problems/diameter-of-binary-tree/description/
Compute the length (number of edges) of the longest path between any two nodes in a tree.

[Solution]

  • Use a global variable maxLen
  • At every node, update maxLen with left + right. Here left and right are the edge counts of the longest downward arms, each including the edge to that child.
  • When returning a value to the upper recursion, return only the longer of the left and right arms
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
 * LeetCode 543: number of edges on the longest path between any two nodes.
 */
class Solution {
    // Longest path (in edges) seen so far in the current call.
    int maxLen = 0;

    public int diameterOfBinaryTree(TreeNode root) {
        maxLen = 0; // reset so a reused instance does not report a previous tree's diameter
        if (root == null) return maxLen;
        helper(root);
        return maxLen;
    }

    /**
     * @param root non-null node
     * @return number of edges on the longest downward path starting at root (0 for a leaf)
     */
    private int helper(TreeNode root) {
        int left = 0;  // edges down the left side, including the edge root -> root.left
        int right = 0; // edges down the right side, including the edge root -> root.right
        if (root.left != null) left = helper(root.left) + 1;
        if (root.right != null) right = helper(root.right) + 1;

        // The longest path that bends at root uses both arms.
        maxLen = Math.max(left + right, maxLen);

        // Only one arm can be extended by the parent.
        return Math.max(left, right);
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution:
    def __init__(self):
        # Longest path seen so far, measured in nodes (converted to edges on return).
        self.max = 0

    def diameterOfBinaryTree(self, root: TreeNode) -> int:
        """LeetCode 543: number of edges on the longest path between any two nodes (0 if empty)."""
        if not root:
            return 0
        # Reset so a second call on the same instance doesn't reuse the previous tree's maximum.
        self.max = 0

        def longestPath(node):
            # Number of nodes on the longest downward path starting at `node` (0 for None).
            if not node:
                return 0
            left = longestPath(node.left)
            right = longestPath(node.right)
            # Path that bends at `node`: both arms plus the node itself.
            self.max = max(self.max, left + right + 1)
            # Only one arm can be extended upward by the parent.
            return max(left, right) + 1

        longestPath(root)
        # self.max counts nodes; a path of k nodes has k - 1 edges.
        return self.max - 1

337. House Robber III

Leetcode: https://leetcode.com/problems/house-robber-iii/

Video: https://www.youtube.com/watch?v=-i2BFAU25Zk

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
var rob = function(root) {
  // Post-order DP: for every subtree compute [bestIfNodeSkipped, bestIfNodeRobbed].
  const dfs = (node) => {
    if (!node) return [0, 0];
    const [leftSkip, leftTake] = dfs(node.left);
    const [rightSkip, rightTake] = dfs(node.right);
    // Skipping this node leaves each child free to be robbed or not.
    const skip = Math.max(leftSkip, leftTake) + Math.max(rightSkip, rightTake);
    // Robbing this node forces both children to be skipped.
    const take = node.val + leftSkip + rightSkip;
    return [skip, take];
  };
  const [skipRoot, takeRoot] = dfs(root);
  return Math.max(skipRoot, takeRoot);
};

863. All Nodes Distance K in Binary Tree

Leetcode: https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree/

Note:

  • Do a dfs to form an undirected graph from all nodes first
  • Then do a bfs to find the nodes in certain distance
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def distanceK(self, root, target, K):
    """LeetCode 863: values of all nodes exactly K edges away from `target`.

    Step 1 turns the tree into an undirected adjacency list keyed by node value
    (LeetCode guarantees unique values). Step 2 runs a level-order BFS from target
    for K levels. Both steps are iterative, so deep trees cannot hit Python's
    recursion limit. The order of the returned values is unspecified.
    """
    graph = collections.defaultdict(list)  # {1: [0, 8], ...}
    stack = [(root, None)]
    while stack:
        node, parent = stack.pop()
        if node is None:
            continue
        if parent is not None:
            graph[parent.val].append(node.val)
            graph[node.val].append(parent.val)
        stack.append((node.left, node))
        stack.append((node.right, node))

    frontier = [target.val]
    visited = {target.val}
    for _ in range(K):
        nxt = []
        for val in frontier:
            for nb in graph[val]:
                if nb not in visited:
                    visited.add(nb)
                    nxt.append(nb)
        frontier = nxt
        if not frontier:  # ran out of nodes before reaching distance K
            break
    return frontier