Commit 4a3b584

jue-jue-zi and kashif authored
fix: use shift_labels for metrics when using CP or SP (#4579)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
1 parent d2e4315 commit 4a3b584

File tree: 1 file changed (+3, -2 lines)


trl/trainer/sft_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -1118,7 +1118,8 @@ def compute_loss(
         # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used.
         # This can be removed when this issue is fixed.
-        labels = inputs["labels"]
+        # When using CP or SP, labels are pre-shifted, we must use shift_labels instead.
+        labels = inputs["labels"] if "shift_labels" not in inputs else None

         # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
         inputs["use_cache"] = False
@@ -1172,7 +1173,7 @@ def compute_loss(
         # Compute accuracy from logits using argmax (traditional method)
         with torch.no_grad():
             if "shift_labels" in inputs:
-                # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
+                # When using CP or SP, labels are pre-shifted. We must use these (and cannot manually shift) because:
                 # - The first discarded token from inputs["labels"] actually belongs to process n-1
                 # - The last logits require the label from process n+1
                 shift_logits = outputs.logits.contiguous()
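
For context, here is a minimal, self-contained sketch of the next-token-accuracy branching this commit fixes. It is an illustration under assumptions, not TRL's actual implementation: the helper name token_accuracy, the bare inputs dict, and the use of -100 as the ignore index are stand-ins for the trainer's real objects.

import torch

def token_accuracy(logits: torch.Tensor, inputs: dict) -> torch.Tensor:
    # Hypothetical helper mirroring the branch in compute_loss above.
    with torch.no_grad():
        if "shift_labels" in inputs:
            # CP/SP case: labels were already shifted across ranks, so logits
            # and labels are aligned as-is. Shifting again here would misalign
            # tokens at the rank boundaries (the first/last tokens belong to
            # the neighboring processes).
            shift_logits = logits.contiguous()
            shift_labels = inputs["shift_labels"].contiguous()
        else:
            # Single-process case: drop the last logit and the first label to
            # align each prediction with its next-token target.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs["labels"][..., 1:].contiguous()
        predictions = shift_logits.argmax(dim=-1)
        mask = shift_labels != -100  # ignore padded/prompt positions
        correct = (predictions == shift_labels) & mask
        return correct.sum().float() / mask.sum().clamp(min=1).float()

# Example (single-process case with manually aligned labels):
logits = torch.randn(1, 5, 10)
labels = torch.randint(0, 10, (1, 5))
print(token_accuracy(logits, {"labels": labels}))

The design point the diff comments make is that under CP or SP each rank only holds a slice of the sequence, so re-shifting locally would discard a label that belongs to a neighboring rank; consuming the pre-shifted shift_labels avoids that misalignment, which is why the commit routes both the loss labels and the accuracy metric through shift_labels whenever it is present.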
