1 file changed: 3 additions, 2 deletions
@@ -1118,7 +1118,8 @@ def compute_loss(
 
         # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used.
         # This can be removed when this issue is fixed.
-        labels = inputs["labels"]
+        # When using CP or SP, labels are pre-shifted, so we must use shift_labels instead.
+        labels = inputs["labels"] if "shift_labels" not in inputs else None
 
         # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
         inputs["use_cache"] = False
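
To make the first hunk concrete, here is a minimal, self-contained sketch of the branch it guards. Only the `labels`/`shift_labels` key names and the pre-shifting behaviour come from the diff; the function name, tensor shapes, and the `ignore_index=-100` convention are illustrative assumptions, not the trainer's actual API:

```python
import torch
import torch.nn.functional as F

def toy_loss(inputs: dict, logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the guard above.
    if "shift_labels" in inputs:
        # CP/SP path: the sharding collator already shifted the labels,
        # so every logit position lines up with its target as-is.
        targets = inputs["shift_labels"]
        preds = logits
    else:
        # Single-process path: shift manually so token t predicts token t+1.
        targets = inputs["labels"][:, 1:]
        preds = logits[:, :-1, :]
    return F.cross_entropy(
        preds.reshape(-1, preds.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,  # assumed padding/ignore convention
    )

# Example with made-up sizes: batch=2, seq=5, vocab=11.
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
print(toy_loss({"labels": labels}, logits))
```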
@@ -1172,7 +1173,7 @@ def compute_loss(
         # Compute accuracy from logits using argmax (traditional method)
         with torch.no_grad():
             if "shift_labels" in inputs:
-                # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
+                # When using CP or SP, labels are pre-shifted. We must use these (and cannot manually shift) because:
                 # - The first discarded token from inputs["labels"] actually belongs to process n-1
                 # - The last logits require the label from process n+1
                 shift_logits = outputs.logits.contiguous()
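
A sketch of why the second hunk consumes the pre-shifted labels whole under CP/SP: each rank holds only its shard of the sequence, so re-shifting locally would discard a label that belongs to process n-1 and leave the last logit without its target from process n+1. Here `torch.roll` plus an ignore marker merely stands in for whatever the collator actually produces; shapes and the -100 sentinel are assumptions:

```python
import torch

IGNORE_INDEX = -100  # assumed ignore convention

logits = torch.randn(2, 5, 11)           # batch=2, local sequence shard=5, vocab=11
labels = torch.randint(0, 11, (2, 5))

# Stand-in for inputs["shift_labels"]: labels rolled left one position,
# with the last slot blanked because its true target lives on process n+1.
shift_labels = torch.roll(labels, shifts=-1, dims=1)
shift_labels[:, -1] = IGNORE_INDEX

# Accuracy uses the logits whole (no [:, :-1] slice), matching the diff:
# a manual shift here would drop tokens owned by the neighbouring ranks.
with torch.no_grad():
    mask = shift_labels != IGNORE_INDEX
    preds = logits.argmax(dim=-1)
    accuracy = (preds[mask] == shift_labels[mask]).float().mean()
print(accuracy)
```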