diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 6a2b820b4ff8b7..74454d244e8d31 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -924,6 +924,8 @@ def forward( loss = None if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" @@ -1046,6 +1048,8 @@ def forward( loss = None if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))