def loss_func(pred, label):
    """Compute the cross-entropy loss for one batch and its averaged report value.

    The labels are first reordered with the ``"b s -> s b"`` pattern, then the
    output of ``tensor_parallel.gather_from_tensor_model_parallel_region(pred)``
    is scored against them with ``F.cross_entropy``.  This replaces the earlier
    ``tensor_parallel.vocab_parallel_cross_entropy(pred, label).mean()``
    computation.

    Args:
        pred: Model output passed to the tensor-parallel gather.
            NOTE(review): presumably the vocab-sharded logits laid out as
            (seq, batch, vocab_shard) -- confirm against the caller.
        label: Target token ids with two axes, named ``b`` and ``s`` in the
            rearrange pattern.

    Returns:
        A tuple ``(loss, stats)`` where ``loss`` is the value returned by
        ``F.cross_entropy`` (default reduction) and ``stats`` is
        ``{"nice_loss": averaged_loss}``, with ``averaged_loss`` produced by
        ``average_losses_across_data_parallel_group([loss])``.
    """
    # Swap the two label axes so they line up with the prediction layout.
    seq_first_labels = rearrange(label, "b s -> s b").contiguous()

    # NOTE(review): the helper name suggests this collects the full logits
    # from every tensor-parallel rank -- confirm; that costs more memory than
    # the vocab-parallel loss it replaces.
    gathered_logits = tensor_parallel.gather_from_tensor_model_parallel_region(pred)

    # Collapse every leading axis so each row holds the scores for one target.
    num_classes = gathered_logits.shape[-1]
    flat_logits = gathered_logits.view(-1, num_classes).contiguous()
    flat_targets = seq_first_labels.view(-1).contiguous()
    loss = F.cross_entropy(flat_logits, flat_targets)

    # The un-averaged loss is returned first; the averaged copy is only reported.
    averaged_loss = average_losses_across_data_parallel_group([loss])
    return loss, {"nice_loss": averaged_loss}