Merged
Commits
27 commits
a254769  add return_token_accuracy flag to fused_linear_cross_entropy (kashif, Oct 16, 2025)
b670b7d  rename to token_accuracy (kashif, Oct 17, 2025)
e872bf4  return token_accuracy in transformer models (kashif, Oct 17, 2025)
d11b24d  formatting (kashif, Oct 17, 2025)
002c0ec  add missing output class (kashif, Oct 17, 2025)
d67c511  typos (kashif, Oct 17, 2025)
3a4a883  more typos (kashif, Oct 17, 2025)
1e6da16  added test_correctness_with_token_accuracy (kashif, Oct 17, 2025)
e9d0954  formatting (kashif, Oct 17, 2025)
038035d  consistency (kashif, Oct 17, 2025)
2212623  use CrossEntropyOutput (kashif, Oct 20, 2025)
33a999b  Merge branch 'main' into mean_token_accuracy (kashif, Oct 20, 2025)
a50e03e  update qwen3 next (kashif, Oct 20, 2025)
338e70a  formatting (kashif, Oct 20, 2025)
d1d9f52  add missing return_dict (kashif, Oct 20, 2025)
c5857fd  Merge branch 'main' into mean_token_accuracy (kashif, Oct 21, 2025)
ddfdb0b  Merge branch 'main' into mean_token_accuracy (kashif, Oct 25, 2025)
f268c27  Merge branch 'main' into mean_token_accuracy (shimizust, Oct 28, 2025)
704c3b4  Merge branch 'main' into mean_token_accuracy (kashif, Nov 1, 2025)
c6c2d27  Merge branch 'main' into mean_token_accuracy (kashif, Nov 5, 2025)
181b11f  checktyle fixes (vaibhavjindal, Nov 5, 2025)
a06c5db  Merge branch 'main' into mean_token_accuracy (vaibhavjindal, Nov 5, 2025)
0069dcf  fix qwen3_vl (kashif, Nov 5, 2025)
3ef06ee  checkstyle (kashif, Nov 5, 2025)
f855c29  Merge branch 'main' into mean_token_accuracy (vaibhavjindal, Nov 5, 2025)
dd0790c  fix circular import (vaibhavjindal, Nov 5, 2025)
d20e8b6  fix output classes for different transformers versions (vaibhavjindal, Nov 5, 2025)
68 changes: 59 additions & 9 deletions src/liger_kernel/ops/cross_entropy.py
@@ -32,6 +32,8 @@ def liger_cross_entropy_kernel(
loss_ptr,
z_loss_ptr,
loss_stride,
token_accuracy_ptr,
token_accuracy_stride,
n_cols,
n_non_ignore,
sum_non_ignore_weight,
@@ -42,6 +44,7 @@ def liger_cross_entropy_kernel(
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
RETURN_TOKEN_ACCURACY: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
@@ -60,6 +63,8 @@ def liger_cross_entropy_kernel(
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
token_accuracy_stride (int): The stride of the token accuracy tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (float): The number of non-ignored elements in the batch.
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -69,7 +74,8 @@ def liger_cross_entropy_kernel(
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -92,11 +98,17 @@ def liger_cross_entropy_kernel(
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
# For ignored tokens, set token accuracy to 0
if RETURN_TOKEN_ACCURACY:
token_accuracy_ptr += program_id * token_accuracy_stride
tl.store(token_accuracy_ptr, 0.0)
return

loss_ptr += program_id * loss_stride
if RETURN_Z_LOSS:
z_loss_ptr += program_id * loss_stride
if RETURN_TOKEN_ACCURACY:
token_accuracy_ptr += program_id * token_accuracy_stride

if HAS_WEIGHT:
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -107,6 +119,7 @@ def liger_cross_entropy_kernel(
# 3. [Online softmax] first pass: find max + sum
m = float("-inf") # m is the max value. use the notation from the paper
d = 0.0 # d is the sum. use the notation from the paper
argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
if HAS_SOFTCAPPING:
ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -127,6 +140,16 @@ def liger_cross_entropy_kernel(
if HAS_SOFTCAPPING:
X_block = softcap * tanh(X_block / softcap)
block_max = tl.max(X_block)

# Track argmax for accuracy computation
if RETURN_TOKEN_ACCURACY and block_max > m:
# Find the index of the maximum value in this block
is_max_mask = X_block == block_max
# Mask out invalid indices with a value larger than n_cols
masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
# Get the first (smallest) index where max occurs
argmax_idx = tl.min(masked_offsets)

if label_smoothing > 0:
# scale X beforehand to avoid overflow
if HAS_WEIGHT:
@@ -256,6 +279,10 @@ def liger_cross_entropy_kernel(
tl.store(loss_ptr, loss)
if RETURN_Z_LOSS:
tl.store(z_loss_ptr, z_loss)
if RETURN_TOKEN_ACCURACY:
# Store 1.0 if prediction is correct, 0.0 otherwise
is_correct = 1.0 if argmax_idx == y else 0.0
tl.store(token_accuracy_ptr, is_correct)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -274,8 +301,12 @@ def cross_entropy_forward(
reduction,
softcap,
return_z_loss,
return_token_accuracy=False,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_token_accuracy, bool), (
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
)

BT, V = _input.shape
n_rows = BT
@@ -285,6 +316,9 @@ def cross_entropy_forward(
# unreduced loss
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
token_accuracy_1d = (
torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
)

target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
@@ -321,6 +355,10 @@ def cross_entropy_forward(
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
token_accuracy_ptr=token_accuracy_1d,
token_accuracy_stride=token_accuracy_1d.stride(-1)
if return_token_accuracy
else 0, # always 1 if accuracy is enabled
n_cols=V,
n_non_ignore=n_non_ignore,
sum_non_ignore_weight=sum_non_ignore_weight,
@@ -331,6 +369,7 @@ def cross_entropy_forward(
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
RETURN_TOKEN_ACCURACY=return_token_accuracy,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -343,11 +382,14 @@ def cross_entropy_forward(
if reduction == "none":
loss = loss_1d
z_loss = z_loss_1d if return_z_loss else None
token_accuracy = token_accuracy_1d if return_token_accuracy else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
# For accuracy, we compute the mean across all non-ignored tokens
token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

return loss, z_loss, _input
return loss, z_loss, token_accuracy, _input


def cross_entropy_backward(_input, grad_output):
@@ -395,6 +437,7 @@ def forward(
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_token_accuracy: bool = False,
):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -409,14 +452,15 @@ def forward(
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

Returns:
tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
"""
input_requires_grad = _input.requires_grad

loss, z_loss, _input = cross_entropy_forward(
loss, z_loss, token_accuracy, _input = cross_entropy_forward(
_input,
target,
weight,
@@ -426,30 +470,35 @@ def forward(
reduction,
softcap,
return_z_loss,
return_token_accuracy,
)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
if input_requires_grad:
ctx.save_for_backward(_input.detach())
ctx.return_z_loss = return_z_loss
ctx.return_token_accuracy = return_token_accuracy

return loss, z_loss
return loss, z_loss, token_accuracy

@staticmethod
def backward(ctx, grad_output, grad_ouput2):
def backward(ctx, grad_output, grad_output2, grad_output3):
"""
The backward pass of the Liger Cross Entropy loss.

Parameters:
ctx : The context object with saved tensors.
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
grad_output2 (tenosr): No use.
grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
if ctx.return_z_loss:
del grad_ouput2 # z_loss is only for logging
del grad_output2 # z_loss is only for logging
if ctx.return_token_accuracy:
del grad_output3 # token_accuracy is only for metrics

(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
@@ -463,4 +512,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)
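
Note on the kernel change above: the argmax bookkeeping rides along with the online-softmax max/sum pass. Whenever a block's maximum beats the running maximum m, the first column index attaining it is recorded, so per-token accuracy costs no extra sweep over the vocabulary. A rough eager-mode sketch of that bookkeeping for a single row (a sketch for intuition only; the block size, names, and use of PyTorch here are illustrative, not the Triton implementation):

import torch

def blockwise_argmax(row: torch.Tensor, block_size: int = 4096) -> int:
    """Track the running max and the first index attaining it, block by block."""
    m, argmax_idx = float("-inf"), 0
    for start in range(0, row.numel(), block_size):
        block = row[start:start + block_size]
        block_max = block.max().item()
        if block_max > m:  # a later block only takes over with a strictly larger max
            m = block_max
            # first (smallest) offset within the row where the block max occurs
            argmax_idx = start + int((block == block_max).nonzero()[0].item())
    return argmax_idx

Because a later block only wins on a strictly larger maximum and ties within a block resolve to the smallest offset, the first occurrence of the global maximum is kept, mirroring tl.min(masked_offsets) above. Per row, the kernel then stores 1.0 if argmax_idx == y and 0.0 otherwise, and cross_entropy_forward averages those flags over the non-ignored tokens when a reduction is requested.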
31 changes: 27 additions & 4 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -27,8 +27,12 @@ def fused_linear_cross_entropy_forward(
return_z_loss=False,
accum_dtype=None,
use_token_scaling=False,
return_token_accuracy=False,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_token_accuracy, bool), (
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
)
device = _input.device

input_requires_grad = _input.requires_grad
@@ -64,6 +68,7 @@ def fused_linear_cross_entropy_forward(

loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None

# TODO: evaluate how CUDA synchronization caused by .item() affects the speed
target_mask = target != ignore_index
@@ -129,6 +134,7 @@ def fused_linear_cross_entropy_forward(
# unreduced loss
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None

# ensure _input and target are contiguous
logits_chunk = logits_chunk.contiguous()
@@ -144,6 +150,10 @@ def fused_linear_cross_entropy_forward(
loss_ptr=loss_1d_slice,
z_loss_ptr=z_loss_1d_slice,
loss_stride=loss_1d_slice.stride(-1), # always 1
token_accuracy_ptr=token_accuracy_1d_slice,
token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
if return_token_accuracy
else 0, # always 1 if accuracy is enabled
n_cols=V,
n_non_ignore=total_n_non_ignore,
sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -154,6 +164,7 @@ def fused_linear_cross_entropy_forward(
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
RETURN_TOKEN_ACCURACY=return_token_accuracy,
HAS_WEIGHT=True if ce_weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
HAS_GRADIENTS=input_requires_grad,
@@ -170,6 +181,8 @@ def fused_linear_cross_entropy_forward(
loss_1d[start_idx:end_idx] = loss_1d_slice
if return_z_loss:
z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
if return_token_accuracy:
token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
grad_logits_chunk = logits_chunk # chunk_size x V

# Apply token scaling to gradients if requested
@@ -201,15 +214,18 @@ def fused_linear_cross_entropy_forward(
# Return per-token losses
loss = loss_1d
z_loss = z_loss_1d if return_z_loss else None
token_accuracy = token_accuracy_1d if return_token_accuracy else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
# For accuracy, we compute the mean across all non-ignored tokens
token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None

# Cast back to original dtype
grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None

return loss, z_loss, grad_input, grad_weight, grad_bias
return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias


def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -277,6 +293,7 @@ def forward(
return_z_loss: bool = False,
accum_dtype=None,
use_token_scaling: bool = False,
return_token_accuracy: bool = False,
):
"""
Fusing the last linear layer with cross-entropy loss
@@ -300,9 +317,10 @@ def forward(
use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
Default: False.
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
"""

loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
_input=_input,
weight=weight,
target=target,
@@ -316,6 +334,7 @@ def forward(
return_z_loss=return_z_loss,
accum_dtype=accum_dtype,
use_token_scaling=use_token_scaling,
return_token_accuracy=return_token_accuracy,
)
# downcast to dtype and store for backward
ctx.save_for_backward(
@@ -324,13 +343,16 @@ def forward(
grad_bias.detach() if bias is not None else None,
)
ctx.return_z_loss = return_z_loss
return loss, z_loss
ctx.return_token_accuracy = return_token_accuracy
return loss, z_loss, token_accuracy

@staticmethod
@amp_custom_bwd
def backward(ctx, grad_output, grad_output2):
def backward(ctx, grad_output, grad_output2, grad_output3):
if ctx.return_z_loss:
del grad_output2 # z_loss is only for logging
if ctx.return_token_accuracy:
del grad_output3 # token_accuracy is only for metrics
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
grad_output, grad_input, grad_weight, grad_bias
@@ -349,4 +371,5 @@ def backward(ctx, grad_output, grad_output2):
None,
None,
None, # use_token_scaling
None, # return_token_accuracy
)
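
Note on the fused path above: as with the loss, accuracy is accumulated chunk by chunk, so the full BT x V logits matrix is never materialized. Each chunk's slice of token_accuracy_1d is filled by the kernel, and the mean over non-ignored tokens is taken at the end. A rough eager-mode equivalent of that accumulation (a sketch under assumed shapes and chunk size, not the fused implementation):

import torch

def chunked_token_accuracy(hidden, lm_head_weight, target, ignore_index=-100, chunk_size=1024):
    """Eager sketch of the accuracy the fused forward accumulates chunk by chunk."""
    n_non_ignore = (target != ignore_index).sum().item()
    correct = 0.0
    for start in range(0, hidden.shape[0], chunk_size):
        end = start + chunk_size
        logits_chunk = hidden[start:end] @ lm_head_weight.t()  # only chunk_size x V lives in memory
        preds = logits_chunk.argmax(dim=-1)
        mask = target[start:end] != ignore_index
        correct += ((preds == target[start:end]) & mask).sum().item()
    return correct / max(n_non_ignore, 1)  # mean over non-ignored tokens, as in the reduction branch

The fused kernel produces the same numbers without the extra matmul or argmax pass, since the per-chunk argmax falls out of the online-softmax pass it already performs; with reduction="none" it instead returns the per-token 0/1 vector, matching the loss behavior.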
11 changes: 8 additions & 3 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -3,6 +3,7 @@
import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
from liger_kernel.transformers.functional import CrossEntropyOutput


class LigerCrossEntropyLoss(torch.nn.Module):
@@ -15,6 +16,7 @@ def __init__(
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_token_accuracy: bool = False,
):
super().__init__()
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -33,9 +35,10 @@ def __init__(
self.reduction = reduction
self.softcap = softcap
self.return_z_loss = return_z_loss
self.return_token_accuracy = return_token_accuracy

def forward(self, _input: torch.Tensor, target: torch.Tensor):
loss, z_loss = LigerCrossEntropyFunction.apply(
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
_input,
target,
self.weight,
@@ -45,7 +48,9 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor):
self.reduction,
self.softcap,
self.return_z_loss,
self.return_token_accuracy,
)
if not self.return_z_loss:
if not self.return_z_loss and not self.return_token_accuracy:
return loss
return loss, z_loss

return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
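
A short usage sketch of the updated module API (the CUDA device, shapes, and dataclass-style field access on CrossEntropyOutput are assumptions for illustration; the output class itself comes from liger_kernel.transformers.functional as imported above):

import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

ce = LigerCrossEntropyLoss(return_token_accuracy=True)

logits = torch.randn(32, 4096, device="cuda", requires_grad=True)  # (BT, V)
target = torch.randint(0, 4096, (32,), device="cuda")              # (BT,)

out = ce(logits, target)   # CrossEntropyOutput, because token accuracy was requested
out.loss.backward()        # the loss remains differentiable as before
print(out.token_accuracy)  # scalar fraction of non-ignored tokens predicted correctly
print(out.z_loss)          # None, since return_z_loss was left at its default

With both return_z_loss and return_token_accuracy left at False, forward still returns a bare loss tensor, so existing callers are unaffected.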