move float8_experimental to torchao/float8 #551

Merged
merged 1 commit on Jul 30, 2024
6 changes: 6 additions & 0 deletions README.md
@@ -85,6 +85,12 @@ In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch

### Training

#### Float8

[torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
Member: wanna mention a topline speedup marketing number?

Contributor Author: our last public number is from 2023H2; we plan to release new speedups in ~weeks, but they're not ready yet. Will add it here when it's posted.


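For illustration only (not part of this README change): a minimal sketch of swapping a single layer to float8 training, using the `Float8Linear.from_float` and `Float8LinearConfig` APIs exercised by the benchmark added in this PR. The layer shape and the all-dynamic scaling config are assumptions for the example, not a recommended recipe.

```python
import copy

import torch

from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear import Float8Linear

# dynamic scaling for input, weight, and grad_output (mirrors the benchmark defaults)
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType("dynamic")),
    cast_config_weight=CastConfig(scaling_type=ScalingType("dynamic")),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType("dynamic")),
)

# hypothetical bf16 layer; the shape is illustrative only
linear_ref = torch.nn.Linear(8192, 1280, bias=False).to(device="cuda", dtype=torch.bfloat16)

# swap in a Float8Linear so the forward/backward matmuls use the scaled float8 dtypes
linear_float8 = Float8Linear.from_float(copy.deepcopy(linear_ref), config=config)

x = torch.randn(16, 8192, device="cuda", dtype=torch.bfloat16, requires_grad=True)
linear_float8(x).sum().backward()
```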
#### Sparsity

We've added support for semi-structured 2:4 sparsity with 6% end-to-end speedups on ViT-L.

The code change is a one-liner, with the full example available [here](torchao/sparsity/training/).
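Also for illustration only, a hedged sketch of what that one-liner might look like; the `SemiSparseLinear` and `swap_linear_with_semi_sparse_linear` names are assumed from the linked torchao/sparsity/training folder, so treat the example there as authoritative.

```python
import torch

# assumed API from torchao/sparsity/training (see the linked folder for the canonical example)
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear

# toy fp16 model; 2:4 semi-structured sparse kernels target CUDA + half precision
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
).cuda().half()

# the "one-liner": swap the named Linear modules for their semi-sparse training counterparts
swap_linear_with_semi_sparse_linear(model, {"0": SemiSparseLinear, "2": SemiSparseLinear})
```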
307 changes: 307 additions & 0 deletions benchmarks/float8/bench_linear_float8.py
@@ -0,0 +1,307 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import copy
from dataclasses import dataclass
from itertools import product
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import pandas as pd

import torch
import torch.utils.benchmark as benchmark
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_tensor import ScaledMMConfig
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}
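# illustrative arithmetic (hypothetical timing): for M=16384, K=8192, N=1280, one
# fwd+bwd is roughly 3 * 2*M*K*N ~= 1.03e12 FLOPs; at 1 ms per iteration that is
# ~1.03e15 FLOP/s, or ~26% of the 3958e12 float8 tensor core peak above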

# prevent splitting columns when printing a data frame
pd.set_option("display.expand_frame_repr", False)
# print the entire data frame
pd_print_full_ctx = pd.option_context(
"display.max_rows", None, "display.max_columns", None
)


def benchmark_torch_function_in_microseconds(
func: Callable,
*args,
**kwargs,
) -> float:
t0 = benchmark.Timer(
stmt="func(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "func": func},
)
return t0.blocked_autorange().median * 1e6


@dataclass
class Experiment:
name: str
shape: Tuple[int, int, int]
ref_time_sec: float
float8_time_sec: float
dtype: torch.dtype
compiled: bool
use_fast_accum: bool
scaling_repr: str

    # 3x the forward matmul FLOPs (2 * M * K * N): the backward pass adds two more
    # matmuls of the same size, for grad_input and grad_weight
@property
def ref_tops_sec(self):
M, K, N = self.shape
return float(3 * (2 * M * K * N)) / self.ref_time_sec

@property
def ref_pct_top_peak(self):
return self.ref_tops_sec / dtype_to_peak_tops[self.dtype]

@property
def float8_tops_sec(self):
M, K, N = self.shape
return float(3 * (2 * M * K * N)) / self.float8_time_sec

@property
def float8_pct_top_peak(self):
return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]


def main(
sweep_path: Optional[Path] = None,
compile: bool = True,
n_limit: Optional[int] = None,
fast_accum_filter: Optional[bool] = None,
shape_name_filter: Optional[str] = None,
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
):
device = "cuda"
print(f"Compile is set to | {compile}")

scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
config = Float8LinearConfig(
cast_config_input=CastConfig(scaling_type=scaling_type_input),
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
)

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}
input_bias = False
if fast_accum_filter is not None:
use_fast_accum = [fast_accum_filter]
else:
use_fast_accum = [True, False]
if shape_name_filter is not None:
k = shape_name_filter
name_to_shapes_70b = {k: name_to_shapes_70b[k]}
experiment_list: List[Experiment] = []
dtype = torch.bfloat16
for idx, (fast_accum, (name, (K, N))) in enumerate(
tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
):
if n_limit is not None and idx >= n_limit:
break
linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
device=device, dtype=dtype
)

linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref),
config=config,
)
scaling_repr = linear_float8.scaling_repr()

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
else:
linear_float8.forward_config = ScaledMMConfig(False, False, False)

bsz, seq_len = 4, 4096
M = bsz * seq_len
input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

def float8_forw_backward():
if linear_requires_sync(config):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

        # run fn n times per timed call so per-call benchmarking overhead is amortized
        def n_times(n, fn, *args, **kwargs):
def wrapper(*args, **kwargs):
for _ in range(n):
fn(*args, **kwargs)

return wrapper

REPEAT_N = 100

ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)

if compile:
ref_forw_backward = torch.compile(ref_forw_backward)
float8_forw_backward = torch.compile(float8_forw_backward)

for _ in range(5):
ref_forw_backward()
float8_forw_backward()

ref_time = (
benchmark_torch_function_in_microseconds(ref_forw_backward)
* 1e-6
/ REPEAT_N
)
float8_time = (
benchmark_torch_function_in_microseconds(float8_forw_backward)
* 1e-6
/ REPEAT_N
)
experiment = Experiment(
name,
(M, K, N),
ref_time,
float8_time,
dtype,
compile,
use_fast_accum=fast_accum,
scaling_repr=scaling_repr,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
experiment_list.append(experiment)
torch._dynamo.reset()

headers = [
"name",
"M",
"K",
"N",
"scaling_repr",
"ref_dtype",
"compiled",
"use_fast_accum",
"ref_time_sec",
"pt_fp8_time_sec",
"ref_tops_sec",
"ref_pct_top_peak",
"pt_fp8_tops_sec",
"pt_fp8_pct_top_peak",
]
data = []
for experiment in experiment_list:
data.append(
[
experiment.name,
experiment.shape[0],
experiment.shape[1],
experiment.shape[2],
experiment.scaling_repr,
experiment.dtype,
experiment.compiled,
experiment.use_fast_accum,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.ref_tops_sec,
experiment.ref_pct_top_peak,
experiment.float8_tops_sec,
experiment.float8_pct_top_peak,
]
)

data_pd = pd.DataFrame(data, columns=headers)
data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
data_pd["shape"] = (
"("
+ data_pd["M"].astype(str)
+ ", "
+ data_pd["K"].astype(str)
+ ", "
+ data_pd["N"].astype(str)
+ ")"
)

data_pd_simple = data_pd[
[
"name",
"shape",
"scaling_repr",
"compiled",
"use_fast_accum",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_speedup",
]
]
with pd_print_full_ctx:
print(data_pd_simple)

if sweep_path is not None:
sweep_path = sweep_path.with_suffix(".csv")
data_pd.to_csv(sweep_path)


def invoke_main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_path", type=str, required=False)
parser.add_argument("--disable_compile", action="store_true")
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument("--fast_accum_filter", type=bool, required=False)
parser.add_argument("--shape_name_filter", type=str, required=False)
parser.add_argument("--scaling_type_input", type=str, required=False)
parser.add_argument("--scaling_type_weight", type=str, required=False)
parser.add_argument("--scaling_type_grad_output", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
if args.scaling_type_input is not None:
kwargs["scaling_type_input"] = args.scaling_type_input
if args.scaling_type_weight is not None:
kwargs["scaling_type_weight"] = args.scaling_type_weight
if args.scaling_type_grad_output is not None:
kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
main(
output_path,
not args.disable_compile,
args.n_limit,
args.fast_accum_filter,
args.shape_name_filter,
**kwargs,
)


if __name__ == "__main__":
invoke_main() # pragma: no cover
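# Example invocation (illustrative; the output path and filter values are placeholders,
# the flags are the ones defined in invoke_main above):
#   python benchmarks/float8/bench_linear_float8.py \
#       -o /tmp/bench_fp8 --shape_name_filter attn.wqkv \
#       --scaling_type_input dynamic --scaling_type_weight dynamic --scaling_type_grad_output dynamic
# The script prints a per-shape float8 speedup table and writes the full results to the
# output path with a .csv suffix.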