diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py
index 209ff9703ff261..fa02c6a1d6b5c1 100644
--- a/src/transformers/models/bridgetower/modeling_bridgetower.py
+++ b/src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -1630,6 +1630,8 @@ def forward(
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()  # -100 index = padding token
+
+            labels = labels.to(mlm_logits.device)
             masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))
 
         if not return_dict:
@@ -1730,6 +1732,8 @@ def forward(
         itm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(logits.device)
             itm_loss = loss_fct(logits, labels)
 
         if not return_dict:
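
For context, both hunks apply the same pattern: move `labels` onto the device of the logits before computing the loss, which avoids a device-mismatch error when the model is sharded across devices (for example with a multi-GPU `device_map` setup, where the batch labels may still live on the CPU or the first GPU). Below is a minimal standalone sketch of that pattern; the tensor shapes, the `vocab_size` value, and the device placements are made up for illustration, only `CrossEntropyLoss` and the `.to(logits.device)` move come from the diff.

```python
import torch
from torch.nn import CrossEntropyLoss

# Hypothetical stand-ins for the model output and the batch labels.
# Under model parallelism, the logits produced by the last shard may sit on a
# different device (e.g. "cuda:1") than the labels coming from the dataloader.
vocab_size = 50265
mlm_logits = torch.randn(2, 8, vocab_size)      # e.g. on the last shard's GPU
labels = torch.randint(0, vocab_size, (2, 8))   # typically still on CPU or "cuda:0"

loss_fct = CrossEntropyLoss()  # -100 index = padding token

# Align devices before the loss; without this, PyTorch raises a RuntimeError
# when the logits and labels tensors live on different devices.
labels = labels.to(mlm_logits.device)
masked_lm_loss = loss_fct(mlm_logits.view(-1, vocab_size), labels.view(-1))
print(masked_lm_loss)
```

The same reasoning applies to the second hunk, where the image-text-matching `logits` and `labels` are aligned before `loss_fct(logits, labels)`.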