Skip to content

Commit

Permalink
Fixing a bug when MlFlow try to log a torch.tensor (#29932)
Browse files Browse the repository at this point in the history
* Update integration_utils.py

Add the case where a tensor with one element is log with Mlflow

* Update src/transformers/integrations/integration_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update integration_utils.py add a whitespace

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
etiennebonnafoux and amyeroberts authored Apr 10, 2024
1 parent 0fe4405 commit 3280b13
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,8 @@ def on_log(self, args, state, control, logs, model=None, **kwargs):
for k, v in logs.items():
if isinstance(v, (int, float)):
metrics[k] = v
elif isinstance(v, torch.Tensor) and v.numel() == 1:
metrics[k] = v.item()
else:
logger.warning(
f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
Expand Down

0 comments on commit 3280b13

Please sign in to comment.