Commit

Adding comparison for different fp8 matmuls
drisspg committed Sep 27, 2024
1 parent 83f754e commit a9ace20
Showing 2 changed files with 695 additions and 0 deletions.
182 changes: 182 additions & 0 deletions benchmarks/fp8_matmul.py
@@ -0,0 +1,182 @@
import itertools
from dataclasses import dataclass
from typing import List
import torch
from tabulate import tabulate
from tqdm import tqdm
from jsonargparse import CLI
from transformer_nuggets.utils.benchmark import benchmark_cuda_function_in_microseconds
from torchao.float8.inference import (
addmm_float8_unwrapped_inference,
preprocess_data,
Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_device_tma_persistent,
)

torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
M: int
K: int
N: int
scaling_strategy: str
fp8_kernel: str
compile: bool


@dataclass(frozen=True)
class ExperimentResult:
bf16_time: float
fp8_time: float


@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
result: ExperimentResult


def get_fp8_matmul(A, B, scaling_strategy, fp8_kernel):
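    """Cast A and B to float8_e4m3fn and return a zero-argument callable that runs the
    selected fp8 kernel under the chosen scaling strategy. Kernel names without an
    explicit branch (e.g. "Scaled_MM") fall through to addmm_float8_unwrapped_inference."""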
A_fp8 = A.to(torch.float8_e4m3fn)
B_fp8 = B.to(torch.float8_e4m3fn)
A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))
    if scaling_strategy == "PerTensor":
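        # Per-tensor scaling: a single scalar scale per operand (fixed at 1.0 here since
        # the benchmark measures throughput, not numerical accuracy).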
a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
if fp8_kernel == "Persistent":
return lambda: matmul_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
if fp8_kernel == "Persistent-TMA":
return lambda: matmul_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
if fp8_kernel == "Device-TMA":
return lambda: matmul_device_tma_persistent(
A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16
)
return lambda: addmm_float8_unwrapped_inference(
A_fp8,
a_scale,
B_fp8,
b_scale,
output_dtype=torch.bfloat16,
use_fast_accum=True,
)
    elif scaling_strategy == "PerRow":
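        # Per-row scaling: one scale per row of A and one per column of B (shape (1, N)
        # after the transpose). Only the addmm_float8_unwrapped_inference path is
        # exercised for this strategy.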
a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
return lambda: addmm_float8_unwrapped_inference(
A_fp8,
a_scale,
B_fp8,
b_scale,
output_dtype=torch.bfloat16,
use_fast_accum=True,
)
else:
raise ValueError("Invalid scaling strategy")


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
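    """Time a bf16 torch.matmul baseline against the selected fp8 kernel for one config.
    Both timings come from benchmark_cuda_function_in_microseconds, so they are in µs."""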
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)

fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)

if config.compile:
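        # Compile both the bf16 baseline and the fp8 closure so the comparison stays like-for-like.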
bf16_matmul = torch.compile(lambda x, y: torch.matmul(x, y))
fp8_matmul = torch.compile(fp8_matmul)
else:
bf16_matmul = lambda x, y: torch.matmul(x, y)

# Warmup phase
warmup_iterations = 5
for _ in range(warmup_iterations):
_ = bf16_matmul(A, B)
_ = fp8_matmul()
torch.cuda.synchronize()

# Actual benchmarking
bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B))
fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul)

return ExperimentResult(bf16_time=bf16_time, fp8_time=fp8_time)


def print_results(experiments: List[Experiment]):
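    """Print a summary table; Speedup is the bf16 time divided by the fp8 time."""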
headers = [
"M",
"K",
"N",
"Scaling Strategy",
"Fp8 Kernel",
"Compiled",
"BF16 Time (ms)",
"FP8 Time (ms)",
"Speedup",
]
rows = []
for experiment in experiments:
config = experiment.config
result = experiment.result
speedup = result.bf16_time / result.fp8_time
rows.append(
[
config.M,
config.K,
config.N,
config.scaling_strategy,
config.fp8_kernel,
config.compile,
f"{result.bf16_time:.4f}",
f"{result.fp8_time:.4f}",
f"{speedup:.2f}x",
]
)
print(tabulate(rows, headers=headers, floatfmt=".4f"))


def get_configs() -> List[ExperimentConfig]:
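    """Build the cross product of (M, K, N) shapes, scaling strategies, compile options,
    and fp8 kernels. The commented-out shapes can be re-enabled for a broader sweep."""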
shapes = [
(1, 2048, 4096),
# (1, 2048, 8192),
# (1, 2048, 16384),
# (1, 4096, 4096),
# (1, 4096, 8192),
# (1, 4096, 16384),
# (4, 2048, 4096),
# (4, 2048, 8192),
# (16, 512, 4096),
(512, 2048, 8192),
]
scaling_strategies = ["PerTensor"]
compile_options = [False]
configs = []
fp8_kernels = ["Persistent", "Scaled_MM", "Persistent-TMA", "Device-TMA"]
for (M, K, N), strategy, compile, kernel in itertools.product(
shapes, scaling_strategies, compile_options, fp8_kernels
):
configs.append(
ExperimentConfig(
M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel
)
)
return configs


def main():
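    """Run every generated config and print the results table."""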
torch.random.manual_seed(123)
configs = get_configs()
results = []
for config in tqdm(configs):
result = run_experiment(config)
results.append(Experiment(config=config, result=result))
print_results(results)


if __name__ == "__main__":
CLI(main)
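
# Example invocation (assumes a CUDA GPU plus the transformer_nuggets and torchao
# dependencies are installed):
#   python benchmarks/fp8_matmul.py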