From 7451b72cd914c648dda01ba0ef2f1ab8e1201af5 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Mon, 4 Mar 2024 09:32:18 -0500
Subject: [PATCH 1/2] Fix test failure

---
 src/transformers/trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 99792019846210..e2445d6a364721 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2011,7 +2011,10 @@ def _inner_training_loop(
                             is_accelerate_available()
                             and self.accelerator.distributed_type == DistributedType.DEEPSPEED
                         ):
-                            grad_norm = model.get_global_grad_norm().item()
+                            grad_norm = model.get_global_grad_norm()
+                            # In some cases the grad norm may not return a float
+                            if isinstance(grad_norm, torch.Tensor):
+                                grad_norm = grad_norm.item()
                         else:
                             grad_norm = _grad_norm.item() if _grad_norm is not None else None
 

From 9d01bdd98091d7b2755c3312b885691cbffc16b2 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Mon, 4 Mar 2024 09:33:45 -0500
Subject: [PATCH 2/2] use item

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e2445d6a364721..056f7a2ca96e34 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2013,7 +2013,7 @@ def _inner_training_loop(
                         ):
                             grad_norm = model.get_global_grad_norm()
                             # In some cases the grad norm may not return a float
-                            if isinstance(grad_norm, torch.Tensor):
+                            if hasattr(grad_norm, "item"):
                                 grad_norm = grad_norm.item()
                         else:
                             grad_norm = _grad_norm.item() if _grad_norm is not None else None
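Note (not part of the patches above): the second commit replaces the
isinstance(grad_norm, torch.Tensor) check with hasattr(grad_norm, "item"),
presumably because DeepSpeed's get_global_grad_norm() does not always return
a torch.Tensor, yet the returned object may still expose a scalar .item()
method. Below is a minimal sketch of the duck-typed conversion pattern the
fix relies on; normalize_grad_norm is a hypothetical helper for illustration
only, not a function in trainer.py:

    import torch

    def normalize_grad_norm(grad_norm):
        # Anything exposing .item() (a torch.Tensor, a NumPy scalar, or a
        # framework-specific wrapper) is collapsed to a plain Python float;
        # plain floats and None pass through unchanged.
        if hasattr(grad_norm, "item"):
            return grad_norm.item()
        return grad_norm

    # Example: a tensor and a plain float both normalize to a float.
    print(normalize_grad_norm(torch.tensor(1.5)))  # 1.5
    print(normalize_grad_norm(1.5))                # 1.5

Checking for the .item attribute rather than a concrete type keeps the
logging path working for any scalar-like return value, which is why the
follow-up commit loosened the original isinstance check.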