Adding comparison for different fp8 matmuls
drisspg committed Oct 1, 2024
1 parent 83f754e commit f7d68be
Showing 3 changed files with 1,064 additions and 0 deletions.
276 changes: 276 additions & 0 deletions benchmarks/fp8_matmul.py
@@ -0,0 +1,276 @@
import itertools
from dataclasses import dataclass
from typing import List, Optional
import torch
from tabulate import tabulate
from tqdm import tqdm
from jsonargparse import CLI
from pathlib import Path
from transformer_nuggets.utils.benchmark import benchmark_cuda_function_in_microseconds
from torchao.float8.inference import (
addmm_float8_unwrapped_inference,
preprocess_data,
Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_device_tma_persistent,
)
from enum import Enum
import csv

torch._dynamo.config.cache_size_limit = 1000
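# Note: raising the cache_size_limit above lets torch.compile recompile across the
# many distinct shapes swept below without falling back to eager execution.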


class FP8Kernel(Enum):
PERSISTENT = "Persistent"
PERSISTENT_TMA = "Persistent-TMA"
DEVICE_TMA = "Device-TMA"
SCALED_MM = "Scaled-MM"


class ScalingStrategy(Enum):
PER_TENSOR = "PerTensor"
PER_ROW = "PerRow"


def is_col_major(stride):
assert len(stride) == 2, "is_col_major only supports 2D tensors"
return stride[1] > stride[0] and stride[0] == 1
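
# Illustrative check of the stride convention (not part of the benchmark itself):
#   torch.randn(4, 8).stride()   -> (8, 1): row-major,    is_col_major -> False
#   torch.randn(8, 4).T.stride() -> (1, 8): column-major, is_col_major -> True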


def get_fp8_matmul(
A: torch.Tensor, B: torch.Tensor, scaling_strategy: ScalingStrategy, fp8_kernel: FP8Kernel
):
A_fp8 = A.to(torch.float8_e4m3fn)
B_fp8 = B.to(torch.float8_e4m3fn)
A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))

if scaling_strategy == ScalingStrategy.PER_TENSOR:
a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
elif scaling_strategy == ScalingStrategy.PER_ROW:
a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
else:
raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")
if fp8_kernel == FP8Kernel.PERSISTENT:
return lambda: matmul_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
elif fp8_kernel == FP8Kernel.PERSISTENT_TMA:
return lambda: matmul_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
elif fp8_kernel == FP8Kernel.DEVICE_TMA:
return lambda: matmul_device_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
elif fp8_kernel == FP8Kernel.SCALED_MM:
return lambda: addmm_float8_unwrapped_inference(
A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
)
else:
raise ValueError(f"Invalid FP8 kernel: {fp8_kernel}")


@dataclass(frozen=True)
class ExperimentConfig:
M: int
K: int
N: int
scaling_strategy: ScalingStrategy
fp8_kernel: FP8Kernel
compile: bool


@dataclass(frozen=True)
class ExperimentResult:
bf16_time: float
fp8_time: float
bf16_tflops: float
fp8_tflops: float


@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
result: ExperimentResult


def calculate_tflops(M: int, N: int, K: int, time_us: float) -> float:
"""Calculate TFLOPS (Tera Floating Point Operations Per Second)"""
flops = 2 * M * N * K # Number of floating point operations for matrix multiplication
tflops = (flops / time_us) / 1e6 # Convert to TFLOPS
return tflops
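
# Worked example (illustrative): for M = N = K = 8192, flops = 2 * 8192**3
# ~= 1.0995e12; at time_us = 1_000 this gives (1.0995e12 / 1_000) / 1e6
# ~= 1099.5 TFLOPS, the same as flops / seconds / 1e12.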


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)

bf16_matmul = lambda x, y: torch.matmul(x, y)
fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)

if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
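        # Only the scaled-mm path is wrapped in torch.compile; the handwritten
        # Triton kernels are benchmarked as-is.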
bf16_matmul = torch.compile(bf16_matmul)
fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune")

# Warmup phase
warmup_iterations = 5
for _ in range(warmup_iterations):
_ = bf16_matmul(A, B)
_ = fp8_matmul()
torch.cuda.synchronize()

# Actual benchmarking
bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B))
fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul)

# Calculate TFLOPS
bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time)
fp8_tflops = calculate_tflops(config.M, config.N, config.K, fp8_time)

    # Correctness check against the scaled-mm baseline
    scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
    out_base = scaled_mm_base()
    out = fp8_matmul()
    # Known issue: this check fails for one sample with large N
    torch.testing.assert_close(out, out_base)

return ExperimentResult(
bf16_time=bf16_time, fp8_time=fp8_time, bf16_tflops=bf16_tflops, fp8_tflops=fp8_tflops
)


def print_results(experiments: List[Experiment], save_path: Optional[Path] = None):
headers = [
"M",
"K",
"N",
"Scaling Strategy",
"Fp8 Kernel",
"Compiled",
"BF16 Time (ms)",
"FP8 Time (ms)",
"Speedup",
"BF16 TFLOPS",
"FP8 TFLOPS",
"TFLOPS Ratio",
]
rows = []
for experiment in experiments:
config = experiment.config
result = experiment.result
speedup = result.bf16_time / result.fp8_time
tflops_ratio = result.fp8_tflops / result.bf16_tflops
rows.append(
[
config.M,
config.K,
config.N,
config.scaling_strategy,
config.fp8_kernel,
config.compile,
f"{result.bf16_time:.4f}",
f"{result.fp8_time:.4f}",
f"{speedup:.2f}x",
f"{result.bf16_tflops:.2f}",
f"{result.fp8_tflops:.2f}",
f"{tflops_ratio:.2f}x",
]
)
print(tabulate(rows, headers=headers, floatfmt=".4f"))

if save_path is not None:
with open(save_path, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(headers)
writer.writerows(rows)
print(f"💾 Results saved to: {save_path}")


def get_configs_varying_k() -> List[ExperimentConfig]:
shapes = [(8192, K, 8192) for K in range(512, 8193, 512)]
scaling_strategies = [ScalingStrategy.PER_ROW]
compile_options = [False]
configs = []
fp8_kernels = [
FP8Kernel.SCALED_MM,
FP8Kernel.PERSISTENT,
FP8Kernel.PERSISTENT_TMA,
FP8Kernel.DEVICE_TMA,
]

for (M, K, N), strategy, compile, kernel in itertools.product(
shapes, scaling_strategies, compile_options, fp8_kernels
):
configs.append(
ExperimentConfig(
M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel
)
)
return configs


def get_configs() -> List[ExperimentConfig]:
shapes = [
(8192, 512, 8192),
(1024, 20512, 2048),
(196608, 656, 256),
(1024, 1024, 137568),
(1024, 256, 137568),
(1024, 1568, 1664),
(1024, 1024, 43040),
(1024, 43040, 2048),
(1024, 1568, 1920),
(1024, 1024, 20512),
(137216, 1920, 384),
(1024, 5632, 98304),
(196608, 656, 624),
(196608, 3600, 624),
(1024, 1568, 640),
(1024, 192, 98304),
(1024, 1568, 1024),
(1024, 137568, 2048),
        # (1024, 24576, 98304),  # fails for large N values
(1024, 1024, 1024),
]

scaling_strategies = [ScalingStrategy.PER_TENSOR]
compile_options = [False]
configs = []
fp8_kernels = [
FP8Kernel.SCALED_MM,
FP8Kernel.PERSISTENT,
FP8Kernel.PERSISTENT_TMA,
FP8Kernel.DEVICE_TMA,
]
for (M, K, N), strategy, compile, kernel in itertools.product(
shapes, scaling_strategies, compile_options, fp8_kernels
):
configs.append(
ExperimentConfig(
M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel
)
)
return configs


def main(save_path: Optional[str] = None):
"""Benchmark FP8 MatMul with different configurations.
Args:
save_path (Optional[str], optional): Path to save the results. Defaults to None.
"""
torch.random.manual_seed(123)
configs = get_configs_varying_k()
results = []
if save_path is not None:
save_path = Path(save_path)
save_path = save_path.with_suffix(".csv")
save_path.parent.mkdir(parents=True, exist_ok=True)
for config in tqdm(configs):
result = run_experiment(config)
results.append(Experiment(config=config, result=result))
print_results(results, save_path)


if __name__ == "__main__":
CLI(main)
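
Usage sketch (assuming the repository's dependencies are installed; jsonargparse's CLI exposes main's parameters as flags):

    python benchmarks/fp8_matmul.py --save_path results/fp8_matmul

The script appends a .csv suffix to save_path and creates parent directories as needed before writing the results table.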