Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Liu <austin362667@gmail.com>
  • Loading branch information
austin362667 committed Dec 17, 2024
1 parent f951da3 commit 65bcc2c
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
beta (float): Weight for the CPO loss
"""
logits = beta * (chosen_logps - rejected_logps)
loss = - F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
loss = -F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
return loss

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def preference_loss_fn(
rejected_logratios = rejected_logps - ref_rejected_logps

logits_diff = beta * (chosen_logratios - rejected_logratios)
loss = - F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
return loss

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/orpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
- torch.log1p(-torch.exp(rejected_logps))
)
ratio = F.logsigmoid(log_odds)
loss = - beta * ratio.sum() / (full_target.shape[0] // 2)
loss = -beta * ratio.sum() / (full_target.shape[0] // 2)

chosen_rewards = beta * chosen_logps
rejected_rewards = beta * rejected_logps
Expand Down
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/simpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def preference_loss_fn(
gamma (float): gamma margin term
"""
logits = beta * (chosen_logps - rejected_logps) - gamma
loss = - F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
loss = -F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
return loss

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions test/chunked_loss/test_cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ def alignment_loss(
# calculates a conservative CPO loss.
if self.loss_type == "sigmoid":
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
losses = - (
losses = -(
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "simpo":
logits = logits - (self.simpo_gamma / self.beta)
losses = - (
losses = -(
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
Expand Down
2 changes: 1 addition & 1 deletion test/chunked_loss/test_orpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def alignment_loss(
- torch.log1p(-torch.exp(policy_rejected_logps))
)
ratio = F.logsigmoid(log_odds)
losses = - self.beta * ratio
losses = -self.beta * ratio

chosen_rewards = self.beta * policy_chosen_logps
rejected_rewards = self.beta * policy_rejected_logps
Expand Down

0 comments on commit 65bcc2c

Please sign in to comment.