Commit
fixup to fix the torch fx?
ArthurZucker committed Oct 17, 2024
1 parent b14c3dd commit dbbc3ce
Showing 1 changed file with 5 additions and 5 deletions.
src/transformers/loss/loss_utils.py: 5 additions & 5 deletions

@@ -21,17 +21,17 @@
 from .loss_rt_detr import RTDetrForObjectDetectionLoss
 
 
-def fixed_cross_entropy(source, target, **kwargs):
-    ignore_index = kwargs.get("ignore_index", -100)
-    num_items_in_batch = kwargs.get("num_items_in_batch", None)
+def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
     reduction = "sum" if num_items_in_batch is not None else "mean"
     loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
     if reduction == "sum":
         loss = loss / num_items_in_batch
     return loss
 
 
-def ForCausalLMLoss(logits, labels, vocab_size, **kwargs):
+def ForCausalLMLoss(
+    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
+):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
     # Shift so that tokens < n predict n
@@ -43,7 +43,7 @@ def ForCausalLMLoss(logits, labels, vocab_size, **kwargs):
     shift_labels = shift_labels.view(-1)
     # Enable model parallelism
     shift_labels = shift_labels.to(shift_logits.device)
-    loss = fixed_cross_entropy(shift_logits, shift_labels, **kwargs)
+    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
     return loss


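For context (not part of the commit itself): the change replaces runtime kwargs.get lookups with explicit num_items_in_batch and ignore_index parameters, presumably because explicit arguments are easier for torch.fx symbolic tracing to handle than dict access on **kwargs. Below is a minimal usage sketch of the new signature; the tensor shapes and the num_items_in_batch value are made up for illustration, and it assumes ForCausalLMLoss is importable from transformers.loss.loss_utils in a release that includes this file.

import torch
from transformers.loss.loss_utils import ForCausalLMLoss  # module path as in this diff; hypothetical install

batch_size, seq_len, vocab_size = 2, 8, 32  # illustrative shapes only
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Default behaviour: cross-entropy averaged over the non-ignored shifted tokens.
loss_mean = ForCausalLMLoss(logits, labels, vocab_size)

# With num_items_in_batch set, the loss is summed and then divided by the supplied
# count, which lets a caller normalize by a token count that spans several
# micro-batches (e.g. gradient accumulation). The value 16 is purely illustrative.
loss_scaled = ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch=16)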
