JZX轻语:简

LeetCode 699 - 掉落的方块

发表于2024年07月28日

#线段树

线段树的应用题,由于之前没系统实现过线段树,所以这次从头学了一遍线段树的简单实现和应用。该题目就是在一个二维的坐标轴上,不断添加方块,并将每次添加后所有方块的最大高度添加到结果中。本质上,每次添加边长为sideLength的方块时,首先需要找到该方块所在x轴区间[left, left + sideLength - 1]的最大高度h,然后将该区间的最大高度更新为h + sideLength。不难得知此类区间更新+查询类型题目可使用线段树解决,每个节点维护一个区间的最大高度,在每次添加方块的时候,首先查询该区间的最大高度,然后更新该区间的最大高度,并将总的最大高度添加到结果中。

该题目之所以为Hard,是因为题意中的数据量较大,且left的范围较广,如果使用传统的基于数组的实现方法会导致内存溢出,此时可考虑使用哈希表来存储线段树的节点,以减少内存的使用(因为暂时用不到的节点可以先不存储在哈希表中,节点按需创建)。此外,还需要用到延迟标记来实现区间更新的懒惰更新,避免区间范围过大时导致递归很深且内存占用过大。

class SegTree:
    def __init__(self, n: int):
        self._n = n
        self._tree = {}  # 存储每个节点对应区间的最大值
        self._delay = {}  # 延迟更新标记

    def _delay_update(self, cur_idx: int):
        """ 处理cur_idx对应区间的节点的延迟更新标记
        更新两个子节点的数据,并将更新标记下发到两个子节点中(当遍历到子节点的时候再继续往下下发)
        """
        lc_idx = cur_idx * 2 + 1
        rc_idx = cur_idx * 2 + 2
        # 当前节点有延迟更新标记(意味着子节点数据还没有更新)
        # 更新两个子节点的数据
        # 并将延迟标记下发到孩子节点中
        self._tree[lc_idx] = self._delay[lc_idx] = self._delay[cur_idx]
        self._tree[rc_idx] = self._delay[rc_idx] = self._delay[cur_idx]
        # 下发完毕, 删除本节点的延迟更新标记
        del self._delay[cur_idx]

    def _search_helper(self, l: int, r: int, cur_idx: int, s: int, t: int) -> int:
        """ 递归搜索以获取区间[l, r]的最大值

        :param l 待搜索区间的左侧
        :param r 待搜索区间的右侧
        :param s 当前区间左侧
        :param t 当前区间右侧
        :param cur_idx 当前区间的节点序号
        """
        if l <= s and t <= r:  # 当前区间[s, t]包含在待搜索区间的[l, r]中,直接返回当前节点所保存的区间[s, t]最大值
            return self._tree.get(cur_idx, 0)
        if t < l or s > r:  # [s, t]和[l, r]没有交集, 提前return
            return 0

        if self._delay.get(cur_idx, 0):
            # 有延迟更新标记, 处理下
            self._delay_update(cur_idx)

        lc_idx = cur_idx * 2 + 1
        rc_idx = cur_idx * 2 + 2
        mid = (s + t) // 2

        # 递归下去搜索
        return max(
            self._search_helper(l, r, lc_idx, s, mid),
            self._search_helper(l, r, rc_idx, mid + 1, t)
        )

    def search(self, l: int, r: int) -> int:
        return self._search_helper(l, r, 0, 0, self._n - 1)

    def _update_helper(self, l: int, r: int, val: int, cur_idx: int, s: int, t: int):
        """ 递归更新区间[l, r]的最大值为val

        :param l 待更新区间的左侧
        :param r 待更新区间的右侧
        :param val 更新后的值
        :param s 当前区间左侧
        :param t 当前区间右侧
        :param cur_idx 当前区间的节点序号
        """
        if t < l or s > r:  # [s, t]和[l, r]没有交集, 提前return
            return

        if l <= s and t <= r:
            # [s, t]包含在[l, r]中,则先**仅对该节点**进行更新操作,并设置延迟更新标记
            self._tree[cur_idx] = self._delay[cur_idx] = val
            return

        if self._delay.get(cur_idx, 0):
            # 更新该节点前,如果该节点有延迟更新标记(上轮更新导致的), 先处理上一轮的延迟更新后,再进行本轮的更新
            self._delay_update(cur_idx)
            
        lc_idx = cur_idx * 2 + 1
        rc_idx = cur_idx * 2 + 2
        mid = (s + t) // 2
        
        self._update_helper(l, r, val, lc_idx, s, mid)
        self._update_helper(l, r, val, rc_idx, mid + 1, t)
        # 更新操作完毕后,也需要更新本节点的值(因为区间有交集),做法为取两个子节点中的最大值即可
        self._tree[cur_idx] = max(self._tree.get(lc_idx, 0), self._tree.get(rc_idx, 0))

    def update(self, l: int, r: int, val: int):
        self._update_helper(l, r, val, 0, 0, self._n - 1)

    def root_val(self):
        """ 返回整个数组的最大值 """
        return self._tree[0]


class Solution:
    def fallingSquares(self, positions: List[List[int]]) -> List[int]:
        r_bound = max(left + side for left, side in positions)
        seg_tree = SegTree(r_bound + 1)
        ans = []
        for left, side in positions:
            cur_max_height = seg_tree.search(left, left + side - 1)
            seg_tree.update(left, left + side - 1, cur_max_height + side)
            ans.append(seg_tree.root_val())
        return ans

MLE的做法,因为没有做延迟标记,导致递归深度过大,内存占用过大。

class SegTree:
    def __init__(self, n: int):
        self._n = n
        self._tree = {}

    def _search_helper(self, l: int, r: int, cur_idx: int, s: int, t: int) -> int:
        if l <= s and t <= r:
            return self._tree.get(cur_idx, 0)
        if t < l or s > r:
            return 0

        mid = (s + t) // 2
        return max(
            self._search_helper(l, r, 2 * cur_idx + 1, s, mid),
            self._search_helper(l, r, 2 * cur_idx + 2, mid + 1, t)
        )

    def search(self, l: int, r: int) -> int:
        return self._search_helper(l, r, 0, 0, self._n - 1)

    def _update_helper(self, l: int, r: int, val: int, cur_idx: int, s: int, t: int):
        if t < l or s > r:
            return
        if s == t:  # !!! 注意这里直至叶子节点才更新
            self._tree[cur_idx] = val
            return
        mid = (s + t) // 2
        self._update_helper(l, r, val, 2 * cur_idx + 1, s, mid)
        self._update_helper(l, r, val, 2 * cur_idx + 2, mid + 1, t)
        self._tree[cur_idx] = max(self._tree.get(cur_idx * 2 + 1, 0), self._tree.get(cur_idx * 2 + 2, 0))

    def update(self, l: int, r: int, val: int):
        self._update_helper(l, r, val, 0, 0, self._n - 1)

    def root_val(self):
        return self._tree[0]


class Solution:
    def fallingSquares(self, positions: List[List[int]]) -> List[int]:
        r_bound = max(left + side for left, side in positions)
        seg_tree = SegTree(r_bound + 1)
        ans = []
        for left, side in positions:
            cur_max_height = seg_tree.search(left, left + side - 1)
            seg_tree.update(left, left + side - 1, cur_max_height + side)
            ans.append(seg_tree.root_val())
        return ans

闪念标签:LC

题目链接:https://leetcode.cn/problems/falling-squares/