Commit d8e457b

fix chosen_nll_loss in chunked losses

kashif committed Dec 17, 2024
1 parent ac56674
Showing 5 changed files with 52 additions and 6 deletions.
11 changes: 10 additions & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
@@ -47,6 +47,7 @@ def forward(
alpha=1.0,
compute_nll_loss=True,
compiled=True,
+ is_encoder_decoder=False,
):
return LigerFusedLinearPreferenceBase.forward(
ctx,
@@ -60,12 +61,13 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
+ is_encoder_decoder=is_encoder_decoder,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None
+ return *grads, None, None, None, None, None, None


class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -80,18 +82,24 @@ def __init__(
alpha: float = 1.0,
compute_nll_loss: bool = True,
compiled: bool = True,
+ is_encoder_decoder: bool = False,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
+ alpha (float): Weight for the NLL loss.
+ compute_nll_loss (bool): Whether to compute NLL loss.
+ compiled (bool): Whether to compile the loss function.
+ is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
+ self.is_encoder_decoder = is_encoder_decoder

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearCPOFunction.apply(
@@ -104,4 +112,5 @@ def forward(self, lin_weight, _input, target, bias=None):
self.alpha,
self.compute_nll_loss,
self.compiled,
+ self.is_encoder_decoder,
)
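
A minimal usage sketch of the updated CPO module. This is an illustration, not part of the commit: the tensor shapes and the chosen-then-rejected batch layout are assumptions taken from chunk_forward below, and some versions may return auxiliary metrics alongside the scalar loss.

import torch
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss

# Assumed layout: first half of the batch is chosen, second half rejected.
B, T, H, V = 4, 16, 64, 128
loss_fn = LigerFusedLinearCPOLoss(
    ignore_index=-100,
    beta=0.1,
    alpha=1.0,
    compute_nll_loss=True,
    is_encoder_decoder=False,  # decoder-only models get the shifted NLL loss
)
lm_head_weight = torch.randn(V, H, requires_grad=True)  # (vocab, hidden)
hidden_states = torch.randn(B, T, H)  # activations fed to the lm_head
target = torch.randint(0, V, (B, T))
loss = loss_fn(lm_head_weight, hidden_states, target)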
10 changes: 8 additions & 2 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -64,9 +64,10 @@ def forward(
ref_bias=None,
ignore_index=-100,
beta=0.1,
- compute_nll_loss=True,
+ compute_nll_loss=False,
compiled=True,
use_ref_model=True,
+ is_encoder_decoder=False,
):
return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
@@ -83,12 +84,13 @@ def forward(
ref_input=ref_input,
ref_weight=ref_weight,
ref_bias=ref_bias,
+ is_encoder_decoder=is_encoder_decoder,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None, None, None, None
+ return *grads, None, None, None, None, None, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -103,6 +105,7 @@ def __init__(
compute_nll_loss: bool = True,
compiled: bool = True,
use_ref_model: bool = False,
+ is_encoder_decoder: bool = False,
):
"""
Args:
@@ -111,13 +114,15 @@ def __init__(
compute_nll_loss (bool): Whether to compute the NLL loss.
compiled (bool): Whether to use the torch compiled kernel.
use_ref_model (bool): Whether to use a reference model for the DPO loss.
+ is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.use_ref_model = use_ref_model
+ self.is_encoder_decoder = is_encoder_decoder

def forward(
self,
@@ -142,4 +147,5 @@ def forward(
self.compute_nll_loss,
self.compiled,
self.use_ref_model,
+ self.is_encoder_decoder,
)
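
The extra trailing None in each backward above follows from a general torch.autograd.Function rule: backward must return exactly one value per forward argument, and non-tensor arguments such as the new is_encoder_decoder flag receive None. A toy illustration (not Liger code):

import torch

class ScaledSum(torch.autograd.Function):
    # Toy example: backward returns one value per forward argument.
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return (x * scale).sum()

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient for x, then None for the non-tensor scale argument --
        # the same reason each backward above gains one trailing None when
        # is_encoder_decoder is added to forward.
        return grad_output * ctx.scale, None

x = torch.randn(3, requires_grad=True)
ScaledSum.apply(x, 2.0).backward()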
22 changes: 20 additions & 2 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -26,6 +26,7 @@ def forward(
ignore_index=-100,
alpha=1.0,
beta=0.1,
+ is_encoder_decoder=False,
compute_nll_loss=True,
compiled=True,
use_ref_model=False,
@@ -56,6 +57,7 @@ def forward(
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the preference loss.
+ is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
compute_nll_loss (bool): Whether to compute NLL loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
@@ -94,6 +96,7 @@ def forward(
use_ref_model=use_ref_model,
ref_weight=ref_weight,
ref_bias=ref_bias,
+ is_encoder_decoder=is_encoder_decoder,
**loss_kwargs,
)

@@ -282,6 +285,7 @@ def chunk_forward(
bias=None,
ignore_index=-100,
compute_nll_loss=True,
+ is_encoder_decoder=False,
):
len_chosen_chunk = target_chunk.shape[0] // 2
logits_chunk = input_chunk @ weight.t()
@@ -291,12 +295,23 @@ def chunk_forward(

chosen_nll_loss = 0.0
if compute_nll_loss:
+ if not is_encoder_decoder:
+ shifted_logits = log_probs_chunk[:len_chosen_chunk, :-1].contiguous()
+ shifted_target = target_chunk[:len_chosen_chunk, 1:].contiguous()
+ else:
+ shifted_logits = log_probs_chunk[:len_chosen_chunk].contiguous()
+ shifted_target = target_chunk[:len_chosen_chunk].contiguous()

chosen_nll_loss = F.nll_loss(
- log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
- target_chunk[:len_chosen_chunk].view(-1),
+ shifted_logits.view(-1, log_probs_chunk.shape[-1]),
+ shifted_target.view(-1),
reduction="sum",
ignore_index=ignore_index,
)
+ else:
+ chosen_nll_loss = torch.zeros(
+ (), device=target_chunk.device, dtype=target_chunk.dtype
+ )

loss_mask = target_chunk != ignore_index
label_chunk = torch.where(loss_mask, target_chunk, 0)
@@ -331,6 +346,7 @@ def _compute_loss(
ignore_index=-100,
alpha=1.0,
beta=0.1,
+ is_encoder_decoder=False,
compute_nll_loss=True,
use_ref_model=False,
ref_input_chunk=None,
@@ -350,6 +366,7 @@ def _compute_loss(
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the preference loss.
+ is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
compute_nll_loss (bool): Whether to compute NLL loss.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
@@ -369,6 +386,7 @@ def _compute_loss(
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
+ is_encoder_decoder=is_encoder_decoder,
)
chosen_nll_loss = (
chosen_nll_loss
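The substance of the fix is the shift in chunk_forward: for decoder-only models the logit at position t predicts the token at position t+1, so log-probs[:, :-1] must be paired with targets[:, 1:], while encoder-decoder models emit logits already aligned with the decoder targets. A self-contained sketch of that alignment (illustrative shapes only, not Liger code):

import torch
import torch.nn.functional as F

B, T, V = 2, 5, 7
log_probs = F.log_softmax(torch.randn(B, T, V), dim=-1)
target = torch.randint(0, V, (B, T))

is_encoder_decoder = False
if not is_encoder_decoder:
    # Drop the last position's prediction and the first position's label.
    shifted_logits = log_probs[:, :-1, :].contiguous()
    shifted_target = target[:, 1:].contiguous()
else:
    shifted_logits, shifted_target = log_probs, target

nll = F.nll_loss(
    shifted_logits.view(-1, V),
    shifted_target.view(-1),
    reduction="sum",
    ignore_index=-100,
)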
10 changes: 9 additions & 1 deletion src/liger_kernel/chunked_loss/orpo_loss.py
@@ -57,6 +57,7 @@ def forward(
beta=0.1,
compute_nll_loss=True,
compiled=True,
+ is_encoder_decoder=False,
):
return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
@@ -69,12 +70,13 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
+ is_encoder_decoder=is_encoder_decoder,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None
+ return *grads, None, None, None, None, None


class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -88,17 +90,22 @@ def __init__(
beta: float = 0.1,
compute_nll_loss: bool = True,
compiled: bool = True,
+ is_encoder_decoder: bool = False,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
+ compute_nll_loss (bool): Whether to compute NLL loss.
+ compiled (bool): Whether to compile the loss function.
+ is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
+ self.is_encoder_decoder = is_encoder_decoder

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearORPOFunction.apply(
@@ -110,4 +117,5 @@ def forward(self, lin_weight, _input, target, bias=None):
self.beta,
self.compute_nll_loss,
self.compiled,
+ self.is_encoder_decoder,
)
5 changes: 5 additions & 0 deletions test/utils.py
@@ -350,11 +350,13 @@ def __init__(
beta: float = 0.1,
ignore_index: int = -100,
use_ref_model: bool = False,
+ is_encoder_decoder: bool = False,
):
self.alpha = alpha
self.beta = beta
self.ignore_index = ignore_index
self.use_ref_model = use_ref_model
+ self.is_encoder_decoder = is_encoder_decoder

@abstractmethod
def alignment_loss(self):
@@ -440,6 +442,9 @@ def concatenated_forward(
def cross_entropy_loss(logits, labels):
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
+ if not self.is_encoder_decoder:
+ logits = logits[..., :-1, :].contiguous()
+ labels = labels[..., 1:].contiguous()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
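The reference path above (CrossEntropyLoss on shifted raw logits) and the kernel path (nll_loss on shifted log-probabilities) compute the same quantity, since cross-entropy on logits equals NLL on log_softmax of those logits. A quick sanity sketch (illustrative shapes only, not part of the commit):

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(3, 6, 11)
labels = torch.randint(0, 11, (3, 6))

ce = nn.CrossEntropyLoss()(
    logits[..., :-1, :].contiguous().view(-1, 11),
    labels[..., 1:].contiguous().view(-1),
)
nll = F.nll_loss(
    F.log_softmax(logits[..., :-1, :].contiguous(), dim=-1).view(-1, 11),
    labels[..., 1:].contiguous().view(-1),
)
assert torch.allclose(ce, nll, atol=1e-6)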
