Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ kernels = [
"kernels"
]
liger = [
"liger-kernel>=0.6.2"
"liger-kernel>=0.6.4"
]
peft = [
"peft>=0.8.0"
Expand Down Expand Up @@ -104,7 +104,7 @@ dev = [
# kernels
"kernels",
# liger
"liger-kernel>=0.6.2",
"liger-kernel>=0.6.4",
# peft
"peft>=0.8.0",
# quality
Expand Down
3 changes: 0 additions & 3 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,7 +1518,6 @@ def reward_func(completions, **kwargs):
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
use_liger_kernel=True, # enable Liger kernel
loss_type="bnpo", # default dapo is not supported yet
report_to="none",
)
trainer = GRPOTrainer(
Expand Down Expand Up @@ -1839,7 +1838,6 @@ def test_training_with_liger_grpo_kernel(self, model_name):
max_completion_length=self.max_length,
report_to="none",
logging_strategy="no",
loss_type="bnpo", # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620
)

model = AutoModelForCausalLM.from_pretrained(model_name)
Expand Down Expand Up @@ -1888,7 +1886,6 @@ def test_training_with_liger_grpo_kernel_and_peft(self, model_name):
max_completion_length=self.max_length,
report_to="none",
logging_strategy="no",
loss_type="bnpo", # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620
)

model = AutoModelForCausalLM.from_pretrained(model_name)
Expand Down
2 changes: 1 addition & 1 deletion trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from transformers.utils.import_utils import _is_package_available


LIGER_KERNEL_MIN_VERSION = "0.5.8"
LIGER_KERNEL_MIN_VERSION = "0.6.4"

# Use same as transformers.utils.import_utils
_deepspeed_available = _is_package_available("deepspeed")
Expand Down
47 changes: 29 additions & 18 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,18 +822,19 @@ def __init__(
)

# Loss function
if args.loss_type == "nll":
pass # use the default loss
elif args.loss_type == "dft":
if compute_loss_func is not None:
raise ValueError(
"You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
"When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a "
"`compute_loss_func` is not allowed."
)
compute_loss_func = dft_loss
else:
raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
if not args.use_liger_kernel: # liger supports dft loss by just passing use_token_scaling=True
if args.loss_type == "nll":
pass # use the default loss
elif args.loss_type == "dft":
if compute_loss_func is not None:
raise ValueError(
"You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
"When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so "
"passing a `compute_loss_func` is not allowed."
)
compute_loss_func = dft_loss
else:
raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")

# Initialize the metrics
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
Expand Down Expand Up @@ -1113,6 +1114,11 @@ def compute_loss(

# If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
inputs["use_cache"] = False
# Request token accuracy from Liger kernel and set token scaling if using DFT loss
if self.args.use_liger_kernel:
inputs["return_token_accuracy"] = True
inputs["use_token_scaling"] = self.args.loss_type == "dft"

(loss, outputs) = super().compute_loss(
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
)
Expand Down Expand Up @@ -1151,8 +1157,11 @@ def compute_loss(
self._total_train_tokens += num_tokens_in_batch
self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

# Compute token accuracy if we have labels and if the model is not using Liger (no logits)
if not self.args.use_liger_kernel:
if self.args.use_liger_kernel:
token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item()
self._metrics[mode]["mean_token_accuracy"].append(token_accuracy)
else:
# Compute accuracy from logits using argmax (traditional method)
with torch.no_grad():
if "shift_labels" in inputs:
# When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
Expand Down Expand Up @@ -1190,10 +1199,12 @@ def compute_loss(
total_sum = total_tokens.sum()
accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
self._metrics[mode]["mean_token_accuracy"].append(accuracy)
if self.aux_loss_enabled:
aux_loss = outputs.aux_loss
aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
self._metrics[mode]["aux_loss"].append(aux_loss)

# Log auxiliary loss if enabled (applies to both Liger and non-Liger)
if self.aux_loss_enabled:
aux_loss = outputs.aux_loss
aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
self._metrics[mode]["aux_loss"].append(aux_loss)

return (loss, outputs) if return_outputs else loss

Expand Down
Loading