Skip to content

Commit

Permalink
moved labels to the same device as logits for LILT model (huggingface…
Browse files Browse the repository at this point in the history
  • Loading branch information
sushmanthreddy authored and novice03 committed Jun 23, 2023
1 parent c1ebdae commit e6c3db1
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/lilt/modeling_lilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,8 @@ def forward(

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
Expand Down Expand Up @@ -1046,6 +1048,8 @@ def forward(

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

Expand Down

0 comments on commit e6c3db1

Please sign in to comment.