JZX轻语:简
LeetCode 802 - 找到最终的安全状态
发表于2024年05月09日
三色标记法+DFS,使用一个数组维护每个节点可能的三种状态:未遍历,正在遍历,遍历完毕。然后使用DFS计算每个节点是否为安全节点:如果其为终端节点,即也是安全节点;否则,遍历所有的路径,如果所有的路径经过的节点皆为安全节点,则该节点也为安全节点;如果碰上一个正在遍历的节点(在递归栈中),则存在环,说明该路径是走不到终端节点的。
这道题其实也是一个比较典型的拓扑排序应用题(毕竟终端节点定义为出度为0,很符合拓扑排序的特征),可以将这个图“翻转”起来,即调转所有的边的方向,然后应用拓扑排序将入度为0的节点添加到结果中并剥离,直至图中只剩下环。
class Solution:
def eventualSafeNodes(self, graph: List[List[int]]) -> List[int]:
n = len(graph)
# 寻找终端节点
terminal_nodes = set()
for u, adj_list in enumerate(graph):
if not adj_list:
terminal_nodes.add(u)
safe_nodes = set(terminal_nodes)
# === 原来这就是三色标记法 ===
# 其实就是用来维护三种可能的状态嘛
visit_status = [0] * n # 0: 从来未访问 1: 正在访问中 2: 访问完毕
def dfs(u: int) -> bool:
""" 遍历节点u的所有路径, 最终返回u是不是安全节点 """
if u in terminal_nodes: # 已经到达终端节点, 返回True
return True
if visit_status[u] == 1: # 发现环, 说明这条路径是走不通的
return False
if visit_status[u] == 2: # 如果遍历的路径经过一个已经遍历的节点
return u in safe_nodes # 直接检查其是否安全节点即可
visit_status[u] = 1 # 标记正在访问
is_safe = True
for v in graph[u]:
if not dfs(v):
is_safe = False
break
visit_status[u] = 2 # 标记访问完毕
if is_safe:
safe_nodes.add(u)
return is_safe
for u in range(n):
if u not in terminal_nodes and visit_status[u] == 0:
dfs(u)
return sorted(safe_nodes)
一个更简洁一点的方法,去掉terminal_nodes
(反正遍历的时候,终端节点本就没有出路,直接走dfs逻辑就好了)和safe_nodes
,直接使用状态1
共同表示正在访问中/非安全节点,状态2
表示已经访问完毕且是安全节点。
class Solution:
def eventualSafeNodes(self, graph: List[List[int]]) -> List[int]:
n = len(graph)
# === 原来这就是三色标记法 ===
# 其实就是用来维护三种可能的状态嘛
visit_status = [0] * n # 0: 从来未访问 1: 正在访问中 or 非安全节点 2: 安全节点
def dfs(u: int) -> bool:
""" 遍历节点u的所有路径, 最终返回u是不是安全节点 """
if visit_status[u] == 1:
return False
if visit_status[u] == 2: # 如果遍历的路径经过一个已经遍历的节点
return True
visit_status[u] = 1 # 标记正在访问
for v in graph[u]:
if not dfs(v):
# 直接返回就行了
# 此时的状态1表示为非安全节点了!
return False
visit_status[u] = 2 # 标记访问完毕
return True
ans = []
for u in range(n):
if visit_status[u] == 0:
dfs(u)
if visit_status[u] == 2:
ans.append(u)
return ans
应用拓扑排序的版本,注意寻找入度为0的节点不要循环,会超时。维护一个队列来保存入度为0的节点。
from collections import deque
class Solution:
def eventualSafeNodes(self, graph: List[List[int]]) -> List[int]:
n = len(graph)
rev_graph = [[] for _ in range(n)]
for u, adj_list in enumerate(graph):
for v in adj_list:
rev_graph[v].append(u)
q = deque()
in_deg = [0] * n
for u in range(n):
in_deg[u] = len(graph[u])
if in_deg[u] == 0:
q.append(u)
ans = set()
while q:
zero_deg_node = q.popleft()
ans.add(zero_deg_node)
for v in rev_graph[zero_deg_node]:
in_deg[v] -= 1
if not in_deg[v]:
q.append(v)
return sorted(ans)