JZX轻语:简

LeetCode 3067 - 在带权树网络中统计可连接服务器对数目

发表于2024年06月04日

#图论 #树 #DFS #BFS

其实没啥好的做法,需要枚举每个节点,计算其作为根节点时,每棵子树(每个方向)中路径长度为signalSpeed的倍数的边的数量。然后经过该节点的服务器对数目就是以该节点作为根时,不同子树中满足上述条件的边两两乘积的和。可以用BFSDFS来实现。

两两乘积的和,可以利用后缀和来优化二次循环:先计算出各个方向结果的总和,然后每次遍历的时候减去当前的值,就可以得到剩下未遍历元素的总和,然后再乘起来就可以了。

使用BFS的做法:

from collections import deque


class Solution:
    def countPairsOfConnectableServers(self, edges: List[List[int]], signalSpeed: int) -> List[int]:
        n = len(edges) + 1
        adj_list = [[] for _ in range(n)]
        for u, v, w in edges:
            adj_list[u].append((v, w))
            adj_list[v].append((u, w))

        def bfs(root: int) -> int:
            q = collections.deque()
            results = [0] * len(adj_list[root])

            for i, (child, weight) in enumerate(adj_list[root]):
                q.append((child, weight, root, i))

            while q:
                node, dis, parent, direction = q.popleft()
                if not (dis % signalSpeed):
                    results[direction] += 1
                for child, weight in adj_list[node]:
                    if child == parent:
                        continue
                    q.append((child, dis + weight, node, direction))

            ans_for_root = 0
            # 利用后缀和来优化二次循环
            sum_ = sum(results)
            for result in results:
                sum_ -= result
                ans_for_root += result * sum_
            return ans_for_root

        return [
            bfs(node) for node in range(n)
        ]

一开始使用DFS的做法。

class Solution:
    def countPairsOfConnectableServers(self, edges: List[List[int]], signalSpeed: int) -> List[int]:
        n = len(edges) + 1
        adj_list = [[] for _ in range(n)]
        for u, v, w in edges:
            adj_list[u].append((v, w))
            adj_list[v].append((u, w))

        def dfs(node: int, parent: int, prev_dis: int) -> int:
            nonlocal ans

            res = int(prev_dis % signalSpeed == 0)
            for c, w in adj_list[node]:
                if c != parent:
                    res += dfs(c, node, prev_dis + w)

            return res

        ans = [0] * n
        for p in range(n):
            child_res = []
            for child, weight in adj_list[p]:
                child_res.append(dfs(child, p, weight))
            for i in range(len(child_res)):
                for j in range(i + 1, len(child_res)):
                    ans[p] += child_res[i] * child_res[j]

        return ans

闪念标签:LC

题目链接:https://leetcode.cn/problems/count-pairs-of-connectable-servers-in-a-weighted-tree-network/