Skip to content

Commit

Permalink
Add benchmarks for JSD loss
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Liu <austin362667@gmail.com>
  • Loading branch information
austin362667 authored and shivam15s committed Dec 17, 2024
1 parent 82c80fe commit 7b22ac7
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 0 deletions.
24 changes: 24 additions & 0 deletions benchmark/data/all_benchmark_data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,27 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,8192,60.24163055419922,60.24163055419922,60.24163055419922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,1024,10.906111717224121,10.903244972229004,10.91296672821045,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,2048,21.480207443237305,21.465139389038086,21.489286422729492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,4096,42.96339416503906,42.96237564086914,42.96440887451172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,8192,85.3946533203125,85.3946533203125,85.3946533203125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,1024,8.312895774841309,8.310400009155273,8.326751708984375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,2048,15.770208358764648,15.767775535583496,15.774784088134766,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,4096,30.922752380371094,30.920312881469727,30.927898406982422,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,8192,60.70627212524414,60.70627212524414,60.70627212524414,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,1024,28.72480010986328,28.718809127807617,28.728179931640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,2048,54.281761169433594,54.281761169433594,54.281761169433594,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,4096,107.08905792236328,107.08905792236328,107.08905792236328,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,8192,213.1598663330078,213.1598663330078,213.1598663330078,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,1024,10913.541015625,10913.541015625,10913.541015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,2048,10941.548828125,10941.548828125,10941.548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,4096,10997.564453125,10997.564453125,10997.564453125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,8192,11109.595703125,11109.595703125,11109.595703125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,16174.0390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
270 changes: 270 additions & 0 deletions benchmark/scripts/benchmark_distill_jsd_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
import os
import sys

import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchJSDLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
bias: bool = False,
):
from test.chunked_loss.test_jsd_loss import HFJSDLoss

super().__init__()
self.student_lin = torch.nn.Linear(
in_features=H // 2, out_features=V, bias=bias, dtype=dtype
)
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.jsd_loss = HFJSDLoss(
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
temperature=temperature,
).get_batch_loss_metrics

def forward(self, student, teacher, target):
return self.jsd_loss(
student,
self.student_lin.weight,
teacher,
self.teacher_lin.weight,
target,
)


class LigerJSDLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
bias: bool = False,
):
super().__init__()
self.student_lin = torch.nn.Linear(
in_features=H // 2, out_features=V, bias=bias, dtype=dtype
)
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.ignore_index = ignore_index
self.temperature = temperature
self.jsd_loss = LigerFusedLinearJSDFunction.apply

def forward(self, student, teacher, target):
return self.jsd_loss(
student,
self.student_lin.weight,
teacher,
self.teacher_lin.weight,
target,
self.weight_hard_loss,
self.weight_soft_loss,
)


def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider

torch_jsd_loss = TorchJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
liger_jsd_loss = LigerJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)

_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
student_input1 = _tensor.detach().clone().requires_grad_(True)
student_input2 = _tensor.detach().clone().requires_grad_(True)

teacher_input = torch.rand(BT, H, device=device, dtype=dtype)

target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)

def fwd():
if provider == "liger":
return liger_jsd_loss(student_input1, teacher_input, target)
elif provider == "torch":
return torch_jsd_loss(student_input2, teacher_input, target)

def full():
y = fwd()
y.backward()

mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_jsd_loss = TorchJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
liger_jsd_loss = LigerJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)

_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
student_input1 = _tensor.detach().clone().requires_grad_(True)
student_input2 = _tensor.detach().clone().requires_grad_(True)

teacher_input = torch.rand(BT, H, device=device, dtype=dtype)

target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)

def fwd():
if provider == "liger":
return liger_jsd_loss(student_input1, teacher_input, target)
elif provider == "torch":
return torch_jsd_loss(student_input2, teacher_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[student_input1, student_input2],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":

def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

common_configs = {
"kernel_name": "distill_jsd_loss",
"x_name": "BT",
"x_label": "B x T",
"x_values": [2**i for i in range(10, 14)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": False,
"weight_hard_loss": 0.5,
"weight_soft_loss": 0.5,
"ignore_index": -100,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_jsd_loss,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
)

run_benchmarks(
bench_test_fn=bench_memory_jsd_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
)

0 comments on commit 7b22ac7

Please sign in to comment.