Add KTO Loss
hebiao064 authored and Biao He committed Dec 12, 2024
1 parent 55e3755 commit 4471ba6
Showing 7 changed files with 463 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ site/
.venv/
venv/
.ipynb_checkpoints/
+.vscode/

# Misc
.DS_Store
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/README.md
@@ -1,6 +1,6 @@
# Liger FlexChunkLoss: Alignment and Distillation loss

-Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
+Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.

### User interface
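As a minimal, hedged sketch of how the new KTO loss plugs into this interface (toy shapes; the names and the chosen-first/rejected-second batch convention follow the shared preference base, not this diff):

import torch
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss

B, T, H, V = 4, 8, 16, 32  # toy sizes: 2 chosen + 2 rejected sequences
hidden_states = torch.randn(B, T, H, requires_grad=True)  # last hidden states
lm_head_weight = torch.randn(V, H, requires_grad=True)    # lm_head projection, fused into the loss
target = torch.randint(0, V, (B, T))                      # completion token ids

kto_loss = LigerFusedLinearKTOLoss(beta=0.1, compute_nll_loss=False)
loss = kto_loss(lm_head_weight, hidden_states, target)  # note: the weight comes first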

1 change: 1 addition & 0 deletions src/liger_kernel/chunked_loss/__init__.py
@@ -1,4 +1,5 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
2 changes: 2 additions & 0 deletions src/liger_kernel/chunked_loss/functional.py
@@ -1,9 +1,11 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
+liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
144 changes: 144 additions & 0 deletions src/liger_kernel/chunked_loss/kto_loss.py
@@ -0,0 +1,144 @@
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)


class LigerFusedLinearKTOFunction(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(
        chosen_logps,
        rejected_logps,
        ref_chosen_logps=None,
        ref_rejected_logps=None,
        beta=0.1,
    ):
        """
        Paper: https://arxiv.org/abs/2402.01306
        Formula (per example, matching the implementation below):
            L_KTO = 1 - σ(β * (log[π(y|x)/π₀(y|x)] - z_ref))   for chosen (desirable) y
            L_KTO = 1 - σ(β * (z_ref - log[π(y|x)/π₀(y|x)]))   for rejected (undesirable) y
        Where:
            - σ: Sigmoid function
            - β: Temperature parameter
            - z_ref: Reference-point term, an estimate of KL(π||π₀); fixed at 0 here
        Args:
            chosen_logps: Log probabilities of chosen tokens (batch_size,)
            rejected_logps: Log probabilities of rejected tokens (batch_size,)
            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
            beta: Temperature parameter for the KTO loss
        """
        if ref_chosen_logps is None:
            ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
        if ref_rejected_logps is None:
            ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)

        chosen_logratios = chosen_logps - ref_chosen_logps
        rejected_logratios = rejected_logps - ref_rejected_logps

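        # NOTE: z_ref is fixed at 0 here; the commented-out lines below sketch an
        # alternative in-batch estimate from the clamped mean log-ratios.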
        kl = torch.zeros(1, device=chosen_logps.device)
        # chosen_KL = chosen_logratios.mean().clamp(min=0)
        # rejected_KL = rejected_logratios.mean().clamp(min=0)

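        # Chosen (desirable) losses come first, then rejected (undesirable),
        # concatenated along the batch dimension.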
        losses = torch.cat(
            (
                1 - F.sigmoid(beta * (chosen_logratios - kl)),
                1 - F.sigmoid(beta * (kl - rejected_logratios)),
            ),
            0,
        )

        chosen_rewards = beta * chosen_logratios.detach()
        rejected_rewards = beta * rejected_logratios.detach()

        return losses, chosen_rewards, rejected_rewards

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=True,
        compiled=True,
        use_ref_model=False,
    ):
        # The reference-model arguments are accepted here so the signature matches
        # the module's .apply(...) call below; they are forwarded to the shared
        # base, assuming it accepts them as the DPO implementation's base call does.
        return LigerFusedLinearPreferenceBase.forward(
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
            ignore_index=ignore_index,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
            use_ref_model=use_ref_model,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
        )

    @staticmethod
    def backward(ctx, *grad_output):
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # One gradient per forward input: 4 real grads, then None for the 8
        # non-differentiable arguments (ref tensors and hyperparameters).
        return *grads, None, None, None, None, None, None, None, None


class LigerFusedLinearKTOLoss(torch.nn.Module):
    """
    Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        use_ref_model: bool = False,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss calculation
            beta (float): Temperature parameter for the KTO loss
            compute_nll_loss (bool): Whether to compute the NLL loss alongside KTO
            compiled (bool): Whether to use compiled operations
            use_ref_model (bool): Whether to use a reference model for the KTO loss
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled
        self.use_ref_model = use_ref_model

    def forward(
        self,
        lin_weight,
        _input,
        target,
        bias=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        return LigerFusedLinearKTOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            ref_input,
            ref_weight,
            ref_bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
            self.use_ref_model,
        )
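To make the formula in preference_loss_fn concrete, a small sanity check calling it directly (toy values, not part of this commit):

import torch
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction

# The policy is more confident than the reference on chosen answers,
# and less confident on rejected ones.
chosen_logps = torch.tensor([-1.0, -2.0])
rejected_logps = torch.tensor([-2.5, -3.0])
ref_chosen_logps = torch.tensor([-1.5, -2.5])
ref_rejected_logps = torch.tensor([-2.0, -2.5])

losses, chosen_rewards, rejected_rewards = LigerFusedLinearKTOFunction.preference_loss_fn(
    chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1
)
# All four log-ratios have magnitude 0.5, so with z_ref = 0 every loss entry is
# 1 - sigmoid(0.1 * 0.5) ≈ 0.4875; chosen_rewards ≈ [0.05, 0.05],
# rejected_rewards ≈ [-0.05, -0.05].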
4 changes: 2 additions & 2 deletions test/chunked_loss/test_dpo_loss.py
@@ -17,9 +17,9 @@

class HFDPOLoss(HFAlignmentLoss):
    """
-    Implementation of the Odds Ratio Preference Optimization (ORPO) loss,
+    Implementation of the Direct Preference Optimization (DPO) loss,
    adapted from Hugging Face's implementation.
-    Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py
+    Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py
    """

    def __init__(