
Commit 4cb1a25

kashif and qgallouedec authored
[SFT] Log mean token accuracy from Liger kernel (#4302)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 468b9d4 commit 4cb1a25

File tree

pyproject.toml
tests/test_grpo_trainer.py
trl/import_utils.py
trl/trainer/sft_trainer.py

4 files changed: +32 −24 lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,7 @@ kernels = [
     "kernels"
 ]
 liger = [
-    "liger-kernel>=0.6.2"
+    "liger-kernel>=0.6.4"
 ]
 peft = [
     "peft>=0.8.0"
@@ -104,7 +104,7 @@ dev = [
     # kernels
     "kernels",
     # liger
-    "liger-kernel>=0.6.2",
+    "liger-kernel>=0.6.4",
     # peft
     "peft>=0.8.0",
     # quality
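Both optional-dependency pins raise the Liger floor from 0.6.2 to 0.6.4, the version this commit relies on for the new forward kwargs. A quick hedged check that an environment meets the new floor (plain Python, no TRL import needed):

```python
from importlib.metadata import version
from packaging.version import parse

# Assumes liger-kernel >= 0.6.4 is what ships return_token_accuracy / use_token_scaling
assert parse(version("liger-kernel")) >= parse("0.6.4"), "upgrade liger-kernel"
```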

tests/test_grpo_trainer.py

Lines changed: 0 additions & 3 deletions
@@ -1518,7 +1518,6 @@ def reward_func(completions, **kwargs):
             num_generations=3,  # reduce the number of generations to reduce memory usage
             max_completion_length=8,  # reduce the completion length to reduce memory usage
             use_liger_kernel=True,  # enable Liger kernel
-            loss_type="bnpo",  # default dapo is not supported yet
             report_to="none",
         )
         trainer = GRPOTrainer(
@@ -1839,7 +1838,6 @@ def test_training_with_liger_grpo_kernel(self, model_name):
             max_completion_length=self.max_length,
             report_to="none",
             logging_strategy="no",
-            loss_type="bnpo",  # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620
         )
 
         model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -1888,7 +1886,6 @@ def test_training_with_liger_grpo_kernel_and_peft(self, model_name):
             max_completion_length=self.max_length,
             report_to="none",
             logging_strategy="no",
-            loss_type="bnpo",  # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620
         )
 
         model = AutoModelForCausalLM.from_pretrained(model_name)
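The three deleted `loss_type="bnpo"` lines were workarounds for older Liger releases that could not compute the trainer's default "dapo" loss (see the linked Liger-Kernel issue). With the new minimum version, a Liger-enabled GRPO config needs no override; a minimal sketch (output path is hypothetical, the other values mirror the tests):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-liger-demo",  # hypothetical output path
    use_liger_kernel=True,         # enable the Liger GRPO loss
    num_generations=3,             # small values keep memory low, as in the tests
    max_completion_length=8,
    report_to="none",
    # no loss_type="bnpo" needed: the default loss works with liger-kernel >= 0.6.4
)
```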

trl/import_utils.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 from transformers.utils.import_utils import _is_package_available
 
 
-LIGER_KERNEL_MIN_VERSION = "0.5.8"
+LIGER_KERNEL_MIN_VERSION = "0.6.4"
 
 # Use same as transformers.utils.import_utils
 _deepspeed_available = _is_package_available("deepspeed")
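The bumped constant feeds whatever availability helper gates the Liger code paths. A sketch of how such a version-gated check typically looks in this module's style; the actual helper in trl/import_utils.py may differ in detail:

```python
from packaging import version
from transformers.utils.import_utils import _is_package_available

LIGER_KERNEL_MIN_VERSION = "0.6.4"

_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)

def is_liger_kernel_available(min_version: str = LIGER_KERNEL_MIN_VERSION) -> bool:
    # Available only if the package is importable *and* new enough for the
    # features this commit uses (token accuracy, token scaling)
    return _liger_kernel_available and version.parse(_liger_kernel_version) >= version.parse(min_version)
```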

trl/trainer/sft_trainer.py

Lines changed: 29 additions & 18 deletions
@@ -822,18 +822,19 @@ def __init__(
         )
 
         # Loss function
-        if args.loss_type == "nll":
-            pass  # use the default loss
-        elif args.loss_type == "dft":
-            if compute_loss_func is not None:
-                raise ValueError(
-                    "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
-                    "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a "
-                    "`compute_loss_func` is not allowed."
-                )
-            compute_loss_func = dft_loss
-        else:
-            raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
+        if not args.use_liger_kernel:  # liger supports dft loss by just passing use_token_scaling=True
+            if args.loss_type == "nll":
+                pass  # use the default loss
+            elif args.loss_type == "dft":
+                if compute_loss_func is not None:
+                    raise ValueError(
+                        "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
+                        "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so "
+                        "passing a `compute_loss_func` is not allowed."
+                    )
+                compute_loss_func = dft_loss
+            else:
+                raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
 
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
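For context, DFT scales each token's negative log-likelihood by the (detached) probability the model assigns to the target token, which is what Liger's `use_token_scaling=True` toggles inside the fused kernel per the comment above; the Python-side `dft_loss` shortcut is therefore only wired up on the non-Liger path. A minimal sketch of the idea, not TRL's exact `dft_loss` implementation:

```python
import torch
import torch.nn.functional as F

def dft_loss_sketch(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Shift so that position t predicts token t+1
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    targets = labels[:, 1:]
    mask = targets != ignore_index
    safe_targets = targets.clamp(min=0)  # placeholder index for masked positions
    token_logp = log_probs.gather(-1, safe_targets.unsqueeze(-1)).squeeze(-1)
    # DFT: weight each token's NLL by its detached target probability
    scaled_nll = -token_logp * token_logp.detach().exp()
    return (scaled_nll * mask).sum() / mask.sum()
```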
@@ -1113,6 +1114,11 @@ def compute_loss(
 
         # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
         inputs["use_cache"] = False
+        # Request token accuracy from Liger kernel and set token scaling if using DFT loss
+        if self.args.use_liger_kernel:
+            inputs["return_token_accuracy"] = True
+            inputs["use_token_scaling"] = self.args.loss_type == "dft"
+
         (loss, outputs) = super().compute_loss(
             model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
         )
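Together with the `__init__` change above, enabling DFT under Liger becomes purely a config matter: `compute_loss` forwards the two kwargs and the kernel does the rest. A hedged usage sketch (model and dataset names are placeholders):

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")  # placeholder dataset
training_args = SFTConfig(
    output_dir="sft-liger-dft",  # hypothetical output path
    use_liger_kernel=True,  # Liger computes the loss and the token accuracy
    loss_type="dft",        # mapped to use_token_scaling=True at forward time
)
trainer = SFTTrainer(model="Qwen/Qwen2.5-0.5B", args=training_args, train_dataset=dataset)
trainer.train()
```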
@@ -1151,8 +1157,11 @@ def compute_loss(
             self._total_train_tokens += num_tokens_in_batch
             self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
 
-        # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
-        if not self.args.use_liger_kernel:
+        if self.args.use_liger_kernel:
+            token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item()
+            self._metrics[mode]["mean_token_accuracy"].append(token_accuracy)
+        else:
+            # Compute accuracy from logits using argmax (traditional method)
             with torch.no_grad():
                 if "shift_labels" in inputs:
                     # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
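On the Liger path the accuracy arrives as a tensor already computed inside the fused kernel, so logging it reduces to averaging across processes. A standalone sketch of that gathering pattern:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
# Each process holds its local scalar metric, e.g. the batch token accuracy
local_accuracy = torch.tensor(0.93, device=accelerator.device)
# gather_for_metrics collects one value per process (dropping samples a
# distributed sampler may have duplicated), so .mean() is the global average
global_accuracy = accelerator.gather_for_metrics(local_accuracy).mean().item()
```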
@@ -1190,10 +1199,12 @@ def compute_loss(
                 total_sum = total_tokens.sum()
                 accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
                 self._metrics[mode]["mean_token_accuracy"].append(accuracy)
-        if self.aux_loss_enabled:
-            aux_loss = outputs.aux_loss
-            aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
-            self._metrics[mode]["aux_loss"].append(aux_loss)
+
+        # Log auxiliary loss if enabled (applies to both Liger and non-Liger)
+        if self.aux_loss_enabled:
+            aux_loss = outputs.aux_loss
+            aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
+            self._metrics[mode]["aux_loss"].append(aux_loss)
 
         return (loss, outputs) if return_outputs else loss
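The non-Liger branch keeps the traditional argmax computation shown above; stripped of the context-parallel (`shift_labels`) handling, it amounts to this sketch:

```python
import torch

def mean_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
    # Shift so that position t predicts token t+1, mirroring the loss
    preds = logits[:, :-1, :].argmax(dim=-1)
    targets = labels[:, 1:]
    mask = targets != ignore_index
    correct = ((preds == targets) & mask).sum()
    total = mask.sum()
    return (correct / total).item() if total > 0 else 0.0
```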
