diff --git a/README.md b/README.md
index e31dc63a8..8905b3f43 100644
--- a/README.md
+++ b/README.md
@@ -85,6 +85,12 @@ In some cases we rewrote popular GenAI models to be significantly faster in nati
### Training
+#### Float8
+
+[torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
+
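+A minimal sketch of enabling float8 training (API names taken from `torchao.float8`; see [torchao.float8](torchao/float8) for the full set of options):
+
+```python
+import torch
+import torch.nn as nn
+
+from torchao.float8.config import Float8LinearConfig
+from torchao.float8.float8_linear_utils import convert_to_float8_training
+
+# toy model; real float8 gemms need an H100-class GPU, while
+# Float8LinearConfig(emulate=True) allows running on other GPUs
+m = nn.Sequential(nn.Linear(2048, 4096, bias=False), nn.Linear(4096, 2048, bias=False))
+m = m.cuda().to(torch.bfloat16)
+
+# the default config uses dynamic scaling for input, weight and grad_output
+config = Float8LinearConfig()
+m = convert_to_float8_training(m, config=config)
+
+# train as usual: the nn.Linear matmuls in forward and backward now run in float8
+```
+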
+#### Sparsity
+
We've added support for semi-structured 2:4 sparsity, with 6% end-to-end speedups on ViT-L.
The code change is a one-liner, and the full example is available [here](torchao/sparsity/training/).
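+
+A minimal sketch of that one-liner (module and helper names assumed from `torchao/sparsity/training`; treat the linked example as authoritative):
+
+```python
+import torch.nn as nn
+
+# assumed import path and names; see torchao/sparsity/training for the exact API
+from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
+
+model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
+
+# the one-liner: swap the selected nn.Linear layers for 2:4 semi-sparse training versions
+swap_linear_with_semi_sparse_linear(model, {"0": SemiSparseLinear, "2": SemiSparseLinear})
+```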
diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py
new file mode 100644
index 000000000..b44d4f5dc
--- /dev/null
+++ b/benchmarks/float8/bench_linear_float8.py
@@ -0,0 +1,307 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import argparse
+import copy
+from dataclasses import dataclass
+from itertools import product
+from pathlib import Path
+from typing import Callable, List, Optional, Tuple
+
+import pandas as pd
+
+import torch
+import torch.utils.benchmark as benchmark
+from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear import Float8Linear
+from torchao.float8.float8_linear_utils import (
+ linear_requires_sync,
+ sync_float8_amax_and_scale_history,
+)
+from torchao.float8.float8_tensor import ScaledMMConfig
+from tqdm import tqdm
+
+# estimating TOPs for matmuls in fp32, fp16, fp8
+# assuming A * B = C, with A being M * K, B being K * N, C being M * N
+
+# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
+h100_peak_flops_float32 = 67e12
+h100_peak_flops_fp16_tc = 1979e12
+h100_peak_tops_float8_tc = 3958e12
+
+dtype_to_peak_tops = {
+ torch.float32: h100_peak_flops_float32,
+ torch.float16: h100_peak_flops_fp16_tc,
+ torch.bfloat16: h100_peak_flops_fp16_tc,
+ torch.float8_e4m3fn: h100_peak_tops_float8_tc,
+ torch.float8_e5m2: h100_peak_tops_float8_tc,
+}
+
+# prevent splitting columns when printing a data frame
+pd.set_option("display.expand_frame_repr", False)
+# print the entire data frame
+pd_print_full_ctx = pd.option_context(
+ "display.max_rows", None, "display.max_columns", None
+)
+
+
+def benchmark_torch_function_in_microseconds(
+ func: Callable,
+ *args,
+ **kwargs,
+) -> float:
+ t0 = benchmark.Timer(
+ stmt="func(*args, **kwargs)",
+ globals={"args": args, "kwargs": kwargs, "func": func},
+ )
+ return t0.blocked_autorange().median * 1e6
+
+
+@dataclass
+class Experiment:
+ name: str
+ shape: Tuple[int, int, int]
+ ref_time_sec: float
+ float8_time_sec: float
+ dtype: torch.dtype
+ compiled: bool
+ use_fast_accum: bool
+ scaling_repr: str
+
+    # 3x factor: one matmul in the forward pass plus two in the backward (grad_input and grad_weight)
+ @property
+ def ref_tops_sec(self):
+ M, K, N = self.shape
+ return float(3 * (2 * M * K * N)) / self.ref_time_sec
+
+ @property
+ def ref_pct_top_peak(self):
+ return self.ref_tops_sec / dtype_to_peak_tops[self.dtype]
+
+ @property
+ def float8_tops_sec(self):
+ M, K, N = self.shape
+ return float(3 * (2 * M * K * N)) / self.float8_time_sec
+
+ @property
+ def float8_pct_top_peak(self):
+ return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]
+
+
+def main(
+ sweep_path: Optional[Path] = None,
+ compile: bool = True,
+ n_limit: Optional[int] = None,
+ fast_accum_filter: Optional[bool] = None,
+ shape_name_filter: Optional[str] = None,
+ scaling_type_input: str = "dynamic",
+ scaling_type_weight: str = "dynamic",
+ scaling_type_grad_output: str = "dynamic",
+):
+ device = "cuda"
+ print(f"Compile is set to | {compile}")
+
+ scaling_type_input = ScalingType(scaling_type_input)
+ scaling_type_weight = ScalingType(scaling_type_weight)
+ scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=scaling_type_input),
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+ )
+
+ # LLaMa 2 70B single-node weight shapes
+ # assumes fused attn.wqkv and ffn.w13
+ name_to_shapes_70b = {
+ "attn.wqkv": (8192, 1280),
+ "attn.w0": (1024, 8192),
+ "ffn.w13": (8192, 7168),
+ "ffn.w2": (3584, 8192),
+ }
+ input_bias = False
+ if fast_accum_filter is not None:
+ use_fast_accum = [fast_accum_filter]
+ else:
+ use_fast_accum = [True, False]
+ if shape_name_filter is not None:
+ k = shape_name_filter
+ name_to_shapes_70b = {k: name_to_shapes_70b[k]}
+ experiment_list: List[Experiment] = []
+ dtype = torch.bfloat16
+ for idx, (fast_accum, (name, (K, N))) in enumerate(
+ tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
+ ):
+ if n_limit is not None and idx >= n_limit:
+ break
+ linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
+ device=device, dtype=dtype
+ )
+
+ linear_float8 = Float8Linear.from_float(
+ copy.deepcopy(linear_ref),
+ config=config,
+ )
+ scaling_repr = linear_float8.scaling_repr()
+
+ if fast_accum:
+ linear_float8.forward_config = ScaledMMConfig(False, True, False)
+ else:
+ linear_float8.forward_config = ScaledMMConfig(False, False, False)
+
+ bsz, seq_len = 4, 4096
+ M = bsz * seq_len
+ input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
+ ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
+
+ def float8_forw_backward():
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(linear_float8)
+ linear_float8(input_tensor).sum().backward()
+
+ def n_times(n, fn, *args, **kwargs):
+ def wrapper(*args, **kwargs):
+ for _ in range(n):
+ fn(*args, **kwargs)
+
+ return wrapper
+
+ REPEAT_N = 100
+
+ ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
+ float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)
+
+ if compile:
+ ref_forw_backward = torch.compile(ref_forw_backward)
+ float8_forw_backward = torch.compile(float8_forw_backward)
+
+ for _ in range(5):
+ ref_forw_backward()
+ float8_forw_backward()
+
+ ref_time = (
+ benchmark_torch_function_in_microseconds(ref_forw_backward)
+ * 1e-6
+ / REPEAT_N
+ )
+ float8_time = (
+ benchmark_torch_function_in_microseconds(float8_forw_backward)
+ * 1e-6
+ / REPEAT_N
+ )
+ experiment = Experiment(
+ name,
+ (M, K, N),
+ ref_time,
+ float8_time,
+ dtype,
+ compile,
+ use_fast_accum=fast_accum,
+ scaling_repr=scaling_repr,
+ )
+ print(experiment)
+ print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
+ experiment_list.append(experiment)
+ torch._dynamo.reset()
+
+ headers = [
+ "name",
+ "M",
+ "K",
+ "N",
+ "scaling_repr",
+ "ref_dtype",
+ "compiled",
+ "use_fast_accum",
+ "ref_time_sec",
+ "pt_fp8_time_sec",
+ "ref_tops_sec",
+ "ref_pct_top_peak",
+ "pt_fp8_tops_sec",
+ "pt_fp8_pct_top_peak",
+ ]
+ data = []
+ for experiment in experiment_list:
+ data.append(
+ [
+ experiment.name,
+ experiment.shape[0],
+ experiment.shape[1],
+ experiment.shape[2],
+ experiment.scaling_repr,
+ experiment.dtype,
+ experiment.compiled,
+ experiment.use_fast_accum,
+ experiment.ref_time_sec,
+ experiment.float8_time_sec,
+ experiment.ref_tops_sec,
+ experiment.ref_pct_top_peak,
+ experiment.float8_tops_sec,
+ experiment.float8_pct_top_peak,
+ ]
+ )
+
+ data_pd = pd.DataFrame(data, columns=headers)
+ data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
+ data_pd["shape"] = (
+ "("
+ + data_pd["M"].astype(str)
+ + ", "
+ + data_pd["K"].astype(str)
+ + ", "
+ + data_pd["N"].astype(str)
+ + ")"
+ )
+
+ data_pd_simple = data_pd[
+ [
+ "name",
+ "shape",
+ "scaling_repr",
+ "compiled",
+ "use_fast_accum",
+ "ref_time_sec",
+ "pt_fp8_time_sec",
+ "pt_fp8_speedup",
+ ]
+ ]
+ with pd_print_full_ctx:
+ print(data_pd_simple)
+
+ if sweep_path is not None:
+ sweep_path = sweep_path.with_suffix(".csv")
+ data_pd.to_csv(sweep_path)
+
+
+def invoke_main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-o", "--output_path", type=str, required=False)
+ parser.add_argument("--disable_compile", action="store_true")
+ parser.add_argument("-n", "--n_limit", type=int, required=False)
+    # parse the flag into a real boolean; argparse's type=bool would treat any
+    # non-empty string (including "False") as True
+    parser.add_argument(
+        "--fast_accum_filter", type=lambda s: s.lower() == "true", required=False
+    )
+ parser.add_argument("--shape_name_filter", type=str, required=False)
+ parser.add_argument("--scaling_type_input", type=str, required=False)
+ parser.add_argument("--scaling_type_weight", type=str, required=False)
+ parser.add_argument("--scaling_type_grad_output", type=str, required=False)
+ args = parser.parse_args()
+ output_path = Path(args.output_path) if args.output_path is not None else None
+ kwargs = {}
+ if args.scaling_type_input is not None:
+ kwargs["scaling_type_input"] = args.scaling_type_input
+ if args.scaling_type_weight is not None:
+ kwargs["scaling_type_weight"] = args.scaling_type_weight
+ if args.scaling_type_grad_output is not None:
+ kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
+ main(
+ output_path,
+ not args.disable_compile,
+ args.n_limit,
+ args.fast_accum_filter,
+ args.shape_name_filter,
+ **kwargs,
+ )
+
+
+if __name__ == "__main__":
+ invoke_main() # pragma: no cover
diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py
new file mode 100644
index 000000000..6220670ee
--- /dev/null
+++ b/benchmarks/float8/bench_matmul.py
@@ -0,0 +1,139 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import itertools
+from typing import Optional
+
+import fire
+import pandas as pd
+
+import torch
+import torch.nn as nn
+import torch.utils.benchmark as benchmark
+
+# estimating TOPs for matmuls in fp32, fp16, fp8
+# assuming A * B = C, with A being M * K, B being K * N, C being M * N
+
+# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
+h100_peak_flops_float32 = 67e12
+h100_peak_flops_fp16_tc = 989e12
+h100_peak_tops_float8_tc = 1979e12
+
+dtype_to_peak_tops = {
+ torch.float32: h100_peak_flops_float32,
+ torch.float16: h100_peak_flops_fp16_tc,
+ torch.bfloat16: h100_peak_flops_fp16_tc,
+ torch.float8_e4m3fn: h100_peak_tops_float8_tc,
+ torch.float8_e5m2: h100_peak_tops_float8_tc,
+}
+
+
+def benchmark_fn_in_sec(f, *args, **kwargs):
+ # Manual warmup
+ for _ in range(4):
+ f(*args, **kwargs)
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+ )
+ measurement = t0.blocked_autorange()
+ return measurement.mean
+
+
+def do_benchmarks(tops, peak_tops, f, *args, **kwargs):
+ time_sec = benchmark_fn_in_sec(f, *args, **kwargs)
+ tops_sec = float(tops) / time_sec
+ pct_top_peak = tops_sec / peak_tops
+ return time_sec, tops_sec, pct_top_peak
+
+
+@torch.inference_mode()
+def run(n_limit: Optional[int] = None):
+ device = "cuda"
+
+ # LLaMa 2 70B single-node weight shapes
+ # assumes fused attn.wqkv and ffn.w13
+ # source: https://fburl.com/gsheet/g8onr7rh
+ name_to_shapes_70b = {
+ "attn.wqkv": (8192, 1280),
+ "attn.w0": (1024, 8192),
+ "ffn.w13": (8192, 7168),
+ "ffn.w2": (3584, 8192),
+ }
+
+ headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup")
+ results = []
+
+ name_to_shapes = name_to_shapes_70b
+ dtypes = torch.bfloat16, torch.float16
+
+ for idx, (dtype, (name, (K, N))) in enumerate(
+ itertools.product(dtypes, name_to_shapes.items())
+ ):
+ if n_limit is not None and idx >= n_limit:
+ break
+
+ # source: Xiao Sun, these are realistic for LLaMa 70B training
+ bsz, seq_len = 4, 4096
+
+ M = bsz * seq_len
+ print("M, K, N:", M, K, N)
+ tops = 2 * M * N * K
+ print(f"tops: {tops:.2E}")
+
+ # raw torch.mm
+ A = torch.randn(M, K, device=device, dtype=dtype)
+ m_ref = nn.Sequential(nn.Linear(K, N, dtype=dtype, device=device, bias=False))
+ ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks(
+ tops, dtype_to_peak_tops[dtype], m_ref, A
+ )
+ print(
+ f"{dtype} time_sec {ref_time_sec:.2E}, tops/sec {ref_tops_sec:.2E}, pct_peak {ref_pct_top_peak:.3f}"
+ )
+
+ del A
+
+        # raw float8 matmul (upper bound for what we can achieve in eager mode)
+ # TODO(future): add e5m2
+ d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
+ A = torch.zeros(M, K, device=device, dtype=d1)
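+        # lay B out in column-major memory format, which torch._scaled_mm requires for its second operand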
+ B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+
+ def do_matmul(A, B):
+ scale_a = torch.tensor([1.0], device=device)
+ scale_b = torch.tensor([1.0], device=device)
+ return torch._scaled_mm(
+ A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+ )
+
+ fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
+ tops, dtype_to_peak_tops[d1], do_matmul, A, B
+ )
+ print(
+ f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+ )
+
+ del A, B
+
+ results.append(
+ [
+ name,
+ (M, K, N),
+ dtype,
+ ref_time_sec,
+ fp8_time_sec,
+ ref_time_sec / fp8_time_sec,
+ ]
+ )
+
+ data_pd = pd.DataFrame(results, columns=headers)
+ print(data_pd)
+
+
+def main() -> None:
+ fire.Fire(run)
+
+
+if __name__ == "__main__":
+ main() # pragma: no cover
diff --git a/benchmarks/float8/bench_multi_gpu.py b/benchmarks/float8/bench_multi_gpu.py
new file mode 100644
index 000000000..44c758d1b
--- /dev/null
+++ b/benchmarks/float8/bench_multi_gpu.py
@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable
+
+import fire
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.utils.benchmark as benchmark
+from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear_utils import (
+ convert_to_float8_training,
+ sync_float8_amax_and_scale_history,
+)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+
+torch.manual_seed(0)
+
+# TODO: Add more shapes for the benchmark
+B, M, K, N = 32, 1024, 1024, 1024
+lr = 0.01
+
+config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+)
+
+
+def benchmark_torch_function_in_microseconds(
+ func: Callable,
+ *args,
+ **kwargs,
+) -> float:
+ t0 = benchmark.Timer(
+ stmt="func(*args, **kwargs)",
+ globals={"args": args, "kwargs": kwargs, "func": func},
+ )
+ return t0.blocked_autorange().median * 1e6
+
+
+def setup(rank, world_size):
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12355"
+
+ # initialize the process group
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def cleanup():
+ dist.destroy_process_group()
+
+
+def get_model(K, N, is_fp8, base_dtype=torch.float32):
+ modules = [
+ nn.Linear(K, N, dtype=base_dtype),
+ nn.ReLU(),
+ ]
+ N_LAYERS = 20
+    # build out the remaining N_LAYERS - 1 linear + ReLU blocks
+ for _ in range(N_LAYERS - 1):
+ modules.append(nn.Linear(N, N, dtype=base_dtype))
+ modules.append(nn.ReLU())
+ m = nn.Sequential(*modules)
+ if is_fp8:
+ convert_to_float8_training(
+ m,
+ config=config,
+ )
+ return m
+
+
+def fsdp_main(rank, world_size, args):
+ setup(rank, world_size)
+ torch.cuda.set_device(rank)
+
+ base_dtype, input_global, compile = args
+
+ # basic distributed data sampling
+ assert B % world_size == 0
+ bsz_local_start = int(rank / world_size * B)
+ bsz_local_end = int((rank + 1) / world_size * B)
+ input_tensor = input_global[bsz_local_start:bsz_local_end].to(rank)
+
+ fp8_model = get_model(K, N, is_fp8=True, base_dtype=base_dtype).to(rank)
+ # Need use_orig_params=True to compile FSDP
+ fp8_model = FSDP(fp8_model, use_orig_params=True)
+ fp8_optimizer = torch.optim.SGD(fp8_model.parameters(), lr=lr * world_size)
+
+    # Run one iteration to make compile work; see the experiments doc for more context on this issue.
+ fp8_optimizer.zero_grad()
+ y_local = fp8_model(input_tensor)
+ y_local.sum().backward()
+ fp8_optimizer.step()
+ sync_float8_amax_and_scale_history(fp8_model)
+
+ sync_float8_func = sync_float8_amax_and_scale_history
+ if compile:
+ # TODO: Need to fix issues with compile
+ fp8_model = torch.compile(fp8_model)
+ sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)
+
+ def float8_forw_backward():
+ fp8_optimizer.zero_grad()
+ y_local = fp8_model(input_tensor)
+ y_local.sum().backward()
+ fp8_optimizer.step()
+ sync_float8_func(fp8_model)
+
+ ref_model = get_model(K, N, is_fp8=False, base_dtype=base_dtype).to(rank)
+ ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size)
+ if compile:
+ ref_model = torch.compile(ref_model)
+
+ ref_model = FSDP(ref_model, use_orig_params=True)
+
+ def ref_forw_backward():
+ ref_optimizer.zero_grad()
+ ref_model(input_tensor).sum().backward()
+ ref_optimizer.step()
+
+ def run_n_iterations(n, fn):
+ for _ in range(n):
+ fn()
+ # make sure training is done on all ranks
+ dist.barrier()
+
+ # warmup
+ run_n_iterations(50, ref_forw_backward)
+ run_n_iterations(50, float8_forw_backward)
+
+ N_ITER = 50
+ ref_time = (
+ benchmark_torch_function_in_microseconds(
+ run_n_iterations, N_ITER, ref_forw_backward
+ )
+ * 1e-6
+ / N_ITER
+ )
+ float8_time = (
+ benchmark_torch_function_in_microseconds(
+ run_n_iterations, N_ITER, float8_forw_backward
+ )
+ * 1e-6
+ / N_ITER
+ )
+
+ if rank == 0:
+ print("ref_time", ref_time)
+ print("float8_time", float8_time)
+ print("float8 speedup", ref_time / float8_time)
+
+ cleanup()
+
+
+def run(compile: bool):
+ base_dtype = torch.bfloat16
+ WORLD_SIZE = torch.cuda.device_count()
+ print(f"{base_dtype = }")
+ print(f"{compile = }")
+ print(f"{WORLD_SIZE = }")
+
+ # generate input data
+ ref_input = torch.randn(B, M, K).cuda().to(base_dtype)
+ # run fsdp model
+ args = (base_dtype, ref_input, compile)
+ mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
+
+
+# Usage:
+# CUDA_VISIBLE_DEVICES=0,1 python benchmarks/float8/bench_multi_gpu.py
+if __name__ == "__main__":
+ fire.Fire(run)
diff --git a/benchmarks/float8/bench_padding.py b/benchmarks/float8/bench_padding.py
new file mode 100644
index 000000000..977755343
--- /dev/null
+++ b/benchmarks/float8/bench_padding.py
@@ -0,0 +1,223 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import fire
+
+import torch
+from torchao.float8.float8_tensor import (
+ GemmInputRole,
+ hp_tensor_and_scale_to_float8,
+ LinearMMConfig,
+ ScaledMMConfig,
+)
+from torchao.float8.float8_utils import pad_tensor_for_matmul
+from tabulate import tabulate
+from torch._inductor.utils import do_bench_using_profiling
+from tqdm import tqdm
+
+# estimating TOPs for matmuls in fp32, fp16, fp8
+# assuming A * B = C, with A being M * K, B being K * N, C being M * N
+
+# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
+h100_peak_flops_float32 = 67e12
+h100_peak_flops_fp16_tc = 1979e12
+h100_peak_tops_float8_tc = 3958e12
+
+dtype_to_peak_tops = {
+ torch.float32: h100_peak_flops_float32,
+ torch.float16: h100_peak_flops_fp16_tc,
+ torch.bfloat16: h100_peak_flops_fp16_tc,
+ torch.float8_e4m3fn: h100_peak_tops_float8_tc,
+ torch.float8_e5m2: h100_peak_tops_float8_tc,
+}
+
+
+def benchmark_fn_in_usec(f, *args, **kwargs):
+ no_args = lambda: f(*args, **kwargs)
+ time = do_bench_using_profiling(no_args)
+ return time * 1e3
+
+
+def get_tops_info(tops, time, peak_tops):
+ time_sec = time / 1e6
+ tops_sec = float(tops) / time_sec
+ pct_top_peak = tops_sec / peak_tops
+ return tops_sec, pct_top_peak
+
+
+def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
+ scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
+ scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
+
+ a_config = ScaledMMConfig(
+ emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+ )
+ b_config = ScaledMMConfig(
+ emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+ )
+ a_config = LinearMMConfig(a_config, a_config, a_config)
+ b_config = LinearMMConfig(b_config, b_config, b_config)
+
+ a_fp8 = hp_tensor_and_scale_to_float8(
+ A,
+ scale_a,
+ fp8_dtype,
+ a_config,
+ GemmInputRole.INPUT,
+ )
+ b_fp8 = hp_tensor_and_scale_to_float8(
+ B,
+ scale_b,
+ fp8_dtype,
+ b_config,
+ GemmInputRole.WEIGHT,
+ )
+
+ return a_fp8 @ b_fp8
+
+
+def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
+ # Breaks with compile due to trying to pad on fp8 dtype
+ # return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
+ A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy
+ B_pad = pad_tensor_for_matmul(B, dims=0) # mem copy
+
+ scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
+ scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
+
+ A_pad = A_pad.to(fp8_dtype) # mem copy
+ B_pad = B_pad.to(fp8_dtype) # mem copy
+
+ B_pad = B_pad.t().contiguous().t() # mem copy
+
+ return torch._scaled_mm(
+ A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
+ )
+
+
+def do_hp_matmul(A, B):
+ return torch.matmul(A, B)
+
+
+def do_aligned_bf16_matmul(A, B):
+ A_pad = pad_tensor_for_matmul(A, dims=1)
+ B_pad = pad_tensor_for_matmul(B, dims=0)
+ return torch.matmul(A_pad, B_pad)
+
+
+@dataclass
+class Experiment_config:
+ M: int
+ K: int
+ N: int
+ output_dtype: torch.dtype
+ fp8_dtype: torch.dtype
+
+ def __iter__(self):
+ return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype))
+
+
+def gen_configs():
+    shapes = [
+ (8193, 2501, 5008),
+ (65, 253, 4096),
+ (1023, 1029, 2512),
+ (4095, 511, 10000),
+ (2047, 3073, 8192),
+ (511, 769, 7504),
+ (127, 4097, 12288),
+ (32769, 15, 15024),
+ (9217, 8191, 20480),
+ (16385, 1025, 25008),
+ ]
+ output_dtype = torch.bfloat16
+ fp8_dtype = torch.float8_e4m3fn
+ return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
+
+
+@torch.no_grad()
+def run(compile: bool = False, n_limit: Optional[int] = None):
+ device = "cuda"
+ experiments = gen_configs()
+ results = []
+ tops_table = []
+ tops_headers = [
+ "Shape",
+ "Ref Dtype",
+ "Ref Tops",
+ "Aligned BF16 Tops",
+ "FP8 Tops",
+ "Ref % Peak",
+ "Aligned BF16 % Peak",
+ "FP8 % Peak",
+ ]
+
+ for experiment in tqdm(experiments):
+ M, K, N, output_dtype, fp8_dtype = experiment
+ tops = 2 * M * N * K
+
+ A_base = torch.rand(M, K, device=device, dtype=output_dtype)
+ B_base = torch.rand(K, N, device=device, dtype=output_dtype)
+
+ hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul
+ aligned_bf16_func = (
+ torch.compile(do_aligned_bf16_matmul) if compile else do_aligned_bf16_matmul
+ )
+ fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul
+
+ ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base)
+ aligned_bf16_time = benchmark_fn_in_usec(aligned_bf16_func, A_base, B_base)
+ fp8_time = benchmark_fn_in_usec(
+ fp8_func, A_base, B_base, fp8_dtype, output_dtype
+ )
+
+ ref_tops_sec, ref_pct_top_peak = get_tops_info(
+ tops, ref_time, dtype_to_peak_tops[output_dtype]
+ )
+ aligned_bf16_tops_sec, aligned_bf16_pct_top_peak = get_tops_info(
+ tops, aligned_bf16_time, dtype_to_peak_tops[torch.bfloat16]
+ )
+ fp8_tops_sec, fp8_pct_top_peak = get_tops_info(
+ tops, fp8_time, dtype_to_peak_tops[fp8_dtype]
+ )
+ tops_table.append(
+ [
+ f"({M}x{K}x{N})",
+ f"{output_dtype}",
+ f"{ref_tops_sec:.2E}",
+ f"{aligned_bf16_tops_sec:.2E}",
+ f"{fp8_tops_sec:.2E}",
+ f"{ref_pct_top_peak:.3f}",
+ f"{aligned_bf16_pct_top_peak:.3f}",
+ f"{fp8_pct_top_peak:.3f}",
+ ]
+ )
+ results.append(
+ [
+ (M, K, N),
+ output_dtype,
+ ref_time,
+ aligned_bf16_time,
+ fp8_time,
+ ref_time / aligned_bf16_time,
+ ref_time / fp8_time,
+ ]
+ )
+
+ print("TOPs".center(80, "*"))
+ print(tabulate(tops_table, headers=tops_headers))
+ print("Speed Results".center(80, "*"))
+ headers = [
+ "Shape",
+ "Ref Dtype",
+ "Ref Time",
+ "Aligned BF16 Time",
+ "FP8 Time",
+ "Aligned BF16 Speedup",
+ "FP8 Speedup",
+ ]
+ print(tabulate(results, headers=headers, tablefmt="grid"))
+
+
+if __name__ == "__main__":
+ fire.Fire(run)
diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py
new file mode 100644
index 000000000..914759849
--- /dev/null
+++ b/benchmarks/float8/profile_linear_float8.py
@@ -0,0 +1,447 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import io
+import random
+from contextlib import nullcontext, redirect_stdout
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Callable, Optional
+
+import fire
+import pandas as pd
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear_utils import (
+ convert_to_float8_training,
+ linear_requires_sync,
+ sync_float8_amax_and_scale_history,
+)
+from torch.profiler import profile, ProfilerActivity, record_function
+from utils import (
+ kernel_name_to_category,
+ parse_bw_and_kernel_name,
+ profiler_output_to_gpu_time_for_key,
+ profiler_output_to_time_by_kernel_name,
+)
+
+# don't truncate long kernel names
+pd.options.display.max_colwidth = 100
+# display 3 trailing decimal points for floats
+pd.set_option("display.float_format", "{:.3f}".format)
+
+
+class LNLinear(torch.nn.Module):
+ def __init__(self, fc_dim1, fc_dim2):
+ super().__init__()
+ self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False)
+ self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False)
+
+ def forward(self, x):
+ x = self.ln(x)
+ x = self.fc(x)
+ return x
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
+class RMSNorm(nn.Module):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x: torch.Tensor):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: torch.Tensor):
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+ def reset_parameters(self):
+ torch.nn.init.ones_(self.weight) # type: ignore
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+class FeedForward(nn.Module):
+ """
+ FeedForward module
+
+ Args:
+ dim (int): Input dimension.
+ hidden_dim (int): Hidden dimension of the feedforward layer.
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+ ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+ Attributes:
+ w1 (Linear): Linear transformation for the first layer.
+ w2 (Linear): Linear transformation for the second layer.
+ w3 (Linear): Linear transformation for the third layer.
+
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ # custom dim factor multiplier
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def forward(self, x):
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+ def init_weights(self, init_std: float):
+ nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+ for linear in (self.w2, self.w3):
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class NormFFNResidualNorm(nn.Module):
+ """
+ A fragment representing the end of TransformerBlock n and the start
+ of TransformerBlock n + 1, intended to include the fusions relevant
+ to float8 gemms in the FFN module in forward and backward.
+ """
+
+ def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier):
+ super().__init__()
+ self.ffn_norm = RMSNorm(dim)
+ self.ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier)
+ self.attn_norm = RMSNorm(dim)
+
+ def forward(self, h):
+ # end of transformer block n
+ x = self.ffn_norm(h)
+ x = self.ffn(x)
+ x = h + x
+ # start of transformer block n + 1
+ x = self.attn_norm(x)
+ return x
+
+
+@dataclass
+class ProfileConfig:
+ file_path: Optional[str] = None
+ name: Optional[str] = None
+ cuda: bool = True
+ iters: int = 0
+ warmup_iters: int = 0
+ sync: bool = False
+ extra_kwargs: dict = field(default_factory=dict)
+ memory_profile_path: Optional[str] = None
+
+
+def profile_function(
+ config: ProfileConfig, func: Callable, *args, **kwargs
+) -> torch.profiler.profile:
+ """Profile a torch function and save the result to a file"""
+ seed = 123
+ random.seed(seed)
+ torch.manual_seed(seed)
+
+ activities = [ProfilerActivity.CPU]
+ if config.cuda:
+ activities.append(ProfilerActivity.CUDA)
+
+ if config.warmup_iters >= 0:
+ for _ in range(config.warmup_iters):
+ func(*args, **kwargs)
+ if config.sync:
+ torch.cuda.synchronize()
+ name_context = (
+ nullcontext() if config.name is None else record_function(config.name)
+ )
+ profile_memory = config.memory_profile_path is not None
+ with profile(
+ activities=activities,
+ profile_memory=profile_memory,
+ record_shapes=profile_memory,
+ with_stack=profile_memory,
+ **config.extra_kwargs,
+ ) as prof:
+ for _ in range(config.iters):
+ with name_context:
+ func(*args, **kwargs)
+ if config.sync:
+ torch.cuda.synchronize()
+
+ if config.file_path is not None:
+ prof.export_chrome_trace(config.file_path)
+
+ return prof
+
+
+def main(
+ profile_path_prefix: Path,
+ compile: bool = True,
+ scaling_type_input: str = "dynamic",
+ scaling_type_weight: str = "dynamic",
+ scaling_type_grad_output: str = "dynamic",
+ model_type: str = "linear",
+ dtype_filter: str = "both",
+):
+ assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+ assert dtype_filter in ("both", "float8", "bfloat16")
+
+ scaling_type_input = ScalingType(scaling_type_input)
+ scaling_type_weight = ScalingType(scaling_type_weight)
+ scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=scaling_type_input),
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+ )
+ scaling_repr = "_".join(
+ [
+ s.short_str()
+ for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output)
+ ]
+ )
+
+ print(f"Compile is set to | {compile}")
+ print(f"model_type is set to | {model_type}")
+ print(f"scaling_repr is set to | {scaling_repr}")
+
+ device = "cuda"
+ ref_dtype = torch.bfloat16
+ if model_type == "ln_linear":
+ M, K, N = 4 * 4096, 8192, 7168
+ m_ref = LNLinear(K, N)
+ input_tensor = torch.randn(
+ M, K, device=device, dtype=ref_dtype, requires_grad=True
+ )
+ elif model_type == "norm_ffn_norm":
+ m_ref = NormFFNResidualNorm(
+ dim=4096,
+ hidden_dim=16384,
+ multiple_of=1024,
+ ffn_dim_multiplier=1.3,
+ )
+ input_tensor = torch.randn(
+ 1, 8192, 4096, device=device, dtype=ref_dtype
+ ).requires_grad_()
+ else:
+ M, K, N = 4 * 4096, 8192, 7168
+ m_ref = torch.nn.Sequential(
+ torch.nn.Linear(K, N, bias=False),
+ )
+ input_tensor = torch.randn(
+ M, K, device=device, dtype=ref_dtype, requires_grad=True
+ )
+
+ m_ref = m_ref.to(device).to(ref_dtype)
+
+ m_float8 = copy.deepcopy(m_ref)
+ convert_to_float8_training(m_float8, config=config)
+
+ def ref_forw_backward(x):
+ out = m_ref(x)
+ out.sum().backward()
+
+ def float8_forw(x):
+ out = m_float8(x)
+ return out
+
+ sync_amax_history = sync_float8_amax_and_scale_history
+
+ def float8_forw_backward_wrapper(x):
+ # sync_float8_amax_and_scale_history is not full graph torch
+ # compile friendly, so we add a high level wrapper to allow
+ # inspection of the fw+bw torch.compile without the scale
+ # syncing code
+ # TODO(future): make this better
+ if linear_requires_sync(config):
+ with record_function("scale_amax_and_scales"):
+ sync_amax_history(m_float8)
+ out = float8_forw(x)
+
+ # out.sum().backward() is also not torch.compile fullgraph
+ # friendly
+ with record_function("backward"):
+ out.sum().backward()
+
+ if compile:
+ m_ref = torch.compile(m_ref, fullgraph=True)
+ float8_forw = torch.compile(float8_forw, fullgraph=True)
+        # Note: it's faster to compile the combination of sync_amax_history with
+ # forward because we only look up from dynamo cache once.
+ # However, compiling the sync function separately makes it more
+ # convenient to analyze the total time spent on it.
+ sync_amax_history = torch.compile(sync_amax_history)
+
+ # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
+ # to populate triton kernel bandwidth further down in the script
+ f = io.StringIO()
+ with redirect_stdout(f):
+ # warm up
+ for _ in range(1):
+ if dtype_filter != "float8":
+ ref_forw_backward(input_tensor)
+ if dtype_filter != "bfloat16":
+ float8_forw_backward_wrapper(input_tensor)
+
+ profile_iters = 5
+ ref_times, float8_times = None, None
+ data = []
+
+ if dtype_filter != "float8":
+ # Profile Reference Model
+ print("profiling ref")
+ ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+ ref_path = profile_path_prefix + ref_suffix
+ profile_config = ProfileConfig(
+ ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+ )
+ p = profile_function(profile_config, ref_forw_backward, input_tensor)
+ print(f"saved {ref_path}")
+ ref_times = profiler_output_to_time_by_kernel_name(p)
+ total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
+ for k, v in ref_times.items():
+ v_ms = v / 1e3 / profile_iters
+ data.append(
+ [
+ "0_ref",
+ k,
+ kernel_name_to_category(k),
+ v_ms,
+ v_ms / total_time_ms,
+ None,
+ ]
+ )
+
+ if dtype_filter != "bfloat16":
+ # Profile Float8 Model
+ print("profiling float8")
+ float8_suffix = (
+ f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
+ )
+ float8_path = profile_path_prefix + float8_suffix
+ profile_config = ProfileConfig(
+ float8_path,
+ float8_suffix,
+ iters=profile_iters,
+ warmup_iters=2,
+ sync=True,
+ )
+ p = profile_function(
+ profile_config, float8_forw_backward_wrapper, input_tensor
+ )
+ print(f"saved {float8_path}")
+ float8_times = profiler_output_to_time_by_kernel_name(p)
+ total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
+ for k, v in float8_times.items():
+ v_ms = v / 1e3 / profile_iters
+ data.append(
+ [
+ "1_float8",
+ k,
+ kernel_name_to_category(k),
+                    v_ms,
+ v_ms / total_time_ms,
+ None,
+ ]
+ )
+
+ # get the time spent per user annotation
+ sync_time_us = profiler_output_to_gpu_time_for_key(
+ p, "scale_amax_and_scales"
+ )
+ sync_time_ms = sync_time_us / profile_iters / 1e3
+ print(f"Sync time ms: {sync_time_ms}")
+
+ # print the redirected stdout back to regular stdout
+ print(f.getvalue())
+
+ # populate the triton kernel bandwidth
+ for line in f.getvalue().split("\n"):
+ maybe_bw, maybe_kernel_name = parse_bw_and_kernel_name(line)
+ if maybe_kernel_name is not None:
+ # O(N) search, but it's ok since lists are small
+ for datum in data:
+ if datum[1] == maybe_kernel_name:
+ datum[-1] = maybe_bw
+
+ df = pd.DataFrame(
+ data,
+ columns=[
+ "experiment",
+ "kernel",
+ "category",
+ "time_ms",
+ "pct_gpu_time",
+ "bw_gpbs",
+ ],
+ )
+ df.sort_values(
+ ["experiment", "category", "pct_gpu_time"],
+ ascending=[True, True, False],
+ inplace=True,
+ )
+ print("\nSummary of GPU time by CPU kernel\n\n", df)
+
+ # compare gemm and overhead time
+ df_p = df.pivot_table(
+ columns=["category"],
+ index="experiment",
+ values="time_ms",
+ aggfunc="sum",
+ fill_value=0,
+ margins=True,
+ )
+    # drop the last row, which totals across ref + float8 and is not meaningful
+ df_p = df_p[:-1]
+ df_p = df_p.transpose()
+
+ if dtype_filter == "both":
+ df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"]
+ df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"]
+
+ # calculate sync time as pct of total float time
+ # note: this time is not useful if TORCHINDUCTOR_PROFILE is on
+ total_float8_ms = df_p.iloc[3]["1_float8"]
+ sync_approx_ratio = sync_time_ms / total_float8_ms
+ print(
+ f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}"
+ )
+
+ print("\nSummary of time (ms) by kernel category\n\n", df_p)
+
+
+def invoke_main() -> None:
+    # Example usage: python benchmarks/float8/profile_linear_float8.py benchmarks/data/profiles/current_profile --compile=True --scaling_type_input=dynamic
+ # You can set TORCHINDUCTOR_PROFILE=1 to also capture triton kernel bandwidth
+ fire.Fire(main)
+
+
+if __name__ == "__main__":
+ invoke_main() # pragma: no cover
diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py
new file mode 100644
index 000000000..aec19e2cd
--- /dev/null
+++ b/benchmarks/float8/utils.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import collections
+import re
+
+
+def profiler_output_to_time_by_kernel_name(prof):
+ """
+ Input: a profiler with captured events.
+    Output: a dict of total GPU time in microseconds, keyed by top-level CPU kernel name
+
+ Note that if there are user_annotations in the captured events, `torch.profiler`
+ will include their time in the total GPU time displayed at the bottom of
+ `key_averages.table()`. The filter below excludes them to prevent double
+ counting.
+ """
+ key_averages = prof.key_averages()
+ thresh = 1e-10
+ kernel_name_to_gpu_time_us = collections.defaultdict(float)
+ for e in key_averages:
+ # manually filter top-level CPU events with attributed CUDA time
+ # example CPU event row:
+ # aten::addmm 0.83% 76.554us 0.98% 90.846us 90.846us 1.022ms 31.82% 1.022ms 1.022ms 1
+ # and it maps to this CUDA event:
+ # sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize256x64... 0.00% 0.000us 0.00% 0.000us 0.000us 1.022ms 31.82% 1.022ms 1.022ms 1
+ if not (e.self_cpu_time_total > thresh and e.self_device_time_total > thresh):
+ continue
+ kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
+ return kernel_name_to_gpu_time_us
+
+
+def profiler_output_to_gpu_time_for_key(prof, key):
+ """
+ Input: an event name
+ Output: sum of GPU time of all events with that name in `prof`
+
+ This is useful to get the total time of a user annotation
+ """
+ total = 0
+ for e in prof.profiler.function_events:
+ if e.key == key:
+ total += e.device_time_total
+ return total
+
+
+def kernel_name_to_category(k):
+ # number prefix is for easy sorting
+ if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"):
+ return "0_gemm"
+ elif (
+ # max(abs(tensor))
+ ("abs" in k and "max" in k)
+ or
+ # casting pointwise to float8
+ ("clamp" in k)
+ or
+ # things related to scaled_mm
+ ("scaled_mm" in k)
+ or
+ # syncing amaxes and scales
+ ("roll" in k)
+ ):
+ # note: the above filter is approximate and will give false
+ # positives if model code contains other code to abs/max/clamp
+ return "1_f8_overhead"
+ return "2_other"
+
+
+def parse_bw_and_kernel_name(line):
+ """
+ Input: a single line of stdout of TORCHINDUCTOR_PROFILE=1 output, such as
+ 0.257ms 0.537 GB 2092.43GB/s triton_red_fused_native_layer_norm_0
+ Output: the bandwidth value and the kernel name, or None and None
+ """
+    result = re.search(r".* ([0-9\.]+)GB/s.*(triton_[a-z_0-9]+)", line)
+ if result:
+ return result.group(1), result.group(2)
+ else:
+ return None, None
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
new file mode 100644
index 000000000..0780968aa
--- /dev/null
+++ b/test/float8/test_base.py
@@ -0,0 +1,723 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import copy
+import io
+import itertools
+import random
+import re
+import unittest
+import warnings
+
+import pytest
+
+import torch
+import torch.nn as nn
+
+from torchao.utils import TORCH_VERSION_AFTER_2_4
+
+if not TORCH_VERSION_AFTER_2_4:
+ pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear import Float8Linear
+from torchao.float8.float8_linear_utils import (
+ convert_to_float8_training,
+ linear_requires_sync,
+ sync_float8_amax_and_scale_history,
+)
+from torchao.float8.float8_python_api import addmm_float8_unwrapped
+from torchao.float8.float8_tensor import (
+ Float8Tensor,
+ GemmInputRole,
+ hp_tensor_and_scale_to_float8,
+ LinearMMConfig,
+ ScaledMMConfig,
+)
+from torchao.float8.float8_utils import (
+ compute_error,
+ e4m3_dtype,
+ e5m2_dtype,
+ fp8_tensor_statistics,
+ FP8_TYPES,
+ tensor_to_scale,
+)
+from torchao.float8.inference import (
+ ActivationCasting,
+ QuantConfig,
+ quantize_to_float8,
+)
+
+random.seed(0)
+torch.manual_seed(0)
+
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
+    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
+ assert torch.all(a._data == b._data).item(), "data is not identical"
+ return True
+
+
+class TestFloat8Tensor(unittest.TestCase):
+ def test_preserves_dtype(self) -> None:
+ # hp means high precision, lp means low precision
+ hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
+ lp_dtypes = FP8_TYPES
+ for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
+ x1_hp = torch.randn(4, 4, dtype=hp_dtype)
+ x1_s = tensor_to_scale(x1_hp, lp_dtype)
+ x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
+ x3_hp = x2_lp.to_original_precision()
+ self.assertTrue(x3_hp.dtype == hp_dtype)
+
+ def test_differentiable_casts(self) -> None:
+ lp_dtypes = (e4m3_dtype, e5m2_dtype)
+ for f8_dtype in lp_dtypes:
+ x = torch.randn(1).requires_grad_()
+ grad = torch.randn(1)
+ x_s = tensor_to_scale(x, f8_dtype)
+ x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
+ x_f8_hp = x_f8.to_original_precision()
+ x_f8_hp.backward(grad)
+ # the gradient should be unchanged through both casts
+ torch.testing.assert_close(grad, x.grad, rtol=0, atol=0)
+
+ def test_split_cat(self):
+ a = torch.rand(16, 16, dtype=torch.bfloat16)
+ scale = tensor_to_scale(a, e4m3_dtype)
+ fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)
+
+ splits = torch.split(fp8_a, 16)
+ catted = torch.cat(splits, dim=0)
+ assert bitwise_identical(fp8_a, catted)
+
+ def test_index_put(self):
+ a = torch.rand(16, dtype=torch.bfloat16)
+ scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+ fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
+
+ index = torch.randint(0, 15, (16,), dtype=torch.long)
+
+ b = torch.rand(16, 16, dtype=torch.bfloat16)
+ scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
+ fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
+ fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)
+
+ with self.assertRaises(AssertionError):
+ b[index] = fp8_a
+ fp8_b[index] = a
+ fp8_b_bad[index] = fp8_a
+ fp8_b[index] = fp8_a
+
+ def test_copy_(self):
+ a = torch.rand(16, dtype=torch.bfloat16)
+ scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+ fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
+
+ b = torch.empty(16, dtype=torch.bfloat16)
+ b.copy_(fp8_a) # Should work
+ torch.testing.assert_close(b, fp8_a.to_original_precision())
+ with self.assertRaises(RuntimeError):
+ fp8_a.copy_(b) # Should fail
+
+ fp8_b = Float8Tensor(
+ torch.empty(16, dtype=torch.float8_e4m3fn),
+ scale_a,
+ torch.bfloat16,
+ fp8_a._linear_mm_config,
+ )
+ fp8_b.copy_(fp8_a)
+ torch.testing.assert_close(fp8_a._data, fp8_b._data)
+
+ def test_weights_only_load(self):
+ module = nn.Linear(16, 16)
+ # Save model state dict
+ buffer = io.BytesIO()
+ fp8_module = quantize_to_float8(
+ module,
+ QuantConfig(
+ ActivationCasting.DYNAMIC,
+ ),
+ )
+
+ torch.save(fp8_module.state_dict(), buffer)
+ buffer.seek(0)
+ _ = torch.load(buffer, weights_only=True)
+
+
+class TestFloat8Linear:
+ def _test_linear_impl(
+ self,
+ x,
+ m_ref,
+ config: Float8LinearConfig,
+ ):
+ m_fp8 = Float8Linear.from_float(
+ copy.deepcopy(m_ref),
+ config,
+ )
+ for _ in range(2):
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(m_fp8)
+ y_fp8 = m_fp8(x)
+ y_fp8.sum().backward()
+ y_ref = m_ref(x)
+ y_ref.sum().backward()
+
+ assert y_ref.shape == y_fp8.shape
+
+ y_sqnr = compute_error(y_ref, y_fp8)
+ g_sqnr = compute_error(m_ref.weight.grad, m_fp8.weight.grad)
+ # verify sqnr is reasonable
+ assert y_sqnr >= 18.0, f"{y_sqnr} is too low"
+ assert g_sqnr >= 17.0, f"{g_sqnr} is too low"
+ if m_ref.bias is not None:
+ torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad)
+
+ # verify all of the amax buffers got updated
+ if linear_requires_sync(config):
+ # only check buffers that are actually used, based on per-tensor
+ # scaling settings
+ amax_buffer_names = []
+ amax_history_buffer_names = []
+ scale_buffer_names = []
+ if config.cast_config_input.scaling_type is ScalingType.DELAYED:
+ amax_buffer_names.append("fp8_amax_input")
+ amax_history_buffer_names.append("fp8_amax_history_input")
+ scale_buffer_names.append("fp8_scale_input")
+ if config.cast_config_weight.scaling_type is ScalingType.DELAYED:
+ amax_buffer_names.append("fp8_amax_weight")
+ amax_history_buffer_names.append("fp8_amax_history_weight")
+ scale_buffer_names.append("fp8_scale_weight")
+ if config.cast_config_grad_output.scaling_type is ScalingType.DELAYED:
+ amax_buffer_names.append("fp8_amax_grad_output")
+ amax_history_buffer_names.append("fp8_amax_history_grad_output")
+ scale_buffer_names.append("fp8_scale_grad_output")
+
+ # verify all of the amax buffers got updated
+ max_float8_pos = {torch.finfo(dtype).max for dtype in FP8_TYPES}
+ for buffer_name in amax_buffer_names:
+ buffer_value = getattr(m_fp8, buffer_name)
+ for init_val in max_float8_pos:
+ assert torch.ne(
+ buffer_value, torch.tensor(init_val)
+ ), f"{buffer_name} not filled, current value {buffer_value}"
+
+ # verify all of the amax history buffers got updated
+ for buffer_name in amax_history_buffer_names:
+ buffer_value = getattr(m_fp8, buffer_name)
+ assert torch.max(buffer_value) > 0.0, f"{buffer_name} not filled"
+
+ # verify all of the scale buffers got updated
+ for buffer_name in scale_buffer_names:
+ buffer_value = getattr(m_fp8, buffer_name)
+ assert torch.ne(
+ buffer_value, torch.tensor(1.0)
+ ), f"{buffer_name} not filled, current value {buffer_value}"
+
+ # verify initialization flags got updated
+ assert m_fp8.is_amax_initialized, "Amax was not properly initialized"
+
+ @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+ @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
+ @pytest.mark.parametrize(
+ "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+ )
+ @pytest.mark.parametrize(
+ "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+ )
+ @pytest.mark.parametrize(
+ "scaling_type_grad_output",
+ [ScalingType.DELAYED, ScalingType.DYNAMIC],
+ )
+ @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
+ @pytest.mark.parametrize("linear_bias", [False, True])
+ @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+ def test_linear(
+ self,
+ x_shape,
+ emulate: bool,
+ scaling_type_input: ScalingType,
+ scaling_type_weight: ScalingType,
+ scaling_type_grad_output: ScalingType,
+ linear_dtype: torch.dtype,
+ linear_bias: bool,
+ ):
+ if not emulate:
+ if not torch.cuda.is_available():
+ warnings.warn("CUDA not available")
+ pytest.skip()
+ elif torch.cuda.get_device_capability() < (9, 0):
+ warnings.warn(
+ f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
+ )
+ pytest.skip()
+ x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
+ m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=scaling_type_input),
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+ emulate=emulate,
+ )
+ self._test_linear_impl(
+ x,
+ m_ref,
+ config,
+ )
+
+ @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+ @pytest.mark.parametrize(
+ "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+ )
+ @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+ def test_autocast_outputs(
+ self,
+ emulate: bool,
+ linear_dtype: torch.dtype,
+ ):
+ if not emulate:
+ if not torch.cuda.is_available():
+ warnings.warn("CUDA not available")
+ pytest.skip()
+ elif torch.cuda.get_device_capability() < (9, 0):
+ warnings.warn(
+ f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
+ )
+ pytest.skip()
+
+ m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+ emulate=emulate,
+ )
+ m = Float8Linear.from_float(copy.deepcopy(m_ref), config)
+
+ # autocast off
+ x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(m)
+ y = m(x)
+ assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"
+
+ # autocast on
+ with torch.autocast("cuda"):
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(m)
+ y = m(x)
+ assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"
+
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(m)
+ y = m(x)
+ assert (
+ y.dtype == torch.bfloat16
+ ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
+
+ @pytest.mark.parametrize(
+ "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+ )
+ @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+ @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
+ emulate = (
+ not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
+ )
+
+ m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
+ config = Float8LinearConfig(emulate=emulate)
+ m = Float8Linear.from_float(copy.deepcopy(m), config)
+
+ # Cast the module to dtype
+ m = m.to(dtype=linear_dtype)
+ if linear_requires_sync(config):
+ # Check amax buffer types
+ for key in [
+ "fp8_amax_input",
+ "fp8_amax_history_input",
+ "fp8_scale_input",
+ "fp8_amax_weight",
+ "fp8_amax_history_weight",
+ "fp8_scale_weight",
+ "fp8_amax_grad_output",
+ "fp8_amax_history_grad_output",
+ "fp8_scale_grad_output",
+ ]:
+ assert (
+ m._buffers[key].dtype == torch.float32
+ ), f"{key}.dtype is {m._buffers[key].dtype}, expected torch.float32"
+
+ # autocast off
+ x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(m)
+ y = m(x)
+ assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"
+
+ # autocast on
+ with torch.autocast("cuda"):
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(m)
+ y = m(x)
+ assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"
+
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(m)
+ y = m(x)
+ assert (
+ y.dtype == torch.bfloat16
+ ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
+
+ def test_repr(self):
+ m = nn.Linear(32, 16)
+ config = Float8LinearConfig(
+ cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+ emulate=True,
+ )
+ m = Float8Linear.from_float(
+ copy.deepcopy(m),
+ config=config,
+ )
+ s = m.__repr__()
+ assert "i:dyn,w:del,go:dyn" in s
+
+
+class TestScaledMM:
+ @unittest.skipIf(
+ not is_H100,
+        "Requires an H100-class GPU (CUDA capability >= 9.0)",
+ )
+ @pytest.mark.parametrize(
+ "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
+ )
+ @pytest.mark.parametrize("use_fast_accum", [True, False])
+ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
+ torch.manual_seed(42)
+ input_dtype = e4m3_dtype
+ output_dtype = base_dtype
+ compare_type = torch.float32
+
+ a = torch.randn(16, 16, device="cuda", dtype=base_dtype)
+ b = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
+
+ a_scale = tensor_to_scale(a, input_dtype).float()
+ b_scale = tensor_to_scale(b, input_dtype).float()
+
+ a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
+ b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)
+
+ out_scaled_mm = addmm_float8_unwrapped(
+ a_fp8._data,
+ a_fp8._scale,
+ b_fp8._data,
+ b_fp8._scale,
+ output_dtype=output_dtype,
+ use_fast_accum=use_fast_accum,
+ )
+ out_emulated = torch.ops.aten.mm_float8_emulated(
+ a_fp8._data, a_fp8._scale, b_fp8._data, b_fp8._scale, output_dtype
+ )
+
+ if output_dtype != base_dtype:
+ out_scaled_mm = out_scaled_mm.to(compare_type)
+ out_emulated = out_emulated.to(compare_type)
+
+ if base_dtype in {torch.bfloat16, torch.float16}:
+ atol, rtol = 7e-2, 7e-2
+ else:
+ atol, rtol = 2e-3, 2e-3
+ torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+
+    @unittest.skipIf(not is_H100, "Requires an H100-class GPU (CUDA capability >= 9.0)")
+ def test_different_configs_error(self):
+ x_fp32 = torch.randn(16, 16, device="cuda")
+ x_scale = torch.tensor(1.0, device="cuda")
+ fp8_dtype = e4m3_dtype
+ linear_config_a = LinearMMConfig(
+ ScaledMMConfig(False, True, False, False),
+ ScaledMMConfig(False, False, False, False),
+ ScaledMMConfig(False, False, False, False),
+ )
+ linear_config_b = LinearMMConfig(
+ ScaledMMConfig(True, True, False, False),
+ ScaledMMConfig(True, False, False, False),
+ ScaledMMConfig(True, False, False, False),
+ )
+ a = hp_tensor_and_scale_to_float8(
+ x_fp32,
+ x_scale,
+ fp8_dtype,
+ linear_config_a,
+ GemmInputRole.INPUT,
+ )
+ b = hp_tensor_and_scale_to_float8(
+ x_fp32,
+ x_scale,
+ fp8_dtype,
+ linear_config_b,
+ GemmInputRole.WEIGHT,
+ )
+ with pytest.raises(
+ AssertionError,
+ match="linear_mm_config.output mismatch",
+ ):
+ a @ b
+
+ @unittest.skipIf(
+ not is_H100,
+        "Requires an H100-class GPU (CUDA capability >= 9.0)",
+ )
+ @pytest.mark.parametrize(
+ "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
+ )
+ @pytest.mark.parametrize("use_fast_accum", [True, False])
+ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
+ torch.manual_seed(42)
+ input_dtype = torch.float8_e4m3fn
+ compare_type = torch.float32
+
+ a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
+ b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
+
+ a_scale = tensor_to_scale(a, input_dtype).float()
+ b_scale = tensor_to_scale(b, input_dtype).float()
+
+ a_fp8 = hp_tensor_and_scale_to_float8(
+ a, a_scale, input_dtype, None, GemmInputRole.INPUT
+ )
+ b_fp8 = hp_tensor_and_scale_to_float8(
+ b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
+ )
+
+ with pytest.raises(
+ RuntimeError,
+ match=re.escape(
+ "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
+ ),
+ ):
+ a_fp8 @ b_fp8
+
+ scaled_mm_config = ScaledMMConfig(False, use_fast_accum, False, True)
+ pad_config = LinearMMConfig(
+ scaled_mm_config, scaled_mm_config, scaled_mm_config
+ )
+
+ a_fp8 = hp_tensor_and_scale_to_float8(
+ a,
+ a_scale,
+ input_dtype,
+ pad_config,
+ GemmInputRole.INPUT,
+ )
+ b_fp8 = hp_tensor_and_scale_to_float8(
+ b,
+ b_scale,
+ input_dtype,
+ pad_config,
+ GemmInputRole.WEIGHT,
+ )
+        out_padded = a_fp8 @ b_fp8
+        out_padded = out_padded.to(compare_type)
+
+ emulated_scaled_mm_config = ScaledMMConfig(True, use_fast_accum, False, False)
+ emulated_config = LinearMMConfig(
+ emulated_scaled_mm_config,
+ emulated_scaled_mm_config,
+ emulated_scaled_mm_config,
+ )
+ a_fp8 = hp_tensor_and_scale_to_float8(
+ a,
+ a_scale,
+ input_dtype,
+ emulated_config,
+ GemmInputRole.INPUT,
+ )
+ b_fp8 = hp_tensor_and_scale_to_float8(
+ b,
+ b_scale,
+ input_dtype,
+ emulated_config,
+ GemmInputRole.WEIGHT,
+ )
+ out_emulated = a_fp8 @ b_fp8
+ out_emulated = out_emulated.to(compare_type)
+
+ if base_dtype in {torch.bfloat16, torch.float16}:
+ atol, rtol = 7e-2, 7e-2
+ else:
+ atol, rtol = 2e-3, 2e-3
+ torch.testing.assert_close(out_padded, out_emulated, atol=atol, rtol=rtol)
+
+
+class TestNumerics:
+ @pytest.mark.parametrize(
+ "float8_dtype",
+ [
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ torch.float8_e4m3fnuz,
+ torch.float8_e5m2fnuz,
+ ],
+ )
+ @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+ def test_small_amax_float16(self, float8_dtype):
+ # If we calculate scale naively with FP8_MAX_POS / amax,
+ # the result may not be representable in fp16. Verify that
+ # the way we calculate scales actually works for tensors with
+ # small values.
+ #
+ # naive_s = fp8_max_pos / (amax + eps)
+ #
+ # failing case:
+ #
+ # fp8_max_pos / (amax + eps) >= fp16_max_pos, or
+ #
+ # amax + eps <= fp8_max_pos / fp16_max_pos
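+ #
+ # For example (approximate values, e4m3): fp8_max_pos ~= 448 and
+ # fp16_max_pos = 65504, so the threshold is ~448 / 65504 ~= 6.8e-3; any
+ # amax at or below that would make the naive scale overflow fp16.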
+
+ float8_max_pos = torch.finfo(float8_dtype).max
+ FP16_MAX_POS = torch.finfo(torch.float16).max
+
+ target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12)
+ x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
+ scale = tensor_to_scale(x, float8_dtype)
+ assert not torch.any(torch.isinf(scale))
+
+
+class TestFloat8LinearUtils(unittest.TestCase):
+ def test_swap_root_linear(self):
+ for emulate in [True, False]:
+ module = nn.Linear(3, 3)
+ config = Float8LinearConfig(emulate=emulate)
+ module = convert_to_float8_training(module, config=config)
+ self.assertIsInstance(module, Float8Linear)
+ self.assertEqual(module.linear_mm_config.output.emulate, emulate)
+
+ def test_swap_root_linear_with_children_raises(self):
+ for emulate in [True, False]:
+ module = nn.Linear(3, 3)
+ module.child = nn.Sequential(nn.Linear(3, 3))
+ config = Float8LinearConfig(emulate=emulate)
+ with self.assertRaisesRegex(
+ AssertionError,
+ "Does not support a root nn.Linear with children",
+ ):
+ convert_to_float8_training(module, config=config)
+
+ def test_swap_submodule_linears(self):
+ class MLP(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.lin1 = nn.Linear(dim, 4 * dim)
+ self.lin2 = nn.Linear(4 * dim, dim)
+
+ for emulate in [True, False]:
+ model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
+ config = Float8LinearConfig(emulate=emulate)
+ model = convert_to_float8_training(model, config=config)
+ self.assertIsInstance(model[0].lin1, Float8Linear)
+ self.assertIsInstance(model[0].lin2, Float8Linear)
+ self.assertIsInstance(model[1], Float8Linear)
+ self.assertIsInstance(model[2].lin1, Float8Linear)
+ self.assertIsInstance(model[2].lin2, Float8Linear)
+
+ def test_swap_linears_with_filters(self):
+ class MLP(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.lin1 = nn.Linear(dim, 4 * dim)
+ self.lin2 = nn.Linear(4 * dim, 4 * dim)
+
+ model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
+ # only convert linear layers whose in/out features are at least 32 and divisible by 16.
+
+ size_limit = 32
+
+ def module_filter_fn(mod, fqn):
+ return (
+ mod.in_features >= size_limit
+ and mod.out_features >= size_limit
+ and mod.in_features % 16 == 0
+ and mod.out_features % 16 == 0
+ )
+
+ config = Float8LinearConfig(emulate=True)
+ model = convert_to_float8_training(
+ model,
+ config=config,
+ module_filter_fn=module_filter_fn,
+ )
+ # in_features=8 is smaller than the size limit of 32, so lin1 is not converted.
+ self.assertNotIsInstance(model[0].lin1, Float8Linear)
+ # in_features=32, out_features=32: converted.
+ self.assertIsInstance(model[0].lin2, Float8Linear)
+ # in_features=32, out_features=32: converted.
+ self.assertIsInstance(model[1], Float8Linear)
+ # in_features=40 is not divisible by 16, so lin1 is not converted.
+ self.assertNotIsInstance(model[2].lin1, Float8Linear)
+ # in_features=160, out_features=160: converted.
+ self.assertIsInstance(model[2].lin2, Float8Linear)
+
+ def test_swap_submodule_linears_with_skip(self):
+ class MLP(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.lin1 = nn.Linear(dim, 4 * dim)
+ self.lin2 = nn.Linear(4 * dim, dim)
+
+ model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
+ module_filter_fn = lambda mod, fqn: fqn not in [
+ "0.lin2",
+ "2.lin1",
+ ]
+ config = Float8LinearConfig(emulate=True)
+ model = convert_to_float8_training(
+ model,
+ config=config,
+ module_filter_fn=module_filter_fn,
+ )
+ self.assertTrue(type(model[0].lin1) is Float8Linear)
+ self.assertTrue(type(model[0].lin2) is nn.Linear)
+ self.assertTrue(type(model[1]) is Float8Linear)
+ self.assertTrue(type(model[2].lin1) is nn.Linear)
+ self.assertTrue(type(model[2].lin2) is Float8Linear)
+
+ def test_fp8_tensor_statistics(self):
+ hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
+ lp_dtypes = (e4m3_dtype, e5m2_dtype)
+ for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
+ x1_hp = torch.ones(4, 4, dtype=hp_dtype)
+ tensor_len = x1_hp.numel()
+
+ # Overflow caused by a too large scaling factor
+ s_overflow = torch.tensor(1e9)
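+ # (e.g. 1.0 * 1e9 is far above the fp8 max of ~448 for e4m3 / ~57344 for
+ # e5m2, so every element saturates to the max representable value)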
+ fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
+ (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
+ self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))
+
+ # Underflow caused by a too small scaling factor
+ s_underflow = torch.tensor(1e-9)
+ fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
+ (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
+ self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))
+
+ # Both overflow and underflow
+ x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
+ fp8_over_underflow = hp_tensor_and_scale_to_float8(
+ x2_hp, torch.tensor(1.0), lp_dtype
+ )
+ (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
+ self.assertEqual((zero_cnt, max_cnt), (tensor_len, tensor_len))
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
new file mode 100644
index 000000000..1f3ebe169
--- /dev/null
+++ b/test/float8/test_compile.py
@@ -0,0 +1,329 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import copy
+import random
+import sys
+import unittest
+from io import StringIO
+
+import pytest
+
+from torchao.utils import TORCH_VERSION_AFTER_2_4
+
+if not TORCH_VERSION_AFTER_2_4:
+ pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+import torch
+import torch.nn as nn
+from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear import Float8Linear
+from torchao.float8.float8_linear_utils import (
+ convert_to_float8_training,
+ get_float8_layers,
+ sync_float8_amax_and_scale_history,
+)
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed
+from torchao.float8.float8_tensor import LinearMMConfig
+from torchao.float8.float8_utils import e4m3_dtype
+
+from torch._dynamo.test_case import TestCase as DynamoTestCase
+from torch._dynamo.testing import CompileCounterWithBackend
+
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+def _test_compile_base(
+ backend: str,
+ fullgraph: bool,
+ config: Float8LinearConfig,
+ dtype: torch.dtype,
+):
+ random.seed(0)
+ torch.manual_seed(0)
+ x_shape = (16, 16)
+ linear_dtype = torch.bfloat16
+
+ x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
+ m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
+
+ m_fp8 = Float8Linear.from_float(
+ copy.deepcopy(m_ref),
+ config,
+ )
+
+ m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph)
+ m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph)
+ y_fp8 = m_fp8(x)
+ y_fp8.sum().backward()
+ y_ref = m_ref(x)
+ y_ref.sum().backward()
+ torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2)
+ torch.testing.assert_close(
+ m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1
+ )
+ torch.testing.assert_close(m_fp8.bias.grad, m_ref.bias.grad, atol=8e-2, rtol=8e-2)
+
+
+@pytest.mark.parametrize("fullgraph", [True])
+@pytest.mark.parametrize(
+ "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+)
+@pytest.mark.parametrize(
+ "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+)
+@pytest.mark.parametrize(
+ "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+)
+@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+def test_eager_only(
+ fullgraph,
+ emulate: bool,
+ scaling_type_input: ScalingType,
+ scaling_type_weight: ScalingType,
+ scaling_type_grad_output: ScalingType,
+ dtype: torch.dtype,
+):
+ torch._dynamo.reset()
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=scaling_type_input),
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+ emulate=emulate,
+ )
+ _test_compile_base(
+ "eager",
+ fullgraph,
+ config,
+ dtype,
+ )
+
+
+@pytest.mark.parametrize("fullgraph", [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
+@pytest.mark.parametrize(
+ "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+)
+@pytest.mark.parametrize(
+ "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+)
+@pytest.mark.parametrize(
+ "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+)
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+def test_aot_eager(
+ fullgraph,
+ emulate: bool,
+ scaling_type_input: ScalingType,
+ scaling_type_weight: ScalingType,
+ scaling_type_grad_output: ScalingType,
+ dtype: torch.dtype,
+):
+ torch._dynamo.reset()
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=scaling_type_input),
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+ emulate=emulate,
+ )
+ _test_compile_base(
+ "aot_eager",
+ fullgraph,
+ config,
+ dtype,
+ )
+
+
+@pytest.mark.parametrize("fullgraph", [True])
+@pytest.mark.parametrize("emulate", [False])
+@pytest.mark.parametrize(
+ "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+)
+@pytest.mark.parametrize(
+ "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+)
+@pytest.mark.parametrize(
+ "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+)
+@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+def test_inductor(
+ fullgraph,
+ emulate: bool,
+ scaling_type_input: ScalingType,
+ scaling_type_weight: ScalingType,
+ scaling_type_grad_output: ScalingType,
+ dtype: torch.dtype,
+):
+ torch._dynamo.reset()
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=scaling_type_input),
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+ emulate=emulate,
+ )
+ _test_compile_base(
+ "inductor",
+ fullgraph,
+ config,
+ dtype,
+ )
+
+
+class TestGraphBreaks(DynamoTestCase):
+ class MockLinear(torch.nn.Module):
+ def __init__(self, graph_break: bool):
+ super().__init__()
+ self.register_buffer("fp8_amax_x", torch.tensor(1.0))
+ self.register_buffer("fp8_scale_x", torch.tensor(1.0))
+ self.graph_break = graph_break
+
+ def forward(self, x):
+ x_fp8 = hp_tensor_to_float8_delayed(
+ x,
+ self.fp8_scale_x,
+ e4m3_dtype,
+ self.fp8_amax_x,
+ LinearMMConfig(),
+ )
+ if self.graph_break:
+ torch._dynamo.graph_break()
+ x_hp = x_fp8.to_original_precision()
+ return x_hp
+ return x_fp8
+
+ @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+ def test_float8_with_graph_break_in_the_middle(self):
+ """Test that having Float8Tensor object at the boundary of a subgraph"""
+ cnts = CompileCounterWithBackend("inductor")
+ mod = self.MockLinear(graph_break=True).cuda()
+ compiled_mod = copy.deepcopy(mod)
+ compiled_mod = torch.compile(compiled_mod, backend=cnts)
+ x = torch.randn(16, 16, device="cuda")
+ y_eager = mod(x)
+ y_compiled = compiled_mod(x)
+ self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
+ torch.testing.assert_close(y_eager, y_compiled)
+
+ @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+ def test_float8_graph_input(self):
+ """Test that having Float8Tensor object as a graph input"""
+
+ def to_float(x):
+ return x.to_original_precision()
+
+ cnts = CompileCounterWithBackend("inductor")
+ mod = self.MockLinear(graph_break=False).cuda()
+ x = torch.randn(2, 2, device="cuda")
+ compiled_to_float = torch.compile(to_float, backend=cnts)
+ y = mod(x)
+ y2_eager = to_float(y)
+ y2_compiled = compiled_to_float(y)
+ self.assertEqual(
+ cnts.frame_count,
+ 1,
+ "to_float was not compiled into 1 frame and likely encountered a skip!",
+ )
+ torch.testing.assert_close(y2_eager, y2_compiled)
+
+ @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+ def test_float8_graph_output(self):
+ """Test that having Float8Tensor object as a graph output works"""
+ cnts = CompileCounterWithBackend("inductor")
+ mod = self.MockLinear(graph_break=False).cuda()
+ compiled_mod = torch.compile(mod, backend=cnts)
+ x = torch.randn(16, 16, device="cuda")
+ y_compiled = compiled_mod(x)
+
+ self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
+ tensors, ctx = y_compiled.__tensor_flatten__()
+ for tensor in tensors:
+ assert not isinstance(
+ getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor
+ ), "Float8Tensor should not contain any FakeTensors!"
+ assert isinstance(
+ y_compiled._orig_dtype, torch.dtype
+ ), "Float8Tensor._orig_dtype should be a dtype but got {}".format(
+ type(y_compiled._orig_dtype)
+ )
+ assert isinstance(
+ y_compiled._linear_mm_config.output.emulate, bool
+ ), "Float8Tensor._emulate should be a bool but got {}".format(
+ type(y_compiled._linear_mm_config.output.emulate)
+ )
+
+
+@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+def test_sync_amax_func():
+ torch._dynamo.reset()
+ cnts = CompileCounterWithBackend("inductor")
+ module = torch.nn.Sequential(
+ nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
+ )
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+ )
+ float8_mod = convert_to_float8_training(
+ module,
+ config=config,
+ )
+ compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
+ compiled_swap_func(float8_mod)
+ assert cnts.frame_count == 1, "Compiled graph should have 1 frame!"
+
+
+class capture_stderr(list):
+ """
+ Replace sys.stderr with a temporary StringIO
+ """
+
+ def __enter__(self):
+ self.sys_stderr = sys.stderr
+ self.stringio = StringIO()
+ sys.stderr = self.stringio
+ return self
+
+ def __exit__(self, *args):
+ self.append(str(self.stringio.getvalue()))
+ del self.stringio
+ sys.stderr = self.sys_stderr
+
+
+@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+def test_sync_amax_func_cuda_graph_success():
+ torch._dynamo.reset()
+ with capture_stderr() as stderr:
+ my_module = nn.Sequential(
+ nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
+ ).to("cuda")
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+ )
+ convert_to_float8_training(
+ my_module,
+ config=config,
+ )
+ inpt = torch.randn(
+ 16, 16, device="cuda", dtype=torch.float32, requires_grad=True
+ )
+ sync_func = torch.compile(
+ sync_float8_amax_and_scale_history, mode="reduce-overhead", fullgraph=True
+ )
+ fp8_layers = get_float8_layers(my_module)
+ my_module(inpt)
+ sync_func(my_module, fp8_layers)
+
+ assert "skipping cudagraphs due to mutaton on input" not in stderr[0]
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
new file mode 100644
index 000000000..7a6c9125d
--- /dev/null
+++ b/test/float8/test_dtensor.py
@@ -0,0 +1,327 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Test numerics of manually defined float16 TP vs float8 TP of toy models
+
+Note: for now, this does not run in CI.
+TODO(future): make this run in CI
+"""
+
+import copy
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import pytest
+
+from torchao.utils import TORCH_VERSION_AFTER_2_4
+
+if not TORCH_VERSION_AFTER_2_4:
+ pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+from torchao.float8 import Float8LinearConfig
+from torchao.float8.float8_linear_utils import convert_to_float8_training
+
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
+from torchao.float8.float8_tensor import (
+ Float8Tensor,
+ GemmInputRole,
+ hp_tensor_and_scale_to_float8,
+ LinearMMConfig,
+)
+from torchao.float8.float8_tensor_parallel import (
+ Float8ColwiseParallel,
+ Float8RowwiseParallel,
+ PrepareFloat8ModuleInput,
+)
+from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
+from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import parallelize_module
+from tqdm import tqdm
+
+
+def setup_distributed():
+ world_size = int(os.environ.get("WORLD_SIZE", -1))
+ device_mesh = init_device_mesh("cuda", (world_size,))
+ # seed must be the same in all processes
+ torch.manual_seed(1)
+ return device_mesh
+
+
+class FeedForward(nn.Module):
+ """MLP based model"""
+
+ def __init__(self):
+ super(FeedForward, self).__init__()
+ self.w1 = nn.Linear(16, 32, bias=False)
+ self.w2 = nn.Linear(16, 32, bias=False)
+ self.out_proj = nn.Linear(32, 16, bias=False)
+
+ def forward(self, x):
+ return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
+
+
+class ToyModel(nn.Module):
+ def __init__(self):
+ super(ToyModel, self).__init__()
+ self.ffn = FeedForward()
+
+ def forward(self, x):
+ return self.ffn(x)
+
+
+def _test_scaled_mm(mesh: DeviceMesh, size=16):
+ device = mesh.device_type
+ fp8_dtype = e4m3_dtype
+ world_size = mesh.size()
+
+ x_fp32 = torch.rand(size, size, device=device)
+ y_fp32 = torch.eye(size, device=device).t()
+
+ placement_combs = (
+ (Shard(0), Replicate()),
+ (Replicate(), Shard(1)),
+ (Shard(1), Shard(0)),
+ )
+ expected_dt_out_shape = (
+ (size * world_size, size),
+ (size, size * world_size),
+ (size, size),
+ )
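+ # e.g. with world_size=2 and size=16: Shard(0) @ Replicate() gives a global
+ # (32, 16) output, Replicate() @ Shard(1) gives (16, 32), and Shard(1) @ Shard(0)
+ # contracts the sharded inner dimension into a (16, 16) result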
+ for idx, (lhs_placement, rhs_placement) in enumerate(placement_combs):
+ x_scale = tensor_to_scale(x_fp32, fp8_dtype).float()
+ y_scale = tensor_to_scale(y_fp32, fp8_dtype).float()
+
+ x_fp8 = hp_tensor_and_scale_to_float8(
+ x_fp32, x_scale, fp8_dtype, None, GemmInputRole.INPUT
+ )
+ y_fp8 = hp_tensor_and_scale_to_float8(
+ y_fp32, y_scale, fp8_dtype, None, GemmInputRole.WEIGHT
+ )
+
+ dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False)
+ dist_y_fp8 = DTensor.from_local(y_fp8, mesh, [rhs_placement], run_check=False)
+
+ assert isinstance(dist_x_fp8.to_local(), Float8Tensor)
+ assert isinstance(dist_y_fp8.to_local(), Float8Tensor)
+ assert dist_x_fp8.to_local()._orig_dtype == torch.float32
+ out_fp8 = torch.mm(dist_x_fp8, dist_y_fp8)
+ local_fp8_out = out_fp8.to_local()
+ assert out_fp8.shape == expected_dt_out_shape[idx], (idx, local_fp8_out.shape)
+
+ # after mm the out dtype should be fp32
+ assert local_fp8_out.dtype == torch.float32
+
+
+def _test_fp8_redistribute(mesh: DeviceMesh, size=16):
+ device = mesh.device_type
+ fp8_dtype = e4m3_dtype
+ world_size = mesh.size()
+
+ x_fp32 = torch.rand(size, size, device=device)
+
+ x_scale = tensor_to_scale(x_fp32, fp8_dtype).float()
+
+ x_fp8 = hp_tensor_and_scale_to_float8(x_fp32, x_scale, fp8_dtype)
+
+ dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [Shard(0)], run_check=False)
+ out_dist = dist_x_fp8.redistribute(placements=[Replicate()])
+ assert out_dist.shape == (size * world_size, size)
+ assert out_dist.placements == (Replicate(),)
+ out_local = out_dist.to_local()
+ # after redistribute to Replicate, each rank holds the full (size * world_size, size) tensor
+ assert out_local.shape == (size * world_size, size)
+ from torch.distributed._functional_collectives import AsyncCollectiveTensor
+
+ if isinstance(out_local, AsyncCollectiveTensor):
+ out_local = out_local.wait()
+
+ assert isinstance(out_local, Float8Tensor)
+ assert out_local._data.dtype == fp8_dtype
+
+
+def _test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
+ device = mesh.device_type
+ fp8_dtype = e4m3_dtype
+
+ x_fp32 = torch.rand(size, size, device=device)
+ dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
+
+ dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float()
+ assert isinstance(dist_x_scale, DTensor)
+
+ dist_x_fp8 = hp_tensor_and_scale_to_float8(dist_x_fp32, dist_x_scale, fp8_dtype)
+ assert isinstance(dist_x_fp8, DTensor)
+
+
+def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
+ device = mesh.device_type
+ fp8_dtype = e4m3_dtype
+
+ x_fp32 = torch.rand(size, size, device=device, requires_grad=True)
+ local_weight = torch.rand(2 * size, size, device=device, requires_grad=True)
+ target = torch.rand(size, 2 * size, device=device)
+
+ dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
+ dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float()
+
+ dist_weight_fp32 = distribute_tensor(local_weight, mesh, [Shard(0)])
+ dist_weight_scale = tensor_to_scale(dist_weight_fp32, fp8_dtype).float()
+ dist_target = distribute_tensor(target, mesh, [Shard(0)])
+
+ dist_x_fp8 = hp_tensor_and_scale_to_float8(
+ dist_x_fp32,
+ dist_x_scale,
+ fp8_dtype,
+ None,
+ GemmInputRole.INPUT,
+ )
+ dist_weight_fp8 = hp_tensor_and_scale_to_float8(
+ dist_weight_fp32,
+ dist_weight_scale,
+ fp8_dtype,
+ None,
+ GemmInputRole.WEIGHT,
+ )
+
+ out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
+ out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
+ assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
+ loss = torch.sum(torch.abs(out - dist_target))
+ loss.backward()
+
+
+def _test_fp8_mlp_tensor_parallelism_base(
+ mesh: DeviceMesh, size=16, compile: bool = False
+):
+ device = mesh.device_type
+ # For now, only supports dynamic scaling of `x` and `dL_dY`.
+ # TODO(future): add support for float8 all-gather with delayed scaling
+ # for activations and gradients.
+ config = Float8LinearConfig(emulate=True)
+
+ toy_model = ToyModel().to(device)
+ toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
+
+ tp_model = copy.deepcopy(toy_model)
+ tp_model = convert_to_float8_training(tp_model, config=config)
+ sp_model = copy.deepcopy(toy_model)
+ sp_model = convert_to_float8_training(sp_model, config=config)
+
+ # vanilla TP
+ tp_model = parallelize_module(
+ tp_model,
+ mesh,
+ {
+ "ffn.w1": Float8ColwiseParallel(),
+ "ffn.w2": Float8ColwiseParallel(),
+ "ffn.out_proj": Float8RowwiseParallel(),
+ },
+ )
+
+ # "sequence parallel" mlp computation
+ sp_model = parallelize_module(
+ sp_model,
+ mesh,
+ {
+ "ffn": PrepareFloat8ModuleInput(
+ input_layouts=Shard(1), desired_input_layouts=Replicate()
+ ),
+ "ffn.w1": Float8ColwiseParallel(),
+ "ffn.w2": Float8ColwiseParallel(),
+ "ffn.out_proj": Float8RowwiseParallel(
+ output_layouts=Shard(1), use_local_output=False
+ ),
+ },
+ )
+
+ # PrepareFloat8ModuleInput with specific submodule fqn
+ sp_model2 = copy.deepcopy(toy_model)
+ sp_model2 = convert_to_float8_training(sp_model2, config=config)
+
+ sp_model2 = parallelize_module(
+ sp_model2,
+ mesh,
+ {
+ "ffn": PrepareFloat8ModuleInput(
+ input_layouts=Shard(1),
+ desired_input_layouts=Replicate(),
+ fwd_config_submodule_fqn="w2",
+ ),
+ "ffn.w1": Float8ColwiseParallel(),
+ "ffn.w2": Float8ColwiseParallel(),
+ "ffn.out_proj": Float8RowwiseParallel(
+ output_layouts=Shard(1), use_local_output=False
+ ),
+ },
+ )
+
+ if compile:
+ tp_model = torch.compile(tp_model)
+ sp_model = torch.compile(sp_model)
+ sp_model2 = torch.compile(sp_model2)
+
+ x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+ x_fp32_tp_input = x_fp32.clone()
+ x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
+
+ tp_out = tp_model(x_fp32_tp_input)
+ tp_out.sum().backward()
+ sp_out = sp_model(x_fp32_sp_input)
+ sp_out.sum().backward()
+ global_out = toy_model_fp8(x_fp32)
+ global_out.sum().backward()
+ torch.testing.assert_close(tp_out, global_out)
+ torch.testing.assert_close(sp_out.full_tensor(), global_out)
+ torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
+ torch.testing.assert_close(
+ tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+ )
+
+ sp_out2 = sp_model2(x_fp32_sp_input)
+ sp_out2.sum().backward()
+ torch.testing.assert_close(sp_out2.full_tensor(), global_out)
+ torch.testing.assert_close(
+ tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
+ )
+ torch.testing.assert_close(
+ tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
+ )
+
+
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+ _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)
+
+
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+ _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+
+
+if __name__ == "__main__":
+ # float8 compute requires an H100 GPU, so we only test CUDA here. Following the
+ # other test files, we do not use TestCase and instead add the test cases in the
+ # main function.
+ device_mesh = setup_distributed()
+ tests = [
+ _test_scaled_mm,
+ _test_fp8_redistribute,
+ _test_dtensor_cast_to_fp8,
+ _test_dtensor_fp8_autograd,
+ _test_fp8_mlp_tensor_parallelism_eager,
+ _test_fp8_mlp_tensor_parallelism_compile,
+ ]
+
+ for test in tqdm(tests, desc="Running tests"):
+ try:
+ test(device_mesh)
+ except Exception as e:
+ print(f"Test {test.__name__} failed with error: {e}")
+ raise e
+
+ torch.distributed.destroy_process_group()
diff --git a/test/float8/test_dtensor.sh b/test/float8/test_dtensor.sh
new file mode 100755
index 000000000..2e38feffe
--- /dev/null
+++ b/test/float8/test_dtensor.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+# terminate script on first error
+set -e
+
+if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
+ echo "Skipping test_dtensor.sh because no CUDA devices are available."
+ exit
+fi
+
+NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/float8/test_dtensor.py
diff --git a/test/float8/test_everything.sh b/test/float8/test_everything.sh
new file mode 100755
index 000000000..d70833323
--- /dev/null
+++ b/test/float8/test_everything.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# terminate script on first error
+set -e
+IS_ROCM=$(rocm-smi --version || true)
+
+pytest test/float8/test_base.py
+pytest test/float8/test_compile.py
+pytest test/float8/test_inference_flows.py
+pytest test/float8/test_numerics_integration.py
+
+# These tests do not work on ROCm yet
+if [ -z "$IS_ROCM" ]
+then
+./test/float8/test_fsdp.sh
+./test/float8/test_fsdp_compile.sh
+./test/float8/test_dtensor.sh
+pytest test/float8/test_fsdp2/test_fsdp2.py
+fi
+
+echo "all tests successful"
diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py
new file mode 100644
index 000000000..f30878b33
--- /dev/null
+++ b/test/float8/test_fsdp.py
@@ -0,0 +1,212 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Test numerics of bf16 versus float8 with FSDP on. At a high level:
+1. start with a reference model, with FSDP on
+2. run forward + backward + optim for 2 iterations
+3. repeat 2 with float8 enabled (2 iterations needed for delayed scaling)
+4. compare outputs and state dict between (2) and (3), should be close
+"""
+
+import copy
+import os
+import pytest
+import warnings
+
+import fire
+
+from torchao.utils import TORCH_VERSION_AFTER_2_4
+
+if not TORCH_VERSION_AFTER_2_4:
+ pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear_utils import (
+ convert_to_float8_training,
+ linear_requires_sync,
+ sync_float8_amax_and_scale_history,
+)
+from torchao.float8.float8_utils import compute_error
+from torch.distributed.fsdp import (
+ FullStateDictConfig,
+ FullyShardedDataParallel as FSDP,
+ StateDictType,
+)
+
+torch.manual_seed(0)
+
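+# toy problem sizes: B = batch, M = tokens per sample, K / N = linear in / out features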
+B, M, K, N = 8, 8, 32, 32
+lr = 0.01
+N_ITER = 2
+
+
+def setup(rank, world_size):
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12355"
+
+ # initialize the process group
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def cleanup():
+ dist.destroy_process_group()
+
+
+def get_model(K, N, base_dtype=torch.float32):
+ m = nn.Sequential(
+ nn.Linear(K, N, dtype=base_dtype),
+ nn.ReLU(),
+ nn.Linear(N, N, dtype=base_dtype),
+ nn.ReLU(),
+ )
+ return m
+
+
+# taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
+# and modified
+def fsdp_main(rank, world_size, args):
+ setup(rank, world_size)
+ torch.cuda.set_device(rank)
+
+ emulate, base_dtype, compile, use_weight_dynamic_scaling = args
+ model = get_model(K, N, base_dtype=base_dtype).to(rank)
+ model_fp8 = copy.deepcopy(model)
+
+ scaling_type_weight = (
+ ScalingType.DYNAMIC if use_weight_dynamic_scaling else ScalingType.DELAYED
+ )
+ config = Float8LinearConfig(
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ # TODO(future): delete this arg as it's always False
+ emulate=False,
+ )
+
+ # Note: we only iterate over `scaling_type_weight` because FSDP only interacts
+ # with weights.
+ convert_to_float8_training(
+ model_fp8,
+ config=config,
+ )
+
+ # To compile FSDP, we need use_orig_params=True
+ model = FSDP(model, use_orig_params=True)
+ model_fp8 = FSDP(model_fp8, use_orig_params=True)
+ # TODO: The following line doesn't work. We should fix it.
+ # model = FSDP(torch.compile(model), use_orig_params=True)
+
+ optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+ optimizer_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)
+
+ # Note: we need two different inputs to properly measure the impact of
+ # delayed scaling, because the first input uses dynamic scaling to
+ # populate the buffers
+ ref_input_global = [
+ torch.randn(B, M, K).cuda().to(base_dtype),
+ torch.randn(B, M, K).cuda().to(base_dtype),
+ ]
+ ref_grad_global = [
+ torch.randn(B, M, N).cuda().to(base_dtype),
+ torch.randn(B, M, N).cuda().to(base_dtype),
+ ]
+ ref_input_local = []
+ ref_grad_local = []
+
+ # basic distributed data sampling
+ assert B % world_size == 0
+ bsz_local_start = int(rank / world_size * B)
+ bsz_local_end = int((rank + 1) / world_size * B)
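+ # e.g. with B=8 and world_size=2, rank 0 takes rows [0:4) and rank 1 takes rows [4:8)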
+ for idx in range(N_ITER):
+ ref_input_local.append(
+ ref_input_global[idx][bsz_local_start:bsz_local_end].to(rank)
+ )
+ ref_grad_local.append(
+ ref_grad_global[idx][bsz_local_start:bsz_local_end].to(rank)
+ )
+
+ sync_float8_func = sync_float8_amax_and_scale_history
+ if compile:
+ sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)
+
+ def forward_backward(model, optim, is_fp8, i):
+ optim.zero_grad()
+ y_local = model(ref_input_local[i])
+ y_local.backward(ref_grad_local[i])
+ if is_fp8 and linear_requires_sync(config):
+ sync_float8_func(model)
+ optim.step()
+ return y_local
+
+ for i in range(N_ITER):
+ # We first run one iteration without compile, as a workaround for compiling float8 layers.
+ # In the first iter, float8 layers take the branches of "self.is_amax_initialized == False".
+ # After that, float8 layers take the branches of "self.is_amax_initialized == True".
+ # TODO: Need to fix compile to run without this workaround.
+ if i == 1 and compile:
+ model = torch.compile(model)
+ model_fp8 = torch.compile(model_fp8)
+ y_local = forward_backward(model, optimizer, is_fp8=False, i=i)
+ y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i)
+ local_sqnr = compute_error(y_local, y_local_fp8) # noqa: F841
+
+ # get global y
+ y_global = [
+ torch.zeros(*y_local.shape, dtype=base_dtype).to(rank)
+ for r in range(world_size)
+ ]
+ dist.all_gather(y_global, y_local)
+ y_global = torch.cat(y_global, dim=0)
+ y_global_fp8 = [
+ torch.zeros(*y_local_fp8.shape, dtype=base_dtype).to(rank)
+ for r in range(world_size)
+ ]
+ dist.all_gather(y_global_fp8, y_local_fp8)
+ y_global_fp8 = torch.cat(y_global_fp8, dim=0)
+ if rank == 0:
+ sqnr = compute_error(y_global, y_global_fp8)
+ assert sqnr > 15.0, f"SQNR of {sqnr} is too low"
+
+ # get global state dict
+ # https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html
+ dist.barrier()
+ save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+ with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+ cpu_state = model.state_dict()
+ with FSDP.state_dict_type(model_fp8, StateDictType.FULL_STATE_DICT, save_policy):
+ cpu_state_fp8 = model_fp8.state_dict()
+ if rank == 0:
+ for k, v1 in cpu_state.items():
+ v2 = cpu_state_fp8[k]
+ v1, v2 = v1.cpu(), v2.cpu()
+ sqnr = compute_error(v1, v2)
+ assert sqnr > 15.0, f"SQNR of {sqnr} is too low, k: {k}, v1: {v1}, v2: {v2}"
+
+ cleanup()
+
+
+def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False):
+ base_dtype = torch.bfloat16
+
+ emulate = False
+ if not torch.cuda.is_available():
+ warnings.warn("CUDA not available, running in emulation_mode")
+ emulate = True
+ elif torch.cuda.get_device_capability() < (9, 0):
+ warnings.warn(
+ f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
+ )
+ emulate = True
+
+ WORLD_SIZE = torch.cuda.device_count()
+ args = (emulate, base_dtype, compile_fsdp, use_weight_dynamic_scaling)
+ mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
+
+
+if __name__ == "__main__":
+ fire.Fire(run)
diff --git a/test/float8/test_fsdp.sh b/test/float8/test_fsdp.sh
new file mode 100755
index 000000000..3ff19d917
--- /dev/null
+++ b/test/float8/test_fsdp.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# terminate script on first error
+set -e
+
+launch() {
+ echo "launching compile_fsdp $COMPILE, use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING"
+
+ # the NCCL_DEBUG setting is to avoid log spew
+ # the CUDA_VISIBLE_DEVICES setting is for easy debugging
+ NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp.py \
+ --compile_fsdp $COMPILE --use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING
+
+ echo "✅ All Tests Passed ✅"
+}
+
+if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
+ echo "Skipping test_fsdp.sh because no CUDA devices are available."
+ exit
+fi
+
+# COMPILE, USE_WEIGHT_DYNAMIC_SCALING
+for i in False,False False,True True,False True,True
+do
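+ # split the "COMPILE,USE_WEIGHT_DYNAMIC_SCALING" pair into positional args $1 and $2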
+ IFS=","; set -- $i;
+ COMPILE=$1; USE_WEIGHT_DYNAMIC_SCALING=$2
+ launch
+done
diff --git a/test/float8/test_fsdp2/fsdp2_common.py b/test/float8/test_fsdp2/fsdp2_common.py
new file mode 100644
index 000000000..333206ba4
--- /dev/null
+++ b/test/float8/test_fsdp2/fsdp2_common.py
@@ -0,0 +1,89 @@
+import contextlib
+from typing import List, Optional
+
+import torchao.float8.config as config
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torchao.float8.config import Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear_utils import (
+ linear_requires_sync,
+ sync_float8_amax_and_scale_history,
+)
+from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+
+
+def check_parity_no_mp(
+ test_cls,
+ ref_model: nn.Module,
+ ref_optim: torch.optim.Optimizer,
+ fsdp_model: nn.Module,
+ fsdp_optim: torch.optim.Optimizer,
+ local_inp: torch.Tensor,
+ precompute: bool = False,
+ config: Optional[Float8LinearConfig] = None,
+ compile_transformer_block: bool = False,
+):
+ # TODO(before land): reorder args and make config not optional
+ for iter_idx in range(10):
+ losses: List[torch.Tensor] = []
+ for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):
+ optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
+ losses.append(model(local_inp).sum())
+ losses[-1].backward()
+ if model is ref_model:
+ for param in model.parameters():
+ dist.all_reduce(param.grad)
+ param.grad.div_(dist.get_world_size())
+
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(model)
+
+ optim.step()
+ if (
+ model is fsdp_model
+ and precompute
+ and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC
+ ):
+ precompute_float8_dynamic_scale_for_fsdp(model)
+
+ if compile_transformer_block:
+ test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4)
+ else:
+ test_cls.assertEqual(losses[0], losses[1])
+
+
+def check_parity_bf16_mp(
+ test_cls,
+ ref_model: nn.Module,
+ ref_model_bf16: nn.Module,
+ ref_optim: torch.optim.Optimizer,
+ fsdp_model: nn.Module,
+ fsdp_optim: torch.optim.Optimizer,
+ local_inp: torch.Tensor,
+):
+ for iter_idx in range(10):
+ losses: List[torch.Tensor] = []
+ for model, optim in (
+ (ref_model_bf16, ref_optim),
+ (fsdp_model, fsdp_optim),
+ ):
+ optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
+ losses.append(model(local_inp).sum())
+ losses[-1].backward()
+ if model is ref_model_bf16:
+ for param_bf16, param_fp32 in zip(
+ ref_model_bf16.parameters(), ref_model.parameters()
+ ):
+ dist.all_reduce(param_bf16.grad)
+ param_bf16.grad.div_(dist.get_world_size())
+ param_fp32.grad = param_bf16.grad.float()
+ param_bf16.grad = None
+ # TODO(future): add amax syncing once delayed scaling is supported
+ optim.step()
+ for param_fp32, param_bf16 in zip(
+ ref_model.parameters(), ref_model_bf16.parameters()
+ ):
+ param_bf16.detach().copy_(param_fp32)
+ test_cls.assertEqual(losses[0], losses[1])
diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py
new file mode 100644
index 000000000..7004b3a1c
--- /dev/null
+++ b/test/float8/test_fsdp2/test_fsdp2.py
@@ -0,0 +1,561 @@
+import copy
+import itertools
+import pytest
+import threading
+import unittest
+from typing import Any, List
+
+from torchao.utils import TORCH_VERSION_AFTER_2_4
+
+if not TORCH_VERSION_AFTER_2_4:
+ pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+import torch
+import torch._dynamo.testing
+import torch.distributed as dist
+import torch.nn as nn
+from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
+from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._tensor import DTensor
+from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import (
+ FSDPTest,
+ FSDPTestMultiThread,
+ MLP,
+ patch_all_gather,
+)
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+ ModelArgs,
+ Transformer,
+ TransformerBlock,
+)
+
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+if not is_H100:
+ pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
+
+class TestFloat8Common:
+ def broadcast_module(self, module: nn.Module) -> None:
+ # Broadcast for multi-threaded process group tests since seed is per
+ # process, not per thread
+ for param in module.parameters():
+ dist.broadcast(param, src=0)
+
+ def init_single_module(self) -> nn.Module:
+ torch.manual_seed(42)
+ module = nn.Linear(16, 16, device="cuda")
+ self.broadcast_module(module)
+ return module
+
+ def init_multi_module(self) -> nn.Module:
+ torch.manual_seed(42)
+ module = nn.Sequential(*[MLP(16, device="cuda") for _ in range(3)])
+ self.broadcast_module(module)
+ return module
+
+ def init_transformer(self, weight_tying: bool) -> nn.Module:
+ torch.manual_seed(42)
+ args = ModelArgs(
+ n_layers=3,
+ dim=768,
+ n_heads=12,
+ dropout_p=0.0,
+ weight_tying=weight_tying,
+ vocab_size=32,
+ )
+ module = Transformer(args).cuda()
+ self.broadcast_module(module)
+ return module
+
+ def get_local_inp(self, dtype: torch.dtype = torch.float32):
+ torch.manual_seed(42)
+ global_inp = torch.randn((16 * self.world_size, 16), device="cuda", dtype=dtype)
+ dist.broadcast(global_inp, src=0)
+ return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
+
+
+class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
+ @property
+ def world_size(self) -> int:
+ return min(torch.cuda.device_count(), 2)
+
+ @skip_if_lt_x_gpu(2)
+ def test_transformer_parity(self):
+ self.run_subtests(
+ {
+ "enable_fsdp_float8_all_gather": [False, True],
+ "precompute": [False, True],
+ "scaling_type_weight": [
+ ScalingType.DYNAMIC,
+ ScalingType.DELAYED,
+ ],
+ "compile_transformer_block": [False, True],
+ },
+ self._test_transformer_parity,
+ )
+
+ def _test_transformer_parity(
+ self,
+ enable_fsdp_float8_all_gather: bool,
+ precompute: bool,
+ scaling_type_weight: ScalingType,
+ compile_transformer_block: bool,
+ ):
+ if not enable_fsdp_float8_all_gather and precompute:
+ return
+ elif scaling_type_weight is ScalingType.DELAYED and precompute:
+ return
+
+ # NOTE: Weight-tying does not compose with fp8 all-gather because the
+ # embedding weight and output linear weight are tied but only the
+ # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
+ # fp8 for that tied weight, incorrectly using fp8 for the embedding.
+ weight_tying = not enable_fsdp_float8_all_gather
+ module = self.init_transformer(weight_tying=weight_tying).cuda()
+ ref_module = copy.deepcopy(module)
+ float8_linear_config1 = Float8LinearConfig(
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ )
+ convert_to_float8_training(
+ ref_module,
+ config=float8_linear_config1,
+ )
+ if compile_transformer_block:
+ for layer_id, transformer_block in ref_module.layers.named_children():
+ transformer_block = torch.compile(transformer_block, dynamic=False)
+ ref_module.layers.register_module(layer_id, transformer_block)
+ float8_linear_config2 = Float8LinearConfig(
+ enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ )
+ convert_to_float8_training(
+ module,
+ config=float8_linear_config2,
+ )
+ for layer_id, transformer_block in module.layers.named_children():
+ if compile_transformer_block:
+ transformer_block = torch.compile(transformer_block, dynamic=False)
+ fully_shard(transformer_block)
+ module.layers.register_module(layer_id, transformer_block)
+ fully_shard(module)
+ ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
+ optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
+ local_inp = torch.randint(
+ 0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
+ )
+ check_parity_no_mp(
+ self,
+ ref_module,
+ ref_optim,
+ module,
+ optim,
+ local_inp,
+ precompute,
+ config=float8_linear_config2,
+ compile_transformer_block=compile_transformer_block,
+ )
+
+ @skip_if_lt_x_gpu(2)
+ def test_transformer_memory(self):
+ """Tests peak active memory in the forward and backward passes."""
+ for enable_fsdp_float8_all_gather in [False, True]:
+ self._test_transformer_memory(enable_fsdp_float8_all_gather)
+
+ def _test_transformer_memory(self, enable_fsdp_float8_all_gather: bool):
+ torch.manual_seed(42)
+ # Pre-run a linear forward (gemm and bias) and backward (gemm) to
+ # allocate the cuBLAS workspaces before measuring the memory usage
+ # since the workspace size can differ across hardware
+ lin = torch.nn.Linear(768, 768, device="cuda")
+ inp = torch.randn(1, 768, device="cuda")
+ lin(inp).sum().backward()
+ torch.cuda.empty_cache()
+ base_mem_mb = self._get_peak_active_memory_mb()
+
+ vocab_size = 32
+ model_args = ModelArgs(
+ vocab_size=vocab_size,
+ n_layers=3,
+ dim=768,
+ n_heads=12,
+ weight_tying=False,
+ )
+ model = Transformer(model_args)
+ # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
+ # requirement to use a smaller activation size
+ float8_linear_config = Float8LinearConfig(
+ enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+ emulate=True,
+ )
+ convert_to_float8_training(model, config=float8_linear_config)
+ model_unsharded_numel = sum(p.numel() for p in model.parameters())
+ model_sharded_numel = (model_unsharded_numel + 1) // 2
+ block_lin_weight_numel = 0
+ block_other_numel = 0
+ for module in model.layers[0].modules():
+ for param in module.parameters(recurse=False):
+ if isinstance(module, nn.Linear):
+ block_lin_weight_numel += param.numel()
+ else:
+ block_other_numel += param.numel()
+ non_block_numel = round(
+ sum(p.numel() for p in model.tok_embeddings.parameters())
+ + sum(p.numel() for p in model.pos_embeddings.parameters())
+ + sum(p.numel() for p in model.norm.parameters())
+ + sum(p.numel() for p in model.output.parameters())
+ )
+ for module in model.modules():
+ if isinstance(module, TransformerBlock):
+ fully_shard(module)
+ fully_shard(model)
+
+ # Init: Each module is moved to GPU before sharding parameters
+ peak_mem_mb = self._get_peak_active_memory_mb()
+ curr_mem_mb = self._get_curr_active_memory_mb()
+ init_mem_mb = (
+ (model_sharded_numel + block_lin_weight_numel + block_other_numel) * 4 / 1e6
+ )
+ # Allow for some buffer for the peak memory since original parameters
+ # are not freed until a `fully_shard` call returns
+ buffer_mb = 4
+ self.assertLessEqual(peak_mem_mb - base_mem_mb, init_mem_mb + buffer_mb)
+ self.assertLessEqual(curr_mem_mb - base_mem_mb, init_mem_mb)
+
+ # Use a small input to minimize activation memory usage
+ inp = torch.randint(0, vocab_size, (1, 4), device="cuda")
+
+ # Forward:
+ loss = model(inp)
+ mem_mb = self._get_peak_active_memory_mb()
+ # Allow for some buffer for fragmentation/activations (where this
+ # number is kept much smaller than the actual memory usage, which is on
+ # the order of 100-200+ MB)
+ buffer_mb = 16
+ if enable_fsdp_float8_all_gather:
+ # Non-block parameters (fp32), 3x block non-linear-weight
+ # parameters (fp32) and block linear-weight parameters (fp8)
+ # (current all-gather, copy-out, and next all-gather), and other
+ expected_mem_mb = (
+ (non_block_numel * 4)
+ + 3 * (block_lin_weight_numel + block_other_numel * 4)
+ ) / 1e6 + buffer_mb
+ else:
+ # Non-block parameters (fp32), 3x block parameters (fp32)
+ # (current all-gather, copy-out, and next all-gather), Nx block
+ # linear-weight parameters (fp8) for N blocks (saved by autograd),
+ # and other
+ expected_mem_mb = (
+ (non_block_numel + 3 * (block_lin_weight_numel + block_other_numel)) * 4
+ + model_args.n_layers * block_lin_weight_numel
+ ) / 1e6 + buffer_mb
+ # Sharded parameters
+ expected_mem_mb += model_sharded_numel * 4 / 1e6
+ self.assertLessEqual(mem_mb, expected_mem_mb + base_mem_mb)
+
+ # Backward:
+ loss.sum().backward()
+ mem_mb = self._get_peak_active_memory_mb()
+ if enable_fsdp_float8_all_gather:
+ # Non-block parameters (fp32), 2x block non-linear weight
+ # parameters (fp32) and block linear-weight parameters (fp8)
+ # (current copy-out and next all-gather), 1x block gradients (fp32)
+ expected_mem_mb = (
+ (non_block_numel * 4)
+ + 2 * (block_lin_weight_numel + block_other_numel * 4)
+ + 1 * (block_lin_weight_numel + block_other_numel) * 4
+ ) / 1e6 + buffer_mb
+ else:
+ # Non-block parameters (fp32), 3x block parameters (fp32) (current
+ # copy-out, next all-gather, current gradients)
+ expected_mem_mb = (
+ non_block_numel + 3 * (block_lin_weight_numel + block_other_numel) * 4
+ ) * 4 / 1e6 + buffer_mb
+ # 2x sharded parameters/gradients
+ expected_mem_mb += 2 * model_sharded_numel * 4 / 1e6
+ self.assertLessEqual(mem_mb, expected_mem_mb + base_mem_mb)
+
+ def _get_peak_active_memory_mb(self) -> int:
+ mem_stats = torch.cuda.memory_stats()
+ return round(mem_stats["active_bytes.all.peak"] / 1e6)
+
+ def _get_curr_active_memory_mb(self) -> int:
+ mem_stats = torch.cuda.memory_stats()
+ return round(mem_stats["active_bytes.all.current"] / 1e6)
+
+
+class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
+ @property
+ def world_size(self) -> int:
+ return 2
+
+ @unittest.skipIf(not TEST_CUDA, "no cuda")
+ def test_weight_subclass_dynamic(self):
+ tensor_cls = WeightWithDynamicFloat8CastTensor
+ # Check for a single FSDP parameter group
+ module_fp32 = self.init_single_module()
+ float8_linear_config = Float8LinearConfig(
+ enable_fsdp_float8_all_gather=True,
+ emulate=True,
+ )
+ module = convert_to_float8_training(
+ module_fp32,
+ config=float8_linear_config,
+ )
+ self.assertIsInstance(module.weight, tensor_cls)
+ fully_shard(module)
+ for param_name, param in module.named_parameters():
+ self.assertIsInstance(param, DTensor)
+ if "weight" in param_name:
+ self.assertIsInstance(param.to_local(), tensor_cls)
+
+ # Check for multiple FSDP parameter groups
+ module = self.init_multi_module()
+ module = convert_to_float8_training(
+ module,
+ config=float8_linear_config,
+ )
+ for param_name, param in module.named_parameters():
+ if "weight" in param_name:
+ self.assertIsInstance(param, tensor_cls)
+ for mlp in module:
+ fully_shard(mlp)
+ fully_shard(module)
+ for param_name, param in module.named_parameters():
+ self.assertIsInstance(param, DTensor)
+ if "weight" in param_name:
+ self.assertIsInstance(param.to_local(), tensor_cls)
+
+ @unittest.skipIf(not TEST_CUDA, "no cuda")
+ def test_fp8_fp32_all_gather_dynamic_comm_size(self):
+ """
+ Tests that fp8 all-gather with dynamic scaling communicates the
+ expected number of bytes.
+ """
+ orig_all_gather = dist.all_gather_into_tensor
+ all_gather_sizes: List[int] = []
+ lock = threading.Lock()
+
+ def all_gather(*args: Any, **kwargs: Any):
+ nonlocal all_gather_sizes
+ if len(args) > 0:
+ output = args[0]
+ elif "output_tensor" in kwargs:
+ output = kwargs["output_tensor"]
+ else:
+ raise AssertionError(
+ f"Cannot get all-gather output from\nargs: {args}\nkwargs: {kwargs}"
+ )
+ with lock:
+ all_gather_sizes.append(output.numel() * output.itemsize)
+ return orig_all_gather(*args, **kwargs)
+
+ def get_expected_all_gather_size(module: nn.Module):
+ size = 0
+ for param_name, param in module.named_parameters():
+ bytes_per_numel = 1 if "weight" in param_name else param.itemsize
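+ # fp8-cast weights are all-gathered as 1 byte per element, while other params
+ # (e.g. fp32 biases) keep their original itemsize; e.g. for a single
+ # nn.Linear(16, 16) with bias: 256 * 1 + 16 * 4 = 320 bytes (illustrative)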
+ size += param.numel() * bytes_per_numel
+ return size
+
+ # - Check for a single FSDP parameter group
+ module_fp32 = self.init_single_module()
+ ref_module = copy.deepcopy(module_fp32)
+ float8_linear_config = Float8LinearConfig(
+ enable_fsdp_float8_all_gather=True,
+ )
+ module_fp32 = convert_to_float8_training(
+ module_fp32, config=float8_linear_config
+ )
+ module = module_fp32
+ fully_shard(module)
+ local_inp = self.get_local_inp()
+ expected_all_gather_size = get_expected_all_gather_size(ref_module)
+ with patch_all_gather(all_gather):
+ out = module(local_inp)
+ # For the multi-threaded process group (MPTG), one rank runs all the all-gathers, each of the same size
+ if all_gather_sizes:
+ self.assertEqual(len(all_gather_sizes), self.world_size)
+ self.assertEqual(
+ all_gather_sizes, [expected_all_gather_size] * self.world_size
+ )
+ all_gather_sizes.clear()
+ # Force-reshard the module to check the backward all-gather
+ module.reshard()
+ with patch_all_gather(all_gather):
+ out.sum().backward()
+ if all_gather_sizes:
+ self.assertEqual(len(all_gather_sizes), self.world_size)
+ self.assertEqual(
+ all_gather_sizes, [expected_all_gather_size] * self.world_size
+ )
+ all_gather_sizes.clear()
+
+ # - Check for multiple FSDP parameter groups
+ module = self.init_multi_module()
+ ref_module = copy.deepcopy(module)
+ module = convert_to_float8_training(module, config=float8_linear_config)
+ for submodule in module:
+ fully_shard(submodule)
+ fully_shard(module)
+ expected_all_gather_sizes = (
+ get_expected_all_gather_size(submodule) for submodule in module
+ )
+ with patch_all_gather(all_gather):
+ out = module(local_inp)
+ if all_gather_sizes:
+ self.assertEqual(len(all_gather_sizes), self.world_size * len(module))
+ self.assertEqual(
+ all_gather_sizes,
+ [s for s in expected_all_gather_sizes for _ in range(self.world_size)],
+ )
+
+ @unittest.skipIf(not TEST_CUDA, "no cuda")
+ def test_fp32_fp8_single_module_parity(self):
+ """
+ Tests numeric parity for fp32 parameters with fp8 computation with a
+ single module/FSDP communication group.
+ """
+ choices = itertools.product(
+ [False, True],
+ [ScalingType.DYNAMIC, ScalingType.DELAYED],
+ )
+ for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
+ float8_linear_config1 = Float8LinearConfig(
+ enable_fsdp_float8_all_gather=False,
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ )
+ float8_linear_config2 = Float8LinearConfig(
+ enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ )
+ module_fp32 = self.init_single_module()
+ ref_module = copy.deepcopy(module_fp32)
+ ref_module = convert_to_float8_training(
+ ref_module,
+ config=float8_linear_config1,
+ )
+ ref_module = ref_module.cuda()
+ module = convert_to_float8_training(
+ module_fp32,
+ config=float8_linear_config2,
+ )
+ fully_shard(module)
+ ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
+ optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
+ local_inp = self.get_local_inp()
+ check_parity_no_mp(
+ self,
+ ref_module,
+ ref_optim,
+ module,
+ optim,
+ local_inp,
+ config=float8_linear_config2,
+ )
+
+ @unittest.skipIf(not TEST_CUDA, "no cuda")
+ def test_fp32_fp8_multi_module_parity(self):
+ """
+ Tests numeric parity for fp32 parameters with fp8 computation with
+ multiple modules/FSDP communication groups.
+ """
+ choices = itertools.product(
+ [False, True],
+ [ScalingType.DYNAMIC, ScalingType.DELAYED],
+ )
+ for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
+ float8_linear_config1 = Float8LinearConfig(
+ enable_fsdp_float8_all_gather=False,
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ )
+ float8_linear_config2 = Float8LinearConfig(
+ enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ )
+ module = self.init_multi_module().cuda()
+ ref_module = copy.deepcopy(module)
+ ref_module = convert_to_float8_training(
+ ref_module,
+ config=float8_linear_config1,
+ )
+ module = convert_to_float8_training(
+ module,
+ config=float8_linear_config2,
+ )
+ for submodule in module:
+ fully_shard(submodule)
+ fully_shard(module)
+ ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
+ optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
+ local_inp = self.get_local_inp()
+ check_parity_no_mp(
+ self,
+ ref_module,
+ ref_optim,
+ module,
+ optim,
+ local_inp,
+ config=float8_linear_config2,
+ )
+
+ @unittest.skipIf(not TEST_CUDA, "no cuda")
+ def test_bf16_mp_fp8_dynamic_multi_parity(self):
+ """
+ Tests numeric parity for fp32 parameters with FSDP's bf16 mixed
+ precision and fp8 computation with multiple modules/FSDP communication
+ groups. Parameters are all-gathered in bf16 before being cast to fp8.
+ """
+ # NOTE: We cannot test easily with fp8 all-gather because then the scale
+ # is computed using the fp32 sharded parameters, not the bf16 unsharded
+ # parameters, changing the numerics.
+ module = self.init_multi_module()
+ ref_module_bf16 = copy.deepcopy(module).to(torch.bfloat16)
+ float8_config = Float8LinearConfig(emulate=True)
+ ref_module_bf16 = convert_to_float8_training(
+ ref_module_bf16,
+ config=float8_config,
+ )
+ ref_module_fp32 = copy.deepcopy(module).cuda()
+ module = convert_to_float8_training(module, config=float8_config)
+ mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
+ for mlp in module:
+ fully_shard(mlp, mp_policy=mp_policy)
+ fully_shard(module, mp_policy=mp_policy)
+ check_parity_bf16_mp(
+ self,
+ ref_module_fp32,
+ ref_module_bf16,
+ torch.optim.Adam(ref_module_fp32.parameters(), lr=1e-2),
+ module,
+ torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True),
+ self.get_local_inp(torch.bfloat16),
+ )
+
+ @unittest.skipIf(not TEST_CUDA, "no cuda")
+ def test_delayed_scaling_inplace_update(self):
+ """
+ Verify that `WeightWithDelayedFloat8CastTensor` updates buffers inplace
+ """
+ module = self.init_single_module()
+ float8_linear_config = Float8LinearConfig(
+ enable_fsdp_float8_all_gather=True,
+ cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+ )
+ m_fp8 = convert_to_float8_training(
+ module,
+ config=float8_linear_config,
+ )
+
+ fp8_amax_weight_old = m_fp8.fp8_amax_weight.clone().detach()
+ dummy_mesh = None
+ data, scale = m_fp8.weight.fsdp_pre_all_gather(dummy_mesh)
+ self.assertNotEqual(fp8_amax_weight_old.item(), m_fp8.fp8_amax_weight.item())
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py
new file mode 100644
index 000000000..f4ca160fd
--- /dev/null
+++ b/test/float8/test_fsdp_compile.py
@@ -0,0 +1,139 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Test autocast + torch.compile + FSDP + Float8Linear
+"""
+
+import os
+import warnings
+
+import fire
+
+import pytest
+
+from torchao.utils import TORCH_VERSION_AFTER_2_4
+
+if not TORCH_VERSION_AFTER_2_4:
+ pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from torchao.float8 import Float8LinearConfig
+from torchao.float8.config import CastConfig, ScalingType
+from torchao.float8.float8_linear_utils import (
+ convert_to_float8_training,
+ sync_float8_amax_and_scale_history,
+)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+torch.manual_seed(0)
+
+B, M, K, N = 8, 8, 32, 32
+lr = 0.01
+N_ITER = 1
+
+
+def setup(rank, world_size):
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12355"
+
+ # initialize the process group
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def cleanup():
+ dist.destroy_process_group()
+
+
+def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
+ # composability of torch.compile + FSDP + autocast + Float8Linear
+    # as of 2023-12-30
+
+ # without any changes to the Float8Linear, we get this error:
+ # https://gist.github.com/vkuzo/3bcb81806cc92f99ac0b9c5fdf287730
+
+ # if we initialize Float8Linear with is_amax_initialized=True and
+ # amax_and_scale_synced=True, we get
+ # https://gist.github.com/vkuzo/ed8e168fd9f7463f1fce34301334ab55
+ # to get around this, we can disable amax init
+ config = Float8LinearConfig(
+ enable_amax_init=False,
+ cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+ emulate=emulate,
+ )
+
+ m = nn.Sequential(
+ nn.Linear(K, N, dtype=base_dtype),
+ nn.ReLU(),
+ )
+ convert_to_float8_training(
+ m,
+ config=config,
+ )
+ return m
+
+
+# taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
+# and modified
+def fsdp_main(rank, world_size, args):
+ setup(rank, world_size)
+ torch.cuda.set_device(rank)
+
+ (emulate,) = args
+
+ # finally, if we remove the usage of self.bias_dtype, then
+ # things work e2e. Note that FSDP does not support full-graph compile
+ # regardless of float8.
+
+ model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to(
+ rank
+ )
+
+    # To compile FSDP, we need to set use_orig_params to True
+ model = FSDP(model, use_orig_params=True)
+
+ optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
+ input_local = torch.randn(B, M, K, N, device="cuda")
+ sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)
+
+ model = torch.compile(model)
+
+ for _iter in range(N_ITER):
+ optimizer.zero_grad()
+ with torch.autocast("cuda"):
+ y_local = model(input_local)
+ y_local.sum().backward()
+ sync_float8_func(model)
+ optimizer.step()
+
+ print("done!")
+ cleanup()
+
+
+def run():
+ emulate = False
+ if not torch.cuda.is_available():
+ warnings.warn("CUDA not available, running in emulation_mode", stacklevel=2)
+ emulate = True
+ elif torch.cuda.get_device_capability() < (9, 0):
+ warnings.warn(
+ f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode",
+ stacklevel=2,
+ )
+ emulate = True
+
+ WORLD_SIZE = torch.cuda.device_count()
+ args = (emulate,)
+ mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
+
+
+if __name__ == "__main__":
+ fire.Fire(run)
diff --git a/test/float8/test_fsdp_compile.sh b/test/float8/test_fsdp_compile.sh
new file mode 100755
index 000000000..666136aba
--- /dev/null
+++ b/test/float8/test_fsdp_compile.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+# terminate script on first error
+set -e
+if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
+ echo "Skipping test_fsdp_compile.sh because no CUDA devices are available."
+ exit
+fi
+
+# Code to be executed if CUDA devices are available
+NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp_compile.py
diff --git a/test/float8/test_inference_flows.py b/test/float8/test_inference_flows.py
new file mode 100644
index 000000000..c76a43df0
--- /dev/null
+++ b/test/float8/test_inference_flows.py
@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import copy
+import io
+import random
+import unittest
+
+import pytest
+
+from torchao.utils import TORCH_VERSION_AFTER_2_4
+
+if not TORCH_VERSION_AFTER_2_4:
+ pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchao.float8.config import ScalingType
+from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_tensor import Float8Tensor
+from torchao.float8.float8_utils import compute_error
+from torchao.float8.inference import (
+ ActivationCasting,
+ Float8InferenceLinear,
+ QuantConfig,
+ quantize_to_float8,
+)
+
+
+random.seed(0)
+torch.manual_seed(0)
+
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+class FeedForward(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(4096, 14336, bias=False)
+ self.w3 = nn.Linear(4096, 14336, bias=False)
+ self.w2 = nn.Linear(14336, 4096, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+ def reset_parameters(self):
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ m.reset_parameters()
+
+
+class TestHPTrainToFP8LinearInference:
+ def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor):
+ with torch.no_grad():
+ base_output = base_mlp(input_tensor)
+ transformed_output = quantized_mlp(input_tensor)
+
+ # Compute and check SQNR
+ sqnr = compute_error(base_output, transformed_output)
+ assert sqnr.item() > 20, f"SQNR is too low: {sqnr.item()} dB"
+
+ @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
+ @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+ @unittest.skipIf(
+ not torch.cuda.is_available() or not is_H100,
+ "CUDA not available or on non H100 machine",
+ )
+ def test_dynamic_fp8_mlp(self, compile_backend, dtype):
+ original_mlp = FeedForward().to("cuda", dtype=dtype)
+ original_mlp.reset_parameters()
+
+ dynamic_fp8_mlp = copy.deepcopy(original_mlp)
+
+ quant_config = QuantConfig(ActivationCasting.DYNAMIC)
+ quantize_to_float8(dynamic_fp8_mlp, quant_config)
+
+ batch_size = 4
+ num_tokens = 1024
+ embedding_dim = 4096
+
+ input_tensor = torch.randn(
+ batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype
+ )
+
+ # Compile the models
+ compiled_original_mlp = torch.compile(
+ original_mlp, backend=compile_backend, fullgraph=True
+ )
+ compiled_dynamic_fp8_mlp = torch.compile(
+ dynamic_fp8_mlp, backend=compile_backend, fullgraph=True
+ )
+
+ self.base_test_mlp_transform(
+ compiled_original_mlp, compiled_dynamic_fp8_mlp, input_tensor
+ )
+
+ @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
+ @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+ @unittest.skipIf(
+ not torch.cuda.is_available() or not is_H100,
+ "CUDA not available or on non H100 machine",
+ )
+ def test_static_fp8_mlp(self, compile_backend, dtype):
+ original_mlp = FeedForward().to("cuda", dtype=dtype)
+ original_mlp.reset_parameters()
+
+ static_fp8_mlp = copy.deepcopy(original_mlp)
+ quant_config = QuantConfig(
+ ActivationCasting.STATIC,
+ static_quantization_scale=torch.tensor(
+ [1.0], device="cuda", dtype=torch.float32
+ ),
+ )
+ quantize_to_float8(static_fp8_mlp, quant_config)
+
+ batch_size = 4
+ num_tokens = 1024
+ embedding_dim = 4096
+
+ input_tensor = torch.randn(
+ batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype
+ )
+
+ # Compile the models
+ compiled_original_mlp = torch.compile(
+ original_mlp, backend=compile_backend, fullgraph=True
+ )
+ compiled_static_fp8_mlp = torch.compile(
+ static_fp8_mlp, backend=compile_backend, fullgraph=True
+ )
+
+ self.base_test_mlp_transform(
+ compiled_original_mlp, compiled_static_fp8_mlp, input_tensor
+ )
+
+ @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
+ @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+ @unittest.skipIf(
+ not torch.cuda.is_available() or not is_H100,
+ "CUDA not available or on non H100 machine",
+ )
+ def test_weight_only_fp8_mlp(self, compile_backend, dtype):
+ original_mlp = FeedForward().to("cuda", dtype=dtype)
+ original_mlp.reset_parameters()
+
+ static_fp8_mlp = copy.deepcopy(original_mlp)
+ quant_config = QuantConfig(ActivationCasting.WEIGHT_ONLY)
+ quantize_to_float8(static_fp8_mlp, quant_config)
+
+ batch_size = 4
+ num_tokens = 1024
+ embedding_dim = 4096
+
+ input_tensor = torch.randn(
+ batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype
+ )
+
+ # Compile the models
+ compiled_original_mlp = torch.compile(
+ original_mlp, backend=compile_backend, fullgraph=True
+ )
+ compiled_static_fp8_mlp = torch.compile(
+ static_fp8_mlp, backend=compile_backend, fullgraph=True
+ )
+
+ self.base_test_mlp_transform(
+ compiled_original_mlp, compiled_static_fp8_mlp, input_tensor
+ )
+
+
+class TestFP8TrainToFP8LinearInference:
+ def train(self, model: nn.Module, dtype: torch.dtype):
+ model.train()
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+ criterion = nn.MSELoss()
+ target_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+ for _ in range(10):
+ input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+ optimizer.zero_grad()
+ output = model(input_tensor)
+ loss = criterion(output, target_tensor)
+ loss.backward()
+ optimizer.step()
+ model.eval()
+ return model
+
+ @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+ @unittest.skipIf(
+ not torch.cuda.is_available() or not is_H100,
+ "CUDA not available or on non H100 machine",
+ )
+ def test_fp8_save_and_load(self, dtype: torch.dtype):
+ # Initialize FP8 model
+ fp8_mlp = FeedForward().to("cuda", dtype=torch.float32)
+ fp8_mlp.reset_parameters()
+ convert_to_float8_training(fp8_mlp)
+
+ # Train the model
+ self.train(fp8_mlp, dtype)
+
+ # Generate input tensor and original out
+ input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+ og_out = fp8_mlp(input_tensor)
+
+ # Save model state dict
+ buffer = io.BytesIO()
+ torch.save(fp8_mlp.state_dict(), buffer)
+
+ # Reset buffer position to the beginning
+ buffer.seek(0)
+
+        # Later on, load the model; it will have Float8Linear modules on the meta device
+ with torch.device("meta"):
+ new_fp8_mlp = FeedForward().to(dtype=dtype)
+ convert_to_float8_training(new_fp8_mlp)
+
+ # Load the actual data
+ new_fp8_mlp.load_state_dict(
+ torch.load(buffer, weights_only=True), strict=True, assign=True
+ )
+
+ quant_config = QuantConfig(ActivationCasting.DYNAMIC)
+ quantize_to_float8(new_fp8_mlp, quant_config)
+
+ fp8_mod_count = 0
+ for module in new_fp8_mlp.modules():
+ if isinstance(module, Float8InferenceLinear):
+ assert isinstance(module.weight, Float8Tensor)
+ assert module.weight.requires_grad is False
+ fp8_mod_count += 1
+ assert fp8_mod_count == 3, "Expected 3 FP8 modules, got {}".format(
+ fp8_mod_count
+ )
+
+ new_out = new_fp8_mlp(input_tensor)
+
+ # Assert exact equality
+ assert torch.all(og_out == new_out).item()
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py
new file mode 100644
index 000000000..fd724b340
--- /dev/null
+++ b/test/float8/test_numerics_integration.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Tests LLaMa FeedForward numerics with float8
+
+import copy
+from typing import Optional
+
+import pytest
+
+from torchao.utils import TORCH_VERSION_AFTER_2_4
+
+if not TORCH_VERSION_AFTER_2_4:
+ pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear_utils import (
+ convert_to_float8_training,
+ linear_requires_sync,
+ sync_float8_amax_and_scale_history,
+)
+from torchao.float8.float8_utils import compute_error, IS_ROCM
+
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+torch.manual_seed(0)
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+class FeedForward(nn.Module):
+ """
+ FeedForward module
+
+ Args:
+ dim (int): Input dimension.
+ hidden_dim (int): Hidden dimension of the feedforward layer.
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+ ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+ Attributes:
+ w1 (Linear): Linear transformation for the first layer.
+ w2 (Linear): Linear transformation for the second layer.
+ w3 (Linear): Linear transformation for the third layer.
+
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ # custom dim factor multiplier
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def forward(self, x):
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+ def init_weights(self, init_std: float):
+ nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+ for linear in (self.w2, self.w3):
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class TestFloat8NumericsIntegrationTest:
+ @pytest.mark.parametrize(
+ "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+ )
+ @pytest.mark.parametrize(
+ "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+ )
+ @pytest.mark.parametrize(
+ "scaling_type_grad_output",
+ [ScalingType.DELAYED, ScalingType.DYNAMIC],
+ )
+ @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+ @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
+ def test_encoder_fw_bw(
+ self,
+ scaling_type_input: ScalingType,
+ scaling_type_weight: ScalingType,
+ scaling_type_grad_output: ScalingType,
+ ):
+ # TODO(later): maybe add float16 back if it becomes important
+ data_dtype = torch.bfloat16
+
+ # LLaMa 3 70B shapes
+ model_ref = (
+ FeedForward(
+ dim=4096,
+ hidden_dim=16384,
+ multiple_of=1024,
+ ffn_dim_multiplier=1.3,
+ )
+ .cuda()
+ .to(data_dtype)
+ )
+
+ # for now just test the encoder to simplify things
+ model_fp8 = copy.deepcopy(model_ref)
+ config = Float8LinearConfig(
+ cast_config_input=CastConfig(scaling_type=scaling_type_input),
+ cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+ cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+ )
+ convert_to_float8_training(
+ model_fp8,
+ config=config,
+ )
+
+ lr = 0.01
+ optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr)
+ optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)
+
+ # Note: you need two different inputs to properly test numerics
+ # of delayed scaling, because the first time around the initialization
+ # logic of delayed scaling behaves as dynamic scaling
+ # TODO(future): also make unit tests do this properly
+ shape = (1, 8192, 4096)
+ data1 = torch.randn(*shape, device="cuda", dtype=data_dtype)
+ data2 = torch.randn(*shape, device="cuda", dtype=data_dtype)
+
+ model_ref(data1).sum().backward()
+ # zero out grads without stepping, since we just want to compare grads
+ # of the second datum
+ optim_ref.zero_grad()
+ model_ref_out = model_ref(data2)
+ model_ref_out.sum().backward()
+
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(model_fp8)
+ model_fp8(data1).sum().backward()
+ # zero out grads without stepping, since we just want to compare grads
+ # of the second datum
+ optim_fp8.zero_grad()
+ if linear_requires_sync(config):
+ sync_float8_amax_and_scale_history(model_fp8)
+ model_fp8_out = model_fp8(data2)
+ model_fp8_out.sum().backward()
+
+ out_sqnr = compute_error(model_ref_out, model_fp8_out)
+ assert out_sqnr > 20.0
+
+ ref_name_to_grad = {
+ name: param.grad for name, param in model_ref.named_parameters()
+ }
+
+ grad_sqnr_threshold = 20.0
+
+ for name, param in model_fp8.named_parameters():
+ ref_grad = ref_name_to_grad[name]
+ cur_grad = param.grad
+ sqnr = compute_error(ref_grad, cur_grad)
+ assert sqnr > grad_sqnr_threshold
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/torchao/float8/README.md b/torchao/float8/README.md
new file mode 100644
index 000000000..abab3b9fa
--- /dev/null
+++ b/torchao/float8/README.md
@@ -0,0 +1,159 @@
+# torchao.float8
+
+This is an early version of a library for accelerating training with float8 in native PyTorch
+according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
+The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling,
+and composable with key systems such as autograd, `torch.compile` and distributed.
+With `torch.compile` on, initial results show
+throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.
+
+:warning: See the [feature tracker](https://github.com/pytorch-labs/torchao.float8/issues/187) for upcoming features.
+
+:warning: Backwards compatibility is not guaranteed at this point. The codebase is in active development and
+will change rapidly.
+
+# Single GPU User API
+
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`).
+
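+As a quick illustration (a minimal sketch; the full conversion flow is shown in the recipes below), the scaling strategy can be mixed per tensor via `Float8LinearConfig` and `CastConfig`:
+
+```python
+from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
+
+# example: delayed scaling for the weight, dynamic scaling for input and grad_output
+config = Float8LinearConfig(
+    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
+    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
+)
+```
+
+Note that any configuration which uses delayed scaling for at least one tensor still requires the `sync_float8_amax_and_scale_history` call shown in the delayed scaling recipe below.
+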
+## float8 linear with dynamic scaling for `input`, `weight` and `grad_output`
+
+This is the most accurate recipe as every tensor is scaled dynamically.
+
+```python
+from torchao.float8 import (
+ convert_to_float8_training,
+ precompute_float8_dynamic_scale_for_fsdp,
+)
+
+# create model
+m = Model(...)
+
+# optional: filter out modules that should not be converted to float8
+def module_filter_fn(mod: torch.nn.Module, fqn: str):
+ # don't convert the output module
+ if fqn == "output":
+ return False
+ # don't convert linear modules with weight dimensions not divisible by 16
+ if isinstance(mod, torch.nn.Linear):
+ if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+ return False
+ return True
+
+# convert all `torch.nn.Linear` modules to `Float8Linear`
+convert_to_float8_training(m, module_filter_fn=module_filter_fn)
+
+# optional: use FSDP
+m = FSDP(m, use_orig_params=True)
+
+# optional: enable torch.compile for improved performance
+m = torch.compile(m)
+
+# toy training loop
+for _ in range(N_ITER):
+ optimizer.zero_grad()
+ y = m(x)
+ y.sum().backward()
+ optimizer.step()
+
+ # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
+ # this method is optional but is highly recommended for performance
+    # it calculates scales for all parameters in a single all-reduce
+    precompute_float8_dynamic_scale_for_fsdp(m)
+
+```
+
+## float8 linear with delayed scaling
+
+This is theoretically the most performant recipe as it minimizes memory reads.
+
+```python
+from torchao.float8 import (
+ convert_to_float8_training,
+ sync_float8_amax_and_scale_history,
+ ScalingType,
+)
+
+# create model
+m = Model(...)
+
+# optional: configure for compatibility with FSDP. Note that workarounds
+# gated with config.enable_amax_init and
+# config.enable_pre_and_post_forward are needed for
+# autocast + compile + FSDP + float8 to work
+from torchao.float8 import Float8LinearConfig, ScalingType, CastConfig
+config = Float8LinearConfig(
+ enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed
+    enable_pre_and_post_forward=False,  # only needed for autocast + compile + FSDP + float8 delayed
+ cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+ cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+)
+
+# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
+# type
+convert_to_float8_training(
+ m,
+ config=config,
+)
+
+# optional: use FSDP
+m = FSDP(m, use_orig_params=True)
+
+# optional: enable torch.compile for improved performance
+m = torch.compile(m)
+
+# toy training loop
+for _ in range(N_ITER):
+ optimizer.zero_grad()
+ y = m(x)
+ y.sum().backward()
+
+ # specific to float8 with delayed scaling: separate step to sync scales/amaxes
+ # in the future, this may move to a context manager
+    sync_float8_amax_and_scale_history(m)
+
+ optimizer.step()
+```
+
+# Multi GPU User API
+
+We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html),
+such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples
+on using `torchao.float8` in a distributed setting.
+
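+For example, here is a minimal sketch of combining `convert_to_float8_training` with the FSDP2 `fully_shard` API and float8 all-gather. The `fully_shard` import path is the prototype `torch.distributed._composable.fsdp` module and may change; as in the toy loops above, `Model`, `N_ITER`, `x` and `optimizer` are placeholders, and `m` is assumed to be a container of transformer blocks.
+
+```python
+from torch.distributed._composable.fsdp import fully_shard  # FSDP2 prototype API
+from torchao.float8 import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
+    convert_to_float8_training,
+    precompute_float8_dynamic_scale_for_fsdp,
+)
+
+m = Model(...)
+
+# enable float8 all-gather so FSDP communicates fp8 weights instead of bf16/fp32
+config = Float8LinearConfig(
+    enable_fsdp_float8_all_gather=True,
+    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+)
+convert_to_float8_training(m, config=config)
+
+# shard each block, then the root module (same pattern as our FSDP2 tests)
+for block in m:
+    fully_shard(block)
+fully_shard(m)
+
+for _ in range(N_ITER):
+    optimizer.zero_grad()
+    m(x).sum().backward()
+    optimizer.step()
+    # amortize dynamic scale computation across all parameters in one all-reduce
+    precompute_float8_dynamic_scale_for_fsdp(m)
+```
+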
+# Testing
+
+```bash
+# run single-GPU unit tests
+pytest test/float8/test_base.py
+
+# run single-GPU compile tests
+pytest test/float8/test_compile.py
+
+# run single-GPU numerics integration tests
+pytest test/float8/test_numerics_integration.py
+
+# run a two-GPU integration test on FSDP
+./test/float8/test_fsdp.sh
+
+# run integration tests on the DTensor TP/SP integration
+./test/float8/test_dtensor.sh
+
+# run integration tests on the FSDP2 integration
+python test/float8/test_fsdp2/test_fsdp2.py
+
+# run all of these tests
+./test/float8/test_everything.sh
+```
+
+# Benchmarking
+
+```bash
+# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
+./benchmarks/float8/bench_matmul.py
+
+# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes
+# make sure to turn on torch.compile to get the best performance
+./benchmarks/float8/bench_linear_float8.py -o ../tmp/test.txt --compile
+```
diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py
new file mode 100644
index 000000000..56c7b28f7
--- /dev/null
+++ b/torchao/float8/__init__.py
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+# Let's define a few top-level things here
+from torchao.float8.config import (
+ CastConfig,
+ DelayedScalingConfig,
+ Float8GemmConfig,
+ Float8LinearConfig,
+ ScalingType,
+)
+from torchao.float8.float8_linear import Float8Linear
+from torchao.float8.float8_linear_utils import (
+ convert_to_float8_training,
+ linear_requires_sync,
+ sync_float8_amax_and_scale_history,
+)
+from torchao.float8.float8_tensor import (
+ Float8Tensor,
+ GemmInputRole,
+ LinearMMConfig,
+ ScaledMMConfig,
+)
+from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+
+# Needed to load Float8Tensor with weights_only = True
+from torch.serialization import add_safe_globals
+
+add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])
+
+__all__ = [
+ # configuration
+ "DelayedScalingConfig",
+ "ScalingType",
+ "Float8GemmConfig",
+ "Float8LinearConfig",
+ "CastConfig",
+ # top level UX
+ "convert_to_float8_training",
+ "linear_requires_sync",
+ "sync_float8_amax_and_scale_history",
+ "precompute_float8_dynamic_scale_for_fsdp",
+ # note: Float8Tensor and Float8Linear are not public APIs
+]
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
new file mode 100644
index 000000000..5d1bf9f54
--- /dev/null
+++ b/torchao/float8/config.py
@@ -0,0 +1,129 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import enum
+from dataclasses import dataclass
+
+
+class ScalingType(enum.Enum):
+ DELAYED = "delayed"
+ DYNAMIC = "dynamic"
+
+ def short_str(self):
+ if self is ScalingType.DELAYED:
+ return "del"
+ else:
+ assert self is ScalingType.DYNAMIC
+ return "dyn"
+
+
+@dataclass(frozen=True)
+class CastConfig:
+ """
+ Configuration for casting a single tensor to float8
+ """
+
+ scaling_type: ScalingType = ScalingType.DYNAMIC
+
+
+@dataclass(frozen=True)
+class DelayedScalingConfig:
+ """
+ Configuration for delayed scaling.
+
+ Note: for now, `history_len` values must be the same for all layers in the
+ model using delayed scaling.
+
+ TODO(future): serialization for recipes
+ """
+
+ # Controls the history length of amax buffers
+ history_len: int = 16
+
+ # Controls the way to calculate current scale from amax history
+ # TODO(future): add other functions as needed, hardcoded or user defined
+ scale_fn_name: str = "max"
+
+ def __post_init__(self):
+ assert (
+ self.scale_fn_name == "max"
+ ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+
+
+@dataclass(frozen=True)
+class Float8GemmConfig:
+ """
+ Configuration for a float8 gemm.
+ """
+
+ # If True, fast accumulation in lower precision is used.
+ # Note: this flag is currently a no-op if emulation is turned on.
+ use_fast_accum: bool = False
+
+
+@dataclass(frozen=True)
+class Float8LinearConfig:
+ """
+ Configuration for converting a `torch.nn.Linear` module to float8
+ for training.
+ """
+
+ #
+ # Per-tensor configuration for `input`, `weight`, `grad_output`
+ #
+ cast_config_input: CastConfig = CastConfig()
+ cast_config_weight: CastConfig = CastConfig()
+ cast_config_grad_output: CastConfig = CastConfig()
+
+ #
+ # Per-gemm configuration for gemms calculating `output`, `grad_input` and
+ # `grad_weight`
+ #
+ gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
+ gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
+ gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig()
+
+ #
+ # Per-linear configuration
+ #
+
+ # If True, on the first iteration of Float8Linear the amaxes will be
+ # initialized with the incoming data. As of 2023-12-30, this doesn't work
+ # with autocast + torch.compile + FSDP. Enabling this option is nice for
+ # testing, but this is not necessary for real training jobs.
+ enable_amax_init: bool = True
+
+ # If True, pre-forward and post-forward functions are run. As of 2023-12-30,
+ # this doesn't work with autocast + torch.compile + FSDP. Enabling this
+ # option is useful for safety, but not strictly necessary.
+ enable_pre_and_post_forward: bool = True
+
+ # If True, then uses a tensor subclass for the float8 linear module's weight that
+ # implements pre/post-all-gather methods to do float8 all-gather with FSDP2.
+ enable_fsdp_float8_all_gather: bool = False
+
+    # If True, then prior to performing the fp8 scaled matmul we will pad the
+    # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for
+    # matmuls with `_scaled_mm`, since it has the strong constraint that for
+    # shapes (M, K) x (K, N), both K and N must be multiples of 16. This can
+    # cause a memory spike, however, so we keep it off by default.
+ pad_inner_dim: bool = False
+
+ # If True, emulation is used instead of hardware accelerated gemm
+ emulate: bool = False
+
+ # Configuration for delayed scaling
+ # Note: this is actually applied per-tensor, but only using the same
+ # configuration for all tensors and layers in the model is currently
+ # supported. If in the future we add support for a more fine grained
+ # configuration, this field may move to per-tensor configs.
+ delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()
+
+
+# If True, use 'fnuz' float8 types for calculations.
+# Currently, ROCm only supports fnuz variants.
+# TODO(future PR): move this to Float8LinearConfig
+use_fnuz_dtype = False
diff --git a/torchao/float8/distributed_utils.py b/torchao/float8/distributed_utils.py
new file mode 100644
index 000000000..ef174b073
--- /dev/null
+++ b/torchao/float8/distributed_utils.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Any
+
+import torch
+
+from fairscale.nn.model_parallel.initialize import get_model_parallel_group
+
+# from float8_tensor import Float8Tensor
+from torchao.float8.float8_tensor import Float8Tensor
+
+# additional differentiable distributed primitives for SP which are not in
+# the Fairscale codebase
+
+
+def _gather_along_first_dim(input_: torch.Tensor):
+ # same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67,
+ # but gather along first dim instead of last dim
+ group = get_model_parallel_group()
+
+ # Bypass the function if we are using only 1 GPU.
+ if torch.distributed.get_world_size(group=group) == 1:
+ return input_
+
+ # Size and dimension.
+ first_dim = 0
+ rank = torch.distributed.get_rank(group=group)
+ world_size = torch.distributed.get_world_size(group=group)
+
+ # If the input is a float8 tensor, we need to do the transformation on the
+ # inner tensor and then return a new wrapper.
+ def _transform(t):
+ # tensors must be contiguous for all_gather to work
+ input_contig = t.contiguous()
+
+ tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)]
+ tensor_list[rank] = input_contig
+ torch.distributed.all_gather(tensor_list, input_contig, group=group)
+
+ # Note: torch.cat already creates a contiguous tensor.
+ output = torch.cat(tensor_list, dim=first_dim).contiguous()
+ return output
+
+ if isinstance(input_, Float8Tensor):
+ new_data = input_._data
+ new_data = new_data.view(torch.int8)
+ new_data = _transform(new_data)
+ new_data = new_data.view(input_._data.dtype)
+ output = Float8Tensor(new_data, input_._scale, input_._orig_dtype)
+ else:
+ output = _transform(input_)
+
+ return output
+
+
+def _reduce_scatter(ctx: Any, input_: torch.Tensor):
+ group = get_model_parallel_group()
+ world_size = torch.distributed.get_world_size(group)
+
+ assert input_.shape[0] % world_size == 0
+ output_shape = (input_.shape[0] // world_size, *input_.shape[1:])
+ output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype)
+
+ torch.distributed.reduce_scatter_tensor(output, input_, group=group)
+ return output
+
+
+def _split_along_first_dim(input_: torch.Tensor):
+ # this is needed for testing
+
+ # like fairscale.nn.model_parallel.mappings._split, but
+ # along the first dim instead of last dim
+
+ group = get_model_parallel_group()
+ local_rank = torch.distributed.get_rank(group)
+ world_size = torch.distributed.get_world_size(group)
+
+ assert input_.shape[0] % world_size == 0
+ input_list = torch.split(input_, input_.shape[0] // world_size)
+ return input_list[local_rank]
+
+
+class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_):
+ return _gather_along_first_dim(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _reduce_scatter(ctx, grad_output)
+
+
+class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_):
+ return _reduce_scatter(ctx, input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _gather_along_first_dim(grad_output)
+
+
+class _AllGatherFwSplitBw(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_):
+ return _gather_along_first_dim(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split_along_first_dim(grad_output)
diff --git a/torchao/float8/float8_aten_api.py b/torchao/float8/float8_aten_api.py
new file mode 100644
index 000000000..41d5083d6
--- /dev/null
+++ b/torchao/float8/float8_aten_api.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+This file defines the aten functions for float8. Today, all of these functions
+are emulated. In the future, they should be calling NVIDIA's float8 kernels.
+"""
+
+import torch
+
+from torch.library import Library
+
+
+def mm_float8_emulated(
+ m1, # input 1 data
+ s1, # input 1 scale
+ m2, # input 2 data
+ s2, # input 2 scale
+ dtype3, # output dtype
+):
+ # naive implementation: dq -> op -> q
+ m1_fp32 = m1.float() / s1
+ m2_fp32 = m2.float() / s2
+ m3_fp32 = torch.mm(m1_fp32, m2_fp32)
+
+ return m3_fp32.to(dtype3)
+
+
+#
+# ATen op placeholders
+#
+
+# Register the aten level functions we need.
+# These are mostly placeholders and might need to be implemented in C++ as needed
+lib = Library("aten", "FRAGMENT")
+
+lib.define(
+ "mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> Tensor"
+)
+lib.impl("mm_float8_emulated", mm_float8_emulated, "CPU")
+lib.impl("mm_float8_emulated", mm_float8_emulated, "CUDA")
+
+
+@torch.library.impl(lib, "mm_float8_emulated", "Meta")
+def _mm_float8_emulated_meta(m1, s1, m2, s2, dtype3):
+ out = torch.mm(m1.float(), m2.float()).to(dtype3)
+ return out
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
new file mode 100644
index 000000000..dd85af921
--- /dev/null
+++ b/torchao/float8/float8_linear.py
@@ -0,0 +1,438 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+A simple module swap UX for a float8 version of `torch.nn.Linear`.
+"""
+
+import dataclasses
+import enum
+
+from typing import Optional
+
+import torch
+
+from torchao.float8.config import Float8LinearConfig, ScalingType
+
+from torchao.float8.float8_scaling_utils import (
+ _maybe_initialize_amaxes_scales_for_float8_cast,
+ hp_tensor_to_float8_delayed,
+ hp_tensor_to_float8_dynamic,
+ NoopFwToFloat8E5M2BwDelayed,
+ NoopFwToFloat8E5M2BwDynamic,
+)
+
+from torchao.float8.float8_tensor import (
+ Float8Tensor,
+ GemmInputRole,
+ LinearMMConfig,
+ ScaledMMConfig,
+)
+
+from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax
+
+from torchao.float8.fsdp_utils import (
+ WeightWithDelayedFloat8CastTensor,
+ WeightWithDynamicFloat8CastTensor,
+)
+
+
+# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files
+@torch._dynamo.allow_in_graph
+class manual_float8_matmul(torch.autograd.Function):
+ """
+ Like torch.matmul, but with the arguments in float8
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ input_fp8,
+ weight_fp8_t,
+ ):
+ ctx.save_for_backward(input_fp8, weight_fp8_t)
+ # the reshapes are needed in order to make the shapes compatible with
+ # torch.mm
+ orig_shape = input_fp8.shape
+ input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
+ res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
+ res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+ return res_bits
+
+ @staticmethod
+ def backward(ctx, grad_output_fp8):
+ input_fp8, weight_fp8_t = ctx.saved_tensors
+
+ # the reshapes are needed in order to make the shapes compatible with
+ # torch.mm
+ grad_output_fp8_orig_shape = grad_output_fp8.shape
+ grad_output_fp8_reshaped = grad_output_fp8.reshape(
+ -1, grad_output_fp8_orig_shape[-1]
+ )
+
+ # calculate grad_input
+ grad_input = torch.mm(
+ grad_output_fp8_reshaped,
+ weight_fp8_t.t(),
+ )
+ grad_input = grad_input.reshape(
+ *grad_output_fp8_orig_shape[:-1], grad_input.shape[-1]
+ )
+
+ input_fp8_orig_shape = input_fp8.shape
+ input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1])
+
+ # calculate grad_weight
+ # Note: the variant below is slightly faster on LLaMa 3 8B pretraining
+        # compared to calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped`
+ grad_weight = torch.mm(
+ grad_output_fp8_reshaped.t(),
+ input_fp8_reshaped,
+ )
+
+ return grad_input, grad_weight.t()
+
+
+class Float8Linear(torch.nn.Linear):
+ """
+ Note: this is **not** a public API and is only intended to be used
+ inside of this repository. Please file an issue if you would benefit
+ from this being a public API.
+
+ A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
+    scales in a way friendly to delayed scaling.
+ """
+
+ def __init__(self, *args, **kwargs):
+ """
+ Additional arguments on top of `torch.nn.Linear`'s arguments:
+ * `config`: Float8LinearConfig
+ """
+
+ # Amax scales should always be kept as float32.
+ self.always_float32_buffers = set()
+ config = kwargs.pop("config")
+ emulate = config.emulate
+ super().__init__(*args, **kwargs)
+
+ # Defines the scaling behavior of input, weight, grad_output
+ self.scaling_type_input = config.cast_config_input.scaling_type
+ self.scaling_type_weight = config.cast_config_weight.scaling_type
+ self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
+ # Convenience flag to skip code related to delayed scaling
+ self.has_any_delayed_scaling = (
+ self.scaling_type_input is ScalingType.DELAYED
+ or self.scaling_type_weight is ScalingType.DELAYED
+ or self.scaling_type_grad_output is ScalingType.DELAYED
+ )
+
+ self.config = config
+
+ self.create_buffers()
+
+ self.linear_mm_config = LinearMMConfig(
+ # output
+ ScaledMMConfig(
+ emulate,
+ self.config.gemm_config_output.use_fast_accum,
+ False,
+ self.config.pad_inner_dim,
+ ),
+ # grad_input
+ ScaledMMConfig(
+ emulate,
+ self.config.gemm_config_grad_input.use_fast_accum,
+ False,
+ self.config.pad_inner_dim,
+ ),
+ # grad_weight
+ ScaledMMConfig(
+ emulate,
+ self.config.gemm_config_grad_weight.use_fast_accum,
+ False,
+ self.config.pad_inner_dim,
+ ),
+ )
+
+ # Note: is_amax_initialized is not a buffer to avoid data dependent
+ # control flow visible to dynamo
+ # TODO(future PR): add serialization for this flag
+ self.is_amax_initialized = not self.config.enable_amax_init
+
+ # Syncing of amaxes and scales happens outside of this function. This
+ # flag is here to enforce that the user does not forget to do this.
+ self.amax_and_scale_synced = not self.config.enable_amax_init
+
+ # This is needed to properly handle autocast in the amax/scale
+ # update function for torch.float16
+ self.last_seen_input_dtype = None
+
+ # pre_forward and post_forward are currently broken with FSDP
+ # and torch.compile, this option can disable them
+ # Note that when using `self.config.enable_pre_and_post_forward = False`,
+ # it's recommended to also set `self.config.enable_amax_init = False`.
+ # Otherwise, the amax buffer would never be marked as initialized and
+ # would be initialized in every iteration.
+ self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward
+
+ def create_buffers(self):
+ # Default values for history buffers, see above TODO
+ history_len = self.config.delayed_scaling_config.history_len
+ device = self.weight.device
+ # TODO(future PR): dtype values below don't have the other float8
+ # flavors, fix it
+ default_input = torch.finfo(torch.float8_e4m3fn).max
+ default_weight = torch.finfo(torch.float8_e4m3fn).max
+ default_grad_output = torch.finfo(torch.float8_e5m2).max
+
+ # Note: for now, create all the buffers if any are needed, to postpone
+ # the work to make the scale and amax syncing and history calculation
+ # handle a heterogeneous setup. We can do that work later if benchmarks
+ # show it is worth doing.
+ if self.has_any_delayed_scaling:
+ self.register_always_float32_buffer(
+ "fp8_amax_input", torch.tensor([default_input], device=device)
+ )
+ self.register_always_float32_buffer(
+ "fp8_amax_history_input", torch.zeros(history_len, device=device)
+ )
+ self.register_always_float32_buffer(
+ "fp8_scale_input", torch.tensor([1.0], device=device)
+ )
+ self.register_always_float32_buffer(
+ "fp8_amax_weight", torch.tensor([default_weight], device=device)
+ )
+ self.register_always_float32_buffer(
+ "fp8_amax_history_weight", torch.zeros(history_len, device=device)
+ )
+ self.register_always_float32_buffer(
+ "fp8_scale_weight", torch.tensor([1.0], device=device)
+ )
+ self.register_always_float32_buffer(
+ "fp8_amax_grad_output",
+ torch.tensor([default_grad_output], device=device),
+ )
+ self.register_always_float32_buffer(
+ "fp8_amax_history_grad_output", torch.zeros(history_len, device=device)
+ )
+ self.register_always_float32_buffer(
+ "fp8_scale_grad_output", torch.tensor([1.0], device=device)
+ )
+
+ def register_always_float32_buffer(
+ self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
+ ) -> None:
+ self.register_buffer(name=name, tensor=tensor, persistent=persistent)
+ self.always_float32_buffers.add(name)
+
+ def _apply(self, fn, recurse=True):
+ ret = super()._apply(fn, recurse)
+ self.convert_amax_buffer_to_float32()
+ return ret
+
+ def convert_amax_buffer_to_float32(self):
+ for key in self.always_float32_buffers:
+ if self._buffers[key] is not None:
+ self._buffers[key] = self._buffers[key].to(torch.float32)
+
+ def cast_input_to_float8(
+ self, input: torch.Tensor, is_amax_initialized: bool
+ ) -> torch.Tensor:
+ # Duplicate the autocast logic for F.linear, so that the output
+ # of our module has the right original precision
+ if torch.is_autocast_enabled():
+ # For now, hardcode to GPU's autocast dtype
+ # if we need CPU support in the future, we can add it
+ autocast_dtype = torch.get_autocast_gpu_dtype()
+ input = input.to(autocast_dtype)
+
+ if self.scaling_type_input is ScalingType.DELAYED:
+ scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
+ _maybe_initialize_amaxes_scales_for_float8_cast(
+ input,
+ self.fp8_amax_input,
+ self.fp8_amax_history_input,
+ self.fp8_scale_input,
+ scale_fn_name,
+ e4m3_dtype,
+ is_amax_initialized,
+ reduce_amax=True,
+ )
+ input_fp8 = hp_tensor_to_float8_delayed(
+ input,
+ self.fp8_scale_input,
+ e4m3_dtype,
+ self.fp8_amax_input,
+ linear_mm_config=self.linear_mm_config,
+ gemm_input_role=GemmInputRole.INPUT,
+ )
+ else:
+ assert self.scaling_type_input is ScalingType.DYNAMIC
+ input_fp8 = hp_tensor_to_float8_dynamic(
+ input, e4m3_dtype, self.linear_mm_config
+ )
+ return input_fp8
+
+ def cast_weight_to_float8(
+ self, weight: torch.Tensor, is_amax_initialized: bool
+ ) -> torch.Tensor:
+ if self.scaling_type_weight is ScalingType.DELAYED:
+ if isinstance(self.weight, Float8Tensor): # cast by FSDP
+ weight_fp8 = self.weight
+ else:
+ scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
+ _maybe_initialize_amaxes_scales_for_float8_cast(
+ weight,
+ self.fp8_amax_weight,
+ self.fp8_amax_history_weight,
+ self.fp8_scale_weight,
+ scale_fn_name,
+ e4m3_dtype,
+ is_amax_initialized,
+ reduce_amax=False,
+ )
+
+ weight_fp8 = hp_tensor_to_float8_delayed(
+ weight,
+ self.fp8_scale_weight,
+ e4m3_dtype,
+ self.fp8_amax_weight,
+ linear_mm_config=self.linear_mm_config,
+ gemm_input_role=GemmInputRole.WEIGHT,
+ )
+ else:
+ assert self.scaling_type_weight is ScalingType.DYNAMIC
+ if isinstance(self.weight, Float8Tensor): # cast by FSDP
+ weight_fp8 = self.weight
+ else:
+ weight_fp8 = hp_tensor_to_float8_dynamic(
+ self.weight,
+ e4m3_dtype,
+ self.linear_mm_config,
+ gemm_input_role=GemmInputRole.WEIGHT,
+ )
+ return weight_fp8
+
+ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
+ if self.scaling_type_grad_output is ScalingType.DELAYED:
+ scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
+ output = NoopFwToFloat8E5M2BwDelayed.apply(
+ output,
+ self.fp8_amax_grad_output,
+ self.fp8_amax_history_grad_output,
+ self.fp8_scale_grad_output,
+ scale_fn_name,
+ self.is_amax_initialized,
+ self.linear_mm_config,
+ )
+ else:
+ assert self.scaling_type_grad_output is ScalingType.DYNAMIC
+ output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
+ return output
+
+ def float8_pre_forward(self, input):
+ if not self.enable_pre_and_post_forward:
+ return
+ if (
+ self.is_amax_initialized
+ and (not self.amax_and_scale_synced)
+ and torch.is_grad_enabled()
+ ):
+ raise AssertionError(
+ "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
+ )
+ self.last_seen_input_dtype = input.dtype
+
+ def float8_post_forward(self):
+ if not self.enable_pre_and_post_forward:
+ return
+ # Ensure that calling forward again will fail until the user syncs
+ # amaxes and scales
+ self.is_amax_initialized = True
+ self.amax_and_scale_synced = False
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ if self.has_any_delayed_scaling:
+ self.float8_pre_forward(input)
+
+ input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
+ weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)
+
+ output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
+
+ # Cast grad_output to float8_e5m2 during backward
+ output = self.cast_output_to_float8_in_bw(output)
+
+ if self.bias is not None:
+ output = output + self.bias.to(output.dtype)
+
+ if self.has_any_delayed_scaling:
+ self.float8_post_forward()
+ return output
+
+ def scaling_repr(self):
+ # add scaling settings without using too many characters
+ # example: "i:del,w:del,go:dyn"
+ return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}"
+
+ def extra_repr(self):
+ s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
+ return s
+
+ @classmethod
+ def from_float(
+ cls,
+ mod,
+ config: Optional[Float8LinearConfig] = None,
+ ):
+ """
+ Create an nn.Linear with fp8 compute from a regular nn.Linear
+
+ Args:
+ mod (torch.nn.Linear): nn.Linear to convert
+ config (Optional[Float8LinearConfig]): configuration for conversion to float8
+ """
+ if config is None:
+ config = Float8LinearConfig()
+ with torch.device("meta"):
+ new_mod = cls(
+ mod.in_features,
+ mod.out_features,
+ bias=False,
+ config=config,
+ )
+ new_mod.weight = mod.weight
+ new_mod.bias = mod.bias
+ # need to create buffers again when moving from meta device to
+ # real device
+ new_mod.create_buffers()
+
+ # If FSDP float8 all-gather is on, wrap the weight in a float8-aware
+ # tensor subclass. This must happen last because:
+ # 1. weight needs to be on the correct device to create the buffers
+ # 2. buffers need to be already created for the delayed scaling version
+ # of the weight wrapper to be initialized
+ if config.enable_fsdp_float8_all_gather:
+ if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC:
+ new_mod.weight = torch.nn.Parameter(
+ WeightWithDynamicFloat8CastTensor(
+ new_mod.weight,
+ new_mod.linear_mm_config,
+ )
+ )
+ else:
+ assert config.cast_config_weight.scaling_type is ScalingType.DELAYED
+ new_mod.weight = torch.nn.Parameter(
+ WeightWithDelayedFloat8CastTensor(
+ new_mod.weight,
+ new_mod.fp8_amax_weight,
+ new_mod.fp8_amax_history_weight,
+ new_mod.fp8_scale_weight,
+ new_mod.linear_mm_config,
+ new_mod.is_amax_initialized,
+ )
+ )
+
+ return new_mod
diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py
new file mode 100644
index 000000000..675ed5ee6
--- /dev/null
+++ b/torchao/float8/float8_linear_utils.py
@@ -0,0 +1,327 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+from typing import Callable, List, Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torchao.float8.config import Float8LinearConfig, ScalingType
+from torchao.float8.float8_linear import Float8Linear
+
+from torchao.float8.float8_utils import (
+ amax_history_to_scale_stack,
+ e4m3_dtype,
+ e5m2_dtype,
+)
+from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
+
+log = logging.getLogger(__name__)
+log.addHandler(logging.NullHandler())
+
+
+def linear_requires_sync(config: Float8LinearConfig):
+ """Returns whether the given linear_type requires sync before forward."""
+ return any(
+ [
+ config.cast_config_input.scaling_type is ScalingType.DELAYED,
+ config.cast_config_weight.scaling_type is ScalingType.DELAYED,
+ config.cast_config_grad_output.scaling_type is ScalingType.DELAYED,
+ ]
+ )
+
+
+def _update_history_stack(
+ new_amax: torch.Tensor, amax_history_stack: torch.Tensor
+) -> torch.Tensor:
+ """
+ Updates `amax_history` (the last N cur_amax values) inplace with the value
+ of `new_amax`.
+
+ Args:
+ new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1)
+ amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length)
+ """
+ assert (
+ amax_history_stack.dim() == 2
+ ), f"Expected amat_history_stack to be 2D, got {amax_history_stack.shape()}"
+ assert new_amax.size(0) == amax_history_stack.size(
+ 0
+ ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}"
+ new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1)
+ new_amax_history_stack[:, 0] = new_amax.squeeze(-1)
+ amax_history_stack.copy_(new_amax_history_stack)
+
+
+def swap_linear_layers(
+ module: nn.Module,
+ from_float_func: Callable[[nn.Linear], nn.Linear],
+ *,
+ module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
+) -> nn.Module:
+ """
+ Generic function to swap linear layers in a module with a new type of linear layer.
+
+ Note:
+        If applied to a root-level nn.Linear, the module will not be modified in
+        place; the new module is returned instead.
+
+ Args:
+ module: Module to modify.
+ from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
+        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the module instance and the FQN.
+
+ Returns:
+ nn.Module: The modified module with swapped linear layers.
+ """
+ if isinstance(module, nn.Linear) and (
+ module_filter_fn is None or module_filter_fn(module, "")
+ ):
+ if len(list(module.children())) > 0:
+ raise AssertionError(
+ f"Does not support a root nn.Linear with children: {module}"
+ )
+ return from_float_func(
+ module,
+ )
+
+ root_module = module
+
+ def post_order_traversal(
+ module: nn.Module,
+ cur_fqn: Optional[str] = None,
+ parent_module: Optional[nn.Module] = None,
+ ):
+ if cur_fqn is None:
+ cur_fqn = ""
+
+ for child_module_name, child_module in module.named_children():
+ if cur_fqn == "":
+ new_fqn = child_module_name
+ else:
+ new_fqn = f"{cur_fqn}.{child_module_name}"
+
+ post_order_traversal(child_module, new_fqn, module)
+
+ if isinstance(module, nn.Linear) and (
+ module_filter_fn is None or module_filter_fn(module, cur_fqn)
+ ):
+ assert (
+ parent_module is not None
+ ), f"Linear root module should return early: {module}"
+ new_linear_module = from_float_func(module)
+ cur_module_name = cur_fqn.split(".")[-1]
+ setattr(parent_module, cur_module_name, new_linear_module)
+
+ post_order_traversal(root_module)
+ return root_module
+
+
+def convert_to_float8_training(
+ module: nn.Module,
+ *,
+ module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
+    config: Optional[Float8LinearConfig] = None,
+) -> nn.Module:
+ """
+ Swaps `torch.nn.Linear` in `module` with `Float8Linear`.
+
+ Args:
+ module: Module to modify.
+        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses
+            that pass the filter function will be swapped. The inputs to the
+ filter function are the module instance and the FQN.
+ config (Float8LinearConfig): configuration for conversion to float8
+
+ Returns:
+ nn.Module: The modified module with swapped linear layers.
+ """
+ if config is None:
+ config = Float8LinearConfig()
+ from_float = lambda m: Float8Linear.from_float(
+ m,
+ config=config,
+ )
+ return swap_linear_layers(
+ module,
+ from_float,
+ module_filter_fn=module_filter_fn,
+ )
+
+
+def get_float8_layers(model: torch.nn.Module):
+ """Iterates through the model and returns all the Float8Linear layers.
+ Args:
+ model (torch.nn.Module): The model to look for Float8Linear layers in.
+ """
+
+ # Get all fp8 layers and tensors
+ fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)]
+ if not torch._dynamo.is_compiling():
+ for layer in fp8_layers:
+ for buf in layer.buffers():
+ torch._dynamo.mark_static_address(buf, guard=True)
+ return fp8_layers
+
+
+@torch.no_grad()
+def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None:
+ """
+ Manages the float8 amax and scale bookkeeping. In detail, it does the
+ following:
+ 1. in distributed contexts, syncs amax values across workers for activations and gradients
+ 2. adds the `amax` values to history
+ 3. calculates the scales to be used for next iteration
+ 4. sets the `amax_and_scale_synced` flag on the Float8Linear modules
+ to signal that they have been synced
+
+ TODO(future): design the UX for this (context manager, etc)
+
+ PERFORMANCE NOTE:
+    When you can, it is much more efficient to call get_float8_layers once at
+    the beginning of the training loop and pass the result to this function,
+    because of how this interacts with torch.compile.
+
+ Args:
+ model (torch.nn.Module): The model to track amaxes for
+        fp8_layers (optional): If provided, we skip traversing the model and
+            loop over the given fp8_layers to sync and update amax scale histories.
+            Users can use get_float8_layers to get all fp8 layers.
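+
+    Example (a sketch; `model`, `optimizer`, `data` and `num_iters` are placeholders)::
+
+        fp8_layers = get_float8_layers(model)
+        for _ in range(num_iters):
+            optimizer.zero_grad()
+            model(data).sum().backward()
+            sync_float8_amax_and_scale_history(model, fp8_layers)
+            optimizer.step()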
+ """
+ if fp8_layers is None:
+ fp8_layers = get_float8_layers(model)
+
+ if len(fp8_layers) == 0:
+        log.warning(
+ "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
+ )
+ return
+
+ def inner_func():
+ """Why do we have this inner_function?
+
+ There are two portions of the outer sync_function that cause graph_breaks:
+ 1. The `get_float8_layers` call can cause graph breaks if the user did not pass
+ in the fp8_layers.
+ 2. At the end of syncing all the amaxes and scales we set the attr on the module
+ signaling that we have synced the amaxes and scales and the next forward can be run.
+ # TODO Maybe we should remove this safety check to remove the graph break?
+
+ By having this inner function, we can ensure that although the outer function may cause graph breaks,
+ the inner function will not.
+ """
+ # Loop over all fp8 layers and grab the needed tensors
+ fp8_amax_input_tensor_list = [None] * len(fp8_layers)
+ fp8_amax_weight_tensor_list = [None] * len(fp8_layers)
+ fp8_amax_grad_output_tensor_list = [None] * len(fp8_layers)
+
+ fp8_input_amax_history_stack = [None] * len(fp8_layers)
+ fp8_weight_amax_history_stack = [None] * len(fp8_layers)
+ fp8_grad_output_amax_history_stack = [None] * len(fp8_layers)
+
+ x_dtypes = set()
+ scale_fn_recipes = set()
+
+ for idx, child in enumerate(fp8_layers):
+ fp8_amax_input_tensor_list[idx] = child.fp8_amax_input
+ fp8_amax_weight_tensor_list[idx] = child.fp8_amax_weight
+ fp8_amax_grad_output_tensor_list[idx] = child.fp8_amax_grad_output
+
+ fp8_input_amax_history_stack[idx] = child.fp8_amax_history_input
+ fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight
+ fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output
+
+ x_dtypes.add(child.last_seen_input_dtype)
+ scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)
+
+ # TODO This way to get the activation dtype is not ideal
+ if len(x_dtypes) != 1:
+ raise ValueError(
+ f"All layers must have the same last seen input_dtype, got {x_dtypes}"
+ )
+ x_dtype = next(iter(x_dtypes))
+
+ if len(scale_fn_recipes) != 1:
+ raise ValueError(
+ f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
+ )
+ scale_fn_recipe = next(iter(scale_fn_recipes))
+
+ assert (
+ len(fp8_amax_input_tensor_list)
+ == len(fp8_amax_weight_tensor_list)
+ == len(fp8_amax_grad_output_tensor_list)
+ ), "Mismatched lengths of amax tensors."
+
+ if dist.is_initialized():
+ all_amax_tensors = torch.cat(
+ fp8_amax_input_tensor_list
+ + fp8_amax_weight_tensor_list
+ + fp8_amax_grad_output_tensor_list
+ )
+ all_reduced_amax_tensor = all_reduce(
+ all_amax_tensors, "MAX", list(range(dist.get_world_size()))
+ )
+ if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
+ all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
+
+ (
+ reduced_fp8_amax_input_tensor,
+ reduced_fp8_amax_weight_tensor,
+ reduced_fp8_amax_grad_output_tensor,
+ ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_input_tensor_list))
+
+ for idx, child in enumerate(fp8_layers):
+ child.fp8_amax_input.copy_(reduced_fp8_amax_input_tensor[idx])
+ child.fp8_amax_weight.copy_(reduced_fp8_amax_weight_tensor[idx])
+ child.fp8_amax_grad_output.copy_(
+ reduced_fp8_amax_grad_output_tensor[idx]
+ )
+
+ # We create two stacked tensor groups: one for the current amax values and one for the amax histories
+ fp8_amax_input_tensors = torch.vstack(fp8_amax_input_tensor_list)
+ fp8_amax_weight_tensors = torch.vstack(fp8_amax_weight_tensor_list)
+ fp8_amax_grad_output_tensors = torch.vstack(fp8_amax_grad_output_tensor_list)
+
+ fp8_input_amax_history_stack = torch.vstack(fp8_input_amax_history_stack)
+ fp8_weight_amax_history_stack = torch.vstack(fp8_weight_amax_history_stack)
+ fp8_grad_output_amax_history_stack = torch.vstack(
+ fp8_grad_output_amax_history_stack
+ )
+
+ # Update the history stacks with the new amax values
+ _update_history_stack(fp8_amax_input_tensors, fp8_input_amax_history_stack)
+ _update_history_stack(fp8_amax_weight_tensors, fp8_weight_amax_history_stack)
+ _update_history_stack(
+ fp8_amax_grad_output_tensors, fp8_grad_output_amax_history_stack
+ )
+
+ # Calculate the new scales from the updated history stacks
+ new_input_scales = amax_history_to_scale_stack(
+ fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+ )
+ new_weight_scales = amax_history_to_scale_stack(
+ fp8_weight_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+ )
+ new_grad_output_scales = amax_history_to_scale_stack(
+ fp8_grad_output_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
+ )
+
+ # Iterate through the layers and update the scales
+ for idx, child in enumerate(fp8_layers):
+ child.fp8_scale_input.copy_(new_input_scales[idx])
+ child.fp8_scale_weight.copy_(new_weight_scales[idx])
+ child.fp8_scale_grad_output.copy_(new_grad_output_scales[idx])
+
+ # This allows the compile to succeed on the inner func and fail on the graph breaks
+ # at the beginning and end of syncing
+ inner_func()
+
+ for child in fp8_layers:
+ # Set a flag to signal amaxes/scales are ready
+ child.amax_and_scale_synced = True
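
A sketch of the call pattern suggested by the performance note above; `model`, `dataloader`, and `optimizer` are assumed to exist, with the model already converted to use delayed scaling:

```python
# fetch the Float8Linear layers once, outside the training loop
fp8_layers = get_float8_layers(model)

for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    # sync amaxes and scales after backward, before the optimizer step
    sync_float8_amax_and_scale_history(model, fp8_layers)
    optimizer.step()
    optimizer.zero_grad()
```
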
diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py
new file mode 100644
index 000000000..d3c3b405b
--- /dev/null
+++ b/torchao/float8/float8_ops.py
@@ -0,0 +1,363 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Any, Dict, Tuple
+
+import torch
+
+from torchao.float8.float8_python_api import addmm_float8_unwrapped
+from torchao.float8.float8_tensor import choose_scaled_mm_config, Float8Tensor
+from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
+
+from torch.utils._pytree import tree_map
+
+aten = torch.ops.aten
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional
+FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def implements(aten_ops):
+ """Register aten ops to the float8 op table"""
+
+ def decorator(func):
+ for op in aten_ops:
+ FLOAT8_OPS_TABLE[op] = func
+ return func
+
+ return decorator
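
The decorator above fills `FLOAT8_OPS_TABLE`, which `Float8Tensor.__torch_dispatch__` (in `float8_tensor.py` below) consults. A hypothetical registration, purely to illustrate the pattern; `aten.neg.default` is not handled in this file and negating raw fp8 bits may not be supported on every backend:

```python
@implements([aten.neg.default])  # hypothetical op, for illustration only
def float8_neg(aten_op, args, kwargs=None):
    x = args[0]
    # operate on the raw fp8 payload, then re-wrap with unchanged metadata
    new_data = aten_op(x._data, *args[1:], **(kwargs or {}))
    return Float8Tensor(
        new_data, x._scale, x._orig_dtype, x._linear_mm_config, x._gemm_input_role
    )
```
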
+
+
+@implements(
+ [
+ aten.view.default,
+ aten._unsafe_view.default,
+ aten.t.default,
+ aten.as_strided.default,
+ aten.clone.default,
+ aten.detach.default,
+ aten.slice.Tensor,
+ aten.transpose.int,
+ aten.fill_.Scalar,
+ ]
+)
+def float8_desugar_op(aten_op, args, kwargs=None):
+ new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+ return Float8Tensor(
+ new_data,
+ args[0]._scale,
+ args[0]._orig_dtype,
+ args[0]._linear_mm_config,
+ args[0]._gemm_input_role,
+ )
+
+
+@implements([aten.split.Tensor])
+def float8_split(aten_op, args, kwargs=None):
+ new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)
+
+ def make_float8(data):
+ return Float8Tensor(
+ data,
+ args[0]._scale,
+ args[0]._orig_dtype,
+ args[0]._linear_mm_config,
+ args[0]._gemm_input_role,
+ )
+
+ out = map(make_float8, new_data_tensors)
+ return list(out)
+
+
+# aten.cat does not support float8 dtypes ("cat_cuda" not implemented for e4m3fn),
+# so we concatenate the raw data viewed as uint8 and view the result back to fp8
+@implements([aten.cat.default])
+def float8_cat(aten_op, args, kwargs=None):
+ chunked_tensors: Tuple[Float8Tensor] = args[0]
+
+ orig_dtype = chunked_tensors[0]._orig_dtype
+ scale = chunked_tensors[0]._scale
+ mm_config = chunked_tensors[0]._linear_mm_config
+ fp8_dtype = chunked_tensors[0]._data.dtype
+ gemm_input_role = chunked_tensors[0]._gemm_input_role
+ chunk_data = []
+ for chunk in chunked_tensors:
+ assert isinstance(
+ chunk, Float8Tensor
+ ), "Expecting all chunks to be of type Float8Tensor"
+ assert (
+ chunk._orig_dtype == orig_dtype
+ ), "Expecting all chunks to be of the same dtype"
+ assert (
+ chunk._scale is scale
+ ), "Expecting all chunks to have thee same scale as a result of a split"
+ assert (
+ chunk._linear_mm_config is mm_config
+ ), "Expecting all chunks to have thee same mm config as a result of a split"
+ assert (
+ chunk._data.dtype == fp8_dtype
+ ), "Expecting all chunks to be of the same dtype as a result of a split"
+ assert (
+ chunk._gemm_input_role is gemm_input_role
+ ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
+ chunk_data.append(chunk._data.view(torch.uint8))
+
+ new_data = aten_op(chunk_data, *args[1:], **kwargs)
+ new_data = new_data.view(fp8_dtype)
+ return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role)
+
+
+@implements([aten.sum.dim_IntList])
+def float8_cast_up_op(aten_op, args, kwargs=None):
+ """Be careful with this function, this is a "fallback" op that
+ casts the output of the op to the original precision. And performs the op.
+
+ We currently need this to support the backward for admmm bias.
+ "addmm" -> out
+ "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
+ """
+
+ def unwrap(x):
+ if isinstance(x, Float8Tensor):
+ return x.to_original_precision()
+ return x
+
+ new_args = tree_map(unwrap, args)
+ new_kwargs = tree_map(unwrap, kwargs)
+ return aten_op(*new_args, **new_kwargs)
+
+
+def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
+ a_data = a._data
+ a_scale = a._scale
+ b_data = b._data
+
+ scaled_mm_config = choose_scaled_mm_config(
+ a._gemm_input_role,
+ a._linear_mm_config,
+ b._gemm_input_role,
+ b._linear_mm_config,
+ )
+
+ if scaled_mm_config.pad_inner_dim:
+ assert a._data.size(1) == b._data.size(
+ 0
+ ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
+ a_data = pad_tensor_for_matmul(a_data, dims=1)
+ b_data = pad_tensor_for_matmul(b_data, dims=0)
+
+ if not is_row_major(a_data.stride()):
+ a_data = a_data.contiguous()
+ if is_row_major(b_data.stride()):
+ b_data = b_data.t().contiguous().t()
+ b_scale = b._scale
+ return a_data, a_scale, b_data, b_scale
+
+
+@implements([aten.mm.default, aten.matmul.default])
+def float8_mm(aten_op, args, kwargs=None):
+ a = args[0]
+ b = args[1]
+
+ assert isinstance(a, Float8Tensor) and isinstance(
+ b, Float8Tensor
+ ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
+ type(a), type(b)
+ )
+ a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+ output_dtype = a._orig_dtype
+ scaled_mm_config = choose_scaled_mm_config(
+ a._gemm_input_role,
+ a._linear_mm_config,
+ b._gemm_input_role,
+ b._linear_mm_config,
+ )
+ if scaled_mm_config.emulate:
+ return torch.ops.aten.mm_float8_emulated(
+ a._data, a._scale, b._data, b._scale, output_dtype
+ )
+ tensor_out = addmm_float8_unwrapped(
+ a_data,
+ a_scale,
+ b_data,
+ b_scale,
+ output_dtype,
+ output_scale=None,
+ bias=None,
+ use_fast_accum=scaled_mm_config.use_fast_accum,
+ )
+ return tensor_out
+
+
+@implements([aten.addmm.default])
+def float8_addmm(aten_op, args, kwargs=None):
+ assert (
+ isinstance(args[0], torch.Tensor)
+ and isinstance(args[1], Float8Tensor)
+ and isinstance(args[2], Float8Tensor)
+ )
+ bias = args[0]
+ a = args[1]
+ b = args[2]
+ a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+ output_dtype = a._orig_dtype
+ assert bias.dtype == output_dtype, "bias dtype must match output dtype"
+ scaled_mm_config = choose_scaled_mm_config(
+ a._gemm_input_role,
+ a._linear_mm_config,
+ b._gemm_input_role,
+ b._linear_mm_config,
+ )
+ if scaled_mm_config.emulate:
+ out = torch.ops.aten.mm_float8_emulated(
+ a._data, a._scale, b._data, b._scale, output_dtype
+ )
+ return out + bias
+ tensor_out = addmm_float8_unwrapped(
+ a_data,
+ a_scale,
+ b_data,
+ b_scale,
+ output_dtype,
+ output_scale=None,
+ bias=bias,
+ use_fast_accum=scaled_mm_config.use_fast_accum,
+ )
+ return tensor_out
+
+
+@implements([aten.is_same_size.default])
+def float8_is_same_size(aten_op, args, kwargs=None):
+ return args[0].shape == args[1].shape
+
+
+@implements([aten._to_copy.default])
+def autocast_to_copy(aten_op, args, kwargs=None):
+ """This gets called when running matmul under autocast
+ when the input is a Float8Tensor, presenting as a fp32
+ tensor.
+ """
+ assert isinstance(args[0], Float8Tensor)
+ assert (
+ len(kwargs) == 1 and "dtype" in kwargs
+ ), "Only support dtype kwarg for autocast"
+ assert kwargs["dtype"] in {
+ torch.float16,
+ torch.bfloat16,
+ }, "Only support floating point conversion for autocast w/ Float8Tensor"
+ return Float8Tensor(
+ args[0]._data,
+ args[0]._scale,
+ kwargs["dtype"],
+ args[0]._linear_mm_config,
+ args[0]._gemm_input_role,
+ )
+
+
+@implements(
+ [
+ c10d_functional.all_gather_into_tensor.default,
+ _c10d_functional.all_gather_into_tensor.default,
+ ]
+)
+def allgather_fp8(aten_op, args, kwargs=None):
+ """
+ override funcol with FP8 handling
+ """
+ fp8_input = args[0]
+ assert isinstance(
+ fp8_input, Float8Tensor
+ ), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"
+
+ fp8_data = fp8_input._data
+ fp8_data = fp8_data.contiguous()
+ fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
+ return Float8Tensor(
+ fp8_out,
+ fp8_input._scale,
+ fp8_input._orig_dtype,
+ fp8_input._linear_mm_config,
+ fp8_input._gemm_input_role,
+ )
+
+
+@implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
+def wait_tensor_fp8(aten_op, args, kwargs=None):
+ fp8_input = args[0]
+ assert isinstance(fp8_input, Float8Tensor)
+
+ fp8_data = fp8_input._data
+ fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
+ return Float8Tensor(
+ fp8_out,
+ fp8_input._scale,
+ fp8_input._orig_dtype,
+ fp8_input._linear_mm_config,
+ fp8_input._gemm_input_role,
+ )
+
+
+@implements([aten.index_put_.default])
+def index_put_fp8(aten_op, args, kwargs=None):
+ fp8_self = args[0]
+ fp8_values = args[2]
+ assert isinstance(fp8_self, Float8Tensor)
+ assert isinstance(fp8_values, Float8Tensor)
+ assert fp8_self._scale == fp8_values._scale
+ assert fp8_self.dtype == fp8_values.dtype
+ assert fp8_self._orig_dtype == fp8_values._orig_dtype
+
+ fp8_data = fp8_self._data
+ fp8_values_data = fp8_values._data
+ fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
+ return Float8Tensor(
+ fp8_out,
+ fp8_self._scale,
+ fp8_self._orig_dtype,
+ fp8_self._linear_mm_config,
+ fp8_self._gemm_input_role,
+ )
+
+
+@implements([aten.copy_.default])
+def copy_fp8(aten_op, args, kwargs=None):
+ # For a copy op with Float8Tensors involved, only the following combinations are allowed:
+ # 1. self is a high precision (hp) tensor, src is a Float8Tensor:
+ # in this case src is upcasted and unscaled to go into the hp tensor
+ # 2. self and src are Float8Tensors:
+ # the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat)
+ # Every other combination is banned as the semantics are not well defined
+
+ self = args[0]
+ src = args[1]
+
+ if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+ src_hp = src.to_original_precision()
+ return aten_op(self, src_hp, *args[2:], **kwargs)
+ elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+ assert (
+ self._orig_dtype == src._orig_dtype
+ ), "Expecting both Float8Tensors to be of the same dtype"
+ assert (
+ self._scale == src._scale
+ ), "Expecting both Float8Tensors to have thee same scale"
+ assert (
+ self._linear_mm_config == src._linear_mm_config
+ ), "Expecting both Float8Tensors to have thee same mm config"
+ assert (
+ self._data.dtype == src._data.dtype
+ ), "Expecting both Float8Tensors to be of the same dtypet"
+ assert (
+ self._gemm_input_role == src._gemm_input_role
+ ), "Expecting both Float8Tensors to have the same gemm_input_role"
+ fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
+ return Float8Tensor(
+ fp8_out,
+ self._scale,
+ self._orig_dtype,
+ self._linear_mm_config,
+ self._gemm_input_role,
+ )
+ else:
+ raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
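
A short sketch of the two allowed combinations above; `x_fp8` is assumed to be an existing `Float8Tensor` and `hp` a high precision tensor of the same shape and `_orig_dtype`:

```python
# 1. high precision destination, Float8Tensor source:
#    src is upcast and unscaled into hp
hp.copy_(x_fp8)

# 2. Float8Tensor destination and source:
#    only allowed when scale, dtypes and configs all match
y_fp8 = x_fp8.clone()
y_fp8.copy_(x_fp8)
```
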
diff --git a/torchao/float8/float8_python_api.py b/torchao/float8/float8_python_api.py
new file mode 100644
index 000000000..16e270574
--- /dev/null
+++ b/torchao/float8/float8_python_api.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+This file defines the Python functions for float8 which expect inputs
+of class `Float8Tensor`. This is a thin wrapper on top of the aten API
+to simplify the product code.
+"""
+
+from typing import Optional
+
+import torchao.float8.float8_aten_api # noqa
+
+import torch
+
+
+# [Note] Usage of scales
+# The meaning of scale in this library can be found in the definition of the Float8Tensor
+# Cublas defines scale to always mean a multiplicative factor for the respective matrices
+# For a,b going from fp8 -> fp32 we multiply by the inverse of the scale
+# For output going from fp32 -> fp8 we multiply by the scale
+def addmm_float8_unwrapped(
+ a_data: torch.Tensor,
+ a_scale: torch.Tensor,
+ b_data: torch.Tensor,
+ b_scale: torch.Tensor,
+ output_dtype: torch.dtype,
+ output_scale: Optional[torch.Tensor] = None,
+ bias: Optional[torch.Tensor] = None,
+ use_fast_accum: bool = False,
+) -> torch.Tensor:
+ """
+ This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
+ as inputs. This is used to standardize the logic between subclassed and non subclassed
+ versions of the linear module.
+ """
+ a_inverse_scale = a_scale.reciprocal()
+ b_inverse_scale = b_scale.reciprocal()
+ if output_dtype == torch.float32 and bias is not None:
+ # Bias is not supported by _scaled_mm when output is fp32
+ output = torch._scaled_mm(
+ a_data,
+ b_data,
+ scale_a=a_inverse_scale,
+ scale_b=b_inverse_scale,
+ scale_result=output_scale,
+ out_dtype=output_dtype,
+ use_fast_accum=use_fast_accum,
+ )
+ output += bias
+ return output
+ output = torch._scaled_mm(
+ a_data,
+ b_data,
+ scale_a=a_inverse_scale,
+ scale_b=b_inverse_scale,
+ bias=bias,
+ scale_result=output_scale,
+ out_dtype=output_dtype,
+ use_fast_accum=use_fast_accum,
+ )
+ return output
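
A tiny numeric illustration of the scale convention described in the note above (values chosen for readability; the E4M3 max of 448 is queried rather than hard coded):

```python
import torch

x = torch.tensor([2.0, -1.0, 0.25])
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()  # 448 / 2 = 224
x_fp8 = (x * scale).to(torch.float8_e4m3fn)   # fp32 -> fp8: multiply by the scale
x_back = x_fp8.to(torch.float32) / scale      # fp8 -> fp32: multiply by the inverse of the scale
```
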
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
new file mode 100644
index 000000000..bbf140eff
--- /dev/null
+++ b/torchao/float8/float8_scaling_utils.py
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utilities for scaling high precision tensors to float8.
+"""
+
+from typing import Optional
+
+import torch
+
+from torchao.float8.float8_tensor import (
+ Float8Tensor,
+ GemmInputRole,
+ hp_tensor_and_scale_to_float8,
+ LinearMMConfig,
+ ScaledMMConfig,
+ tensor_already_casted_to_fp8,
+)
+
+from torchao.float8.float8_utils import (
+ amax_history_to_scale,
+ e4m3_dtype,
+ e5m2_dtype,
+ tensor_to_amax,
+ tensor_to_scale,
+)
+
+
+def hp_tensor_to_float8_dynamic(
+ hp_tensor: torch.Tensor,
+ float8_dtype: torch.dtype,
+ linear_mm_config: LinearMMConfig,
+ reduce_amax: bool = False,
+ gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+) -> Float8Tensor:
+ """
+ Given a high precision tensor `hp_tensor`,
+ scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result.
+
+ Args:
+ hp_tensor: the tensor to convert
+ float8_dtype: the float8 dtype to use
+ linear_mm_config: Defines the configuration for the scaled_mm for
+ the 3 fwd/bwd gemms of linear
+ reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
+ gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+ the 3 fwd/bwd gemms of linear
+ """
+ if tensor_already_casted_to_fp8(hp_tensor):
+ return hp_tensor
+ scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
+ return hp_tensor_and_scale_to_float8(
+ hp_tensor,
+ scale,
+ float8_dtype,
+ linear_mm_config,
+ gemm_input_role,
+ )
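
A minimal sketch of dynamic casting with the helper above; the bfloat16 weight and the default `LinearMMConfig()` are illustrative:

```python
w = torch.randn(32, 32, dtype=torch.bfloat16)
w_fp8 = hp_tensor_to_float8_dynamic(
    w,
    e4m3_dtype,
    LinearMMConfig(),
    gemm_input_role=GemmInputRole.WEIGHT,
)
# w_fp8 is a Float8Tensor; w_fp8.to_original_precision() gives a bf16 approximation of w
```
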
+
+
+def hp_tensor_to_float8_delayed(
+ hp_tensor: torch.Tensor,
+ s: torch.Tensor,
+ float8_dtype: torch.dtype,
+ amax_buffer: torch.Tensor,
+ linear_mm_config: Optional[LinearMMConfig] = None,
+ gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
+) -> Float8Tensor:
+ """
+ Given a high precision tensor `hp_tensor` and relevant metadata, scales it using
+ delayed scaling and returns a `Float8Tensor` of the result. Specifically:
+ 1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace
+ 2. scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor
+
+ Args:
+ hp_tensor: the tensor to convert
+ s: the scale to use to convert the tensor
+ float8_dtype: the float8 dtype to use
+ amax_buffer: the buffer to modify inplace with max(abs(hp_tensor))
+ linear_mm_config: Defines the configuration for the scaled_mm for
+ the 3 fwd/bwd gemms of linear
+ gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+ the 3 fwd/bwd gemms of linear
+ """
+ amax_buffer.fill_(tensor_to_amax(hp_tensor))
+ return hp_tensor_and_scale_to_float8(
+ hp_tensor,
+ s,
+ float8_dtype,
+ linear_mm_config,
+ gemm_input_role,
+ )
+
+
+def _maybe_initialize_amaxes_scales_for_float8_cast(
+ x,
+ cur_amax,
+ amax_history,
+ scale,
+ scale_fn_name,
+ float8_dtype,
+ is_initialized,
+ reduce_amax,
+):
+ """
+ If x is about to be cast to `float8` and the amax buffers are not initialized,
+ initializes them inplace.
+ """
+ if is_initialized:
+ return
+ with torch.no_grad():
+ # Note: we need to enable distributed reduction here in order
+ # to match numerics between single GPU and multi GPU code for
+ # activations and gradients
+ new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+ cur_amax.fill_(new_amax)
+ amax_history[0] = new_amax
+ new_scale = amax_history_to_scale(
+ amax_history, float8_dtype, x.dtype, scale_fn_name
+ )
+ scale.copy_(new_scale)
+
+
+@torch._dynamo.allow_in_graph
+class NoopFwToFloat8E5M2BwDelayed(torch.autograd.Function):
+ """
+ Forward: no-op
+ Backward: convert to float8_e5m2 with delayed scaling, initialize if needed
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ tensor,
+ fp8_amax_grad_output,
+ fp8_amax_history_grad_output,
+ fp8_scale_grad_output,
+ scale_fn_name,
+ is_amax_initialized,
+ linear_mm_config: LinearMMConfig,
+ ):
+ ctx.save_for_backward(
+ fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output
+ )
+ ctx.scale_fn_name = scale_fn_name
+ ctx.is_amax_initialized = is_amax_initialized
+ ctx.linear_mm_config = linear_mm_config
+ return tensor
+
+ @staticmethod
+ def backward(ctx, go):
+ (
+ fp8_amax_grad_output,
+ fp8_amax_history_grad_output,
+ fp8_scale_grad_output,
+ ) = ctx.saved_tensors
+ scale_fn_name = ctx.scale_fn_name
+ is_amax_initialized = ctx.is_amax_initialized
+
+ _maybe_initialize_amaxes_scales_for_float8_cast(
+ go,
+ fp8_amax_grad_output,
+ fp8_amax_history_grad_output,
+ fp8_scale_grad_output,
+ scale_fn_name,
+ e5m2_dtype,
+ is_amax_initialized,
+ reduce_amax=True,
+ )
+
+ fp8_amax_grad_output.fill_(tensor_to_amax(go))
+
+ res = hp_tensor_and_scale_to_float8(
+ go,
+ fp8_scale_grad_output,
+ e5m2_dtype,
+ ctx.linear_mm_config,
+ GemmInputRole.GRAD_OUTPUT,
+ )
+ empty_grads = None, None, None, None, None, None
+ return res, *empty_grads
+
+
+@torch._dynamo.allow_in_graph
+class NoopFwToFloat8E5M2BwDynamic(torch.autograd.Function):
+ """
+ Forward: no-op
+ Backward: convert to float8_e5m2 with dynamic scaling
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ tensor,
+ linear_mm_config: LinearMMConfig,
+ ):
+ ctx.linear_mm_config = linear_mm_config
+ return tensor
+
+ @staticmethod
+ def backward(ctx, gradY):
+ if tensor_already_casted_to_fp8(gradY):
+ return gradY, None
+ gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
+ fp8_tensor = hp_tensor_and_scale_to_float8(
+ gradY,
+ gradY_scale,
+ e5m2_dtype,
+ ctx.linear_mm_config,
+ GemmInputRole.GRAD_OUTPUT,
+ )
+ return fp8_tensor, None
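
For context, a sketch of how the autograd function above is typically used; `y` and `linear_mm_config` are assumed to come from the surrounding Float8Linear code:

```python
# forward: y is returned unchanged
# backward: the gradient w.r.t. y is cast to float8_e5m2 with dynamic scaling
y = NoopFwToFloat8E5M2BwDynamic.apply(y, linear_mm_config)
```
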
diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py
new file mode 100644
index 000000000..a858408fe
--- /dev/null
+++ b/torchao/float8/float8_tensor.py
@@ -0,0 +1,363 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import enum
+from collections import namedtuple
+from typing import Dict, Optional
+
+import torch
+
+import torch.distributed._functional_collectives as funcol
+from torchao.float8.float8_utils import (
+ e4m3_dtype,
+ tensor_to_amax,
+ to_fp8_saturated,
+)
+from torch.distributed._tensor import DTensor
+
+aten = torch.ops.aten
+
+#
+# A note on configuration of float8 logic in a linear
+# TODO(future): move all the configs to separate file
+# TODO(future): change this to input/weight/grad_output notation,
+# can be separate PR because none of this is user facing
+#
+# There are three gemms in a forward + backward of a Linear layer:
+#
+# 1. input @ weight_t = output (forward pass)
+# 2. grad_output @ weight = grad_input (backward pass)
+# 3. input_t @ grad_output = grad_weight (backward pass)
+#
+# In the formulas above, there are:
+# A. six input tensors (input, input_t, weight, weight_t, grad_output, grad_output_t).
+# - Note that grad_output_t is implied because of memory format requirements
+# of float8 gemms
+# B. three output tensors (output, grad_input, grad_weight)
+#
+# We want each input tensor, gemm, and output tensor to be configurable.
+# The state of this configuration today is:
+#
+# i. pairs of input tensors (non-t and t variants) have their scaling
+# configurable via the scaling_type_* arguments to Float8Linear
+# ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing
+# iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed
+# to configure all three gemms, also not user facing
+
+
+# ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass.
+# emulate: whether to emulate the matmuls in fp32
+# use_fast_accum: whether to use the fast-accumulation option for scaled_mm
+# fp8_output: whether to output the result of the scaled_mm in fp8
+# pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16.
+ScaledMMConfig = namedtuple(
+ "ScaledMMConfig",
+ ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
+ defaults=[False, False, False, False],
+)
+
+# The object below is not user facing and exists for convenience:
+# it allows Float8Tensor to pick the right config based on which of the
+# three gemms (producing `output`, `grad_input`, or `grad_weight`)
+# is being called.
+LinearMMConfig = namedtuple(
+ "LinearMMConfig",
+ ["output", "grad_input", "grad_weight"],
+ defaults=[
+ ScaledMMConfig(False, True, False, False),
+ ScaledMMConfig(False, False, False, False),
+ ScaledMMConfig(False, False, False, False),
+ ],
+)
+
+
+class GemmInputRole(enum.Enum):
+ """
+ Given a Float8Tensor, the enum below describes the expected role of this
+ tensor in the three gemms present in the fw + bw pass of a Linear layer.
+ This is used to choose the right config for a float8 gemm when the
+ gemm is performed.
+ """
+
+ INPUT = "input"
+ WEIGHT = "weight"
+ GRAD_OUTPUT = "grad_output"
+
+
+# choose which scaled_mm_config to use based on gemm inputs
+def choose_scaled_mm_config(
+ a_role: GemmInputRole,
+ a_linear_mm_config: LinearMMConfig,
+ b_role: GemmInputRole,
+ b_linear_mm_config: LinearMMConfig,
+):
+ if a_role is GemmInputRole.INPUT and b_role is GemmInputRole.WEIGHT:
+ assert (
+ a_linear_mm_config.output == b_linear_mm_config.output
+ ), f"linear_mm_config.output mismatch: {a_linear_mm_config.output} vs {b_linear_mm_config.output}"
+ return a_linear_mm_config.output
+ elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.WEIGHT:
+ assert (
+ a_linear_mm_config.grad_input == b_linear_mm_config.grad_input
+ ), f"linear_mm_config.grad_input mismatch: {a_linear_mm_config.grad_input} vs {b_linear_mm_config.grad_input}"
+ return a_linear_mm_config.grad_input
+ elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.INPUT:
+ assert (
+ a_linear_mm_config.grad_weight == b_linear_mm_config.grad_weight
+ ), f"linear_mm_config.grad_weight mismatch: {a_linear_mm_config.grad_weight} vs {b_linear_mm_config.grad_weight}"
+ return a_linear_mm_config.grad_weight
+ else:
+ raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")
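
A small illustration of the mapping above, using a default `LinearMMConfig` for both operands:

```python
lmc = LinearMMConfig()

# forward gemm (input @ weight_t) -> the `output` ScaledMMConfig
assert choose_scaled_mm_config(GemmInputRole.INPUT, lmc, GemmInputRole.WEIGHT, lmc) == lmc.output

# backward gemm (grad_output @ weight) -> the `grad_input` ScaledMMConfig
assert choose_scaled_mm_config(GemmInputRole.GRAD_OUTPUT, lmc, GemmInputRole.WEIGHT, lmc) == lmc.grad_input
```
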
+
+
+def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
+ """
+ Check if the tensor is already cast to fp8
+ """
+ if isinstance(tensor, Float8Tensor):
+ return True
+ elif isinstance(tensor, DTensor):
+ # TODO: shall we stick to public API and directly use tensor.to_local() here?
+ return tensor_already_casted_to_fp8(tensor._local_tensor)
+ elif isinstance(tensor, funcol.AsyncCollectiveTensor):
+ return tensor_already_casted_to_fp8(tensor.elem)
+
+ return False
+
+
+@torch._dynamo.allow_in_graph
+class _ToFloat8ConstrFunc(torch.autograd.Function):
+ """
+ A differentiable conversion to fp8.
+ * forward: convert from high precision to float8
+ * backward: pass the gradient without changes
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ tensor: torch.Tensor,
+ scale: torch.Tensor,
+ float8_dtype=e4m3_dtype,
+ linear_mm_config: Optional[LinearMMConfig] = None,
+ gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
+ ):
+ """
+ This function will apply the scaling, and then convert to a Float8Tensor
+
+ Note:
+ We will call this function with a DTensor subclass. Ideally this would be an aten OP
+ that DTensor could overload to ensure proper semantics. There are some technical issues
+ with that composing with FakeTensor, so we special case here.
+
+ DTensor Invariant: DTensor must always be the outermost tensor subclass
+ """
+ tensor_scaled = tensor * scale
+ bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+
+ if isinstance(bits_fp8, DTensor):
+ assert isinstance(
+ scale, DTensor
+ ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
+ bits_mesh = bits_fp8.device_mesh
+ bits_placements = bits_fp8.placements
+ local_bits = bits_fp8.to_local()
+ local_scale = scale.to_local()
+ inner_float8_tensor = Float8Tensor(
+ local_bits,
+ local_scale,
+ tensor.dtype,
+ linear_mm_config=linear_mm_config,
+ gemm_input_role=gemm_input_role,
+ )
+ return DTensor.from_local(
+ inner_float8_tensor,
+ bits_mesh,
+ bits_placements,
+ run_check=False,
+ shape=bits_fp8.size(),
+ stride=bits_fp8.stride(),
+ )
+
+ return Float8Tensor(
+ bits_fp8,
+ scale,
+ tensor.dtype,
+ linear_mm_config=linear_mm_config,
+ gemm_input_role=gemm_input_role,
+ )
+
+ @staticmethod
+ def backward(ctx, g):
+ return g, None, None, None, None, None
+
+
+@torch._dynamo.allow_in_graph
+class _FromFloat8ConstrFunc(torch.autograd.Function):
+ """
+ A differentiable conversion from fp8.
+ * forward: convert from float8 to high precision
+ * backward: pass the gradient without changes
+ """
+
+ @staticmethod
+ def forward(ctx, tensor):
+ return tensor._data.to(tensor._orig_dtype) / tensor._scale
+
+ @staticmethod
+ def backward(ctx, g):
+ return g, None, None
+
+
+def hp_tensor_and_scale_to_float8(
+ hp_tensor: torch.Tensor,
+ s: torch.Tensor,
+ float8_dtype=e4m3_dtype,
+ linear_mm_config: Optional[LinearMMConfig] = None,
+ gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
+):
+ """
+ Given a high precision tensor `hp_tensor` and a precalculated scale `s`,
+ scales `hp_tensor` by `s` and returns a `Float8Tensor` of the result.
+
+ Autograd-aware, the derivative is pass-through.
+ DTensor-aware, if the input is a DTensor the output will be DTensor(Float8Tensor).
+
+ Args:
+ hp_tensor: the tensor to convert
+ s: the scale to use to convert the tensor
+ float8_dtype: the float8 dtype to use
+ linear_mm_config: Defines the configuration for the scaled_mm for
+ the 3 fwd/bwd gemms of linear
+ gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+ the 3 fwd/bwd gemms of linear
+ """
+ return _ToFloat8ConstrFunc.apply(
+ hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role
+ )
+
+
+class Float8Tensor(torch.Tensor):
+ """
+ Note: this is **not** a public API and is only intended to be used
+ inside of this repository. Please file an issue if you would benefit
+ from this being a public API.
+
+ A Python-only Float8 tensor subclass. Contains:
+ * `_data`: the underlying e4m3 or e5m2 data
+ * `_scale`: the scale used to scale the original fp32 tensor. We multiply
+ by scale to go from fp32 range to fp8 range, and divide by scale to go
+ from fp8 range to fp32 range.
+ * `_orig_dtype`: the original dtype of the tensor used to create this
+ tensor.
+ * `_linear_mm_config`: the ScaledMMConfig for each of the 3 gemms of linear,
+ including whether to emulate the matmuls in fp32, which is helpful if you
+ don't have access to H100 hardware.
+ * `_gemm_input_role`: the role of this tensor (input, weight or grad_output)
+ in the 3 fwd/bwd gemms of linear.
+
+ Intended usage of this abstraction:
+ 1. to bundle raw data + fp8 metadata together for easy passing through
+ Python PyTorch systems.
+ 2. Float8-aware user code can use the private fields on these tensors
+ to call into float8 operations.
+ 3. Float8-agnostic user code can use these tensors as is - they will
+ convert to original precision in `__torch_dispatch__`.
+ """
+
+ _data: torch.Tensor
+ _scale: torch.Tensor
+ _orig_dtype: torch.dtype
+ _linear_mm_config: LinearMMConfig
+ __slots__ = ["_data", "_scale", "_orig_dtype", "_linear_mm_config"]
+
+ def __new__(
+ cls,
+ data: torch.Tensor,
+ scale: torch.Tensor,
+ orig_dtype: torch.dtype,
+ linear_mm_config: Optional[LinearMMConfig],
+ gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
+ ):
+ assert (
+ scale.numel() == 1
+ ), "Scale should contain a single value, but got: {} elements".format(
+ scale.numel()
+ )
+
+ self = torch.Tensor._make_wrapper_subclass(
+ cls,
+ data.size(),
+ strides=data.stride(),
+ storage_offset=data.storage_offset(),
+ dtype=orig_dtype,
+ layout=data.layout,
+ requires_grad=data.requires_grad,
+ device=data.device,
+ )
+ self._data = data
+ self._scale = scale
+ self._orig_dtype = orig_dtype
+ self._linear_mm_config = (
+ linear_mm_config if linear_mm_config is not None else LinearMMConfig()
+ )
+ self._gemm_input_role = gemm_input_role
+
+ return self
+
+ def __repr__(self):
+ return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}"
+
+ def __tensor_flatten__(self):
+ ctx = {
+ "_orig_dtype": self._orig_dtype,
+ "_linear_mm_config": self._linear_mm_config,
+ "_gemm_input_role": self._gemm_input_role,
+ }
+ return ["_data", "_scale"], ctx
+
+ @staticmethod
+ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
+ assert len(inner_tensors) == 2
+ return Float8Tensor(
+ inner_tensors["_data"],
+ inner_tensors["_scale"],
+ metadata["_orig_dtype"],
+ metadata["_linear_mm_config"],
+ metadata["_gemm_input_role"],
+ )
+
+ def to_original_precision(self):
+ return _FromFloat8ConstrFunc.apply(self)
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args, kwargs=None):
+ # 1. tracing through __torch_function__ logic is not supported yet in
+ # PT2.0, so we explicitly disallow it here for callsites from user code.
+ # 2. We do need to handle a couple of ops in order for
+ # TorchDynamo tracing to succeed.
+
+ # Lazy import to avoid circular dependency
+ from torchao.float8.float8_ops import FLOAT8_OPS_TABLE
+
+ # All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs
+ # And don't support mixed tensor subclasses. This will trigger the handler for
+ # the next type in the dispatch list
+ def allowed_subclasses(type):
+ return (
+ issubclass(cls, type)
+ or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
+ or issubclass(
+ torch._subclasses.functional_tensor.FunctionalTensor, type
+ )
+ )
+
+ if not all(allowed_subclasses(t) for t in types):
+ return NotImplemented
+
+ if func in FLOAT8_OPS_TABLE:
+ return FLOAT8_OPS_TABLE[func](func, args, kwargs)
+ raise NotImplementedError(f"attempting to run {func}, this is not supported")
+
+ # Do not force the Float8Tensor type on the returned tensor
+ __torch_function__ = torch._C._disabled_torch_function_impl
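
To summarize the subclass, a short round-trip sketch using the constructor helper defined earlier in this file (CPU tensors, illustrative only; fp8 gemms still require supported hardware or emulation):

```python
x = torch.randn(4, 4)
s = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
x_fp8 = hp_tensor_and_scale_to_float8(x, s)   # Float8Tensor, presents as fp32
x_hp = x_fp8.to_original_precision()          # approximate reconstruction of x
```
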
diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py
new file mode 100644
index 000000000..affec2a76
--- /dev/null
+++ b/torchao/float8/float8_tensor_parallel.py
@@ -0,0 +1,235 @@
+import torch
+import torch.nn as nn
+from torchao.float8.config import ScalingType
+from torchao.float8.float8_scaling_utils import (
+ hp_tensor_to_float8_dynamic,
+ NoopFwToFloat8E5M2BwDynamic,
+)
+from torchao.float8.float8_tensor import GemmInputRole
+from torchao.float8.float8_utils import e4m3_dtype
+from torch.distributed._tensor import DTensor
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ PrepareModuleInput,
+ RowwiseParallel,
+)
+
+# Subclass the ColwiseParallel and RowwiseParallel classes to add float8 support.
+# The parameter sharding stays the same as in the core ColwiseParallel and
+# RowwiseParallel; the only difference here is that in input/output handling
+# we do the float8 cast after creating the DTensor.
+
+# NOTE: this only works with, and is only tested for, dynamic scaling
+
+
+def _float8_linear_supports_float8_allgather(m):
+ # TODO(future): add support for delayed scaling for activations
+ # and gradients
+ return (
+ m.scaling_type_input == ScalingType.DYNAMIC
+ and m.scaling_type_grad_output == ScalingType.DYNAMIC
+ )
+
+
+class Float8ColwiseParallel(ColwiseParallel):
+ @staticmethod
+ def _prepare_input_fn(
+ input_layouts, desired_input_layouts, mod, inputs, device_mesh
+ ):
+ # annotate module input placements/sharding with input_layouts
+ input_tensor = inputs[0]
+ if not isinstance(input_tensor, DTensor):
+ input_tensor = DTensor.from_local(
+ input_tensor, device_mesh, input_layouts, run_check=False
+ )
+
+ input_tensor = hp_tensor_to_float8_dynamic(
+ input_tensor,
+ e4m3_dtype,
+ mod.linear_mm_config,
+ gemm_input_role=GemmInputRole.INPUT,
+ ) # DTensor(Float8Tensor)
+
+ # transform the input layouts to the desired layouts of ColwiseParallel
+ if input_layouts != desired_input_layouts:
+ input_tensor = input_tensor.redistribute(
+ placements=desired_input_layouts, async_op=True
+ )
+ return input_tensor
+
+ @staticmethod
+ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+ # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
+ if outputs.placements != output_layouts:
+ outputs = outputs.redistribute(
+ placements=output_layouts, async_op=True
+ ) # DTensor(torch.Tensor)
+
+ # fwd noop bwd cast to DTensor(Float8Tensor)
+ outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config)
+
+ # back to local tensor
+ return outputs.to_local() if use_local_output else outputs
+
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+ from torchao.float8.float8_linear import Float8Linear
+
+ if not isinstance(module, Float8Linear):
+ raise ValueError(
+ f"Expecting module to be Float8Linear but found {type(module)}"
+ )
+ elif isinstance(
+ module, Float8Linear
+ ) and not _float8_linear_supports_float8_allgather(module):
+ raise AssertionError("unsupported")
+
+ return super()._apply(module, device_mesh)
+
+
+class Float8RowwiseParallel(RowwiseParallel):
+ @staticmethod
+ def _prepare_input_fn(
+ input_layouts, desired_input_layouts, mod, inputs, device_mesh
+ ):
+ input_tensor = inputs[0]
+ if not isinstance(input_tensor, DTensor):
+ input_tensor = DTensor.from_local(
+ input_tensor, device_mesh, input_layouts, run_check=False
+ )
+
+ input_tensor = hp_tensor_to_float8_dynamic(
+ input_tensor,
+ e4m3_dtype,
+ mod.linear_mm_config,
+ gemm_input_role=GemmInputRole.INPUT,
+ ) # DTensor(Float8Tensor)
+
+ if input_layouts != desired_input_layouts:
+ input_tensor = input_tensor.redistribute(
+ placements=desired_input_layouts, async_op=True
+ )
+ return input_tensor
+
+ @staticmethod
+ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+ # Rowwise sharding produces partial output, depending on output layouts:
+ # 1. to replicate -> allreduce
+ # 2. to shard -> reduce_scatter
+ if outputs.placements != output_layouts:
+ outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+
+ # fwd noop bwd cast to DTensor(Float8Tensor)
+ outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config)
+
+ # back to local tensor if use_local_output is True
+ return outputs.to_local() if use_local_output else outputs
+
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+ from torchao.float8.float8_linear import Float8Linear
+
+ if not isinstance(module, Float8Linear):
+ raise ValueError(
+ f"Expecting module to be Float8Linear but found {type(module)}"
+ )
+ elif isinstance(
+ module, Float8Linear
+ ) and not _float8_linear_supports_float8_allgather(module):
+ raise AssertionError("unsupported")
+
+ return super()._apply(module, device_mesh)
+
+
+class PrepareFloat8ModuleInput(PrepareModuleInput):
+ # Subclass PrepareModuleInput to implement fp8-specific logic; the only difference is that
+ # after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor).
+ # This ensures the float8 cast happens before the all-gather (i.e. Shard -> Replicate),
+ # so that if there are multiple float8 users of the input activation, we perform the fp8
+ # allgather only once.
+ # FP8 Args:
+ # float8_dtype (torch.dtype, optional): which float8 dtype to cast to when preparing the
+ # module input; we currently only support torch.float8_e4m3fn (the default).
+ # fwd_config_submodule_fqn (str, optional): the FQN of the submodule that contains the
+ # forward config used for the float8 cast. If not specified, we search the submodules for
+ # Float8Linear and use the forward config from that module; in this case all modules'
+ # forward configs must be the same.
+
+ def __init__(
+ self,
+ *,
+ input_layouts=None,
+ desired_input_layouts=None,
+ input_kwarg_layouts=None,
+ desired_input_kwarg_layouts=None,
+ use_local_output=False,
+ float8_dtype=torch.float8_e4m3fn,
+ fwd_config_submodule_fqn=None,
+ ):
+ super().__init__(
+ input_layouts=input_layouts,
+ desired_input_layouts=desired_input_layouts,
+ input_kwarg_layouts=input_kwarg_layouts,
+ desired_input_kwarg_layouts=desired_input_kwarg_layouts,
+ use_local_output=use_local_output,
+ )
+
+ # fp8 specific fields
+ self.float8_dtype = float8_dtype
+ self.linear_mm_config = None
+ self.fwd_config_submodule_fqn = fwd_config_submodule_fqn
+
+ if self.float8_dtype != torch.float8_e4m3fn:
+ raise NotImplementedError(
+ "PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now"
+ )
+
+ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
+ if input_layout is not None:
+ if isinstance(input, DTensor):
+ # TODO: re-enable the check once we fix the compile path
+ # assert inp.placements[0] == input_layout
+ dt_inp = input
+ else:
+ assert isinstance(
+ input, torch.Tensor
+ ), "expecting input to be a torch.Tensor!"
+ dt_inp = DTensor.from_local(
+ input, mesh, (input_layout,), run_check=False
+ )
+
+ dt_inp = hp_tensor_to_float8_dynamic(
+ dt_inp,
+ e4m3_dtype,
+ self.linear_mm_config,
+ gemm_input_role=GemmInputRole.INPUT,
+ ) # DTensor(Float8Tensor)
+ if desired_layout is not None and input_layout != desired_layout:
+ dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+
+ return dt_inp.to_local() if self.use_local_output else dt_inp
+ else:
+ return input
+
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+ from torchao.float8.float8_linear import Float8Linear
+
+ if self.fwd_config_submodule_fqn is not None:
+ fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
+ assert isinstance(fwd_linear, Float8Linear)
+ self.linear_mm_config = fwd_linear.linear_mm_config
+ else:
+ # search for ScaledMM configs for all the submodules and make sure they are the same
+ for mod in module.modules():
+ if isinstance(mod, Float8Linear):
+ if self.linear_mm_config is None:
+ self.linear_mm_config = mod.linear_mm_config
+ else:
+ assert (
+ self.linear_mm_config == mod.linear_mm_config
+ ), "All the Float8Linear modules should have same linear_mm_config!"
+
+ assert self.linear_mm_config is not None
+ super()._apply(module, device_mesh)
+ return module
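
A hedged usage sketch of the parallel styles above, assuming a transformer block whose linears were already swapped to `Float8Linear`, a 1-D tensor parallel `DeviceMesh` named `tp_mesh`, and placeholder submodule FQNs:

```python
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module

plan = {
    "attn.wq": Float8ColwiseParallel(),
    "attn.wo": Float8RowwiseParallel(),
    "ffn": PrepareFloat8ModuleInput(
        input_layouts=Shard(1), desired_input_layouts=Replicate()
    ),
    "ffn.w1": Float8ColwiseParallel(),
    "ffn.w2": Float8RowwiseParallel(),
}
block = parallelize_module(block, tp_mesh, plan)
```
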
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
new file mode 100644
index 000000000..1d6c69d17
--- /dev/null
+++ b/torchao/float8/float8_utils.py
@@ -0,0 +1,247 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Iterable, Literal, Tuple, Union
+
+import torchao.float8.config as config
+
+import torch
+import torch.distributed as dist
+
+# Helpful visualizer for debugging (only supports fp32):
+# https://www.h-schmidt.net/FloatConverter/IEEE754.html
+
+# avoid division by zero when calculating scale
+# TODO: align this value with NVIDIA's assumptions (current value is a guess)
+EPS = 1e-12
+
+IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None
+FP8_TYPES = {
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ torch.float8_e4m3fnuz,
+ torch.float8_e5m2fnuz,
+}
+
+
+# User defined type for using the individual F8 type based on config
+e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
+e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
+
+
+@torch.no_grad()
+def amax_to_scale(
+ amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+ """Converts the amax value of a tensor to the fp8 scale.
+ Args:
+ amax: The amax value of the tensor.
+ float8_dtype: The float8 dtype.
+ orig_dtype: The original dtype of the tensor.
+ """
+ scale = torch.empty_like(amax, dtype=torch.float32)
+ if float8_dtype in FP8_TYPES:
+ res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+ else:
+ raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
+
+ # Ensure that the scale is representable in float16,
+ # this helps when amax is small. We are assuming that we don't need
+ # to care about this for float32/bfloat16.
+ if orig_dtype is torch.float16:
+ res = torch.clamp(res, max=torch.finfo(torch.float16).max)
+ scale.copy_(res)
+ return scale
+
+
+@torch.no_grad()
+def amax_history_to_scale(
+ amax_history: torch.Tensor,
+ float8_dtype: torch.dtype,
+ orig_dtype: torch.dtype,
+ history_to_scale_fn_type: Literal["max"],
+):
+ """Takes in a history of amax values and returns a scale tensor.
+ Args:
+ amax_history: A tensor containing the history of amax values.
+ float8_dtype: The float8 dtype.
+ orig_dtype: The original dtype of the tensor.
+ history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+ """
+ if history_to_scale_fn_type == "max":
+ amax = torch.max(amax_history)
+ return amax_to_scale(amax, float8_dtype, orig_dtype)
+ raise NotImplementedError()
+
+
+@torch.no_grad()
+def amax_history_to_scale_stack(
+ amax_history: torch.Tensor,
+ float8_dtype: torch.dtype,
+ orig_dtype: torch.dtype,
+ history_to_scale_fn_type: Literal["max"],
+) -> torch.Tensor:
+ """Takes in a stack of amax_history tensors and returns a scale tensor.
+ Args:
+ amax_history: A 2D tensor containing a stack of amax histories.
+ float8_dtype: The float8 dtype.
+ orig_dtype: The original dtype of the tensor.
+ history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+ """
+ if history_to_scale_fn_type == "max":
+ amax_stack = torch.max(amax_history, dim=1).values
+ return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
+ raise NotImplementedError(
+ f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}"
+ )
+
+
+@torch.no_grad()
+def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
+ amax = torch.max(torch.abs(x))
+
+ # If the user asked for distributed reduction, do it.
+ # If the user did not ask for it, assume that it will
+ # happen elsewhere.
+ if reduce_amax and dist.is_initialized():
+ dist.all_reduce(amax, op=dist.ReduceOp.MAX)
+
+ return amax
+
+
+@torch.no_grad()
+def tensor_to_scale(
+ x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
+) -> torch.Tensor:
+ amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+ return amax_to_scale(amax, float8_dtype, x.dtype)
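
A quick numeric check of the two helpers above (illustrative values; the E4M3 max is 448):

```python
x = torch.tensor([2.0, -1.0, 0.5])
s = tensor_to_scale(x, torch.float8_e4m3fn)
# amax(x) == 2.0, so s == 448.0 / 2.0 == 224.0
```
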
+
+
+def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
+ """Converts a tensor to a saturated fp8 tensor.
+
+ Note:
+ The default behavior in PyTorch for casting to `float8_e4m3fn`
+ and `e5m2` is to not saturate. In this context, we should saturate.
+ A common case where we want to saturate is when the history of a
+ tensor has a maximum value of `amax1`, and the current amax value
+ is `amax2`, where `amax1 < amax2`. This is common when using delayed
+ scaling.
+ """
+ if float8_dtype in FP8_TYPES:
+ max_value = torch.finfo(float8_dtype).max
+ x = x.clamp(min=-max_value, max=max_value)
+ return x.to(float8_dtype)
+ else:
+ raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
+
+
+def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ """Computes the error between two tensors in dB.
+
+ For more details see:
+ https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+ Args:
+ x: The original tensor.
+ y: The tensor to compare to the original tensor.
+ """
+ Ps = torch.norm(x)
+ Pn = torch.norm(x - y)
+ return 20 * torch.log10(Ps / Pn)
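
For example, this is commonly used to compare a tensor against its float8 round trip; `x` and `x_fp8` are assumed from the sketches above:

```python
sqnr_db = compute_error(x, x_fp8.to_original_precision())
# higher values mean the round-tripped tensor is a closer match to x
```
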
+
+
+def fp8_tensor_statistics(
+ tensor: torch.Tensor, float8_dtype=e4m3_dtype
+) -> Tuple[int, ...]:
+ """Calculate FP8 tensor stats
+
+ Args:
+ tensor: The tensor to calculate stats for.
+ float8_dtype: The float8 dtype.
+
+ Returns:
+ A tuple containing the number of zeros and the number of max values.
+ """
+ if float8_dtype in FP8_TYPES:
+ FP8_MAX = torch.finfo(float8_dtype).max
+ else:
+ raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
+ tensor_orig_type = tensor._data.to(dtype=tensor._orig_dtype)
+ num_max = (torch.abs(tensor_orig_type) == FP8_MAX).sum().item()
+ num_zero = (tensor_orig_type == 0).sum().item()
+ return (num_zero, num_max)
+
+
+def is_row_major(stride):
+ assert len(stride) == 2, "is_row_major only supports 2D tensors"
+ return stride[0] > stride[1] and stride[1] == 1
+
+
+def _get_min_alignment(size: int, alignment_value: int) -> int:
+ """
+ Returns the smallest multiple of `alignment_value` that is greater than or equal to `size`.
+
+ Args:
+ size: The size of the data to be aligned.
+ alignment_value: The alignment value to be used.
+
+ Returns:
+ int: The smallest multiple of `alignment_value` that is greater than or equal to `size`.
+
+ Usage:
+ ```
+ >>> _get_min_alignment(10, 8)
+ 16
+ ```
+ """
+ if size % alignment_value == 0:
+ return size
+ return (1 + (size // alignment_value)) * alignment_value
+
+
+def pad_tensor_for_matmul(
+ tensor: torch.Tensor, dims: Union[int, Iterable[int]]
+) -> torch.Tensor:
+ """
+ Pads a 2D tensor with zeros to ensure that its dimensions are multiples of 16, which is required by `torch._scaled_mm`
+
+ Args:
+ tensor: The tensor to pad.
+ dims: The dimension(s) to pad: 0, 1, or an iterable containing both.
+
+ Returns:
+ torch.Tensor: The padded tensor.
+
+ Usage:
+ ```
+ >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape
+ torch.Size([16, 10])
+ >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape
+ torch.Size([10, 16])
+ >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape
+ torch.Size([16, 16])
+ ```
+ """
+ assert tensor.dim() == 2
+ dim1, dim2 = tensor.shape
+
+ if isinstance(dims, int):
+ dims = (dims,)
+
+ # Calculate aligned dimensions based on the specified dims
+ dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1
+ dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2
+
+ # Check if padding is needed for either dimension
+ if dim1 == dim1_aligned and dim2 == dim2_aligned:
+ return tensor
+
+ # Calculate padding values for both dimensions
+ pad_dim1 = dim1_aligned - dim1
+ pad_dim2 = dim2_aligned - dim2
+
+ return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))
diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py
new file mode 100644
index 000000000..5f53f5d82
--- /dev/null
+++ b/torchao/float8/fsdp_utils.py
@@ -0,0 +1,388 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Any, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.utils._pytree as pytree
+from torchao.float8.float8_scaling_utils import (
+ hp_tensor_to_float8_delayed,
+ hp_tensor_to_float8_dynamic,
+)
+
+from torchao.float8.float8_tensor import (
+ Float8Tensor,
+ GemmInputRole,
+ hp_tensor_and_scale_to_float8,
+ LinearMMConfig,
+)
+
+from torchao.float8.float8_utils import e4m3_dtype, EPS
+from torch._prims_common import suggest_memory_format
+
+
+@torch.no_grad()
+def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
+ """
+ Calculate scale dynamically for all float8 parameters.
+ This should be run after the optimizer step. It performs a single all-reduce to compute the
+ scales for all float8 weights.
+ Example usage:
+ model(input).sum().backward()
+ optim.step()
+ precompute_float8_dynamic_scale_for_fsdp(model)
+ """
+ from torchao.float8.config import ScalingType
+ from torchao.float8.float8_linear import Float8Linear
+ from torch.distributed._tensor import DTensor
+
+ if any(
+ isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED
+ for m in module.modules()
+ ):
+ raise NotImplementedError("Only supports delayed scaling")
+ float8_linears: List[Float8Linear] = [
+ m
+ for m in module.modules()
+ if isinstance(m, Float8Linear)
+ and isinstance(m.weight, DTensor)
+ and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
+ ]
+ weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+
+ if not weights:
+ return
+
+ # inf-norm is equivalent to max(abs(w))
+ max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial
+ amax_tensor = torch.stack(max_weights) # Partial
+ # clamp is dispatched through DTensor
+ # it will issue a single all-reduce
+ amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate
+ scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
+ if amax_tensor.dtype is torch.float16:
+ scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+ local_scale_tensor = scale_tensor.to_local()
+ for i, float8_linear in enumerate(float8_linears):
+ float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
+
+
+# FSDP pads its local tensor on dim-0. The subclass should be preserved such
+# that the padded local tensor (and any transformations like copying to GPU)
+# is of the subclass as well.
+_ops_to_preserve_subclass = {
+ torch.ops.aten.empty_like.default,
+ torch.ops.aten.new_zeros.default,
+ torch.ops.aten.slice.Tensor,
+ torch.ops.aten.copy_.default,
+ torch.ops.aten.view.default,
+ torch.ops.aten.as_strided.default,
+ torch.ops.aten._to_copy.default,
+ torch.ops.aten._pin_memory.default,
+}
+
+
+class WeightWithDynamicFloat8CastTensor(torch.Tensor):
+ @staticmethod
+ def __new__(
+ cls,
+ tensor: torch.Tensor,
+ linear_mm_config: LinearMMConfig,
+ precomputed_scale: Optional[torch.Tensor] = None,
+ ):
+ return torch.Tensor._make_wrapper_subclass(
+ cls,
+ tensor.size(),
+ strides=tensor.stride(),
+ storage_offset=tensor.storage_offset(),
+ memory_format=suggest_memory_format(tensor),
+ dtype=tensor.dtype,
+ layout=tensor.layout,
+ device=tensor.device,
+ pin_memory=tensor.is_pinned(),
+ requires_grad=tensor.requires_grad,
+ )
+
+ def __init__(
+ self,
+ tensor: torch.Tensor,
+ linear_mm_config: LinearMMConfig,
+ precomputed_scale: Optional[torch.Tensor] = None,
+ ):
+ self._tensor = tensor
+ self._linear_mm_config = linear_mm_config
+ # for dynamic scaling
+ # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
+ # for all float8 parameters after optimizer step
+ self._precomputed_scale = precomputed_scale
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args, kwargs=None):
+ if func == torch.ops.aten.detach.default:
+ return WeightWithDynamicFloat8CastTensor(
+ args[0]._tensor, args[0]._linear_mm_config
+ )
+ mm_config: Optional[LinearMMConfig] = None
+
+ def unwrap(t):
+ nonlocal mm_config
+ if mm_config is None:
+ mm_config = t._linear_mm_config
+ else:
+ assert t._linear_mm_config == mm_config
+ return t._tensor
+
+ args, kwargs = pytree.tree_map_only(
+ WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
+ )
+ out = func(*args, **kwargs)
+ if func not in _ops_to_preserve_subclass:
+ return out
+ return pytree.tree_map_only(
+ torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+ )
+
+ def __tensor_flatten__(self):
+ if self._precomputed_scale:
+ return ["_tensor", "_precomputed_scale"], self._linear_mm_config
+ else:
+ return ["_tensor"], self._linear_mm_config
+
+ @staticmethod
+ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+ mm_config = flatten_spec
+ return WeightWithDynamicFloat8CastTensor(
+ inner_tensors["_tensor"],
+ mm_config,
+            inner_tensors.get("_precomputed_scale", None),
+ )
+
+ def __repr__(self):
+ return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config})"
+
+ def fsdp_pre_all_gather(self, mesh):
+ if self._precomputed_scale is not None:
+ float8_tensor = hp_tensor_and_scale_to_float8(
+ self._tensor,
+ self._precomputed_scale,
+ torch.float8_e4m3fn,
+ self._linear_mm_config,
+ GemmInputRole.WEIGHT,
+ )
+ else:
+ float8_tensor = hp_tensor_to_float8_dynamic(
+ self._tensor,
+ e4m3_dtype,
+ self._linear_mm_config,
+ reduce_amax=True,
+ gemm_input_role=GemmInputRole.WEIGHT,
+ )
+ return (float8_tensor._data,), (float8_tensor._scale,)
+
+ def fsdp_post_all_gather(
+ self,
+ all_gather_outputs: Tuple[torch.Tensor, ...],
+ metadata: Any,
+ param_dtype: torch.dtype,
+ *,
+ out: Optional[torch.Tensor] = None,
+ ):
+ (data,) = all_gather_outputs
+ (scale,) = metadata
+ if out is not None:
+ assert isinstance(out, Float8Tensor), f"{type(out)}"
+ out._scale = scale
+ return
+ return Float8Tensor(
+ data,
+ scale,
+ param_dtype,
+ self._linear_mm_config,
+ gemm_input_role=GemmInputRole.WEIGHT,
+ ), (data,)
+
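+# Rough sketch of where this subclass sits in the flow (hedged: the wrapping is
+# presumably done by Float8Linear when float8 all-gather is enabled, and
+# `linear_mm_config` below is a placeholder LinearMMConfig):
+#
+#   m = torch.nn.Linear(1024, 1024, bias=False)
+#   m.weight = torch.nn.Parameter(
+#       WeightWithDynamicFloat8CastTensor(m.weight, linear_mm_config)
+#   )
+#   # FSDP2 then calls fsdp_pre_all_gather / fsdp_post_all_gather on the
+#   # subclass, so the fp8 data and scale are all-gathered instead of the
+#   # high precision weight.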
+
+class WeightWithDelayedFloat8CastTensor(torch.Tensor):
+ @staticmethod
+ def __new__(
+ cls,
+ tensor: torch.Tensor,
+ amax_buffer: torch.Tensor,
+ amax_history_buffer: torch.Tensor,
+ scale_buffer: torch.Tensor,
+ linear_mm_config: LinearMMConfig,
+ is_amax_initialized: bool,
+ ):
+ return torch.Tensor._make_wrapper_subclass(
+ cls,
+ tensor.size(),
+ strides=tensor.stride(),
+ storage_offset=tensor.storage_offset(),
+ memory_format=suggest_memory_format(tensor),
+ dtype=tensor.dtype,
+ layout=tensor.layout,
+ device=tensor.device,
+ pin_memory=tensor.is_pinned(),
+ requires_grad=tensor.requires_grad,
+ )
+
+ def __init__(
+ self,
+ tensor: torch.Tensor,
+ amax_buffer: torch.Tensor,
+ amax_history_buffer: torch.Tensor,
+ scale_buffer: torch.Tensor,
+ linear_mm_config: LinearMMConfig,
+ is_amax_initialized: bool,
+ ):
+ self._tensor = tensor
+ self._amax_buffer = amax_buffer
+ self._amax_history_buffer = amax_history_buffer
+ self._scale_buffer = scale_buffer
+ self._linear_mm_config = linear_mm_config
+
+ # Note: is_amax_initialized is not a buffer to avoid data dependent
+ # control flow visible to dynamo
+ # TODO(future PR): add serialization for this flag
+ self.is_amax_initialized = is_amax_initialized
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args, kwargs=None):
+ if func == torch.ops.aten.detach.default:
+ return WeightWithDelayedFloat8CastTensor(
+ args[0]._tensor,
+ args[0]._amax_buffer,
+ args[0]._amax_history_buffer,
+ args[0]._scale_buffer,
+ args[0]._linear_mm_config,
+ args[0].is_amax_initialized,
+ )
+ mm_config: Optional[LinearMMConfig] = None
+ amax_buffer: Optional[torch.Tensor] = None
+ amax_history_buffer: Optional[torch.Tensor] = None
+ scale_buffer: Optional[torch.Tensor] = None
+ is_amax_initialized: Optional[bool] = None
+
+ def unwrap(t):
+ nonlocal mm_config
+ if mm_config is None:
+ mm_config = t._linear_mm_config
+ else:
+ assert t._linear_mm_config == mm_config
+ nonlocal amax_buffer
+ if amax_buffer is None:
+ amax_buffer = t._amax_buffer
+ nonlocal amax_history_buffer
+ if amax_history_buffer is None:
+ amax_history_buffer = t._amax_history_buffer
+ nonlocal scale_buffer
+ if scale_buffer is None:
+ scale_buffer = t._scale_buffer
+ nonlocal is_amax_initialized
+ if is_amax_initialized is None:
+ is_amax_initialized = t.is_amax_initialized
+ return t._tensor
+
+ args, kwargs = pytree.tree_map_only(
+ WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {})
+ )
+ out = func(*args, **kwargs)
+ if func not in _ops_to_preserve_subclass:
+ return out
+ return pytree.tree_map_only(
+ torch.Tensor,
+ lambda x: WeightWithDelayedFloat8CastTensor(
+ x,
+ amax_buffer,
+ amax_history_buffer,
+ scale_buffer,
+ mm_config,
+ is_amax_initialized,
+ ),
+ out,
+ )
+
+ def __tensor_flatten__(self):
+ return (
+ [
+ "_tensor",
+ "_amax_buffer",
+ "_amax_history_buffer",
+ "_scale_buffer",
+ ],
+ {
+ "mm_config": self._linear_mm_config,
+ "is_amax_initialized": self.is_amax_initialized,
+ },
+ )
+
+ @staticmethod
+ def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
+ return WeightWithDelayedFloat8CastTensor(
+ inner_tensors["_tensor"],
+ inner_tensors["_amax_buffer"],
+ inner_tensors["_amax_history_buffer"],
+ inner_tensors["_scale_buffer"],
+ metadata["mm_config"],
+ metadata["is_amax_initialized"],
+ )
+
+ def __repr__(self):
+ return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config})"
+
+ def fsdp_pre_all_gather(self, mesh):
+ # initialize if needed
+ # TODO(before land): ensure settings are consistent between Float8Linear and here
+ if not self.is_amax_initialized:
+ from torchao.float8.float8_linear import (
+ _maybe_initialize_amaxes_scales_for_float8_cast,
+ )
+
+ _maybe_initialize_amaxes_scales_for_float8_cast(
+ self._tensor,
+ self._amax_buffer,
+ self._amax_history_buffer,
+ self._scale_buffer,
+ "max", # TODO(before land): read this from parent
+ e4m3_dtype,
+ self.is_amax_initialized,
+ reduce_amax=True,
+ )
+ self.is_amax_initialized = True
+
+ float8_tensor = hp_tensor_to_float8_delayed(
+ self._tensor,
+ self._scale_buffer,
+ e4m3_dtype,
+ self._amax_buffer,
+ self._linear_mm_config,
+ GemmInputRole.WEIGHT,
+ )
+ return (float8_tensor._data,), (float8_tensor._scale,)
+
+ def fsdp_post_all_gather(
+ self,
+ all_gather_outputs: Tuple[torch.Tensor, ...],
+ metadata: Any,
+ param_dtype: torch.dtype,
+ *,
+ out: Optional[torch.Tensor] = None,
+ ):
+ (data,) = all_gather_outputs
+ (scale,) = metadata
+ if out is not None:
+ assert isinstance(out, Float8Tensor), f"{type(out)}"
+ out._scale = scale
+ return
+ return Float8Tensor(
+ data,
+ scale,
+ param_dtype,
+ self._linear_mm_config,
+ gemm_input_role=GemmInputRole.WEIGHT,
+ ), (data,)
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
new file mode 100644
index 000000000..f5c504503
--- /dev/null
+++ b/torchao/float8/inference.py
@@ -0,0 +1,244 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Defines an nn module designed to be used during inference
+"""
+
+from dataclasses import dataclass
+
+from enum import auto, Enum
+from typing import Callable, List, Optional
+
+import torch
+import torch.nn as nn
+from torchao.float8.float8_linear_utils import swap_linear_layers
+
+from torchao.float8.float8_tensor import (
+ Float8Tensor,
+ GemmInputRole,
+ hp_tensor_and_scale_to_float8,
+ LinearMMConfig,
+ ScaledMMConfig,
+ tensor_already_casted_to_fp8,
+)
+from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
+
+
+class ActivationCasting(Enum):
+ """Types of quantization to perform on the activations
+
+    WEIGHT_ONLY: Only the weight is quantized; activations are not cast, and the weight is dequantized in the forward pass
+    STATIC: Activations are quantized in the forward pass using a static scale provided at model initialization
+    DYNAMIC: Activations are quantized in the forward pass with a dynamic scale calculated from the input activation
+ """
+
+ # TODO: A better name would be NONE, we should unify this with torchao
+ WEIGHT_ONLY = auto()
+ DYNAMIC = auto()
+ STATIC = auto()
+
+
+@dataclass(frozen=True)
+class QuantConfig:
+ """Defines the configuration for the quantization to fp8 of a linear module
+
+ Args:
+ activation_casting: The type of quantization to perform on the activations
+ static_quantization_scale: The scale of the input to this linear module, used for static quantization only
+ """
+
+ activation_casting: ActivationCasting
+ static_quantization_scale: Optional[torch.Tensor] = None
+
+    # If True, then prior to performing the fp8 scaled matmul we pad the
+    # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for
+    # torch._scaled_mm, which has the strong constraint that for M, N, K,
+    # both N and K must be multiples of 16. Padding can cause a memory spike,
+    # so it is off by default.
+    pad_inner_dim = False
+
+ def __post_init__(self):
+ if self.activation_casting == ActivationCasting.STATIC:
+ assert isinstance(
+ self.static_quantization_scale, torch.Tensor
+ ), "When activation_casting is 'static', activation_scale must be a tensor."
+
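+# Hedged examples of constructing a QuantConfig (the fields are those defined
+# above; the scale value is purely illustrative):
+#
+#   # dynamic activation scaling
+#   dynamic_config = QuantConfig(ActivationCasting.DYNAMIC)
+#
+#   # static activation scaling: the activation scale must be provided up front
+#   static_config = QuantConfig(
+#       ActivationCasting.STATIC,
+#       static_quantization_scale=torch.tensor(1.0),
+#   )
+#
+#   # weight-only quantization: weight stored in fp8, matmul in high precision
+#   weight_only_config = QuantConfig(ActivationCasting.WEIGHT_ONLY)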
+
+class Float8InferenceLinear(torch.nn.Linear):
+ """
+    This is a wrapper around torch.nn.Linear that supports FP8 inference.
+    Supported forms of inference:
+        - FP8 inference with high precision matmul - weight only quantization
+        - FP8 inference with fp8 matmul and dynamic activation casting
+        - FP8 inference with fp8 matmul and static activation casting
+ """
+
+ def __init__(
+ self,
+ # FP8 specific arguments
+ quant_config: QuantConfig,
+ linear_mm_config: LinearMMConfig,
+ # nn.Linear arguments
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> None:
+        # Construct the superclass; this will create dummy weights and biases
+ super().__init__(in_features, out_features, bias, device, dtype)
+ self.linear_mm_config = linear_mm_config
+ self.activation_casting = quant_config.activation_casting
+ if self.activation_casting == ActivationCasting.STATIC:
+ self.register_buffer(
+ "static_quantization_scale", quant_config.static_quantization_scale
+ )
+ else:
+ self.static_quantization_scale = None
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ if self.activation_casting == ActivationCasting.WEIGHT_ONLY:
+ return torch.nn.functional.linear(
+ input, self.weight.to_original_precision()
+ )
+
+ x_fp8 = cast_to_float8_e4m3_inference(
+ input,
+ self.linear_mm_config,
+ static_quantization_scale=self.static_quantization_scale,
+ )
+ return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
+
+ # Builder functions for Float8LinearInference
+ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
+ """This functions converts the weight to a Float8Tensor and sets its requires_grad to False.
+
+ Args:
+ dtype: The dtype to quantize the weight to. Default is e4m3_dtype.
+
+ Note:
+ This function is typically called during inference to quantize the weight once since
+ the weight is not updated during inference.
+
+ """
+ assert not isinstance(
+ self.weight, Float8Tensor
+ ), "Weight has already been quantized, cannot quantize again."
+ scale = tensor_to_scale(self.weight, dtype)
+ quantized_weight = hp_tensor_and_scale_to_float8(
+ self.weight,
+ scale,
+ dtype,
+ self.linear_mm_config,
+ GemmInputRole.WEIGHT,
+ )
+ self.weight = nn.Parameter(quantized_weight)
+ self.weight.requires_grad = False
+
+ def set_weight_and_bias(
+ self, weight: torch.nn.Parameter, bias: Optional[torch.nn.Parameter]
+ ):
+ self.weight = weight
+ self.bias = bias
+
+ @classmethod
+ def from_float(
+ cls, module: nn.Module, quant_config: QuantConfig, use_fast_accum: bool
+ ) -> "Float8InferenceLinear":
+ """
+ Create an nn.Linear with fp8 compute from another nn.Linear
+
+ Args:
+            module (torch.nn.Linear): nn.Linear to convert
+            quant_config (QuantConfig): Configuration for the weight and activation casting
+            use_fast_accum (bool): Whether to enable fast accumulation in the fp8 matmul
+ """
+ forward_config = ScaledMMConfig(
+ False, use_fast_accum, pad_inner_dim=quant_config.pad_inner_dim
+ )
+ linear_mm_config = LinearMMConfig(
+ forward_config, forward_config, forward_config
+ )
+ linear = cls(
+ quant_config,
+ linear_mm_config,
+ module.in_features,
+ module.out_features,
+ False,
+ device=torch.device("meta"),
+ )
+ linear.set_weight_and_bias(module.weight, module.bias)
+ linear.quantize_weight()
+ return linear
+
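+# Hedged sketch of converting a single layer directly via from_float (the layer
+# is a placeholder; quantize_to_float8 below does this for every nn.Linear in a
+# module, and running the fp8 matmul requires a GPU with fp8 support):
+#
+#   linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16).cuda()
+#   fp8_linear = Float8InferenceLinear.from_float(
+#       linear, QuantConfig(ActivationCasting.DYNAMIC), use_fast_accum=True
+#   )
+#   y = fp8_linear(torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda"))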
+
+def cast_to_float8_e4m3_inference(
+ inpt_tensor: torch.Tensor,
+ linear_mm_config: LinearMMConfig,
+ reduce_amax: bool = False,
+ static_quantization_scale: Optional[torch.Tensor] = None,
+) -> Float8Tensor:
+ """Casts an input tensor to the Float8 (e4m3fn*)
+
+ Args:
+ inpt_tensor: The input tensor to be cast.
+ linear_mm_config: Configuration settings for the matrix multiplication
+ reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.
+ static_quantization_scale: Optional tensor specifying the scale for activation. Default is None.
+
+ Returns:
+ Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.
+
+ Note:
+ If the input tensor is already in Float8 format, it is returned as is without re-casting.
+ """
+ if tensor_already_casted_to_fp8(inpt_tensor):
+ return inpt_tensor
+ scale = (
+ static_quantization_scale
+ if static_quantization_scale is not None
+ else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
+ )
+ return hp_tensor_and_scale_to_float8(
+ inpt_tensor,
+ scale,
+ e4m3_dtype,
+ linear_mm_config,
+ GemmInputRole.INPUT,
+ )
+
+
+def quantize_to_float8(
+ module: nn.Module,
+ quant_config: QuantConfig,
+ *,
+ module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
+ use_fast_accum: bool = True,
+) -> nn.Module:
+ """
+ Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
+
+ Note:
+        If applied to a root-level nn.Linear, the module is not modified in place;
+        the converted module is returned instead.
+
+ Args:
+ module (nn.Module): The module to modify.
+ quant_config (QuantConfig): Quantization configuration for Float8 conversion.
+        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
+            pass the filter function will be swapped. The inputs to the
+            filter function are the module instance and the FQN.
+        use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
+
+ Returns:
+ nn.Module: The modified module with applicable Linear layers converted to Float8.
+
+ Raises:
+ AssertionError: If a root-level nn.Linear with children is encountered.
+ """
+ return swap_linear_layers(
+ module,
+ lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
+ module_filter_fn=module_filter_fn,
+ )
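+
+
+# Hedged end-to-end sketch of quantize_to_float8 (the model and filter are
+# placeholders; running it requires a GPU with fp8 support, e.g. H100):
+#
+#   model = torch.nn.Sequential(
+#       torch.nn.Linear(4096, 4096, bias=False),
+#       torch.nn.ReLU(),
+#       torch.nn.Linear(4096, 4096, bias=False),
+#   ).to(device="cuda", dtype=torch.bfloat16)
+#
+#   model = quantize_to_float8(
+#       model,
+#       QuantConfig(ActivationCasting.DYNAMIC),
+#       # e.g. skip layers by name; the filter receives (module, fqn)
+#       module_filter_fn=lambda m, fqn: "skip_me" not in fqn,
+#   )
+#
+#   with torch.inference_mode():
+#       out = model(torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16))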