From f581c379c99f2593bc9c5309054da181e06919e9 Mon Sep 17 00:00:00 2001
From: Erich Schubert
Date: Mon, 19 Feb 2024 19:43:15 +0100
Subject: [PATCH] Move misplaced line

Move misplaced line, improve code comment
---
 src/transformers/models/mistral/modeling_mistral.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index f4251b98304c4e..fbba155f19d57c 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -1176,11 +1176,11 @@ def forward(
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
+            # Ensure tensors are on the same device
             shift_labels = shift_labels.to(shift_logits.device)
+            loss_fct = CrossEntropyLoss()
             loss = loss_fct(shift_logits, shift_labels)
 
         if not return_dict: