-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding comparison for different fp8 matmuls
- Loading branch information
Showing
2 changed files
with
695 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import itertools | ||
from dataclasses import dataclass | ||
from typing import List | ||
import torch | ||
from tabulate import tabulate | ||
from tqdm import tqdm | ||
from jsonargparse import CLI | ||
from transformer_nuggets.utils.benchmark import benchmark_cuda_function_in_microseconds | ||
from torchao.float8.inference import ( | ||
addmm_float8_unwrapped_inference, | ||
preprocess_data, | ||
Float8MMConfig, | ||
) | ||
from transformer_nuggets.fp8.fp8_matmul import ( | ||
matmul_persistent, | ||
matmul_tma_persistent, | ||
matmul_device_tma_persistent, | ||
) | ||
|
||
torch._dynamo.config.cache_size_limit = 1000 | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ExperimentConfig: | ||
M: int | ||
K: int | ||
N: int | ||
scaling_strategy: str | ||
fp8_kernel: str | ||
compile: bool | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ExperimentResult: | ||
bf16_time: float | ||
fp8_time: float | ||
|
||
|
||
@dataclass(frozen=True) | ||
class Experiment: | ||
config: ExperimentConfig | ||
result: ExperimentResult | ||
|
||
|
||
def get_fp8_matmul(A, B, scaling_strategies, fp8_kernel): | ||
A_fp8 = A.to(torch.float8_e4m3fn) | ||
B_fp8 = B.to(torch.float8_e4m3fn) | ||
A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True)) | ||
if scaling_strategies == "PerTensor": | ||
a_scale = torch.tensor(1, device="cuda", dtype=torch.float32) | ||
b_scale = torch.tensor(1, device="cuda", dtype=torch.float32) | ||
if fp8_kernel == "Persistent": | ||
return lambda: matmul_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16) | ||
if fp8_kernel == "Persistent-TMA": | ||
return lambda: matmul_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16) | ||
if fp8_kernel == "Device-TMA": | ||
return lambda: matmul_device_tma_persistent( | ||
A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16 | ||
) | ||
return lambda: addmm_float8_unwrapped_inference( | ||
A_fp8, | ||
a_scale, | ||
B_fp8, | ||
b_scale, | ||
output_dtype=torch.bfloat16, | ||
use_fast_accum=True, | ||
) | ||
elif scaling_strategies == "PerRow": | ||
a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32) | ||
b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T | ||
return lambda: addmm_float8_unwrapped_inference( | ||
A_fp8, | ||
a_scale, | ||
B_fp8, | ||
b_scale, | ||
output_dtype=torch.bfloat16, | ||
use_fast_accum=True, | ||
) | ||
else: | ||
raise ValueError("Invalid scaling strategy") | ||
|
||
|
||
def run_experiment(config: ExperimentConfig) -> ExperimentResult: | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16) | ||
B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16) | ||
|
||
fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel) | ||
|
||
if config.compile: | ||
bf16_matmul = torch.compile(lambda x, y: torch.matmul(x, y)) | ||
fp8_matmul = torch.compile(fp8_matmul) | ||
else: | ||
bf16_matmul = lambda x, y: torch.matmul(x, y) | ||
|
||
# Warmup phase | ||
warmup_iterations = 5 | ||
for _ in range(warmup_iterations): | ||
_ = bf16_matmul(A, B) | ||
_ = fp8_matmul() | ||
torch.cuda.synchronize() | ||
|
||
# Actual benchmarking | ||
bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B)) | ||
fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul) | ||
|
||
return ExperimentResult(bf16_time=bf16_time, fp8_time=fp8_time) | ||
|
||
|
||
def print_results(experiments: List[Experiment]): | ||
headers = [ | ||
"M", | ||
"K", | ||
"N", | ||
"Scaling Strategy", | ||
"Fp8 Kernel", | ||
"Compiled", | ||
"BF16 Time (ms)", | ||
"FP8 Time (ms)", | ||
"Speedup", | ||
] | ||
rows = [] | ||
for experiment in experiments: | ||
config = experiment.config | ||
result = experiment.result | ||
speedup = result.bf16_time / result.fp8_time | ||
rows.append( | ||
[ | ||
config.M, | ||
config.K, | ||
config.N, | ||
config.scaling_strategy, | ||
config.fp8_kernel, | ||
config.compile, | ||
f"{result.bf16_time:.4f}", | ||
f"{result.fp8_time:.4f}", | ||
f"{speedup:.2f}x", | ||
] | ||
) | ||
print(tabulate(rows, headers=headers, floatfmt=".4f")) | ||
|
||
|
||
def get_configs() -> List[ExperimentConfig]: | ||
shapes = [ | ||
(1, 2048, 4096), | ||
# (1, 2048, 8192), | ||
# (1, 2048, 16384), | ||
# (1, 4096, 4096), | ||
# (1, 4096, 8192), | ||
# (1, 4096, 16384), | ||
# (4, 2048, 4096), | ||
# (4, 2048, 8192), | ||
# (16, 512, 4096), | ||
(512, 2048, 8192), | ||
] | ||
scaling_strategies = ["PerTensor"] | ||
compile_options = [False] | ||
configs = [] | ||
fp8_kernels = ["Persistent", "Scaled_MM", "Persistent-TMA", "Device-TMA"] | ||
for (M, K, N), strategy, compile, kernel in itertools.product( | ||
shapes, scaling_strategies, compile_options, fp8_kernels | ||
): | ||
configs.append( | ||
ExperimentConfig( | ||
M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel | ||
) | ||
) | ||
return configs | ||
|
||
|
||
def main(): | ||
torch.random.manual_seed(123) | ||
configs = get_configs() | ||
results = [] | ||
for config in tqdm(configs): | ||
result = run_experiment(config) | ||
results.append(Experiment(config=config, result=result)) | ||
print_results(results) | ||
|
||
|
||
if __name__ == "__main__": | ||
CLI(main) |
Oops, something went wrong.