From 622ee4ac5f44c6a2d4740f837bf6218a52ffbaa0 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sat, 18 Oct 2025 11:47:05 +0200
Subject: [PATCH 1/8] Request token accuracy from Liger kernel if used

---
 trl/trainer/sft_trainer.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index cee2fb82ede..8d8534dee77 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -1095,6 +1095,10 @@ def compute_loss(
             # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
             inputs["use_cache"] = False

+        # Request token accuracy from Liger kernel if used
+        if self.args.use_liger_kernel:
+            inputs["return_token_accuracy"] = True
+
         (loss, outputs) = super().compute_loss(
             model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
         )
@@ -1133,8 +1137,12 @@ def compute_loss(
                 self._total_train_tokens += num_tokens_in_batch
             self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

-        # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
-        if not self.args.use_liger_kernel:
+        if self.args.use_liger_kernel:
+            if hasattr(outputs, "token_accuracy") and outputs.token_accuracy is not None:
+                token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item()
+                self._metrics[mode]["mean_token_accuracy"].append(token_accuracy)
+        else:
+            # Compute accuracy from logits using argmax (traditional method)
             with torch.no_grad():
                 if "shift_labels" in inputs:
                     # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
@@ -1172,10 +1180,12 @@ def compute_loss(
                 total_sum = total_tokens.sum()
                 accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
                 self._metrics[mode]["mean_token_accuracy"].append(accuracy)
-        if self.aux_loss_enabled:
-            aux_loss = outputs.aux_loss
-            aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
-            self._metrics[mode]["aux_loss"].append(aux_loss)
+
+        # Log auxiliary loss if enabled (applies to both Liger and non-Liger)
+        if self.aux_loss_enabled:
+            aux_loss = outputs.aux_loss
+            aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
+            self._metrics[mode]["aux_loss"].append(aux_loss)

         return (loss, outputs) if return_outputs else loss

From 1b9ec9f0cd1d8a9b1bb35b7015b8ef6fb2c5f96f Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sat, 18 Oct 2025 13:31:30 +0200
Subject: [PATCH 2/8] set token scaling flag

---
 trl/trainer/sft_trainer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 8d8534dee77..f6d67b9bc07 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -818,7 +818,7 @@ def __init__(
         )

         # Loss function
-        if args.loss_type == "nll":
+        if args.loss_type == "nll" or args.use_liger_kernel:
             pass  # use the default loss
         elif args.loss_type == "dft":
             if compute_loss_func is not None:
@@ -1095,9 +1095,10 @@ def compute_loss(
             # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
             inputs["use_cache"] = False

-        # Request token accuracy from Liger kernel if used
+        # Request token accuracy from Liger kernel and set token scaling if using DFT loss
         if self.args.use_liger_kernel:
             inputs["return_token_accuracy"] = True
+            inputs["use_token_scaling"] = self.args.loss_type == "dft"

         (loss, outputs) = super().compute_loss(
             model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
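Note on the `use_token_scaling` flag introduced in PATCH 2/8: DFT reweights the standard NLL objective per token. Below is a minimal PyTorch sketch of that idea, assuming DFT scales each token's cross-entropy by its own detached predicted probability; the function name is illustrative and this is not the Liger kernel or TRL implementation.

import torch
import torch.nn.functional as F


def dft_style_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Assumes labels are already aligned (shifted) with the logits of a causal LM.
    # Per-token negative log-likelihood, keeping the token dimension.
    nll = F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=ignore_index, reduction="none"
    )
    # "Token scaling": weight each token's NLL by the detached probability of the target
    # token (exp(-nll)), so each token's gradient is rescaled by the model's confidence.
    scale = torch.exp(-nll).detach()
    mask = (labels.view(-1) != ignore_index).float()
    return (scale * nll * mask).sum() / mask.sum().clamp(min=1)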
From 16e3305f3cca4125c460e89d2522e4ea23be9a1f Mon Sep 17 00:00:00 2001
From: Quentin Gallouédec
Date: Mon, 20 Oct 2025 23:48:08 +0000
Subject: [PATCH 3/8] clarity

---
 trl/trainer/sft_trainer.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index f8600caa5c5..01c54487117 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -818,18 +818,19 @@ def __init__(
         )

         # Loss function
-        if args.loss_type == "nll" or args.use_liger_kernel:
-            pass  # use the default loss
-        elif args.loss_type == "dft":
-            if compute_loss_func is not None:
-                raise ValueError(
-                    "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
-                    "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a "
-                    "`compute_loss_func` is not allowed."
-                )
-            compute_loss_func = dft_loss
-        else:
-            raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
+        if not args.use_liger_kernel:  # liger supports dft loss by just passing use_token_scaling=True
+            if args.loss_type == "nll":
+                pass  # use the default loss
+            elif args.loss_type == "dft":
+                if compute_loss_func is not None:
+                    raise ValueError(
+                        "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
+                        "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so "
+                        "passing a `compute_loss_func` is not allowed."
+                    )
+                compute_loss_func = dft_loss
+            else:
+                raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")

         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}

From 049129ac87df3a9c75d0f15dbccbecb52ff10e1d Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sat, 22 Nov 2025 14:06:49 +0100
Subject: [PATCH 4/8] pin to 0.6.4

---
 pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1b84ff50d7f..61e8c0f4726 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,7 @@ kernels = [
     "kernels"
 ]
 liger = [
-    "liger-kernel>=0.6.2"
+    "liger-kernel>=0.6.4"
 ]
 peft = [
     "peft>=0.8.0"
 ]
@@ -104,7 +104,7 @@ dev = [
     # kernels
     "kernels",
     # liger
-    "liger-kernel>=0.6.2",
+    "liger-kernel>=0.6.4",
     # peft
     "peft>=0.8.0",
     # quality
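With liger-kernel pinned to >= 0.6.4 (PATCH 4/8), enabling the Liger kernel together with the DFT loss needs no custom `compute_loss_func`; the trainer simply forwards `use_token_scaling`. A rough usage sketch follows, where the model and dataset names are placeholders:

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")  # placeholder dataset

training_args = SFTConfig(
    output_dir="sft-liger-dft",
    use_liger_kernel=True,  # fused Liger loss; mean_token_accuracy is returned by the kernel
    loss_type="dft",        # routed through use_token_scaling inside the Liger path
)
trainer = SFTTrainer(model="Qwen/Qwen2.5-0.5B", args=training_args, train_dataset=dataset)
trainer.train()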
From 5c903c69dcef7c05c9389c4e29ae66ce700fc03a Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sat, 22 Nov 2025 14:15:26 +0100
Subject: [PATCH 5/8] explicit condition

---
 trl/trainer/sft_trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index b9847a9930b..ca56db8101e 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -17,6 +17,7 @@
 from collections import defaultdict
 from collections.abc import Callable
 from dataclasses import dataclass
+from importlib.metadata import version
 from pathlib import Path
 from typing import Any

@@ -1115,7 +1116,7 @@ def compute_loss(
             # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
             inputs["use_cache"] = False

         # Request token accuracy from Liger kernel and set token scaling if using DFT loss
-        if self.args.use_liger_kernel:
+        if self.args.use_liger_kernel and version("liger_kernel") >= "0.6.4":
             inputs["return_token_accuracy"] = True
             inputs["use_token_scaling"] = self.args.loss_type == "dft"

From 9b2856eb8b45d95770f79e4e2d6567ef0175ca0a Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sat, 22 Nov 2025 22:22:09 +0100
Subject: [PATCH 6/8] no need for version check as we require 0.6.4

---
 trl/trainer/sft_trainer.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index ca56db8101e..c33781d608a 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -17,7 +17,6 @@
 from collections import defaultdict
 from collections.abc import Callable
 from dataclasses import dataclass
-from importlib.metadata import version
 from pathlib import Path
 from typing import Any

@@ -1116,7 +1115,7 @@ def compute_loss(
             # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
             inputs["use_cache"] = False

         # Request token accuracy from Liger kernel and set token scaling if using DFT loss
-        if self.args.use_liger_kernel and version("liger_kernel") >= "0.6.4":
+        if self.args.use_liger_kernel:
             inputs["return_token_accuracy"] = True
             inputs["use_token_scaling"] = self.args.loss_type == "dft"
@@ -1159,9 +1158,8 @@ def compute_loss(
             self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

         if self.args.use_liger_kernel:
-            if hasattr(outputs, "token_accuracy") and outputs.token_accuracy is not None:
-                token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item()
-                self._metrics[mode]["mean_token_accuracy"].append(token_accuracy)
+            token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item()
+            self._metrics[mode]["mean_token_accuracy"].append(token_accuracy)
         else:
             # Compute accuracy from logits using argmax (traditional method)
             with torch.no_grad():

From ac108e88306149f748091951e0a3d733925c3f2b Mon Sep 17 00:00:00 2001
From: Quentin Gallouédec
Date: Sat, 22 Nov 2025 21:37:04 +0000
Subject: [PATCH 7/8] update version and test

---
 tests/test_grpo_trainer.py | 2 --
 trl/import_utils.py        | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index b3844a399c1..2d12d0d84ba 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1839,7 +1839,6 @@ def test_training_with_liger_grpo_kernel(self, model_name):
             max_completion_length=self.max_length,
             report_to="none",
             logging_strategy="no",
-            loss_type="bnpo",  # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620
         )

         model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -1888,7 +1887,6 @@ def test_training_with_liger_grpo_kernel_and_peft(self, model_name):
             max_completion_length=self.max_length,
             report_to="none",
             logging_strategy="no",
-            loss_type="bnpo",  # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620
         )

         model = AutoModelForCausalLM.from_pretrained(model_name)
diff --git a/trl/import_utils.py b/trl/import_utils.py
index 4d8a9c84ce0..7a062464ada 100644
--- a/trl/import_utils.py
+++ b/trl/import_utils.py
@@ -23,7 +23,7 @@
 from transformers.utils.import_utils import _is_package_available


-LIGER_KERNEL_MIN_VERSION = "0.5.8"
+LIGER_KERNEL_MIN_VERSION = "0.6.4"

 # Use same as transformers.utils.import_utils
 _deepspeed_available = _is_package_available("deepspeed")
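PATCH 5/8 and PATCH 6/8 add and then drop an inline version guard: once the dependency and `LIGER_KERNEL_MIN_VERSION` are both at 0.6.4, the runtime check is redundant, and comparing version strings lexicographically is fragile anyway ("0.10.0" >= "0.6.4" is False as plain strings). If such a guard were ever reintroduced, a more robust sketch would parse the versions first; the helper name below is illustrative, not part of these patches:

from importlib.metadata import version

from packaging.version import Version


def liger_supports_token_accuracy(min_version: str = "0.6.4") -> bool:
    # Parse before comparing so that "0.10.0" correctly sorts after "0.6.4".
    try:
        return Version(version("liger_kernel")) >= Version(min_version)
    except Exception:
        return False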
From 7d0f90f3bcad3c5097ee80ab85f525cb282f84fb Mon Sep 17 00:00:00 2001
From: Quentin Gallouédec
Date: Sat, 22 Nov 2025 22:01:26 +0000
Subject: [PATCH 8/8] rm loss type constraint for liger

---
 tests/test_grpo_trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 2d12d0d84ba..c747e0a1aa6 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1518,7 +1518,6 @@ def reward_func(completions, **kwargs):
             num_generations=3,  # reduce the number of generations to reduce memory usage
             max_completion_length=8,  # reduce the completion length to reduce memory usage
             use_liger_kernel=True,  # enable Liger kernel
-            loss_type="bnpo",  # default dapo is not supported yet
             report_to="none",
         )
         trainer = GRPOTrainer(
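With the `loss_type="bnpo"` overrides gone, the GRPO tests run the default loss type under the Liger kernel. A minimal sketch of the equivalent user-facing setup, with placeholder model, dataset, and reward function:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder dataset


def reward_len(completions, **kwargs):
    # Toy reward: prefer completions that are about 20 characters long.
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir="grpo-liger",
    use_liger_kernel=True,  # default loss type works; no "bnpo" override needed with liger-kernel >= 0.6.4
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()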