From e9018eb6826489b236ab59be962e7c3103241208 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 15 Sep 2023 19:27:18 -0700 Subject: [PATCH] [CE] Implement CrossEntropyLoss in Triton --- csrc/xentropy/README.md | 5 + flash_attn/losses/cross_entropy.py | 147 +++------- flash_attn/ops/triton/cross_entropy.py | 293 ++++++++++++++++++++ tests/losses/test_cross_entropy.py | 28 +- tests/losses/test_cross_entropy_parallel.py | 26 +- 5 files changed, 368 insertions(+), 131 deletions(-) create mode 100644 flash_attn/ops/triton/cross_entropy.py diff --git a/csrc/xentropy/README.md b/csrc/xentropy/README.md index 7970f3939..1bc90fdab 100644 --- a/csrc/xentropy/README.md +++ b/csrc/xentropy/README.md @@ -7,3 +7,8 @@ It has only been tested on A100s. ```sh cd csrc/xentropy && pip install . ``` + +As of 2023-09-15, this extension is no longer used in the FlashAttention repo. +We've instead switched to a Triton-based +[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py). +See the CrossEntropyLoss [module](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) for more details. diff --git a/flash_attn/losses/cross_entropy.py b/flash_attn/losses/cross_entropy.py index c9cd1776e..bb4e02eb9 100644 --- a/flash_attn/losses/cross_entropy.py +++ b/flash_attn/losses/cross_entropy.py @@ -2,115 +2,10 @@ # But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and # the losses we can get the global loss. There's no need to do it step by step # (compute local max, exchange, compute exp, compute local sum, exchange, etc.) -# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py import torch import torch.nn as nn -import xentropy_cuda_lib -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 2 lines are for backward compatibility with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base - - -class SoftmaxCrossEntropyLossFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - logits, - labels, - smoothing=0.0, - ignored_index=-100, - inplace_backward=False, - process_group=None, - ): - """ - logits: (batch, vocab_size) - labels: (batch,) - If process_group is not None, we're doing Tensor Parallel: each process is responsible for - one part of the vocab. The loss needs to be aggregated across processes. - """ - batch, vocab_size = logits.shape - assert labels.shape == (batch,) - world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) - ctx.total_classes = world_size * vocab_size - - if world_size == 1: - losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) - losses.masked_fill_(labels == ignored_index, 0) - labels_local = labels - else: - rank = torch.distributed.get_rank(process_group) - vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size - - # Create a mask of valid vocab ids (1 means it needs to be masked). - labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) - ignored_mask = labels == ignored_index - labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) - - # For tensor parallel cross entropy with smoothing, we want to pass in the total number - # of classes so that smoothing can be applied correctly. If total_classes=-1, use the - # last dimension of the input tensor. - losses, lse_local = xentropy_cuda_lib.forward( - logits, labels_local, smoothing, world_size * vocab_size - ) - assert lse_local.shape == (batch,) - assert losses.shape == (batch,) - losses.masked_fill_(ignored_mask, 0) - # For labels == ignored_index, the loss is always 0. - # If there's no smoothing, if labels are in the vocab of this partition, losses contains - # lse_local - predicted logit, and 0 otherwise. - # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains - # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) - # For labels not in the vocab of this partition, losses contains - # 0.1 * (lse_local - sum logit / total_classes). - - lse_allgather = torch.empty( - world_size, batch, dtype=lse_local.dtype, device=lse_local.device - ) - torch.distributed.all_gather_into_tensor( - lse_allgather, lse_local.contiguous(), group=process_group - ) - handle_losses = torch.distributed.all_reduce( - losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True - ) - lse = torch.logsumexp(lse_allgather, dim=0) - # If there's no smoothing, the total losses are lse_local - predicted_logit, - # we just have to subtract the lse_local and add the lse (global). - # If there's smoothing=0.1, the total losses are - # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) - # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). - rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor") - lse_local = lse_allgather[ - rank_per_sample, torch.arange(batch, device=lse_allgather.device) - ] - - handle_losses.wait() - if smoothing == 0.0: - losses += lse - lse_local - else: - losses += (1 - smoothing) * (lse - lse_local) + smoothing * ( - lse - lse_allgather.sum(dim=0) - ) - losses.masked_fill_(ignored_mask, 0) - - ctx.save_for_backward(logits, lse, labels_local) - ctx.smoothing = smoothing - ctx.ignored_index = ignored_index - ctx.inplace_backward = inplace_backward - return losses - - @staticmethod - def backward(ctx, grad_loss): - logits, lse, labels = ctx.saved_tensors - grad_loss = grad_loss.contiguous() - grad_loss.masked_fill_(labels == ctx.ignored_index, 0) - grad_logits = xentropy_cuda_lib.backward( - grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes - ) - return grad_logits, None, None, None, None, None, None +from flash_attn.ops.triton.cross_entropy import cross_entropy_loss class CrossEntropyLoss(nn.Module): @@ -119,30 +14,52 @@ def __init__( ignore_index=-100, reduction="mean", label_smoothing=0.0, + lse_square_scale=0.0, inplace_backward=False, process_group=None, ): + """ + Arguments: + ignored_index: int. If labels == ignored_index, the loss is set to 0.0. + label_smoothing: float + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + """ super().__init__() - if reduction not in ["mean", "none"]: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") + if reduction not in ["mean", "none", "sum"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'") self.ignore_index = ignore_index self.reduction = reduction self.label_smoothing = label_smoothing + self.lse_square_scale = lse_square_scale self.inplace_backward = inplace_backward self.process_group = process_group def forward(self, input, target): - assert input.is_cuda and target.is_cuda - # SoftmaxCrossEntropyLoss implicitly casts to float - loss = SoftmaxCrossEntropyLossFn.apply( + """ + Arguments: + input: (batch, vocab_size) + target: (batch,) + Returns: + losses: (batch,) if reduction is 'none', else (1,), dtype float + """ + assert input.is_cuda and target.is_cuda, "Only support CUDA tensors" + loss = cross_entropy_loss( input, target, - self.label_smoothing, - self.ignore_index, - self.inplace_backward, - self.process_group, + label_smoothing=self.label_smoothing, + lse_square_scale=self.lse_square_scale, + ignored_index=self.ignore_index, + inplace_backward=self.inplace_backward, + process_group=self.process_group, ) if self.reduction == "mean": return loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + return loss.sum() else: return loss diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py new file mode 100644 index 000000000..cc7fabdd2 --- /dev/null +++ b/flash_attn/ops/triton/cross_entropy.py @@ -0,0 +1,293 @@ +# Copyright (c) 2023, Tri Dao. + +from typing import Tuple, Optional, Union + +import torch + +from einops import rearrange + +import triton +import triton.language as tl + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_fwd_kernel( + loss_ptr, # data ptrs + lse_ptr, + logits_ptr, + labels_ptr, + smoothing, + lse_square_scale, + ignored_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + n_rows, + logits_row_stride, # strides + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, + # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE + SPLIT: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) + max_logits = tl.max(logits, 0) + if HAS_SMOOTHING: + sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0) + lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits + tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse) + if label_idx == ignored_index: + loss = 0.0 + else: + label_idx -= class_start_idx + if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min( + n_cols, (col_block_idx + 1) * BLOCK_SIZE + ): + logits_label = tl.load(logits_ptr + label_idx) + if HAS_SMOOTHING: + loss = ( + (lse if not SPLIT else 0.0) + - smoothing * sum_logits / total_classes + - (1 - smoothing) * logits_label + ) + else: + loss = (lse if not SPLIT else 0.0) - logits_label + else: + # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss + if HAS_SMOOTHING: + loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) + else: + loss = 0.0 + if not SPLIT: + loss += lse_square_scale * lse * lse + tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss) + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_bwd_kernel( + dlogits_ptr, # data ptrs + dloss_ptr, + logits_ptr, + lse_ptr, + labels_ptr, + smoothing, + lse_square_scale, + ignored_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + logits_row_stride, # strides + dlogits_row_stride, + dloss_row_stride, + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride + dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + if label_idx != ignored_index: + dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) + else: + dloss = 0.0 + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) + lse = tl.load(lse_ptr + row_idx) + probs = tl.exp(logits - lse) + probs += 2.0 * lse_square_scale * lse * probs + label_idx -= class_start_idx + if HAS_SMOOTHING: + smooth_positive = 1.0 - smoothing + smooth_negative = smoothing / total_classes + probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative + else: + probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) + tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols) + + +class CrossEntropyLoss(torch.autograd.Function): + @staticmethod + def forward( + ctx, + logits, + labels, + smoothing, + lse_square_scale=0.0, + ignored_index=-100, + inplace_backward=False, + process_group=None, + ): + n_rows, n_cols = logits.shape + assert labels.shape == (n_rows,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + total_classes = world_size * n_cols + rank = 0 if process_group is None else torch.distributed.get_rank(process_group) + class_start_idx = rank * n_cols + + if logits.stride(-1) != 1: + logits = logits.contiguous() + # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py + MAX_BLOCK_SIZE = 64 * 1024 + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) + num_warps = ( + 4 + if BLOCK_SIZE < 2048 + else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) + ) + # We may split the lse computation across multiple blocks, then do a reduction + # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k) + # where having just one thread block processing more than 64k elements is slow. + split = world_size > 1 or n_cols > MAX_BLOCK_SIZE + n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE + loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,) + losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(logits.device.index): + cross_entropy_fwd_kernel[(n_rows, n_splits)]( + losses, # data ptrs + lse, + logits, + labels, + smoothing, + lse_square_scale, + ignored_index, + total_classes, + class_start_idx, + n_cols, # shapes + n_rows, + logits.stride(0), # strides + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + SPLIT=split, + ) + + if split: + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # -0.9 * predicted logit - 0.1 * sum logit / total_classes. + # For labels not in the vocab of this partition, losses contains + # -0.1 * sum logit / total_classes. + if world_size > 1: + lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + handle_losses.wait() + else: + lse = torch.logsumexp(lse, dim=0) + losses = losses.sum(dim=0) + # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, + # we just have to add the lse (global). + # If there's smoothing=0.1, the total losses are + # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. + # Again, we just have to add the lse(global) + losses += lse + if lse_square_scale != 0.0: + losses += lse_square_scale * lse.square() + losses.masked_fill_(labels == ignored_index, 0.0) + + ctx.save_for_backward(logits, lse, labels) + ctx.smoothing = smoothing + ctx.lse_square_scale = lse_square_scale + ctx.ignored_index = ignored_index + ctx.total_classes = total_classes + ctx.class_start_idx = class_start_idx + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_losses): + logits, lse, labels = ctx.saved_tensors + dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) + n_rows, n_cols = logits.shape + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) + num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) + grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(logits.device.index): + cross_entropy_bwd_kernel[grid]( + dlogits, # data ptrs + grad_losses, + logits, + lse, + labels, + ctx.smoothing, + ctx.lse_square_scale, + ctx.ignored_index, + ctx.total_classes, + ctx.class_start_idx, + n_cols, # shapes + logits.stride(0), # strides + dlogits.stride(0), + grad_losses.stride(0), + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + ) + return dlogits, None, None, None, None, None, None, None + + +def cross_entropy_loss( + logits: torch.Tensor, + labels: torch.Tensor, + label_smoothing: float = 0.0, + lse_square_scale: float = 0.0, + ignored_index=-100, + inplace_backward: bool = False, + process_group=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + logits: (batch, vocab_size) + labels: (batch,) + label_smoothing: float + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + ignored_index: int. If labels == ignored_index, the loss is set to 0.0. + inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + Returns: + losses: (batch,), float + """ + return CrossEntropyLoss.apply( + logits, + labels, + label_smoothing, + lse_square_scale, + ignored_index, + inplace_backward, + process_group, + ) diff --git a/tests/losses/test_cross_entropy.py b/tests/losses/test_cross_entropy.py index 10882dc28..21c43b0e7 100644 --- a/tests/losses/test_cross_entropy.py +++ b/tests/losses/test_cross_entropy.py @@ -4,7 +4,7 @@ import torch import torch.nn.functional as F from einops import rearrange -from flash_attn.losses.cross_entropy import CrossEntropyLossApex +from flash_attn.losses.cross_entropy import CrossEntropyLoss is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8 @@ -12,12 +12,16 @@ @pytest.mark.parametrize( "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []) ) -# @pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("inplace_backward", [False, True]) -# @pytest.mark.parametrize('inplace_backward', [False]) +# @pytest.mark.parametrize("inplace_backward", [False]) +@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2]) +# @pytest.mark.parametrize("lse_square_scale", [1e-2]) @pytest.mark.parametrize("smoothing", [0.0, 0.9]) -@pytest.mark.parametrize("vocab_size", [50257]) -def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype): +# @pytest.mark.parametrize("smoothing", [0.0]) +@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024]) # test vocab larger than 64k for split +# @pytest.mark.parametrize("vocab_size", [12]) +def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype): device = "cuda" rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4) # set seed @@ -29,12 +33,20 @@ def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype) ) x = x_pt.detach().clone().requires_grad_() y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device) - y[torch.randperm(batch_size * seqlen)[:10]] = -100 + if batch_size * seqlen > 10: + y[torch.randperm(batch_size * seqlen)[:10]] = -100 model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing) - model = CrossEntropyLossApex(label_smoothing=smoothing, inplace_backward=inplace_backward) + model = CrossEntropyLoss( + label_smoothing=smoothing, + lse_square_scale=lse_square_scale, + inplace_backward=inplace_backward, + ) out = model(x, y) out_pt = model_pt(x_pt.float(), y) - assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) + if lse_square_scale > 0.0: + lse_pt = torch.logsumexp(x_pt.float(), dim=-1) + out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean() + assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6) g = torch.randn_like(out) out_pt.backward(g) diff --git a/tests/losses/test_cross_entropy_parallel.py b/tests/losses/test_cross_entropy_parallel.py index 4aa0518c7..2588a11a6 100644 --- a/tests/losses/test_cross_entropy_parallel.py +++ b/tests/losses/test_cross_entropy_parallel.py @@ -1,5 +1,5 @@ # Run test with: -# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py +# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/losses/test_cross_entropy_parallel.py import math @@ -15,15 +15,20 @@ @pytest.mark.parametrize( "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []) ) -# @pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("inplace_backward", [False, True]) -# @pytest.mark.parametrize('inplace_backward', [False]) +# @pytest.mark.parametrize("inplace_backward", [False]) +@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2]) +# @pytest.mark.parametrize("lse_square_scale", [1e-2]) @pytest.mark.parametrize("smoothing", [0.0, 0.9]) -# @pytest.mark.parametrize('smoothing', [0.9]) -@pytest.mark.parametrize("vocab_size", [50264]) -@pytest.mark.parametrize("world_size", [1, 2, 4, 8]) -# @pytest.mark.parametrize('world_size', [2]) -def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype): +# @pytest.mark.parametrize("smoothing", [0.0]) +@pytest.mark.parametrize("vocab_size", [50264, 128 * 1024]) # test vocab larger than 64k for split +# @pytest.mark.parametrize("vocab_size", [50264]) # test vocab larger than 64k for split +@pytest.mark.parametrize("world_size", [1, 2, 4]) +# @pytest.mark.parametrize("world_size", [2]) +def test_cross_entropy_loss_parallel( + vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype +): assert vocab_size % world_size == 0 rtol, atol = ( (1e-5, 1e-6) @@ -56,11 +61,16 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_ model = CrossEntropyLoss( label_smoothing=smoothing, reduction="none", + lse_square_scale=lse_square_scale, inplace_backward=inplace_backward, process_group=parallel_state.get_tensor_model_parallel_group(), ) out = model(x, y) out_pt = model_pt(x_pt.float(), y) + if lse_square_scale > 0.0: + lse_pt = torch.logsumexp(x_pt.float(), dim=-1) + out_pt += lse_square_scale * lse_pt.square() + out_pt.masked_fill_(y == -100, 0.0) assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6) g = torch.randn_like(out)