JZX轻语:简
LeetCode 2385 - 感染二叉树需要的总时间
发表于2024年04月24日
有点难度的二叉树递归,节点的感染路径有三个方向:往下感染左子树,右子树以及往上感染祖先。可以通过一次遍历得到结果:使用一个变量ans
维护当前的感染最长路径,按后序遍历树节点,如果遇到了start
节点,则先将ans
设置为其子树的高度。否则,对于某个节点,如果其左子树包含start
节点,则可能的最长感染路径为start节点->...[从左子树往上感染]->本节点->...[往下感染]->右子树的最深节点
,右子树亦然,不断更新ans
,直至处理到root
节点。
class Solution:
def amountOfTime(self, root: Optional[TreeNode], start: int) -> int:
ans = 0
def dfs(node: Optional[TreeNode]) -> (int, int):
# 返回值: 该子树的高, 该子树的根距离start节点的距离(如果start节点不在该子树内,则为-1)
nonlocal ans
if node is None:
return 0, -1
left_depth, source_dis_l = dfs(node.left)
right_depth, source_dis_r = dfs(node.right)
if node.val == start:
# 碰到start节点
ans = max(left_depth, right_depth)
return max(left_depth, right_depth) + 1, 0
if source_dis_l != -1:
# 左子树中包含start节点, 则可能的最长感染路径为 start节点 -> ... -> 本节点 -> 右子树最深节点
ans = max(ans, right_depth + 1 + source_dis_l)
elif source_dis_r != -1:
# 右子树包含start节点, 则可能的最长感染路径为 start节点 -> ... -> 本节点 -> 左子树最深节点
ans = max(ans, left_depth + 1 + source_dis_r)
# 计算本节点距离start节点的距离
source_dis = max(source_dis_l, source_dis_r)
if source_dis != -1:
source_dis += 1 # 加上本节点
return 1 + max(left_depth, right_depth), source_dis
dfs(root)
return ans