Skip to content

Commit

Permalink
Fix test failure on DeepSpeed (#29444)
Browse files Browse the repository at this point in the history
* Fix test failure

* use item
  • Loading branch information
muellerzr committed Mar 6, 2024
1 parent 0a5b051 commit 9322576
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2011,7 +2011,10 @@ def _inner_training_loop(
is_accelerate_available()
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
):
grad_norm = model.get_global_grad_norm().item()
grad_norm = model.get_global_grad_norm()
# In some cases the grad norm may not return a float
if hasattr(grad_norm, "item"):
grad_norm = grad_norm.item()
else:
grad_norm = _grad_norm.item() if _grad_norm is not None else None

Expand Down

0 comments on commit 9322576

Please sign in to comment.