JZX轻语:简
LeetCode 699 - 掉落的方块
发表于2024年07月28日
线段树的应用题,由于之前没系统实现过线段树,所以这次从头学了一遍线段树的简单实现和应用。该题目就是在一个二维的坐标轴上,不断添加方块,并将每次添加后所有方块的最大高度添加到结果中。本质上,每次添加边长为sideLength
的方块时,首先需要找到该方块所在x轴区间[left, left + sideLength - 1]
的最大高度h
,然后将该区间的最大高度更新为h + sideLength
。不难得知此类区间更新+查询类型题目可使用线段树解决,每个节点维护一个区间的最大高度,在每次添加方块的时候,首先查询该区间的最大高度,然后更新该区间的最大高度,并将总的最大高度添加到结果中。
该题目之所以为Hard,是因为题意中的数据量较大,且left
的范围较广,如果使用传统的基于数组的实现方法会导致内存溢出,此时可考虑使用哈希表来存储线段树的节点,以减少内存的使用(因为暂时用不到的节点可以先不存储在哈希表中,节点按需创建)。此外,还需要用到延迟标记来实现区间更新的懒惰更新,避免区间范围过大时导致递归很深且内存占用过大。
class SegTree:
def __init__(self, n: int):
self._n = n
self._tree = {} # 存储每个节点对应区间的最大值
self._delay = {} # 延迟更新标记
def _delay_update(self, cur_idx: int):
""" 处理cur_idx对应区间的节点的延迟更新标记
更新两个子节点的数据,并将更新标记下发到两个子节点中(当遍历到子节点的时候再继续往下下发)
"""
lc_idx = cur_idx * 2 + 1
rc_idx = cur_idx * 2 + 2
# 当前节点有延迟更新标记(意味着子节点数据还没有更新)
# 更新两个子节点的数据
# 并将延迟标记下发到孩子节点中
self._tree[lc_idx] = self._delay[lc_idx] = self._delay[cur_idx]
self._tree[rc_idx] = self._delay[rc_idx] = self._delay[cur_idx]
# 下发完毕, 删除本节点的延迟更新标记
del self._delay[cur_idx]
def _search_helper(self, l: int, r: int, cur_idx: int, s: int, t: int) -> int:
""" 递归搜索以获取区间[l, r]的最大值
:param l 待搜索区间的左侧
:param r 待搜索区间的右侧
:param s 当前区间左侧
:param t 当前区间右侧
:param cur_idx 当前区间的节点序号
"""
if l <= s and t <= r: # 当前区间[s, t]包含在待搜索区间的[l, r]中,直接返回当前节点所保存的区间[s, t]最大值
return self._tree.get(cur_idx, 0)
if t < l or s > r: # [s, t]和[l, r]没有交集, 提前return
return 0
if self._delay.get(cur_idx, 0):
# 有延迟更新标记, 处理下
self._delay_update(cur_idx)
lc_idx = cur_idx * 2 + 1
rc_idx = cur_idx * 2 + 2
mid = (s + t) // 2
# 递归下去搜索
return max(
self._search_helper(l, r, lc_idx, s, mid),
self._search_helper(l, r, rc_idx, mid + 1, t)
)
def search(self, l: int, r: int) -> int:
return self._search_helper(l, r, 0, 0, self._n - 1)
def _update_helper(self, l: int, r: int, val: int, cur_idx: int, s: int, t: int):
""" 递归更新区间[l, r]的最大值为val
:param l 待更新区间的左侧
:param r 待更新区间的右侧
:param val 更新后的值
:param s 当前区间左侧
:param t 当前区间右侧
:param cur_idx 当前区间的节点序号
"""
if t < l or s > r: # [s, t]和[l, r]没有交集, 提前return
return
if l <= s and t <= r:
# [s, t]包含在[l, r]中,则先**仅对该节点**进行更新操作,并设置延迟更新标记
self._tree[cur_idx] = self._delay[cur_idx] = val
return
if self._delay.get(cur_idx, 0):
# 更新该节点前,如果该节点有延迟更新标记(上轮更新导致的), 先处理上一轮的延迟更新后,再进行本轮的更新
self._delay_update(cur_idx)
lc_idx = cur_idx * 2 + 1
rc_idx = cur_idx * 2 + 2
mid = (s + t) // 2
self._update_helper(l, r, val, lc_idx, s, mid)
self._update_helper(l, r, val, rc_idx, mid + 1, t)
# 更新操作完毕后,也需要更新本节点的值(因为区间有交集),做法为取两个子节点中的最大值即可
self._tree[cur_idx] = max(self._tree.get(lc_idx, 0), self._tree.get(rc_idx, 0))
def update(self, l: int, r: int, val: int):
self._update_helper(l, r, val, 0, 0, self._n - 1)
def root_val(self):
""" 返回整个数组的最大值 """
return self._tree[0]
class Solution:
def fallingSquares(self, positions: List[List[int]]) -> List[int]:
r_bound = max(left + side for left, side in positions)
seg_tree = SegTree(r_bound + 1)
ans = []
for left, side in positions:
cur_max_height = seg_tree.search(left, left + side - 1)
seg_tree.update(left, left + side - 1, cur_max_height + side)
ans.append(seg_tree.root_val())
return ans
MLE的做法,因为没有做延迟标记,导致递归深度过大,内存占用过大。
class SegTree:
def __init__(self, n: int):
self._n = n
self._tree = {}
def _search_helper(self, l: int, r: int, cur_idx: int, s: int, t: int) -> int:
if l <= s and t <= r:
return self._tree.get(cur_idx, 0)
if t < l or s > r:
return 0
mid = (s + t) // 2
return max(
self._search_helper(l, r, 2 * cur_idx + 1, s, mid),
self._search_helper(l, r, 2 * cur_idx + 2, mid + 1, t)
)
def search(self, l: int, r: int) -> int:
return self._search_helper(l, r, 0, 0, self._n - 1)
def _update_helper(self, l: int, r: int, val: int, cur_idx: int, s: int, t: int):
if t < l or s > r:
return
if s == t: # !!! 注意这里直至叶子节点才更新
self._tree[cur_idx] = val
return
mid = (s + t) // 2
self._update_helper(l, r, val, 2 * cur_idx + 1, s, mid)
self._update_helper(l, r, val, 2 * cur_idx + 2, mid + 1, t)
self._tree[cur_idx] = max(self._tree.get(cur_idx * 2 + 1, 0), self._tree.get(cur_idx * 2 + 2, 0))
def update(self, l: int, r: int, val: int):
self._update_helper(l, r, val, 0, 0, self._n - 1)
def root_val(self):
return self._tree[0]
class Solution:
def fallingSquares(self, positions: List[List[int]]) -> List[int]:
r_bound = max(left + side for left, side in positions)
seg_tree = SegTree(r_bound + 1)
ans = []
for left, side in positions:
cur_max_height = seg_tree.search(left, left + side - 1)
seg_tree.update(left, left + side - 1, cur_max_height + side)
ans.append(seg_tree.root_val())
return ans