From da24657a0ffb3807fbe8e0adf29eb3eb8cf45422 Mon Sep 17 00:00:00 2001 From: saurabhkoshatwar <35650601+saurabhkoshatwar@users.noreply.github.com> Date: Fri, 25 Oct 2024 18:42:08 -0700 Subject: [PATCH 1/8] Feature/tvd loss fused (#1) * Add fused tvd loss --- benchmark/data/all_benchmark_data.csv | 36 ++++ benchmark/scripts/benchmark_tvd.py | 136 +++++++++++++++ src/liger_kernel/ops/tvd.py | 174 ++++++++++++++++++++ src/liger_kernel/transformers/functional.py | 2 + src/liger_kernel/transformers/tvd.py | 11 ++ test/transformers/test_tvd.py | 129 +++++++++++++++ 6 files changed, 488 insertions(+) create mode 100644 benchmark/scripts/benchmark_tvd.py create mode 100644 src/liger_kernel/ops/tvd.py create mode 100644 src/liger_kernel/transformers/tvd.py create mode 100644 test/transformers/test_tvd.py diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 32c8d01ab..eab22dd7e 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -505,3 +505,39 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859 fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 diff --git a/benchmark/scripts/benchmark_tvd.py b/benchmark/scripts/benchmark_tvd.py new file mode 100644 index 000000000..6dc931844 --- /dev/null +++ b/benchmark/scripts/benchmark_tvd.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.transformers.tvd import LigerTVDLoss + + +class TorchTVDLoss(torch.nn.Module): + def __init__(self, reduction='batchmean'): + super(TorchTVDLoss, self).__init__() + self.reduction = reduction + + def forward(self, p, q): + tvd = torch.abs(p - q) / 2.0 + if self.reduction == 'mean': + return torch.sum(tvd) / (p.size(0) * p.size(1)) + elif self.reduction == 'sum': + return torch.sum(tvd) + elif self.reduction == 'none': + return tvd + elif self.reduction == 'batchmean': + return torch.sum(tvd) / p.size(0) + else: + raise ValueError("Invalid reduction type.") + +S, E = 12, 18 + + +def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + reduction = "batchmean" + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + torch_tvd = TorchTVDLoss(reduction=reduction) + liger_tvd = LigerTVDLoss(reduction=reduction) + + _input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1) + target = torch.randn(B * T, V, device="cuda").softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_tvd(_input, target) + else: + return torch_tvd(_input, target) + + if input.kernel_operation_mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) + elif input.kernel_operation_mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[_input], + rep=100, + ) + elif input.kernel_operation_mode == "full": + + def full(): + y = fwd() + y.backward(retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, quantiles=QUANTILES, rep=100 + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + reduction = "batchmean" + torch_tvd = TorchTVDLoss(reduction=reduction) + liger_tvd = LigerTVDLoss(reduction=reduction) + + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + + _input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1) + target = torch.randn(B * T, V, device="cuda").softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_tvd(_input, target) + else: + return torch_tvd(_input, target) + + def full(): + y = fwd() + y.backward(retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + common_args = { + "kernel_name": "tvd", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, 18)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"B": 8, "T": 2048}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_memory_tvd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_args, + ) + + run_benchmarks( + bench_test_fn=bench_speed_tvd, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_args, + ) diff --git a/src/liger_kernel/ops/tvd.py b/src/liger_kernel/ops/tvd.py new file mode 100644 index 000000000..4982a8cbc --- /dev/null +++ b/src/liger_kernel/ops/tvd.py @@ -0,0 +1,174 @@ +import torch +import triton +import triton.language as tl + +from typing import Literal +from liger_kernel.ops.utils import ensure_contiguous + +MAX_FUSED_SIZE = 65536 // 4 + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + +_REDUCTION_MODE_NONE = tl.constexpr(0) +_REDUCTION_MODE_SUM = tl.constexpr(1) +_REDUCTION_MODE_MEAN = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3) + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE.value, + "sum": _REDUCTION_MODE_SUM.value, + "mean": _REDUCTION_MODE_MEAN.value, + "batchmean": _REDUCTION_MODE_BATCHMEAN.value, +} + +def get_num_warps(BLOCK_SIZE): + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + return num_warps + +@triton.jit +def _tv_distance_kernel( + p_ptr, + p_stride, + q_ptr, + q_stride, + loss_ptr, + loss_stride, + grads_ptr, + grads_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, + reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, +): + pid = tl.program_id(0).to(tl.int64) + p_ptr += pid * p_stride + q_ptr += pid * q_stride + loss_ptr += pid * loss_stride + grads_ptr += pid * grads_stride + + base_offsets = tl.arange(0, BLOCK_SIZE) + + loss_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + + p = tl.load(p_ptr + offsets, mask=mask, other=0.0) + q = tl.load(q_ptr + offsets, mask=mask, other=0.0) + + # TVD(P || Q) = 0.5 * |P - Q| + tv_loss = 0.5 * tl.abs(p - q) + + grad_res = tl.where(p > q, 0.5, -0.5) + + tl.store(grads_ptr + offsets, grad_res, mask=mask) + + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, tv_loss, mask=mask) + else: + loss_sum += tl.sum(tv_loss, axis=0) + + if reduction != _REDUCTION_MODE_NONE: + tl.store(loss_ptr, loss_sum) + +def tv_distance_forward_triton(p, q, reduction): + BT, V = p.shape + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + num_warps = get_num_warps(BLOCK_SIZE) + + grid = (BT,) + + reduction = _str_to_reduction_mode[reduction] + + out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) + output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32) + grads = torch.empty_like(p) + + _tv_distance_kernel[grid]( + p, + p.stride(0), + q, + q.stride(0), + output_tensor, + output_tensor.stride(0), + grads, + grads.stride(0), + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + reduction=reduction, + ) + + if reduction == _REDUCTION_MODE_BATCHMEAN.value: + return output_tensor.sum() / BT, grads + elif reduction == _REDUCTION_MODE_SUM.value: + return output_tensor.sum(dim=0), grads + elif reduction == _REDUCTION_MODE_MEAN.value: + return output_tensor.sum() / (BT * V), grads + else: + return output_tensor, grads + +def tvd_backward_triton(grad_output, grads): + + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then. + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return grads + + return grads * grad_output + +class LigerTVDLossFunction(torch.autograd.Function): + """ + Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton. + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, p: torch.Tensor, q: torch.Tensor, reduction: REDUCTION_LITERAL = "batchmean" + ) -> torch.Tensor: + """A forward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution. + q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution. + reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean". + + Returns: + torch.Tensor: The computed Total Variation Distance Loss. + """ + loss, grads = tv_distance_forward_triton(p, q, reduction) + ctx.save_for_backward(grads) + ctx.reduction = reduction + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. + + Returns: + tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the inputs. + """ + grads, = ctx.saved_tensors + BT, V = grads.shape + + grads = tvd_backward_triton(grad_output, grads) + + if ctx.reduction == "batchmean": + grads /= BT + elif ctx.reduction == "mean": + grads /= (BT * V) + + return grads, None, None diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index f160887b8..c766ba1c2 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -10,6 +10,7 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction from liger_kernel.ops.rope import LigerRopeFunction from liger_kernel.ops.swiglu import LigerSiLUMulFunction +from liger_kernel.ops.tvd import LigerTVDLossFunction liger_swiglu = LigerSiLUMulFunction.apply liger_cross_entropy = LigerCrossEntropyFunction.apply @@ -21,3 +22,4 @@ liger_kl_div = LigerKLDivLossFunction.apply liger_jsd = LigerJSDFunction.apply liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply +liger_tvd = LigerTVDLossFunction.apply diff --git a/src/liger_kernel/transformers/tvd.py b/src/liger_kernel/transformers/tvd.py new file mode 100644 index 000000000..9ee0b4908 --- /dev/null +++ b/src/liger_kernel/transformers/tvd.py @@ -0,0 +1,11 @@ +import torch +import torch.nn as nn +from liger_kernel.ops.tvd import LigerTVDLossFunction + +class LigerTVDLoss(nn.Module): + def __init__(self, reduction='batchmean'): + super(LigerTVDLoss, self).__init__() + self.reduction = reduction + + def forward(self, p, q): + return LigerTVDLossFunction.apply(p, q, self.reduction) diff --git a/test/transformers/test_tvd.py b/test/transformers/test_tvd.py new file mode 100644 index 000000000..7ef0d9c89 --- /dev/null +++ b/test/transformers/test_tvd.py @@ -0,0 +1,129 @@ +from test.utils import supports_bfloat16 + +import pytest +import torch +from liger_kernel.transformers.tvd import LigerTVDLoss + +class TorchTVDLoss(torch.nn.Module): + def __init__(self, reduction='batchmean'): + super(TorchTVDLoss, self).__init__() + self.reduction = reduction + + def forward(self, p, q): + + tvd = torch.abs(p - q) / 2.0 + + if self.reduction == 'mean': + return torch.sum(tvd) /(p.size(0) * p.size(1)) + elif self.reduction == 'sum': + return torch.sum(tvd) + elif self.reduction == 'none': + return tvd + elif self.reduction == 'batchmean': + return torch.sum(tvd) / p.size(0) + else: + raise ValueError("Invalid reduction type.") + + +_SHAPE_PARAMS = ( + "B, T, V", + [ + (1, 4096, 32000), + (32, 4096, 1024), + (41, 401, 1271), + pytest.param( + 1, + 4096, + 128256, + marks=pytest.mark.skipif( + torch.cuda.get_device_properties(0).total_memory + < 36 * 1000 * 1000 * 1000, + reason="This test requires a GPU with at least 36GB of memory", + ), + ), + (3, 423, 32000), + ], +) + +_DTYPE_PARAMS = ( + "dtype, atol, rtol", + [ + pytest.param( + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (torch.float32, 1e-8, 1e-6), + (torch.float16, 1e-3, 1e-3), + ], +) + +def _test_correctness_once( + target_tvd, + torch_tvd, + B, + T, + V, + dtype, + atol, + rtol, + reduction, + is_last_layer=True, + device="cuda", +): + torch.manual_seed(0) + input = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True) + + x1 = input.detach().clone().requires_grad_(True) + x2 = input.detach().clone().requires_grad_(True) + + with torch.no_grad(): + target = torch.randn(B * T, V, device=device).softmax(dim=-1) + + output = target_tvd(x1, target) + output2 = torch_tvd(x2, target) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + if not is_last_layer: + output = output * 2.0 + output2 = output2 * 2.0 + + if reduction == "none": + return + + output.backward() + output2.backward() + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +def test_correctness(B, T, V, reduction, dtype, atol, rtol): + liger_tvd = LigerTVDLoss(reduction=reduction) + torch_tvd = TorchTVDLoss(reduction=reduction) + _test_correctness_once( + liger_tvd, torch_tvd, B, T, V, dtype, atol, rtol, reduction + ) + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +def test_correctness_not_last(B, T, V, reduction, dtype, atol, rtol): + liger_tvd = LigerTVDLoss(reduction=reduction) + torch_tvd = TorchTVDLoss(reduction=reduction) + _test_correctness_once( + liger_tvd, + torch_tvd, + B, + T, + V, + dtype, + atol, + rtol, + reduction, + is_last_layer=False, + ) From 7736e3246f45e9453d65b61175baecca62aa9c17 Mon Sep 17 00:00:00 2001 From: saurabhkoshatwar <35650601+saurabhkoshatwar@users.noreply.github.com> Date: Fri, 25 Oct 2024 18:50:18 -0700 Subject: [PATCH 2/8] Add TVD to README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1ddedb790..a8733ea60 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,7 @@ loss.backward() | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` | | JSD | `liger_kernel.transformers.LigerJSD` | | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` | +| TVD | `liger_kernel.transformers.LigerTVDLoss` | - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction. - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup. @@ -278,7 +279,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$ - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size. - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. - +- **TVD**: [TVD](https://aclanthology.org/2023.acl-long.605.pdf) (Total variation distance), is implemented by computing both the loss and gradient in the forward pass. It achieves ~2X speed and ~15% memory reduction for 128k vocab size. ### Experimental Kernels From a45f6ce4aa7b6e9e8989a4050b626cc3782b0520 Mon Sep 17 00:00:00 2001 From: Saurabh Date: Fri, 25 Oct 2024 19:53:02 -0700 Subject: [PATCH 3/8] checkstyle fixes --- benchmark/scripts/benchmark_tvd.py | 14 ++++---- src/liger_kernel/ops/tvd.py | 49 ++++++++++++++++------------ src/liger_kernel/transformers/tvd.py | 7 ++-- test/transformers/test_tvd.py | 29 ++++++++-------- 4 files changed, 56 insertions(+), 43 deletions(-) diff --git a/benchmark/scripts/benchmark_tvd.py b/benchmark/scripts/benchmark_tvd.py index 6dc931844..2e62fd6fa 100644 --- a/benchmark/scripts/benchmark_tvd.py +++ b/benchmark/scripts/benchmark_tvd.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn import triton from utils import ( QUANTILES, @@ -10,27 +9,28 @@ run_benchmarks, ) -from liger_kernel.transformers.tvd import LigerTVDLoss +from liger_kernel.transformers.tvd import LigerTVDLoss class TorchTVDLoss(torch.nn.Module): - def __init__(self, reduction='batchmean'): + def __init__(self, reduction="batchmean"): super(TorchTVDLoss, self).__init__() self.reduction = reduction def forward(self, p, q): tvd = torch.abs(p - q) / 2.0 - if self.reduction == 'mean': + if self.reduction == "mean": return torch.sum(tvd) / (p.size(0) * p.size(1)) - elif self.reduction == 'sum': + elif self.reduction == "sum": return torch.sum(tvd) - elif self.reduction == 'none': + elif self.reduction == "none": return tvd - elif self.reduction == 'batchmean': + elif self.reduction == "batchmean": return torch.sum(tvd) / p.size(0) else: raise ValueError("Invalid reduction type.") + S, E = 12, 18 diff --git a/src/liger_kernel/ops/tvd.py b/src/liger_kernel/ops/tvd.py index 4982a8cbc..4c0df93ce 100644 --- a/src/liger_kernel/ops/tvd.py +++ b/src/liger_kernel/ops/tvd.py @@ -1,8 +1,9 @@ +from typing import Literal + import torch import triton import triton.language as tl -from typing import Literal from liger_kernel.ops.utils import ensure_contiguous MAX_FUSED_SIZE = 65536 // 4 @@ -21,6 +22,7 @@ "batchmean": _REDUCTION_MODE_BATCHMEAN.value, } + def get_num_warps(BLOCK_SIZE): num_warps = 4 if BLOCK_SIZE >= 32768: @@ -32,17 +34,18 @@ def get_num_warps(BLOCK_SIZE): return num_warps + @triton.jit def _tv_distance_kernel( - p_ptr, - p_stride, - q_ptr, - q_stride, - loss_ptr, - loss_stride, - grads_ptr, + p_ptr, + p_stride, + q_ptr, + q_stride, + loss_ptr, + loss_stride, + grads_ptr, grads_stride, - n_cols, + n_cols, BLOCK_SIZE: tl.constexpr, reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, ): @@ -65,7 +68,7 @@ def _tv_distance_kernel( # TVD(P || Q) = 0.5 * |P - Q| tv_loss = 0.5 * tl.abs(p - q) - grad_res = tl.where(p > q, 0.5, -0.5) + grad_res = tl.where(p > q, 0.5, -0.5) tl.store(grads_ptr + offsets, grad_res, mask=mask) @@ -77,19 +80,20 @@ def _tv_distance_kernel( if reduction != _REDUCTION_MODE_NONE: tl.store(loss_ptr, loss_sum) -def tv_distance_forward_triton(p, q, reduction): - BT, V = p.shape - + +def tv_distance_forward_triton(p, q, reduction): + BT, V = p.shape + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) num_warps = get_num_warps(BLOCK_SIZE) - grid = (BT,) - + grid = (BT,) + reduction = _str_to_reduction_mode[reduction] out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32) - grads = torch.empty_like(p) + grads = torch.empty_like(p) _tv_distance_kernel[grid]( p, @@ -105,7 +109,7 @@ def tv_distance_forward_triton(p, q, reduction): num_warps=num_warps, reduction=reduction, ) - + if reduction == _REDUCTION_MODE_BATCHMEAN.value: return output_tensor.sum() / BT, grads elif reduction == _REDUCTION_MODE_SUM.value: @@ -115,6 +119,7 @@ def tv_distance_forward_triton(p, q, reduction): else: return output_tensor, grads + def tvd_backward_triton(grad_output, grads): # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then. @@ -123,6 +128,7 @@ def tvd_backward_triton(grad_output, grads): return grads * grad_output + class LigerTVDLossFunction(torch.autograd.Function): """ Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton. @@ -131,7 +137,10 @@ class LigerTVDLossFunction(torch.autograd.Function): @staticmethod @ensure_contiguous def forward( - ctx, p: torch.Tensor, q: torch.Tensor, reduction: REDUCTION_LITERAL = "batchmean" + ctx, + p: torch.Tensor, + q: torch.Tensor, + reduction: REDUCTION_LITERAL = "batchmean", ) -> torch.Tensor: """A forward pass for the Total Variation Distance Loss. @@ -161,7 +170,7 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: Returns: tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the inputs. """ - grads, = ctx.saved_tensors + (grads,) = ctx.saved_tensors BT, V = grads.shape grads = tvd_backward_triton(grad_output, grads) @@ -169,6 +178,6 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: if ctx.reduction == "batchmean": grads /= BT elif ctx.reduction == "mean": - grads /= (BT * V) + grads /= BT * V return grads, None, None diff --git a/src/liger_kernel/transformers/tvd.py b/src/liger_kernel/transformers/tvd.py index 9ee0b4908..f226ee266 100644 --- a/src/liger_kernel/transformers/tvd.py +++ b/src/liger_kernel/transformers/tvd.py @@ -1,11 +1,12 @@ -import torch import torch.nn as nn + from liger_kernel.ops.tvd import LigerTVDLossFunction + class LigerTVDLoss(nn.Module): - def __init__(self, reduction='batchmean'): + def __init__(self, reduction="batchmean"): super(LigerTVDLoss, self).__init__() - self.reduction = reduction + self.reduction = reduction def forward(self, p, q): return LigerTVDLossFunction.apply(p, q, self.reduction) diff --git a/test/transformers/test_tvd.py b/test/transformers/test_tvd.py index 7ef0d9c89..23f4bf00c 100644 --- a/test/transformers/test_tvd.py +++ b/test/transformers/test_tvd.py @@ -2,25 +2,27 @@ import pytest import torch -from liger_kernel.transformers.tvd import LigerTVDLoss + +from liger_kernel.transformers.tvd import LigerTVDLoss + class TorchTVDLoss(torch.nn.Module): - def __init__(self, reduction='batchmean'): + def __init__(self, reduction="batchmean"): super(TorchTVDLoss, self).__init__() self.reduction = reduction def forward(self, p, q): tvd = torch.abs(p - q) / 2.0 - - if self.reduction == 'mean': - return torch.sum(tvd) /(p.size(0) * p.size(1)) - elif self.reduction == 'sum': + + if self.reduction == "mean": + return torch.sum(tvd) / (p.size(0) * p.size(1)) + elif self.reduction == "sum": return torch.sum(tvd) - elif self.reduction == 'none': + elif self.reduction == "none": return tvd - elif self.reduction == 'batchmean': - return torch.sum(tvd) / p.size(0) + elif self.reduction == "batchmean": + return torch.sum(tvd) / p.size(0) else: raise ValueError("Invalid reduction type.") @@ -61,6 +63,7 @@ def forward(self, p, q): ], ) + def _test_correctness_once( target_tvd, torch_tvd, @@ -85,7 +88,7 @@ def _test_correctness_once( output = target_tvd(x1, target) output2 = torch_tvd(x2, target) - + assert torch.allclose(output, output2, atol=atol, rtol=rtol) if not is_last_layer: @@ -99,15 +102,15 @@ def _test_correctness_once( output2.backward() assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) @pytest.mark.parametrize(*_DTYPE_PARAMS) def test_correctness(B, T, V, reduction, dtype, atol, rtol): liger_tvd = LigerTVDLoss(reduction=reduction) torch_tvd = TorchTVDLoss(reduction=reduction) - _test_correctness_once( - liger_tvd, torch_tvd, B, T, V, dtype, atol, rtol, reduction - ) + _test_correctness_once(liger_tvd, torch_tvd, B, T, V, dtype, atol, rtol, reduction) + @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) From 18fa8b73f2ab0ed1372ca5cd34b02771321f7bdc Mon Sep 17 00:00:00 2001 From: Saurabh Date: Sat, 9 Nov 2024 17:29:46 -0800 Subject: [PATCH 4/8] init and backward pass reduction --- src/liger_kernel/ops/tvd.py | 11 ++--------- src/liger_kernel/transformers/__init__.py | 1 + 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/liger_kernel/ops/tvd.py b/src/liger_kernel/ops/tvd.py index 4c0df93ce..1099a3ec4 100644 --- a/src/liger_kernel/ops/tvd.py +++ b/src/liger_kernel/ops/tvd.py @@ -111,11 +111,11 @@ def tv_distance_forward_triton(p, q, reduction): ) if reduction == _REDUCTION_MODE_BATCHMEAN.value: - return output_tensor.sum() / BT, grads + return output_tensor.sum() / BT, grads / BT elif reduction == _REDUCTION_MODE_SUM.value: return output_tensor.sum(dim=0), grads elif reduction == _REDUCTION_MODE_MEAN.value: - return output_tensor.sum() / (BT * V), grads + return output_tensor.sum() / (BT * V), grads / (BT * V) else: return output_tensor, grads @@ -155,7 +155,6 @@ def forward( """ loss, grads = tv_distance_forward_triton(p, q, reduction) ctx.save_for_backward(grads) - ctx.reduction = reduction return loss @staticmethod @@ -171,13 +170,7 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the inputs. """ (grads,) = ctx.saved_tensors - BT, V = grads.shape grads = tvd_backward_triton(grad_output, grads) - if ctx.reduction == "batchmean": - grads /= BT - elif ctx.reduction == "mean": - grads /= BT * V - return grads, None, None diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index ffb8235cc..7a8d4feea 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -8,6 +8,7 @@ from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401 from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401 from liger_kernel.transformers.jsd import LigerJSD # noqa: F401 +from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401 from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 from liger_kernel.transformers.monkey_patch import ( # noqa: F401 _apply_liger_kernel, From de67f33cc3eff71699fe0cbf7f3b6e27e5a352a4 Mon Sep 17 00:00:00 2001 From: saurabhkoshatwar <35650601+saurabhkoshatwar@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:25:57 -0800 Subject: [PATCH 5/8] Add ignore index (#2) * Add ignore index --- src/liger_kernel/ops/tvd.py | 51 ++++++++++++++++----- src/liger_kernel/transformers/tvd.py | 7 +-- test/transformers/test_tvd.py | 66 +++++++++++++++++++++++++--- 3 files changed, 106 insertions(+), 18 deletions(-) diff --git a/src/liger_kernel/ops/tvd.py b/src/liger_kernel/ops/tvd.py index 1099a3ec4..cbcef30ff 100644 --- a/src/liger_kernel/ops/tvd.py +++ b/src/liger_kernel/ops/tvd.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional import torch import triton @@ -45,8 +45,11 @@ def _tv_distance_kernel( loss_stride, grads_ptr, grads_stride, + label_ptr, + ignore_index: tl.constexpr, n_cols, BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, ): pid = tl.program_id(0).to(tl.int64) @@ -54,9 +57,21 @@ def _tv_distance_kernel( q_ptr += pid * q_stride loss_ptr += pid * loss_stride grads_ptr += pid * grads_stride + label_ptr += pid base_offsets = tl.arange(0, BLOCK_SIZE) + if HAS_LABEL: + label = tl.load(label_ptr) + if label == ignore_index: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + tl.store(grads_ptr + offsets, 0.0, mask=mask) + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, 0.0, mask=mask) + return + loss_sum = 0.0 for i in range(0, n_cols, BLOCK_SIZE): offsets = i + base_offsets @@ -81,7 +96,7 @@ def _tv_distance_kernel( tl.store(loss_ptr, loss_sum) -def tv_distance_forward_triton(p, q, reduction): +def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label): BT, V = p.shape BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) @@ -95,6 +110,8 @@ def tv_distance_forward_triton(p, q, reduction): output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32) grads = torch.empty_like(p) + n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT + _tv_distance_kernel[grid]( p, p.stride(0), @@ -104,24 +121,26 @@ def tv_distance_forward_triton(p, q, reduction): output_tensor.stride(0), grads, grads.stride(0), + shift_labels if has_label else torch.empty(1, device=p.device), + ignore_index, V, BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL = has_label, num_warps=num_warps, reduction=reduction, ) if reduction == _REDUCTION_MODE_BATCHMEAN.value: - return output_tensor.sum() / BT, grads / BT + return output_tensor.sum() / n_non_ignore, grads / n_non_ignore elif reduction == _REDUCTION_MODE_SUM.value: return output_tensor.sum(dim=0), grads elif reduction == _REDUCTION_MODE_MEAN.value: - return output_tensor.sum() / (BT * V), grads / (BT * V) + return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V) else: return output_tensor, grads def tvd_backward_triton(grad_output, grads): - # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then. if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): return grads @@ -140,7 +159,10 @@ def forward( ctx, p: torch.Tensor, q: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, reduction: REDUCTION_LITERAL = "batchmean", + ignore_index: int = -100, + ) -> torch.Tensor: """A forward pass for the Total Variation Distance Loss. @@ -148,15 +170,25 @@ def forward( ctx: Torch autograd context p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution. q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution. + shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels. reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean". + ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100. Returns: torch.Tensor: The computed Total Variation Distance Loss. """ - loss, grads = tv_distance_forward_triton(p, q, reduction) + has_label = False + if shift_labels is not None: + assert shift_labels.shape == ( + p.shape[0], + ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + shift_labels = shift_labels.contiguous() + has_label = True + + loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label) ctx.save_for_backward(grads) return loss - + @staticmethod @ensure_contiguous def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: @@ -167,10 +199,9 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: grad_output (torch.Tensor): The gradient of the loss with respect to the output. Returns: - tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the inputs. + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs. """ (grads,) = ctx.saved_tensors - grads = tvd_backward_triton(grad_output, grads) - return grads, None, None + return grads, None, None, None, None diff --git a/src/liger_kernel/transformers/tvd.py b/src/liger_kernel/transformers/tvd.py index f226ee266..45bf3e7e5 100644 --- a/src/liger_kernel/transformers/tvd.py +++ b/src/liger_kernel/transformers/tvd.py @@ -4,9 +4,10 @@ class LigerTVDLoss(nn.Module): - def __init__(self, reduction="batchmean"): + def __init__(self, reduction="batchmean", ignore_index: int = -100): super(LigerTVDLoss, self).__init__() self.reduction = reduction + self.ignore_index = ignore_index - def forward(self, p, q): - return LigerTVDLossFunction.apply(p, q, self.reduction) + def forward(self, p, q, shift_labels = None): + return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index) diff --git a/test/transformers/test_tvd.py b/test/transformers/test_tvd.py index 23f4bf00c..59519ac95 100644 --- a/test/transformers/test_tvd.py +++ b/test/transformers/test_tvd.py @@ -1,4 +1,4 @@ -from test.utils import supports_bfloat16 +from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 import pytest import torch @@ -7,22 +7,30 @@ class TorchTVDLoss(torch.nn.Module): - def __init__(self, reduction="batchmean"): + def __init__(self, reduction="batchmean", ignore_index: int = -100): super(TorchTVDLoss, self).__init__() self.reduction = reduction + self.ignore_index = ignore_index - def forward(self, p, q): + def forward(self, p, q, label = + None): tvd = torch.abs(p - q) / 2.0 + n_non_ignore = p.size(0) + if label is not None: + tvd = torch.where(label.unsqueeze(1) != self.ignore_index, tvd, torch.zeros_like(tvd)) + n_non_ignore = (label != self.ignore_index).sum().item() + if n_non_ignore == 0: + return torch.tensor(0.0).to(tvd.device) if self.reduction == "mean": - return torch.sum(tvd) / (p.size(0) * p.size(1)) + return torch.sum(tvd) / (n_non_ignore * p.size(1)) elif self.reduction == "sum": return torch.sum(tvd) elif self.reduction == "none": return tvd elif self.reduction == "batchmean": - return torch.sum(tvd) / p.size(0) + return torch.sum(tvd) / n_non_ignore else: raise ValueError("Invalid reduction type.") @@ -102,6 +110,45 @@ def _test_correctness_once( output2.backward() assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_ignore_index_once( + target_tvd, + torch_tvd, + ignore_index, + B, + T, + V, + dtype, + atol, + rtol, + reduction, + device="cuda" +): + input = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True) + + x1 = input.detach().clone().requires_grad_(True) + x2 = input.detach().clone().requires_grad_(True) + + with torch.no_grad(): + target = torch.randn(B * T, V, device=device).softmax(dim=-1) + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + label[indices_to_assign] = ignore_index + + output = torch_tvd(x1, target, label) + output2 = target_tvd(x2, target, label) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + if reduction == "none": + return + + output.backward() + output2.backward() + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) @@ -130,3 +177,12 @@ def test_correctness_not_last(B, T, V, reduction, dtype, atol, rtol): reduction, is_last_layer=False, ) + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +@pytest.mark.parametrize("ignore_index", [-100, 0, 1]) +def test_correctness_with_ignore_index(B, T, V, reduction, dtype, atol, rtol, ignore_index): + liger_tvd = LigerTVDLoss(reduction=reduction, ignore_index=ignore_index) + torch_tvd = TorchTVDLoss(reduction=reduction, ignore_index=ignore_index) + _test_correctness_with_ignore_index_once(liger_tvd, torch_tvd, ignore_index, B, T, V, dtype, atol, rtol, reduction) From c220256b82b03ea0dfe7344dce02bc4c717bb04c Mon Sep 17 00:00:00 2001 From: Saurabh Date: Sat, 7 Dec 2024 18:58:37 -0800 Subject: [PATCH 6/8] lowest rtol --- test/transformers/test_tvd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transformers/test_tvd.py b/test/transformers/test_tvd.py index 23f4bf00c..c75e83e48 100644 --- a/test/transformers/test_tvd.py +++ b/test/transformers/test_tvd.py @@ -53,7 +53,7 @@ def forward(self, p, q): pytest.param( torch.bfloat16, 1e-8, - 5e-2, + 1e-6, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), From 7a2d6e83ba364954f689908ff9baaa26ad1a7958 Mon Sep 17 00:00:00 2001 From: Saurabh Date: Sat, 7 Dec 2024 19:33:06 -0800 Subject: [PATCH 7/8] checkstyle fixes --- src/liger_kernel/ops/tvd.py | 11 ++++--- src/liger_kernel/transformers/__init__.py | 2 +- src/liger_kernel/transformers/tvd.py | 6 ++-- test/transformers/test_tvd.py | 39 +++++++++++++---------- 4 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/liger_kernel/ops/tvd.py b/src/liger_kernel/ops/tvd.py index cbcef30ff..00825755e 100644 --- a/src/liger_kernel/ops/tvd.py +++ b/src/liger_kernel/ops/tvd.py @@ -57,7 +57,7 @@ def _tv_distance_kernel( q_ptr += pid * q_stride loss_ptr += pid * loss_stride grads_ptr += pid * grads_stride - label_ptr += pid + label_ptr += pid base_offsets = tl.arange(0, BLOCK_SIZE) @@ -125,7 +125,7 @@ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_ ignore_index, V, BLOCK_SIZE=BLOCK_SIZE, - HAS_LABEL = has_label, + HAS_LABEL=has_label, num_warps=num_warps, reduction=reduction, ) @@ -162,7 +162,6 @@ def forward( shift_labels: Optional[torch.Tensor] = None, reduction: REDUCTION_LITERAL = "batchmean", ignore_index: int = -100, - ) -> torch.Tensor: """A forward pass for the Total Variation Distance Loss. @@ -185,10 +184,12 @@ def forward( shift_labels = shift_labels.contiguous() has_label = True - loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label) + loss, grads = tv_distance_forward_triton( + p, q, shift_labels, reduction, ignore_index, has_label + ) ctx.save_for_backward(grads) return loss - + @staticmethod @ensure_contiguous def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 7a8d4feea..28d3c9a80 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -8,7 +8,6 @@ from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401 from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401 from liger_kernel.transformers.jsd import LigerJSD # noqa: F401 -from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401 from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 from liger_kernel.transformers.monkey_patch import ( # noqa: F401 _apply_liger_kernel, @@ -30,3 +29,4 @@ LigerPhi3SwiGLUMLP, LigerSwiGLUMLP, ) +from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401 diff --git a/src/liger_kernel/transformers/tvd.py b/src/liger_kernel/transformers/tvd.py index 45bf3e7e5..016c57748 100644 --- a/src/liger_kernel/transformers/tvd.py +++ b/src/liger_kernel/transformers/tvd.py @@ -9,5 +9,7 @@ def __init__(self, reduction="batchmean", ignore_index: int = -100): self.reduction = reduction self.ignore_index = ignore_index - def forward(self, p, q, shift_labels = None): - return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index) + def forward(self, p, q, shift_labels=None): + return LigerTVDLossFunction.apply( + p, q, shift_labels, self.reduction, self.ignore_index + ) diff --git a/test/transformers/test_tvd.py b/test/transformers/test_tvd.py index 79922a9a7..062bf8ffc 100644 --- a/test/transformers/test_tvd.py +++ b/test/transformers/test_tvd.py @@ -12,13 +12,14 @@ def __init__(self, reduction="batchmean", ignore_index: int = -100): self.reduction = reduction self.ignore_index = ignore_index - def forward(self, p, q, label = - None): + def forward(self, p, q, label=None): tvd = torch.abs(p - q) / 2.0 n_non_ignore = p.size(0) if label is not None: - tvd = torch.where(label.unsqueeze(1) != self.ignore_index, tvd, torch.zeros_like(tvd)) + tvd = torch.where( + label.unsqueeze(1) != self.ignore_index, tvd, torch.zeros_like(tvd) + ) n_non_ignore = (label != self.ignore_index).sum().item() if n_non_ignore == 0: return torch.tensor(0.0).to(tvd.device) @@ -110,18 +111,19 @@ def _test_correctness_once( output2.backward() assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + def _test_correctness_with_ignore_index_once( - target_tvd, - torch_tvd, - ignore_index, - B, - T, - V, - dtype, - atol, - rtol, + target_tvd, + torch_tvd, + ignore_index, + B, + T, + V, + dtype, + atol, + rtol, reduction, - device="cuda" + device="cuda", ): input = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True) @@ -144,7 +146,7 @@ def _test_correctness_with_ignore_index_once( if reduction == "none": return - + output.backward() output2.backward() assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) @@ -178,11 +180,16 @@ def test_correctness_not_last(B, T, V, reduction, dtype, atol, rtol): is_last_layer=False, ) + @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) @pytest.mark.parametrize(*_DTYPE_PARAMS) @pytest.mark.parametrize("ignore_index", [-100, 0, 1]) -def test_correctness_with_ignore_index(B, T, V, reduction, dtype, atol, rtol, ignore_index): +def test_correctness_with_ignore_index( + B, T, V, reduction, dtype, atol, rtol, ignore_index +): liger_tvd = LigerTVDLoss(reduction=reduction, ignore_index=ignore_index) torch_tvd = TorchTVDLoss(reduction=reduction, ignore_index=ignore_index) - _test_correctness_with_ignore_index_once(liger_tvd, torch_tvd, ignore_index, B, T, V, dtype, atol, rtol, reduction) + _test_correctness_with_ignore_index_once( + liger_tvd, torch_tvd, ignore_index, B, T, V, dtype, atol, rtol, reduction + ) From 2a89121be36f302dc1b1c6b6a376f64256dbc48e Mon Sep 17 00:00:00 2001 From: Saurabh Date: Sat, 7 Dec 2024 19:38:39 -0800 Subject: [PATCH 8/8] checkstyle fixes --- test/transformers/test_tvd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transformers/test_tvd.py b/test/transformers/test_tvd.py index 062bf8ffc..bd7ff0aed 100644 --- a/test/transformers/test_tvd.py +++ b/test/transformers/test_tvd.py @@ -1,4 +1,4 @@ -from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 +from test.utils import supports_bfloat16 import pytest import torch