diff --git a/recipes/ppo_full_finetune_single_device.py b/recipes/ppo_full_finetune_single_device.py index daefbde3a6..c89521ccfc 100644 --- a/recipes/ppo_full_finetune_single_device.py +++ b/recipes/ppo_full_finetune_single_device.py @@ -935,6 +935,7 @@ def train(self) -> None: curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -1034,6 +1035,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None)