diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index f59457c84..d970fcc5f 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -308,10 +308,6 @@ def chunk_forward( reduction="sum", ignore_index=ignore_index, ) - else: - chosen_nll_loss = torch.zeros( - (), device=target_chunk.device, dtype=target_chunk.dtype - ) loss_mask = target_chunk != ignore_index label_chunk = torch.where(loss_mask, target_chunk, 0)