Add benchmark fp8 script (#9)
* working

* add torch_compile reference

* compile is annoying to microbenchmark, skipping for now
drisspg committed Dec 18, 2023
1 parent 1b9af50 commit 620770b
Showing 3 changed files with 168 additions and 12 deletions.
137 changes: 137 additions & 0 deletions benchmarks/fp8_sat_cast.py
@@ -0,0 +1,137 @@
import itertools
from dataclasses import dataclass

from typing import List

import torch

from tabulate import tabulate
from tqdm import tqdm

from transformer_nuggets.fp8.scaled_quant import eager_scaled_quant, scaled_quant
from transformer_nuggets.utils import benchmark_torch_function_in_microseconds

device = torch.device("cuda")


@dataclass(frozen=True)
class ExperimentConfig:
    numel: int
    high_precision_dtype: torch.dtype
    low_precision_dtype: torch.dtype
    saturated: bool = False


@dataclass(frozen=True)
class ExperimentResult:
    triton_time: float
    pytorch_time: float
    compiled_pytorch_time: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    sizes = [2**21, 2**22, 2**23, 2**24]
    high_precision_dtypes = [torch.bfloat16, torch.float32]
    low_precision_dtypes = [torch.float8_e4m3fn, torch.float8_e5m2]
    saturated = [True, False]
    configs = []
    for size, high_precision_dtype, low_precision_dtype, sat in itertools.product(
        sizes, high_precision_dtypes, low_precision_dtypes, saturated
    ):
        configs.append(
            ExperimentConfig(
                numel=size,
                high_precision_dtype=high_precision_dtype,
                low_precision_dtype=low_precision_dtype,
                saturated=sat,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    high_precision_tensor = torch.randn(
        config.numel, dtype=config.high_precision_dtype, device=device
    )
    triton_hp_tensor = high_precision_tensor.clone()

    eager_abs_max = torch.empty(1, dtype=torch.float32, device=device)
    triton_abs_max = torch.empty(1, dtype=torch.float32, device=device)

    scale = torch.rand(1, dtype=torch.float32, device=device)

    triton_time = benchmark_torch_function_in_microseconds(
        scaled_quant,
        triton_hp_tensor,
        scale,
        triton_abs_max,
        config.low_precision_dtype,
        config.saturated,
    )
    pytorch_time = benchmark_torch_function_in_microseconds(
        eager_scaled_quant,
        high_precision_tensor,
        scale,
        eager_abs_max,
        config.low_precision_dtype,
        config.saturated,
    )
    # compiled_pytorch_fn = torch.compile(eager_scaled_quant, fullgraph=True)
    # compiled_pytorch_time = benchmark_torch_function_in_microseconds(
    #     compiled_pytorch_fn,
    #     high_precision_tensor,
    #     scale,
    #     eager_abs_max,
    #     config.low_precision_dtype,
    #     config.saturated,
    # )
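    # NOTE: torch.compile is hard to microbenchmark reliably here (see the
    # commit message), so compiled_pytorch_time is reported as a 0 placeholder.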
    return ExperimentResult(
        triton_time=triton_time, pytorch_time=pytorch_time, compiled_pytorch_time=0
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "numel",
        "high_precision_dtype",
        "low_precision_dtype",
        "saturated",
        "triton_time",
        "pytorch_time",
        "compiled_pytorch_time",
    ]
    rows = []
    for experiment in experiments:
        rows.append(
            [
                experiment.config.numel,
                experiment.config.high_precision_dtype,
                experiment.config.low_precision_dtype,
                experiment.config.saturated,
                experiment.result.triton_time,
                experiment.result.pytorch_time,
                experiment.result.compiled_pytorch_time,
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # use tabulate to print the results
    print_results(results)


if __name__ == "__main__":
    main()
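
For a quick spot check without running the full sweep, a single configuration can be timed directly. A minimal sketch, assuming a CUDA device is available and the repo root is on PYTHONPATH so benchmarks.fp8_sat_cast is importable:

import torch

from benchmarks.fp8_sat_cast import ExperimentConfig, run_experiment

config = ExperimentConfig(
    numel=2**21,
    high_precision_dtype=torch.bfloat16,
    low_precision_dtype=torch.float8_e4m3fn,
    saturated=True,
)
result = run_experiment(config)
# triton_time and pytorch_time are in microseconds; compiled_pytorch_time is a 0 placeholder
print(result.triton_time, result.pytorch_time)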
15 changes: 3 additions & 12 deletions test/test_fp8.py
@@ -1,17 +1,7 @@
import pytest
import torch

from transformer_nuggets.fp8.scaled_quant import scaled_quant


def eager_scaled_quant(a, scale, abs_max, fp8_dtype, saturated=False):
    out = a * scale
    if saturated:
        out = torch.where(out > torch.finfo(fp8_dtype).max, torch.finfo(fp8_dtype).max, out)
        out = torch.where(
            out < -1 * torch.finfo(fp8_dtype).max, -1 * torch.finfo(fp8_dtype).max, out
        )
    return out.to(fp8_dtype), torch.max(torch.abs(out))
from transformer_nuggets.fp8.scaled_quant import eager_scaled_quant, scaled_quant


@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@@ -35,7 +25,8 @@ def test_saturated(fp8_dtype):
    scale = torch.tensor([4.0], dtype=torch.float32, device="cuda")
    abs_max = torch.tensor([-1.0], dtype=torch.float32, device="cuda")
    output = scaled_quant(a, scale, abs_max, fp8_dtype, saturated=True)
    eager_output, eager_abs_max = eager_scaled_quant(a, scale, abs_max, fp8_dtype, saturated=True)
    eager_abs_max = torch.clone(abs_max)
    eager_output = eager_scaled_quant(a, scale, eager_abs_max, fp8_dtype, saturated=True)
    torch.testing.assert_close(output, eager_output)
    torch.testing.assert_close(
        abs_max,
28 changes: 28 additions & 0 deletions transformer_nuggets/fp8/scaled_quant.py
@@ -17,6 +17,7 @@ def scaled_cast(
"""Quantize tensor to fp8 using a delayed scaled and calculate abs_max"""
offset = tl.program_id(0) * XBLOCK
index = offset + tl.arange(0, XBLOCK)[:]
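    # hint to the Triton compiler that these indices are contiguous and
    # XBLOCK-aligned, which allows coalesced, vectorized memory access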
    index = tl.max_contiguous(tl.multiple_of(index, XBLOCK), XBLOCK)
    mask = index < numel
    inpt = tl.load(inpt_ptr + (index), mask=mask)
    block_max = tl.max(tl.abs(inpt))
@@ -64,3 +65,30 @@ def scaled_quant(
        inpt_tensor, out_tensor, scale, abs_max, numel, 4096, tl_dtype, max_val, num_warps=8
    )
    return out_tensor


def eager_scaled_quant(
    a: torch.Tensor,
    scale: torch.Tensor,
    abs_max: torch.Tensor,
    fp8_dtype: torch.dtype,
    saturated: bool = False,
):
"""Quantize tensor to fp8 using a delayed scaled and calculate abs_max
Args:
a: Input tensor to quantize
scale: Scale to apply to input tensor, calculated from previous abs_max
abs_max: Absolute maximum value of input tensor, will be updated
fp8_dtype: FP8 datatype to quantize to
saturated: Whether to saturate the output tensor to the maximum value
of the fp8 datatype
"""
out = a * scale
if saturated:
out = torch.where(out > torch.finfo(fp8_dtype).max, torch.finfo(fp8_dtype).max, out)
out = torch.where(
out < -1 * torch.finfo(fp8_dtype).max, -1 * torch.finfo(fp8_dtype).max, out
)
abs_max = torch.max(torch.abs(out))
return out.to(fp8_dtype)
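
For reference, a minimal sketch of the saturated path (assumes a CUDA device; torch.finfo(torch.float8_e4m3fn).max is 448.0):

import torch

from transformer_nuggets.fp8.scaled_quant import eager_scaled_quant

a = torch.tensor([1000.0, -1000.0, 0.5], device="cuda")
scale = torch.ones(1, device="cuda")
abs_max = torch.empty(1, device="cuda")
out = eager_scaled_quant(a, scale, abs_max, torch.float8_e4m3fn, saturated=True)
# out is clamped to +/-448 before the cast, and abs_max now holds 448.0
print(out.float(), abs_max)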
