diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index fc16a9a827a920..803f6fe840e7d0 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -526,6 +526,8 @@ def start(self):
             elif is_torch_npu_available():
                 self.torch.npu.reset_peak_memory_stats()
                 self.torch.npu.empty_cache()
+            elif is_torch_mps_available():
+                self.torch.mps.empty_cache()
 
         # gpu
         if self.torch is not None:
@@ -535,6 +537,8 @@ def start(self):
                 self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
             elif is_torch_npu_available():
                 self.gpu_mem_used_at_start = self.torch.npu.memory_allocated()
+            elif is_torch_mps_available():
+                self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory()
 
         # cpu
         self.cpu_mem_used_at_start = self.cpu_mem_used()
@@ -564,6 +568,8 @@ def stop(self, stage):
                 self.torch.xpu.empty_cache()
             elif is_torch_npu_available():
                 self.torch.npu.empty_cache()
+            elif is_torch_mps_available():
+                self.torch.mps.empty_cache()
 
         # concepts:
         # - alloc_delta: the difference of allocated memory between the end and the start
@@ -581,6 +587,11 @@ def stop(self, stage):
             elif is_torch_npu_available():
                 self.gpu_mem_used_now = self.torch.npu.memory_allocated()
                 self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated()
+            elif is_torch_mps_available():
+                self.gpu_mem_used_now = self.torch.mps.current_allocated_memory()
+                # self.torch.mps.max_memory_allocated() does not exist yet
+                self.gpu_mem_used_peak = None
+
             else:
                 raise ValueError("No available GPU device found!")
 
@@ -588,8 +599,11 @@ def stop(self, stage):
                 "begin": self.gpu_mem_used_at_start,
                 "end": self.gpu_mem_used_now,
                 "alloc": (self.gpu_mem_used_now - self.gpu_mem_used_at_start),
-                "peaked": max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now),
             }
+            if self.gpu_mem_used_peak is not None:
+                self.gpu[self.cur_stage]["peaked"] = max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now)
+            else:
+                self.gpu[self.cur_stage]["peaked"] = "Not available"
 
         # cpu
         self.cpu_mem_used_now = self.cpu_mem_used()
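
For reviewers without Apple Silicon hardware, here is a minimal standalone sketch (not part of the patch) of the torch.mps calls the new branches rely on, assuming a recent PyTorch build with the MPS backend enabled; the tensor size and printed output are illustrative only.

```python
# Sketch only: exercises the torch.mps calls used by the new elif branches.
import torch

if torch.backends.mps.is_available():
    torch.mps.empty_cache()                        # analogous to torch.cuda.empty_cache()
    before = torch.mps.current_allocated_memory()  # bytes currently allocated on the MPS device

    x = torch.randn(1024, 1024, device="mps")      # allocate something measurable

    after = torch.mps.current_allocated_memory()
    print(f"alloc delta: {after - before} bytes")
    # There is no torch.mps.max_memory_allocated() counterpart, which is why
    # the patch reports the peaked delta as "Not available" on MPS.
else:
    print("MPS backend not available on this machine")
```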