diff --git a/README.md b/README.md index e31dc63a8..8905b3f43 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,12 @@ In some cases we rewrote popular GenAI models to be significantly faster in nati ### Training +#### Float8 + +[torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433. + +#### Sparsity + We've added support for semi-structured 2:4 sparsity with 6% end to end speedups on ViT-L The code change is a 1 liner with the full example available [here](torchao/sparsity/training/) diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py new file mode 100644 index 000000000..b44d4f5dc --- /dev/null +++ b/benchmarks/float8/bench_linear_float8.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import copy +from dataclasses import dataclass +from itertools import product +from pathlib import Path +from typing import Callable, List, Optional, Tuple + +import pandas as pd + +import torch +import torch.utils.benchmark as benchmark +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.float8_linear import Float8Linear +from torchao.float8.float8_linear_utils import ( + linear_requires_sync, + sync_float8_amax_and_scale_history, +) +from torchao.float8.float8_tensor import ScaledMMConfig +from tqdm import tqdm + +# estimating TOPs for matmuls in fp32, fp16, fp8 +# assuming A * B = C, with A being M * K, B being K * N, C being M * N + +# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ +h100_peak_flops_float32 = 67e12 +h100_peak_flops_fp16_tc = 1979e12 +h100_peak_tops_float8_tc = 3958e12 + +dtype_to_peak_tops = { + torch.float32: h100_peak_flops_float32, + torch.float16: h100_peak_flops_fp16_tc, + torch.bfloat16: h100_peak_flops_fp16_tc, + torch.float8_e4m3fn: h100_peak_tops_float8_tc, + torch.float8_e5m2: h100_peak_tops_float8_tc, +} + +# prevent splitting columns when printing a data frame +pd.set_option("display.expand_frame_repr", False) +# print the entire data frame +pd_print_full_ctx = pd.option_context( + "display.max_rows", None, "display.max_columns", None +) + + +def benchmark_torch_function_in_microseconds( + func: Callable, + *args, + **kwargs, +) -> float: + t0 = benchmark.Timer( + stmt="func(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "func": func}, + ) + return t0.blocked_autorange().median * 1e6 + + +@dataclass +class Experiment: + name: str + shape: Tuple[int, int, int] + ref_time_sec: float + float8_time_sec: float + dtype: torch.dtype + compiled: bool + use_fast_accum: bool + scaling_repr: str + + # 3 Times since we are calculating forward backward + @property + def ref_tops_sec(self): + M, K, N = self.shape + return float(3 * (2 * M * K * N)) / self.ref_time_sec + + @property + def ref_pct_top_peak(self): + return self.ref_tops_sec / dtype_to_peak_tops[self.dtype] + + @property + def float8_tops_sec(self): + M, K, N = self.shape + return float(3 * (2 * M * K * N)) / self.float8_time_sec + + @property + def float8_pct_top_peak(self): + return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn] + + +def main( + sweep_path: Optional[Path] = None, + compile: bool = True, + n_limit: Optional[int] = None, + fast_accum_filter: Optional[bool] = None, + shape_name_filter: Optional[str] = None, + scaling_type_input: str = "dynamic", + scaling_type_weight: str = "dynamic", + scaling_type_grad_output: str = "dynamic", +): + device = "cuda" + print(f"Compile is set to | {compile}") + + scaling_type_input = ScalingType(scaling_type_input) + scaling_type_weight = ScalingType(scaling_type_weight) + scaling_type_grad_output = ScalingType(scaling_type_grad_output) + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + ) + + # LLaMa 2 70B single-node weight shapes + # assumes fused attn.wqkv and ffn.w13 + name_to_shapes_70b = { + "attn.wqkv": (8192, 1280), + "attn.w0": (1024, 8192), + "ffn.w13": (8192, 7168), + "ffn.w2": (3584, 8192), + } + input_bias = False + if fast_accum_filter is not None: + use_fast_accum = [fast_accum_filter] + else: + use_fast_accum = [True, False] + if shape_name_filter is not None: + k = shape_name_filter + name_to_shapes_70b = {k: name_to_shapes_70b[k]} + experiment_list: List[Experiment] = [] + dtype = torch.bfloat16 + for idx, (fast_accum, (name, (K, N))) in enumerate( + tqdm(list(product(use_fast_accum, name_to_shapes_70b.items()))) + ): + if n_limit is not None and idx >= n_limit: + break + linear_ref = torch.nn.Linear(K, N, bias=input_bias).to( + device=device, dtype=dtype + ) + + linear_float8 = Float8Linear.from_float( + copy.deepcopy(linear_ref), + config=config, + ) + scaling_repr = linear_float8.scaling_repr() + + if fast_accum: + linear_float8.forward_config = ScaledMMConfig(False, True, False) + else: + linear_float8.forward_config = ScaledMMConfig(False, False, False) + + bsz, seq_len = 4, 4096 + M = bsz * seq_len + input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True) + ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() + + def float8_forw_backward(): + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(linear_float8) + linear_float8(input_tensor).sum().backward() + + def n_times(n, fn, *args, **kwargs): + def wrapper(*args, **kwargs): + for _ in range(n): + fn(*args, **kwargs) + + return wrapper + + REPEAT_N = 100 + + ref_forw_backward = n_times(REPEAT_N, ref_forw_backward) + float8_forw_backward = n_times(REPEAT_N, float8_forw_backward) + + if compile: + ref_forw_backward = torch.compile(ref_forw_backward) + float8_forw_backward = torch.compile(float8_forw_backward) + + for _ in range(5): + ref_forw_backward() + float8_forw_backward() + + ref_time = ( + benchmark_torch_function_in_microseconds(ref_forw_backward) + * 1e-6 + / REPEAT_N + ) + float8_time = ( + benchmark_torch_function_in_microseconds(float8_forw_backward) + * 1e-6 + / REPEAT_N + ) + experiment = Experiment( + name, + (M, K, N), + ref_time, + float8_time, + dtype, + compile, + use_fast_accum=fast_accum, + scaling_repr=scaling_repr, + ) + print(experiment) + print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) + experiment_list.append(experiment) + torch._dynamo.reset() + + headers = [ + "name", + "M", + "K", + "N", + "scaling_repr", + "ref_dtype", + "compiled", + "use_fast_accum", + "ref_time_sec", + "pt_fp8_time_sec", + "ref_tops_sec", + "ref_pct_top_peak", + "pt_fp8_tops_sec", + "pt_fp8_pct_top_peak", + ] + data = [] + for experiment in experiment_list: + data.append( + [ + experiment.name, + experiment.shape[0], + experiment.shape[1], + experiment.shape[2], + experiment.scaling_repr, + experiment.dtype, + experiment.compiled, + experiment.use_fast_accum, + experiment.ref_time_sec, + experiment.float8_time_sec, + experiment.ref_tops_sec, + experiment.ref_pct_top_peak, + experiment.float8_tops_sec, + experiment.float8_pct_top_peak, + ] + ) + + data_pd = pd.DataFrame(data, columns=headers) + data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"] + data_pd["shape"] = ( + "(" + + data_pd["M"].astype(str) + + ", " + + data_pd["K"].astype(str) + + ", " + + data_pd["N"].astype(str) + + ")" + ) + + data_pd_simple = data_pd[ + [ + "name", + "shape", + "scaling_repr", + "compiled", + "use_fast_accum", + "ref_time_sec", + "pt_fp8_time_sec", + "pt_fp8_speedup", + ] + ] + with pd_print_full_ctx: + print(data_pd_simple) + + if sweep_path is not None: + sweep_path = sweep_path.with_suffix(".csv") + data_pd.to_csv(sweep_path) + + +def invoke_main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("-o", "--output_path", type=str, required=False) + parser.add_argument("--disable_compile", action="store_true") + parser.add_argument("-n", "--n_limit", type=int, required=False) + parser.add_argument("--fast_accum_filter", type=bool, required=False) + parser.add_argument("--shape_name_filter", type=str, required=False) + parser.add_argument("--scaling_type_input", type=str, required=False) + parser.add_argument("--scaling_type_weight", type=str, required=False) + parser.add_argument("--scaling_type_grad_output", type=str, required=False) + args = parser.parse_args() + output_path = Path(args.output_path) if args.output_path is not None else None + kwargs = {} + if args.scaling_type_input is not None: + kwargs["scaling_type_input"] = args.scaling_type_input + if args.scaling_type_weight is not None: + kwargs["scaling_type_weight"] = args.scaling_type_weight + if args.scaling_type_grad_output is not None: + kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output + main( + output_path, + not args.disable_compile, + args.n_limit, + args.fast_accum_filter, + args.shape_name_filter, + **kwargs, + ) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py new file mode 100644 index 000000000..6220670ee --- /dev/null +++ b/benchmarks/float8/bench_matmul.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import itertools +from typing import Optional + +import fire +import pandas as pd + +import torch +import torch.nn as nn +import torch.utils.benchmark as benchmark + +# estimating TOPs for matmuls in fp32, fp16, fp8 +# assuming A * B = C, with A being M * K, B being K * N, C being M * N + +# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ +h100_peak_flops_float32 = 67e12 +h100_peak_flops_fp16_tc = 989e12 +h100_peak_tops_float8_tc = 1979e12 + +dtype_to_peak_tops = { + torch.float32: h100_peak_flops_float32, + torch.float16: h100_peak_flops_fp16_tc, + torch.bfloat16: h100_peak_flops_fp16_tc, + torch.float8_e4m3fn: h100_peak_tops_float8_tc, + torch.float8_e5m2: h100_peak_tops_float8_tc, +} + + +def benchmark_fn_in_sec(f, *args, **kwargs): + # Manual warmup + for _ in range(4): + f(*args, **kwargs) + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} + ) + measurement = t0.blocked_autorange() + return measurement.mean + + +def do_benchmarks(tops, peak_tops, f, *args, **kwargs): + time_sec = benchmark_fn_in_sec(f, *args, **kwargs) + tops_sec = float(tops) / time_sec + pct_top_peak = tops_sec / peak_tops + return time_sec, tops_sec, pct_top_peak + + +@torch.inference_mode() +def run(n_limit: Optional[int] = None): + device = "cuda" + + # LLaMa 2 70B single-node weight shapes + # assumes fused attn.wqkv and ffn.w13 + # source: https://fburl.com/gsheet/g8onr7rh + name_to_shapes_70b = { + "attn.wqkv": (8192, 1280), + "attn.w0": (1024, 8192), + "ffn.w13": (8192, 7168), + "ffn.w2": (3584, 8192), + } + + headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup") + results = [] + + name_to_shapes = name_to_shapes_70b + dtypes = torch.bfloat16, torch.float16 + + for idx, (dtype, (name, (K, N))) in enumerate( + itertools.product(dtypes, name_to_shapes.items()) + ): + if n_limit is not None and idx >= n_limit: + break + + # source: Xiao Sun, these are realistic for LLaMa 70B training + bsz, seq_len = 4, 4096 + + M = bsz * seq_len + print("M, K, N:", M, K, N) + tops = 2 * M * N * K + print(f"tops: {tops:.2E}") + + # raw torch.mm + A = torch.randn(M, K, device=device, dtype=dtype) + m_ref = nn.Sequential(nn.Linear(K, N, dtype=dtype, device=device, bias=False)) + ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks( + tops, dtype_to_peak_tops[dtype], m_ref, A + ) + print( + f"{dtype} time_sec {ref_time_sec:.2E}, tops/sec {ref_tops_sec:.2E}, pct_peak {ref_pct_top_peak:.3f}" + ) + + del A + + # raw float8 matmul (upper bound for what we can achive in eager mode) + # TODO(future): add e5m2 + d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype + A = torch.zeros(M, K, device=device, dtype=d1) + B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t() + + def do_matmul(A, B): + scale_a = torch.tensor([1.0], device=device) + scale_b = torch.tensor([1.0], device=device) + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False + ) + + fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks( + tops, dtype_to_peak_tops[d1], do_matmul, A, B + ) + print( + f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}" + ) + + del A, B + + results.append( + [ + name, + (M, K, N), + dtype, + ref_time_sec, + fp8_time_sec, + ref_time_sec / fp8_time_sec, + ] + ) + + data_pd = pd.DataFrame(results, columns=headers) + print(data_pd) + + +def main() -> None: + fire.Fire(run) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/benchmarks/float8/bench_multi_gpu.py b/benchmarks/float8/bench_multi_gpu.py new file mode 100644 index 000000000..44c758d1b --- /dev/null +++ b/benchmarks/float8/bench_multi_gpu.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Callable + +import fire + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.utils.benchmark as benchmark +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.float8_linear_utils import ( + convert_to_float8_training, + sync_float8_amax_and_scale_history, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + +torch.manual_seed(0) + +# TODO: Add more shapes for the benchmark +B, M, K, N = 32, 1024, 1024, 1024 +lr = 0.01 + +config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), +) + + +def benchmark_torch_function_in_microseconds( + func: Callable, + *args, + **kwargs, +) -> float: + t0 = benchmark.Timer( + stmt="func(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "func": func}, + ) + return t0.blocked_autorange().median * 1e6 + + +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def get_model(K, N, is_fp8, base_dtype=torch.float32): + modules = [ + nn.Linear(K, N, dtype=base_dtype), + nn.ReLU(), + ] + N_LAYERS = 20 + # N linear layers + for _ in range(N_LAYERS - 1): + modules.append(nn.Linear(N, N, dtype=base_dtype)) + modules.append(nn.ReLU()) + m = nn.Sequential(*modules) + if is_fp8: + convert_to_float8_training( + m, + config=config, + ) + return m + + +def fsdp_main(rank, world_size, args): + setup(rank, world_size) + torch.cuda.set_device(rank) + + base_dtype, input_global, compile = args + + # basic distributed data sampling + assert B % world_size == 0 + bsz_local_start = int(rank / world_size * B) + bsz_local_end = int((rank + 1) / world_size * B) + input_tensor = input_global[bsz_local_start:bsz_local_end].to(rank) + + fp8_model = get_model(K, N, is_fp8=True, base_dtype=base_dtype).to(rank) + # Need use_orig_params=True to compile FSDP + fp8_model = FSDP(fp8_model, use_orig_params=True) + fp8_optimizer = torch.optim.SGD(fp8_model.parameters(), lr=lr * world_size) + + # Run one iteration to make compile work, see experiments doc for more context of this issue. + fp8_optimizer.zero_grad() + y_local = fp8_model(input_tensor) + y_local.sum().backward() + fp8_optimizer.step() + sync_float8_amax_and_scale_history(fp8_model) + + sync_float8_func = sync_float8_amax_and_scale_history + if compile: + # TODO: Need to fix issues with compile + fp8_model = torch.compile(fp8_model) + sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) + + def float8_forw_backward(): + fp8_optimizer.zero_grad() + y_local = fp8_model(input_tensor) + y_local.sum().backward() + fp8_optimizer.step() + sync_float8_func(fp8_model) + + ref_model = get_model(K, N, is_fp8=False, base_dtype=base_dtype).to(rank) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size) + if compile: + ref_model = torch.compile(ref_model) + + ref_model = FSDP(ref_model, use_orig_params=True) + + def ref_forw_backward(): + ref_optimizer.zero_grad() + ref_model(input_tensor).sum().backward() + ref_optimizer.step() + + def run_n_iterations(n, fn): + for _ in range(n): + fn() + # make sure training is done on all ranks + dist.barrier() + + # warmup + run_n_iterations(50, ref_forw_backward) + run_n_iterations(50, float8_forw_backward) + + N_ITER = 50 + ref_time = ( + benchmark_torch_function_in_microseconds( + run_n_iterations, N_ITER, ref_forw_backward + ) + * 1e-6 + / N_ITER + ) + float8_time = ( + benchmark_torch_function_in_microseconds( + run_n_iterations, N_ITER, float8_forw_backward + ) + * 1e-6 + / N_ITER + ) + + if rank == 0: + print("ref_time", ref_time) + print("float8_time", float8_time) + print("float8 speedup", ref_time / float8_time) + + cleanup() + + +def run(compile: bool): + base_dtype = torch.bfloat16 + WORLD_SIZE = torch.cuda.device_count() + print(f"{base_dtype = }") + print(f"{compile = }") + print(f"{WORLD_SIZE = }") + + # generate input data + ref_input = torch.randn(B, M, K).cuda().to(base_dtype) + # run fsdp model + args = (base_dtype, ref_input, compile) + mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) + + +# Usgae: +# CUDA_VISIBLE_DEVICES=0,1 python benchmarks/bench_multi_gpu.py +if __name__ == "__main__": + fire.Fire(run) diff --git a/benchmarks/float8/bench_padding.py b/benchmarks/float8/bench_padding.py new file mode 100644 index 000000000..977755343 --- /dev/null +++ b/benchmarks/float8/bench_padding.py @@ -0,0 +1,223 @@ +from dataclasses import dataclass +from typing import Optional + +import fire + +import torch +from torchao.float8.float8_tensor import ( + GemmInputRole, + hp_tensor_and_scale_to_float8, + LinearMMConfig, + ScaledMMConfig, +) +from torchao.float8.float8_utils import pad_tensor_for_matmul +from tabulate import tabulate +from torch._inductor.utils import do_bench_using_profiling +from tqdm import tqdm + +# estimating TOPs for matmuls in fp32, fp16, fp8 +# assuming A * B = C, with A being M * K, B being K * N, C being M * N + +# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ +h100_peak_flops_float32 = 67e12 +h100_peak_flops_fp16_tc = 1979e12 +h100_peak_tops_float8_tc = 3958e12 + +dtype_to_peak_tops = { + torch.float32: h100_peak_flops_float32, + torch.float16: h100_peak_flops_fp16_tc, + torch.bfloat16: h100_peak_flops_fp16_tc, + torch.float8_e4m3fn: h100_peak_tops_float8_tc, + torch.float8_e5m2: h100_peak_tops_float8_tc, +} + + +def benchmark_fn_in_usec(f, *args, **kwargs): + no_args = lambda: f(*args, **kwargs) + time = do_bench_using_profiling(no_args) + return time * 1e3 + + +def get_tops_info(tops, time, peak_tops): + time_sec = time / 1e6 + tops_sec = float(tops) / time_sec + pct_top_peak = tops_sec / peak_tops + return tops_sec, pct_top_peak + + +def do_fp8_matmul(A, B, fp8_dtype, out_dtype): + scale_a = torch.tensor([1], device="cuda", dtype=torch.float32) + scale_b = torch.tensor([1], device="cuda", dtype=torch.float32) + + a_config = ScaledMMConfig( + emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True + ) + b_config = ScaledMMConfig( + emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True + ) + a_config = LinearMMConfig(a_config, a_config, a_config) + b_config = LinearMMConfig(b_config, b_config, b_config) + + a_fp8 = hp_tensor_and_scale_to_float8( + A, + scale_a, + fp8_dtype, + a_config, + GemmInputRole.INPUT, + ) + b_fp8 = hp_tensor_and_scale_to_float8( + B, + scale_b, + fp8_dtype, + b_config, + GemmInputRole.WEIGHT, + ) + + return a_fp8 @ b_fp8 + + +def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype): + # Breaks with compile due to trying to pad on fp8 dtype + # return do_fp8_matmul(A, B, fp8_dtype, out_dtype) + A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy + B_pad = pad_tensor_for_matmul(B, dims=0) # mem copy + + scale_a = torch.tensor([1], device="cuda", dtype=torch.float32) + scale_b = torch.tensor([1], device="cuda", dtype=torch.float32) + + A_pad = A_pad.to(fp8_dtype) # mem copy + B_pad = B_pad.to(fp8_dtype) # mem copy + + B_pad = B_pad.t().contiguous().t() # mem copy + + return torch._scaled_mm( + A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True + ) + + +def do_hp_matmul(A, B): + return torch.matmul(A, B) + + +def do_aligned_bf16_matmul(A, B): + A_pad = pad_tensor_for_matmul(A, dims=1) + B_pad = pad_tensor_for_matmul(B, dims=0) + return torch.matmul(A_pad, B_pad) + + +@dataclass +class Experiment_config: + M: int + K: int + N: int + output_dtype: torch.dtype + fp8_dtype: torch.dtype + + def __iter__(self): + return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype)) + + +def gen_configs(): + shapes = shapes = [ + (8193, 2501, 5008), + (65, 253, 4096), + (1023, 1029, 2512), + (4095, 511, 10000), + (2047, 3073, 8192), + (511, 769, 7504), + (127, 4097, 12288), + (32769, 15, 15024), + (9217, 8191, 20480), + (16385, 1025, 25008), + ] + output_dtype = torch.bfloat16 + fp8_dtype = torch.float8_e4m3fn + return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes] + + +@torch.no_grad() +def run(compile: bool = False, n_limit: Optional[int] = None): + device = "cuda" + experiments = gen_configs() + results = [] + tops_table = [] + tops_headers = [ + "Shape", + "Ref Dtype", + "Ref Tops", + "Aligned BF16 Tops", + "FP8 Tops", + "Ref % Peak", + "Aligned BF16 % Peak", + "FP8 % Peak", + ] + + for experiment in tqdm(experiments): + M, K, N, output_dtype, fp8_dtype = experiment + tops = 2 * M * N * K + + A_base = torch.rand(M, K, device=device, dtype=output_dtype) + B_base = torch.rand(K, N, device=device, dtype=output_dtype) + + hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul + aligned_bf16_func = ( + torch.compile(do_aligned_bf16_matmul) if compile else do_aligned_bf16_matmul + ) + fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul + + ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base) + aligned_bf16_time = benchmark_fn_in_usec(aligned_bf16_func, A_base, B_base) + fp8_time = benchmark_fn_in_usec( + fp8_func, A_base, B_base, fp8_dtype, output_dtype + ) + + ref_tops_sec, ref_pct_top_peak = get_tops_info( + tops, ref_time, dtype_to_peak_tops[output_dtype] + ) + aligned_bf16_tops_sec, aligned_bf16_pct_top_peak = get_tops_info( + tops, aligned_bf16_time, dtype_to_peak_tops[torch.bfloat16] + ) + fp8_tops_sec, fp8_pct_top_peak = get_tops_info( + tops, fp8_time, dtype_to_peak_tops[fp8_dtype] + ) + tops_table.append( + [ + f"({M}x{K}x{N})", + f"{output_dtype}", + f"{ref_tops_sec:.2E}", + f"{aligned_bf16_tops_sec:.2E}", + f"{fp8_tops_sec:.2E}", + f"{ref_pct_top_peak:.3f}", + f"{aligned_bf16_pct_top_peak:.3f}", + f"{fp8_pct_top_peak:.3f}", + ] + ) + results.append( + [ + (M, K, N), + output_dtype, + ref_time, + aligned_bf16_time, + fp8_time, + ref_time / aligned_bf16_time, + ref_time / fp8_time, + ] + ) + + print("TOPs".center(80, "*")) + print(tabulate(tops_table, headers=tops_headers)) + print("Speed Results".center(80, "*")) + headers = [ + "Shape", + "Ref Dtype", + "Ref Time", + "Aligned BF16 Time", + "FP8 Time", + "Aligned BF16 Speedup", + "FP8 Speedup", + ] + print(tabulate(results, headers=headers, tablefmt="grid")) + + +if __name__ == "__main__": + fire.Fire(run) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py new file mode 100644 index 000000000..914759849 --- /dev/null +++ b/benchmarks/float8/profile_linear_float8.py @@ -0,0 +1,447 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import io +import random +from contextlib import nullcontext, redirect_stdout +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, Optional + +import fire +import pandas as pd + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.float8_linear_utils import ( + convert_to_float8_training, + linear_requires_sync, + sync_float8_amax_and_scale_history, +) +from torch.profiler import profile, ProfilerActivity, record_function +from utils import ( + kernel_name_to_category, + parse_bw_and_kernel_name, + profiler_output_to_gpu_time_for_key, + profiler_output_to_time_by_kernel_name, +) + +# don't truncate long kernel names +pd.options.display.max_colwidth = 100 +# display 3 trailing decimal points for floats +pd.set_option("display.float_format", "{:.3f}".format) + + +class LNLinear(torch.nn.Module): + def __init__(self, fc_dim1, fc_dim2): + super().__init__() + self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False) + self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False) + + def forward(self, x): + x = self.ln(x) + x = self.fc(x) + return x + + +# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class NormFFNResidualNorm(nn.Module): + """ + A fragment representing the end of TransformerBlock n and the start + of TransformerBlock n + 1, intended to include the fusions relevant + to float8 gemms in the FFN module in forward and backward. + """ + + def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier): + super().__init__() + self.ffn_norm = RMSNorm(dim) + self.ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier) + self.attn_norm = RMSNorm(dim) + + def forward(self, h): + # end of transformer block n + x = self.ffn_norm(h) + x = self.ffn(x) + x = h + x + # start of transformer block n + 1 + x = self.attn_norm(x) + return x + + +@dataclass +class ProfileConfig: + file_path: Optional[str] = None + name: Optional[str] = None + cuda: bool = True + iters: int = 0 + warmup_iters: int = 0 + sync: bool = False + extra_kwargs: dict = field(default_factory=dict) + memory_profile_path: Optional[str] = None + + +def profile_function( + config: ProfileConfig, func: Callable, *args, **kwargs +) -> torch.profiler.profile: + """Profile a torch function and save the result to a file""" + seed = 123 + random.seed(seed) + torch.manual_seed(seed) + + activities = [ProfilerActivity.CPU] + if config.cuda: + activities.append(ProfilerActivity.CUDA) + + if config.warmup_iters >= 0: + for _ in range(config.warmup_iters): + func(*args, **kwargs) + if config.sync: + torch.cuda.synchronize() + name_context = ( + nullcontext() if config.name is None else record_function(config.name) + ) + profile_memory = config.memory_profile_path is not None + with profile( + activities=activities, + profile_memory=profile_memory, + record_shapes=profile_memory, + with_stack=profile_memory, + **config.extra_kwargs, + ) as prof: + for _ in range(config.iters): + with name_context: + func(*args, **kwargs) + if config.sync: + torch.cuda.synchronize() + + if config.file_path is not None: + prof.export_chrome_trace(config.file_path) + + return prof + + +def main( + profile_path_prefix: Path, + compile: bool = True, + scaling_type_input: str = "dynamic", + scaling_type_weight: str = "dynamic", + scaling_type_grad_output: str = "dynamic", + model_type: str = "linear", + dtype_filter: str = "both", +): + assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported" + assert dtype_filter in ("both", "float8", "bfloat16") + + scaling_type_input = ScalingType(scaling_type_input) + scaling_type_weight = ScalingType(scaling_type_weight) + scaling_type_grad_output = ScalingType(scaling_type_grad_output) + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + ) + scaling_repr = "_".join( + [ + s.short_str() + for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output) + ] + ) + + print(f"Compile is set to | {compile}") + print(f"model_type is set to | {model_type}") + print(f"scaling_repr is set to | {scaling_repr}") + + device = "cuda" + ref_dtype = torch.bfloat16 + if model_type == "ln_linear": + M, K, N = 4 * 4096, 8192, 7168 + m_ref = LNLinear(K, N) + input_tensor = torch.randn( + M, K, device=device, dtype=ref_dtype, requires_grad=True + ) + elif model_type == "norm_ffn_norm": + m_ref = NormFFNResidualNorm( + dim=4096, + hidden_dim=16384, + multiple_of=1024, + ffn_dim_multiplier=1.3, + ) + input_tensor = torch.randn( + 1, 8192, 4096, device=device, dtype=ref_dtype + ).requires_grad_() + else: + M, K, N = 4 * 4096, 8192, 7168 + m_ref = torch.nn.Sequential( + torch.nn.Linear(K, N, bias=False), + ) + input_tensor = torch.randn( + M, K, device=device, dtype=ref_dtype, requires_grad=True + ) + + m_ref = m_ref.to(device).to(ref_dtype) + + m_float8 = copy.deepcopy(m_ref) + convert_to_float8_training(m_float8, config=config) + + def ref_forw_backward(x): + out = m_ref(x) + out.sum().backward() + + def float8_forw(x): + out = m_float8(x) + return out + + sync_amax_history = sync_float8_amax_and_scale_history + + def float8_forw_backward_wrapper(x): + # sync_float8_amax_and_scale_history is not full graph torch + # compile friendly, so we add a high level wrapper to allow + # inspection of the fw+bw torch.compile without the scale + # syncing code + # TODO(future): make this better + if linear_requires_sync(config): + with record_function("scale_amax_and_scales"): + sync_amax_history(m_float8) + out = float8_forw(x) + + # out.sum().backward() is also not torch.compile fullgraph + # friendly + with record_function("backward"): + out.sum().backward() + + if compile: + m_ref = torch.compile(m_ref, fullgraph=True) + float8_forw = torch.compile(float8_forw, fullgraph=True) + # Note: it's faster to compile the combination of sync_amax_history wit + # forward because we only look up from dynamo cache once. + # However, compiling the sync function separately makes it more + # convenient to analyze the total time spent on it. + sync_amax_history = torch.compile(sync_amax_history) + + # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output + # to populate triton kernel bandwidth further down in the script + f = io.StringIO() + with redirect_stdout(f): + # warm up + for _ in range(1): + if dtype_filter != "float8": + ref_forw_backward(input_tensor) + if dtype_filter != "bfloat16": + float8_forw_backward_wrapper(input_tensor) + + profile_iters = 5 + ref_times, float8_times = None, None + data = [] + + if dtype_filter != "float8": + # Profile Reference Model + print("profiling ref") + ref_suffix = f"_{model_type}_ref_compile_{compile}.json" + ref_path = profile_path_prefix + ref_suffix + profile_config = ProfileConfig( + ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True + ) + p = profile_function(profile_config, ref_forw_backward, input_tensor) + print(f"saved {ref_path}") + ref_times = profiler_output_to_time_by_kernel_name(p) + total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters + for k, v in ref_times.items(): + v_ms = v / 1e3 / profile_iters + data.append( + [ + "0_ref", + k, + kernel_name_to_category(k), + v_ms, + v_ms / total_time_ms, + None, + ] + ) + + if dtype_filter != "bfloat16": + # Profile Float8 Model + print("profiling float8") + float8_suffix = ( + f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json" + ) + float8_path = profile_path_prefix + float8_suffix + profile_config = ProfileConfig( + float8_path, + float8_suffix, + iters=profile_iters, + warmup_iters=2, + sync=True, + ) + p = profile_function( + profile_config, float8_forw_backward_wrapper, input_tensor + ) + print(f"saved {float8_path}") + float8_times = profiler_output_to_time_by_kernel_name(p) + total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters + for k, v in float8_times.items(): + v_ms = v / 1e3 / profile_iters + data.append( + [ + "1_float8", + k, + kernel_name_to_category(k), + v / 1e3 / profile_iters, + v_ms / total_time_ms, + None, + ] + ) + + # get the time spent per user annotation + sync_time_us = profiler_output_to_gpu_time_for_key( + p, "scale_amax_and_scales" + ) + sync_time_ms = sync_time_us / profile_iters / 1e3 + print(f"Sync time ms: {sync_time_ms}") + + # print the redirected stdout back to regular stdout + print(f.getvalue()) + + # populate the triton kernel bandwidth + for line in f.getvalue().split("\n"): + maybe_bw, maybe_kernel_name = parse_bw_and_kernel_name(line) + if maybe_kernel_name is not None: + # O(N) search, but it's ok since lists are small + for datum in data: + if datum[1] == maybe_kernel_name: + datum[-1] = maybe_bw + + df = pd.DataFrame( + data, + columns=[ + "experiment", + "kernel", + "category", + "time_ms", + "pct_gpu_time", + "bw_gpbs", + ], + ) + df.sort_values( + ["experiment", "category", "pct_gpu_time"], + ascending=[True, True, False], + inplace=True, + ) + print("\nSummary of GPU time by CPU kernel\n\n", df) + + # compare gemm and overhead time + df_p = df.pivot_table( + columns=["category"], + index="experiment", + values="time_ms", + aggfunc="sum", + fill_value=0, + margins=True, + ) + # drop last row, which has totals across ref + float8 which does not make sense + df_p = df_p[:-1] + df_p = df_p.transpose() + + if dtype_filter == "both": + df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"] + df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"] + + # calculate sync time as pct of total float time + # note: this time is not useful if TORCHINDUCTOR_PROFILE is on + total_float8_ms = df_p.iloc[3]["1_float8"] + sync_approx_ratio = sync_time_ms / total_float8_ms + print( + f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}" + ) + + print("\nSummary of time (ms) by kernel category\n\n", df_p) + + +def invoke_main() -> None: + # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles/current_profile --compile=True --linear_type="dynamic" + # You can set TORCHINDUCTOR_PROFILE=1 to also capture triton kernel bandwidth + fire.Fire(main) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py new file mode 100644 index 000000000..aec19e2cd --- /dev/null +++ b/benchmarks/float8/utils.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import collections +import re + + +def profiler_output_to_time_by_kernel_name(prof): + """ + Input: a profiler with captured events. + Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name + + Note that if there are user_annotations in the captured events, `torch.profiler` + will include their time in the total GPU time displayed at the bottom of + `key_averages.table()`. The filter below excludes them to prevent double + counting. + """ + key_averages = prof.key_averages() + thresh = 1e-10 + kernel_name_to_gpu_time_us = collections.defaultdict(float) + for e in key_averages: + # manually filter top-level CPU events with attributed CUDA time + # example CPU event row: + # aten::addmm 0.83% 76.554us 0.98% 90.846us 90.846us 1.022ms 31.82% 1.022ms 1.022ms 1 + # and it maps to this CUDA event: + # sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize256x64... 0.00% 0.000us 0.00% 0.000us 0.000us 1.022ms 31.82% 1.022ms 1.022ms 1 + if not (e.self_cpu_time_total > thresh and e.self_device_time_total > thresh): + continue + kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total + return kernel_name_to_gpu_time_us + + +def profiler_output_to_gpu_time_for_key(prof, key): + """ + Input: an event name + Output: sum of GPU time of all events with that name in `prof` + + This is useful to get the total time of a user annotation + """ + total = 0 + for e in prof.profiler.function_events: + if e.key == key: + total += e.device_time_total + return total + + +def kernel_name_to_category(k): + # number prefix is for easy sorting + if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"): + return "0_gemm" + elif ( + # max(abs(tensor)) + ("abs" in k and "max" in k) + or + # casting pointwise to float8 + ("clamp" in k) + or + # things related to scaled_mm + ("scaled_mm" in k) + or + # syncing amaxes and scales + ("roll" in k) + ): + # note: the above filter is approximate and will give false + # positives if model code contains other code to abs/max/clamp + return "1_f8_overhead" + return "2_other" + + +def parse_bw_and_kernel_name(line): + """ + Input: a single line of stdout of TORCHINDUCTOR_PROFILE=1 output, such as + 0.257ms 0.537 GB 2092.43GB/s triton_red_fused_native_layer_norm_0 + Output: the bandwidth value and the kernel name, or None and None + """ + result = re.search(".* ([0-9\.]+)GB/s.*(triton_[a-z_0-9]+)", line) + if result: + return result.group(1), result.group(2) + else: + return None, None diff --git a/test/float8/test_base.py b/test/float8/test_base.py new file mode 100644 index 000000000..0780968aa --- /dev/null +++ b/test/float8/test_base.py @@ -0,0 +1,723 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import copy +import io +import itertools +import random +import re +import unittest +import warnings + +import pytest + +import torch +import torch.nn as nn + +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.float8_linear import Float8Linear +from torchao.float8.float8_linear_utils import ( + convert_to_float8_training, + linear_requires_sync, + sync_float8_amax_and_scale_history, +) +from torchao.float8.float8_python_api import addmm_float8_unwrapped +from torchao.float8.float8_tensor import ( + Float8Tensor, + GemmInputRole, + hp_tensor_and_scale_to_float8, + LinearMMConfig, + ScaledMMConfig, +) +from torchao.float8.float8_utils import ( + compute_error, + e4m3_dtype, + e5m2_dtype, + fp8_tensor_statistics, + FP8_TYPES, + tensor_to_scale, +) +from torchao.float8.inference import ( + ActivationCasting, + QuantConfig, + quantize_to_float8, +) + +random.seed(0) +torch.manual_seed(0) + +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + +def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: + assert torch.all(a._data == b._data).item(), "scales are not identical" + assert torch.all(a._data == b._data).item(), "data is not identical" + return True + + +class TestFloat8Tensor(unittest.TestCase): + def test_preserves_dtype(self) -> None: + # hp means high precision, lp means low precision + hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) + lp_dtypes = FP8_TYPES + for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes): + x1_hp = torch.randn(4, 4, dtype=hp_dtype) + x1_s = tensor_to_scale(x1_hp, lp_dtype) + x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype) + x3_hp = x2_lp.to_original_precision() + self.assertTrue(x3_hp.dtype == hp_dtype) + + def test_differentiable_casts(self) -> None: + lp_dtypes = (e4m3_dtype, e5m2_dtype) + for f8_dtype in lp_dtypes: + x = torch.randn(1).requires_grad_() + grad = torch.randn(1) + x_s = tensor_to_scale(x, f8_dtype) + x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype) + x_f8_hp = x_f8.to_original_precision() + x_f8_hp.backward(grad) + # the gradient should be unchanged through both casts + torch.testing.assert_close(grad, x.grad, rtol=0, atol=0) + + def test_split_cat(self): + a = torch.rand(16, 16, dtype=torch.bfloat16) + scale = tensor_to_scale(a, e4m3_dtype) + fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype) + + splits = torch.split(fp8_a, 16) + catted = torch.cat(splits, dim=0) + assert bitwise_identical(fp8_a, catted) + + def test_index_put(self): + a = torch.rand(16, dtype=torch.bfloat16) + scale_a = tensor_to_scale(a, torch.float8_e4m3fn) + fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn) + + index = torch.randint(0, 15, (16,), dtype=torch.long) + + b = torch.rand(16, 16, dtype=torch.bfloat16) + scale_b = tensor_to_scale(b, torch.float8_e4m3fn) + fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn) + fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn) + + with self.assertRaises(AssertionError): + b[index] = fp8_a + fp8_b[index] = a + fp8_b_bad[index] = fp8_a + fp8_b[index] = fp8_a + + def test_copy_(self): + a = torch.rand(16, dtype=torch.bfloat16) + scale_a = tensor_to_scale(a, torch.float8_e4m3fn) + fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn) + + b = torch.empty(16, dtype=torch.bfloat16) + b.copy_(fp8_a) # Should work + torch.testing.assert_close(b, fp8_a.to_original_precision()) + with self.assertRaises(RuntimeError): + fp8_a.copy_(b) # Should fail + + fp8_b = Float8Tensor( + torch.empty(16, dtype=torch.float8_e4m3fn), + scale_a, + torch.bfloat16, + fp8_a._linear_mm_config, + ) + fp8_b.copy_(fp8_a) + torch.testing.assert_close(fp8_a._data, fp8_b._data) + + def test_weights_only_load(self): + module = nn.Linear(16, 16) + # Save model state dict + buffer = io.BytesIO() + fp8_module = quantize_to_float8( + module, + QuantConfig( + ActivationCasting.DYNAMIC, + ), + ) + + torch.save(fp8_module.state_dict(), buffer) + buffer.seek(0) + _ = torch.load(buffer, weights_only=True) + + +class TestFloat8Linear: + def _test_linear_impl( + self, + x, + m_ref, + config: Float8LinearConfig, + ): + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) + for _ in range(2): + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(m_fp8) + y_fp8 = m_fp8(x) + y_fp8.sum().backward() + y_ref = m_ref(x) + y_ref.sum().backward() + + assert y_ref.shape == y_fp8.shape + + y_sqnr = compute_error(y_ref, y_fp8) + g_sqnr = compute_error(m_ref.weight.grad, m_fp8.weight.grad) + # verify sqnr is reasonable + assert y_sqnr >= 18.0, f"{y_sqnr} is too low" + assert g_sqnr >= 17.0, f"{g_sqnr} is too low" + if m_ref.bias is not None: + torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad) + + # verify all of the amax buffers got updated + if linear_requires_sync(config): + # only check buffers that are actually used, based on per-tensor + # scaling settings + amax_buffer_names = [] + amax_history_buffer_names = [] + scale_buffer_names = [] + if config.cast_config_input.scaling_type is ScalingType.DELAYED: + amax_buffer_names.append("fp8_amax_input") + amax_history_buffer_names.append("fp8_amax_history_input") + scale_buffer_names.append("fp8_scale_input") + if config.cast_config_weight.scaling_type is ScalingType.DELAYED: + amax_buffer_names.append("fp8_amax_weight") + amax_history_buffer_names.append("fp8_amax_history_weight") + scale_buffer_names.append("fp8_scale_weight") + if config.cast_config_grad_output.scaling_type is ScalingType.DELAYED: + amax_buffer_names.append("fp8_amax_grad_output") + amax_history_buffer_names.append("fp8_amax_history_grad_output") + scale_buffer_names.append("fp8_scale_grad_output") + + # verify all of the amax buffers got updated + max_float8_pos = {torch.finfo(dtype).max for dtype in FP8_TYPES} + for buffer_name in amax_buffer_names: + buffer_value = getattr(m_fp8, buffer_name) + for init_val in max_float8_pos: + assert torch.ne( + buffer_value, torch.tensor(init_val) + ), f"{buffer_name} not filled, current value {buffer_value}" + + # verify all of the amax history buffers got updated + for buffer_name in amax_history_buffer_names: + buffer_value = getattr(m_fp8, buffer_name) + assert torch.max(buffer_value) > 0.0, f"{buffer_name} not filled" + + # verify all of the scale buffers got updated + for buffer_name in scale_buffer_names: + buffer_value = getattr(m_fp8, buffer_name) + assert torch.ne( + buffer_value, torch.tensor(1.0) + ), f"{buffer_name} not filled, current value {buffer_value}" + + # verify initialization flags got updated + assert m_fp8.is_amax_initialized, "Amax was not properly initialized" + + @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) + @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) + @pytest.mark.parametrize( + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] + ) + @pytest.mark.parametrize( + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] + ) + @pytest.mark.parametrize( + "scaling_type_grad_output", + [ScalingType.DELAYED, ScalingType.DYNAMIC], + ) + @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("linear_bias", [False, True]) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_linear( + self, + x_shape, + emulate: bool, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, + linear_dtype: torch.dtype, + linear_bias: bool, + ): + if not emulate: + if not torch.cuda.is_available(): + warnings.warn("CUDA not available") + pytest.skip() + elif torch.cuda.get_device_capability() < (9, 0): + warnings.warn( + f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" + ) + pytest.skip() + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + emulate=emulate, + ) + self._test_linear_impl( + x, + m_ref, + config, + ) + + @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) + @pytest.mark.parametrize( + "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] + ) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_autocast_outputs( + self, + emulate: bool, + linear_dtype: torch.dtype, + ): + if not emulate: + if not torch.cuda.is_available(): + warnings.warn("CUDA not available") + pytest.skip() + elif torch.cuda.get_device_capability() < (9, 0): + warnings.warn( + f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" + ) + pytest.skip() + + m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), + emulate=emulate, + ) + m = Float8Linear.from_float(copy.deepcopy(m_ref), config) + + # autocast off + x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(m) + y = m(x) + assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" + + # autocast on + with torch.autocast("cuda"): + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(m) + y = m(x) + assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" + + with torch.autocast("cuda", dtype=torch.bfloat16): + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(m) + y = m(x) + assert ( + y.dtype == torch.bfloat16 + ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}" + + @pytest.mark.parametrize( + "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] + ) + @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): + emulate = ( + not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0) + ) + + m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) + config = Float8LinearConfig(emulate=emulate) + m = Float8Linear.from_float(copy.deepcopy(m), config) + + # Cast the module to dtype + m = m.to(dtype=linear_dtype) + if linear_requires_sync(config): + # Check amax buffer types + for key in [ + "fp8_amax_input", + "fp8_amax_history_input", + "fp8_scale_input", + "fp8_amax_weight", + "fp8_amax_history_weight", + "fp8_scale_weight", + "fp8_amax_grad_output", + "fp8_amax_history_grad_output", + "fp8_scale_grad_output", + ]: + assert ( + m._buffers[key].dtype == torch.float32 + ), f"{key}.dtype is {m._buffers[key].dtype}, expected torch.float32" + + # autocast off + x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(m) + y = m(x) + assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" + + # autocast on + with torch.autocast("cuda"): + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(m) + y = m(x) + assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" + + with torch.autocast("cuda", dtype=torch.bfloat16): + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(m) + y = m(x) + assert ( + y.dtype == torch.bfloat16 + ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}" + + def test_repr(self): + m = nn.Linear(32, 16) + config = Float8LinearConfig( + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + emulate=True, + ) + m = Float8Linear.from_float( + copy.deepcopy(m), + config=config, + ) + s = m.__repr__() + assert "i:dyn,w:del,go:dyn" in s + + +class TestScaledMM: + @unittest.skipIf( + not is_H100, + "CUDA not available", + ) + @pytest.mark.parametrize( + "base_dtype", [torch.float16, torch.bfloat16, torch.float32] + ) + @pytest.mark.parametrize("use_fast_accum", [True, False]) + def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): + torch.manual_seed(42) + input_dtype = e4m3_dtype + output_dtype = base_dtype + compare_type = torch.float32 + + a = torch.randn(16, 16, device="cuda", dtype=base_dtype) + b = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + + a_scale = tensor_to_scale(a, input_dtype).float() + b_scale = tensor_to_scale(b, input_dtype).float() + + a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype) + b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype) + + out_scaled_mm = addmm_float8_unwrapped( + a_fp8._data, + a_fp8._scale, + b_fp8._data, + b_fp8._scale, + output_dtype=output_dtype, + use_fast_accum=use_fast_accum, + ) + out_emulated = torch.ops.aten.mm_float8_emulated( + a_fp8._data, a_fp8._scale, b_fp8._data, b_fp8._scale, output_dtype + ) + + if output_dtype != base_dtype: + out_scaled_mm = out_scaled_mm.to(compare_type) + out_emulated = out_emulated.to(compare_type) + + if base_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 7e-2, 7e-2 + else: + atol, rtol = 2e-3, 2e-3 + torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + + @unittest.skipIf(not is_H100, "CUDA not available") + def test_different_configs_error(self): + x_fp32 = torch.randn(16, 16, device="cuda") + x_scale = torch.tensor(1.0, device="cuda") + fp8_dtype = e4m3_dtype + linear_config_a = LinearMMConfig( + ScaledMMConfig(False, True, False, False), + ScaledMMConfig(False, False, False, False), + ScaledMMConfig(False, False, False, False), + ) + linear_config_b = LinearMMConfig( + ScaledMMConfig(True, True, False, False), + ScaledMMConfig(True, False, False, False), + ScaledMMConfig(True, False, False, False), + ) + a = hp_tensor_and_scale_to_float8( + x_fp32, + x_scale, + fp8_dtype, + linear_config_a, + GemmInputRole.INPUT, + ) + b = hp_tensor_and_scale_to_float8( + x_fp32, + x_scale, + fp8_dtype, + linear_config_b, + GemmInputRole.WEIGHT, + ) + with pytest.raises( + AssertionError, + match="linear_mm_config.output mismatch", + ): + a @ b + + @unittest.skipIf( + not is_H100, + "CUDA not available", + ) + @pytest.mark.parametrize( + "base_dtype", [torch.float16, torch.bfloat16, torch.float32] + ) + @pytest.mark.parametrize("use_fast_accum", [True, False]) + def test_pad_inner_dim(self, base_dtype, use_fast_accum): + torch.manual_seed(42) + input_dtype = torch.float8_e4m3fn + compare_type = torch.float32 + + a = torch.randn(16, 41, device="cuda", dtype=base_dtype) + b = torch.randn(41, 128, device="cuda", dtype=base_dtype) + + a_scale = tensor_to_scale(a, input_dtype).float() + b_scale = tensor_to_scale(b, input_dtype).float() + + a_fp8 = hp_tensor_and_scale_to_float8( + a, a_scale, input_dtype, None, GemmInputRole.INPUT + ) + b_fp8 = hp_tensor_and_scale_to_float8( + b, b_scale, input_dtype, None, GemmInputRole.WEIGHT + ) + + with pytest.raises( + RuntimeError, + match=re.escape( + "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41." + ), + ): + a_fp8 @ b_fp8 + + scaled_mm_config = ScaledMMConfig(False, use_fast_accum, False, True) + pad_config = LinearMMConfig( + scaled_mm_config, scaled_mm_config, scaled_mm_config + ) + + a_fp8 = hp_tensor_and_scale_to_float8( + a, + a_scale, + input_dtype, + pad_config, + GemmInputRole.INPUT, + ) + b_fp8 = hp_tensor_and_scale_to_float8( + b, + b_scale, + input_dtype, + pad_config, + GemmInputRole.WEIGHT, + ) + out_padded = a_fp8 @ b_fp8 + out_padded.to(compare_type) + + emulated_scaled_mm_config = ScaledMMConfig(True, use_fast_accum, False, False) + emulated_config = LinearMMConfig( + emulated_scaled_mm_config, + emulated_scaled_mm_config, + emulated_scaled_mm_config, + ) + a_fp8 = hp_tensor_and_scale_to_float8( + a, + a_scale, + input_dtype, + emulated_config, + GemmInputRole.INPUT, + ) + b_fp8 = hp_tensor_and_scale_to_float8( + b, + b_scale, + input_dtype, + emulated_config, + GemmInputRole.WEIGHT, + ) + out_emualted = a_fp8 @ b_fp8 + out_emualted.to(compare_type) + + if base_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 7e-2, 7e-2 + else: + atol, rtol = 2e-3, 2e-3 + torch.testing.assert_close(out_padded, out_emualted, atol=atol, rtol=rtol) + + +class TestNumerics: + @pytest.mark.parametrize( + "float8_dtype", + [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ], + ) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_small_amax_float16(self, float8_dtype): + # If we calculate scale naively with FP8_MAX_POS / amax, + # the result may not be representable in fp16. Verify that + # the way we calculate scales actually works for tensors with + # small values. + # + # naive_s = fp8_max_pos / (amax + eps) + # + # failing case: + # + # fp8_max_pos / (amax + eps) >= fp16_max_pos, or + # + # amax + eps >= fp8_max_pos / fp16_max_pos + + float8_max_pos = torch.finfo(float8_dtype).max + FP16_MAX_POS = torch.finfo(torch.float16).max + + target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12) + x = torch.tensor([target_amax], dtype=torch.float16, device="cuda") + scale = tensor_to_scale(x, float8_dtype) + assert not torch.any(torch.isinf(scale)) + + +class TestFloat8LinearUtils(unittest.TestCase): + def test_swap_root_linear(self): + for emulate in [True, False]: + module = nn.Linear(3, 3) + config = Float8LinearConfig(emulate=emulate) + module = convert_to_float8_training(module, config=config) + self.assertIsInstance(module, Float8Linear) + self.assertEqual(module.linear_mm_config.output.emulate, emulate) + self.assertEqual(module.linear_mm_config.output.emulate, emulate) + + def test_swap_root_linear_with_children_raises(self): + for emulate in [True, False]: + module = nn.Linear(3, 3) + module.child = nn.Sequential(nn.Linear(3, 3)) + config = Float8LinearConfig(emulate=emulate) + with self.assertRaisesRegex( + AssertionError, + "Does not support a root nn.Linear with children", + ): + convert_to_float8_training(module, config=config) + + def test_swap_submodule_linears(self): + class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.lin1 = nn.Linear(dim, 4 * dim) + self.lin2 = nn.Linear(4 * dim, dim) + + for emulate in [True, False]: + model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3)) + config = Float8LinearConfig(emulate=emulate) + model = convert_to_float8_training(model, config=config) + self.assertIsInstance(model[0].lin1, Float8Linear) + self.assertIsInstance(model[0].lin2, Float8Linear) + self.assertIsInstance(model[1], Float8Linear) + self.assertIsInstance(model[2].lin1, Float8Linear) + self.assertIsInstance(model[2].lin2, Float8Linear) + + def test_swap_linears_with_filters(self): + class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.lin1 = nn.Linear(dim, 4 * dim) + self.lin2 = nn.Linear(4 * dim, 4 * dim) + + model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40)) + # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16. + + size_limit = 32 + + def module_filter_fn(mod, fqn): + return ( + mod.in_features >= size_limit + and mod.out_features >= size_limit + and mod.in_features % 16 == 0 + and mod.out_features % 16 == 0 + ) + + config = Float8LinearConfig(emulate=True) + model = convert_to_float8_training( + model, + config=config, + module_filter_fn=module_filter_fn, + ) + # in_features=8, out_features=32, 8 is less than 32. + self.assertNotIsInstance(model[0].lin1, Float8Linear) + # in_features=32, out_features=32, + self.assertIsInstance(model[0].lin2, Float8Linear) + # in_features=32, out_features=32, + self.assertIsInstance(model[1], Float8Linear) + # in_features=40, out_features=160, 40 is not divisible by 16. + self.assertNotIsInstance(model[2].lin1, Float8Linear) + # in_features=160, out_features=160, + self.assertIsInstance(model[2].lin2, Float8Linear) + + def test_swap_submodule_linears_with_skip(self): + class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.lin1 = nn.Linear(dim, 4 * dim) + self.lin2 = nn.Linear(4 * dim, dim) + + model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3)) + module_filter_fn = lambda mod, fqn: fqn not in [ + "0.lin2", + "2.lin1", + ] + config = Float8LinearConfig(emulate=True) + model = convert_to_float8_training( + model, + config=config, + module_filter_fn=module_filter_fn, + ) + self.assertTrue(type(model[0].lin1) is Float8Linear) + self.assertTrue(type(model[0].lin2) is nn.Linear) + self.assertTrue(type(model[1]) is Float8Linear) + self.assertTrue(type(model[2].lin1) is nn.Linear) + self.assertTrue(type(model[2].lin2) is Float8Linear) + + def test_fp8_tensor_statistics(self): + hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) + lp_dtypes = (e4m3_dtype, e5m2_dtype) + for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes): + x1_hp = torch.ones(4, 4, dtype=hp_dtype) + tensor_len = x1_hp.numel() + + # Overflow caused by a too large scaling factor + s_overflow = torch.tensor(1e9) + fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype) + (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype) + self.assertEqual((zero_cnt, max_cnt), (0, tensor_len)) + + # Underflow caused by a too small scaling factor + s_underflow = torch.tensor(1e-9) + fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype) + (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype) + self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0)) + + # Both overflow and underflow + x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0) + fp8_over_underflow = hp_tensor_and_scale_to_float8( + x2_hp, torch.tensor(1.0), lp_dtype + ) + (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype) + self.assertEqual((zero_cnt, max_cnt), (tensor_len, tensor_len)) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py new file mode 100644 index 000000000..1f3ebe169 --- /dev/null +++ b/test/float8/test_compile.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import copy +import random +import sys +import unittest +from io import StringIO + +import pytest + +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +import torch +import torch.nn as nn +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.float8_linear import Float8Linear +from torchao.float8.float8_linear_utils import ( + convert_to_float8_training, + get_float8_layers, + sync_float8_amax_and_scale_history, +) +from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed +from torchao.float8.float8_tensor import LinearMMConfig +from torchao.float8.float8_utils import e4m3_dtype + +from torch._dynamo.test_case import TestCase as DynamoTestCase +from torch._dynamo.testing import CompileCounterWithBackend + +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + +def _test_compile_base( + backend: str, + fullgraph: bool, + config: Float8LinearConfig, + dtype: torch.dtype, +): + random.seed(0) + torch.manual_seed(0) + x_shape = (16, 16) + linear_dtype = torch.bfloat16 + + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) + + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) + + m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) + m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) + y_fp8 = m_fp8(x) + y_fp8.sum().backward() + y_ref = m_ref(x) + y_ref.sum().backward() + torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2) + torch.testing.assert_close( + m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1 + ) + torch.testing.assert_close(m_fp8.bias.grad, m_ref.bias.grad, atol=8e-2, rtol=8e-2) + + +@pytest.mark.parametrize("fullgraph", [True]) +@pytest.mark.parametrize( + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] +) +@pytest.mark.parametrize( + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] +) +@pytest.mark.parametrize( + "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC] +) +@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") +def test_eager_only( + fullgraph, + emulate: bool, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, + dtype: torch.dtype, +): + torch._dynamo.reset() + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + emulate=emulate, + ) + _test_compile_base( + "eager", + fullgraph, + config, + dtype, + ) + + +@pytest.mark.parametrize("fullgraph", [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) +@pytest.mark.parametrize( + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] +) +@pytest.mark.parametrize( + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] +) +@pytest.mark.parametrize( + "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC] +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") +def test_aot_eager( + fullgraph, + emulate: bool, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, + dtype: torch.dtype, +): + torch._dynamo.reset() + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + emulate=emulate, + ) + _test_compile_base( + "aot_eager", + fullgraph, + config, + dtype, + ) + + +@pytest.mark.parametrize("fullgraph", [True]) +@pytest.mark.parametrize("emulate", [False]) +@pytest.mark.parametrize( + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] +) +@pytest.mark.parametrize( + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] +) +@pytest.mark.parametrize( + "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC] +) +@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available") +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_inductor( + fullgraph, + emulate: bool, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, + dtype: torch.dtype, +): + torch._dynamo.reset() + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + emulate=emulate, + ) + _test_compile_base( + "inductor", + fullgraph, + config, + dtype, + ) + + +class TestGraphBreaks(DynamoTestCase): + class MockLinear(torch.nn.Module): + def __init__(self, graph_break: bool): + super().__init__() + self.register_buffer("fp8_amax_x", torch.tensor(1.0)) + self.register_buffer("fp8_scale_x", torch.tensor(1.0)) + self.graph_break = graph_break + + def forward(self, x): + x_fp8 = hp_tensor_to_float8_delayed( + x, + self.fp8_scale_x, + e4m3_dtype, + self.fp8_amax_x, + LinearMMConfig(), + ) + if self.graph_break: + torch._dynamo.graph_break() + x_hp = x_fp8.to_original_precision() + return x_hp + return x_fp8 + + @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available") + def test_float8_with_graph_break_in_the_middle(self): + """Test that having Float8Tensor object at the boundary of a subgraph""" + cnts = CompileCounterWithBackend("inductor") + mod = self.MockLinear(graph_break=True).cuda() + compiled_mod = copy.deepcopy(mod) + compiled_mod = torch.compile(compiled_mod, backend=cnts) + x = torch.randn(16, 16, device="cuda") + y_eager = mod(x) + y_compiled = compiled_mod(x) + self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!") + torch.testing.assert_close(y_eager, y_compiled) + + @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available") + def test_float8_graph_input(self): + """Test that having Float8Tensor object as a graph input""" + + def to_float(x): + return x.to_original_precision() + + cnts = CompileCounterWithBackend("inductor") + mod = self.MockLinear(graph_break=False).cuda() + x = torch.randn(2, 2, device="cuda") + compiled_to_float = torch.compile(to_float, backend=cnts) + y = mod(x) + y2_eager = to_float(y) + y2_compiled = compiled_to_float(y) + self.assertEqual( + cnts.frame_count, + 1, + "to_float was not compiled into 1 frame and likely encountered a skip!", + ) + torch.testing.assert_close(y2_eager, y2_compiled) + + @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available") + def test_float8_graph_output(self): + """Test that having Float8Tensor object as a graph output works""" + cnts = CompileCounterWithBackend("inductor") + mod = self.MockLinear(graph_break=False).cuda() + compiled_mod = torch.compile(mod, backend=cnts) + x = torch.randn(16, 16, device="cuda") + y_compiled = compiled_mod(x) + + self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") + tensors, ctx = y_compiled.__tensor_flatten__() + for tensor in tensors: + assert not isinstance( + getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor + ), "Float8Tensor should not contain any FakeTensors!" + assert isinstance( + y_compiled._orig_dtype, torch.dtype + ), "Float8Tensor._orig_dtype should be a dtype but got {}".format( + type(y_compiled._orig_dtype) + ) + assert isinstance( + y_compiled._linear_mm_config.output.emulate, bool + ), "Float8Tensor._emulate should be a bool but got {}".format( + type(y_compiled._linear_mm_config.output.emulate) + ) + + +@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available") +def test_sync_amax_func(): + torch._dynamo.reset() + cnts = CompileCounterWithBackend("inductor") + module = torch.nn.Sequential( + nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) + ) + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), + ) + float8_mod = convert_to_float8_training( + module, + config=config, + ) + compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts) + compiled_swap_func(float8_mod) + assert cnts.frame_count == 1, "Compiled graph should have 1 frame!" + + +class capture_stderr(list): + """ + Replace sys.stderr with a temporary StringIO + """ + + def __enter__(self): + self.sys_stderr = sys.stderr + self.stringio = StringIO() + sys.stderr = self.stringio + return self + + def __exit__(self, *args): + self.append(str(self.stringio.getvalue())) + del self.stringio + sys.stderr = self.sys_stderr + + +@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available") +def test_sync_amax_func_cuda_graph_success(): + torch._dynamo.reset() + with capture_stderr() as stderr: + my_module = nn.Sequential( + nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) + ).to("cuda") + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), + ) + convert_to_float8_training( + my_module, + config=config, + ) + inpt = torch.randn( + 16, 16, device="cuda", dtype=torch.float32, requires_grad=True + ) + sync_func = torch.compile( + sync_float8_amax_and_scale_history, mode="reduce-overhead", fullgraph=True + ) + fp8_layers = get_float8_layers(my_module) + my_module(inpt) + sync_func(my_module, fp8_layers) + + assert "skipping cudagraphs due to mutaton on input" not in stderr[0] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py new file mode 100644 index 000000000..7a6c9125d --- /dev/null +++ b/test/float8/test_dtensor.py @@ -0,0 +1,327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +""" +Test numerics of manually defined float16 TP vs float8 TP of toy models + +Note: for now, this does not run in CI. +TODO(future): make this run in CI +""" + +import copy +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import pytest + +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +from torchao.float8 import Float8LinearConfig +from torchao.float8.float8_linear_utils import convert_to_float8_training + +from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic +from torchao.float8.float8_tensor import ( + Float8Tensor, + GemmInputRole, + hp_tensor_and_scale_to_float8, + LinearMMConfig, +) +from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, +) +from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale +from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor.parallel import parallelize_module +from tqdm import tqdm + + +def setup_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", -1)) + device_mesh = init_device_mesh("cuda", (world_size,)) + # seed must be the same in all processes + torch.manual_seed(1) + return device_mesh + + +class FeedForward(nn.Module): + """MLP based model""" + + def __init__(self): + super(FeedForward, self).__init__() + self.w1 = nn.Linear(16, 32, bias=False) + self.w2 = nn.Linear(16, 32, bias=False) + self.out_proj = nn.Linear(32, 16, bias=False) + + def forward(self, x): + return self.out_proj(F.silu(self.w1(x)) * self.w2(x)) + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.ffn = FeedForward() + + def forward(self, x): + return self.ffn(x) + + +def _test_scaled_mm(mesh: DeviceMesh, size=16): + device = mesh.device_type + fp8_dtype = e4m3_dtype + world_size = mesh.size() + + x_fp32 = torch.rand(size, size, device=device) + y_fp32 = torch.eye(size, device=device).t() + + placement_combs = ( + (Shard(0), Replicate()), + (Replicate(), Shard(1)), + (Shard(1), Shard(0)), + ) + expected_dt_out_shape = ( + (size * world_size, size), + (size, size * world_size), + (size, size), + ) + for idx, (lhs_placement, rhs_placement) in enumerate(placement_combs): + x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() + y_scale = tensor_to_scale(y_fp32, fp8_dtype).float() + + x_fp8 = hp_tensor_and_scale_to_float8( + x_fp32, x_scale, fp8_dtype, None, GemmInputRole.INPUT + ) + y_fp8 = hp_tensor_and_scale_to_float8( + y_fp32, y_scale, fp8_dtype, None, GemmInputRole.WEIGHT + ) + + dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False) + dist_y_fp8 = DTensor.from_local(y_fp8, mesh, [rhs_placement], run_check=False) + + assert isinstance(dist_x_fp8.to_local(), Float8Tensor) + assert isinstance(dist_y_fp8.to_local(), Float8Tensor) + assert dist_x_fp8.to_local()._orig_dtype == torch.float32 + out_fp8 = torch.mm(dist_x_fp8, dist_y_fp8) + local_fp8_out = out_fp8.to_local() + assert out_fp8.shape == expected_dt_out_shape[idx], (idx, local_fp8_out.shape) + + # after mm the out dtype should be fp32 + assert local_fp8_out.dtype == torch.float32 + + +def _test_fp8_redistribute(mesh: DeviceMesh, size=16): + device = mesh.device_type + fp8_dtype = e4m3_dtype + world_size = mesh.size() + + x_fp32 = torch.rand(size, size, device=device) + + x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() + + x_fp8 = hp_tensor_and_scale_to_float8(x_fp32, x_scale, fp8_dtype) + + dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [Shard(0)], run_check=False) + out_dist = dist_x_fp8.redistribute(placements=[Replicate()]) + assert out_dist.shape == (size * world_size, size) + assert out_dist.placements == (Replicate(),) + out_local = out_dist.to_local() + # after allgather the out shape should be replicate + assert out_local.shape == (size * world_size, size) + from torch.distributed._functional_collectives import AsyncCollectiveTensor + + if isinstance(out_local, AsyncCollectiveTensor): + out_local = out_local.wait() + + assert isinstance(out_local, Float8Tensor) + assert out_local._data.dtype == fp8_dtype + + +def _test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16): + device = mesh.device_type + fp8_dtype = e4m3_dtype + + x_fp32 = torch.rand(size, size, device=device) + dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)]) + + dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float() + assert isinstance(dist_x_scale, DTensor) + + dist_x_fp8 = hp_tensor_and_scale_to_float8(dist_x_fp32, dist_x_scale, fp8_dtype) + assert isinstance(dist_x_fp8, DTensor) + + +def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): + device = mesh.device_type + fp8_dtype = e4m3_dtype + + x_fp32 = torch.rand(size, size, device=device, requires_grad=True) + local_weight = torch.rand(2 * size, size, device=device, requires_grad=True) + target = torch.rand(size, 2 * size, device=device) + + dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)]) + dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float() + + dist_wight_fp32 = distribute_tensor(local_weight, mesh, [Shard(0)]) + dist_weight_scale = tensor_to_scale(dist_wight_fp32, fp8_dtype).float() + dist_target = distribute_tensor(target, mesh, [Shard(0)]) + + dist_x_fp8 = hp_tensor_and_scale_to_float8( + dist_x_fp32, + dist_x_scale, + fp8_dtype, + None, + GemmInputRole.INPUT, + ) + dist_weight_fp8 = hp_tensor_and_scale_to_float8( + dist_wight_fp32, + dist_weight_scale, + fp8_dtype, + None, + GemmInputRole.WEIGHT, + ) + + out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) + out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig()) + assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" + loss = torch.sum(torch.abs(out - dist_target)) + loss.backward() + + +def _test_fp8_mlp_tensor_parallelism_base( + mesh: DeviceMesh, size=16, compile: bool = False +): + device = mesh.device_type + # For now, only supports dynamic scaling of `x` and `dL_dY`. + # TODO(future): add support for float8 all-gather with delayed scaling + # for activations and gradients. + config = Float8LinearConfig(emulate=True) + + toy_model = ToyModel().to(device) + toy_model_fp8 = convert_to_float8_training(toy_model, config=config) + + tp_model = copy.deepcopy(toy_model) + tp_model = convert_to_float8_training(tp_model, config=config) + sp_model = copy.deepcopy(toy_model) + sp_model = convert_to_float8_training(sp_model, config=config) + + # vanilla TP + tp_model = parallelize_module( + tp_model, + mesh, + { + "ffn.w1": Float8ColwiseParallel(), + "ffn.w2": Float8ColwiseParallel(), + "ffn.out_proj": Float8RowwiseParallel(), + }, + ) + + # "sequence parallel" mlp computation + sp_model = parallelize_module( + sp_model, + mesh, + { + "ffn": PrepareFloat8ModuleInput( + input_layouts=Shard(1), desired_input_layouts=Replicate() + ), + "ffn.w1": Float8ColwiseParallel(), + "ffn.w2": Float8ColwiseParallel(), + "ffn.out_proj": Float8RowwiseParallel( + output_layouts=Shard(1), use_local_output=False + ), + }, + ) + + # PrepareFloat8ModuleInput with specific submodule fqn + sp_model2 = copy.deepcopy(toy_model) + sp_model2 = convert_to_float8_training(sp_model2, config=config) + + sp_model2 = parallelize_module( + sp_model2, + mesh, + { + "ffn": PrepareFloat8ModuleInput( + input_layouts=Shard(1), + desired_input_layouts=Replicate(), + fwd_config_submodule_fqn="w2", + ), + "ffn.w1": Float8ColwiseParallel(), + "ffn.w2": Float8ColwiseParallel(), + "ffn.out_proj": Float8RowwiseParallel( + output_layouts=Shard(1), use_local_output=False + ), + }, + ) + + if compile: + tp_model = torch.compile(tp_model) + sp_model = torch.compile(sp_model) + sp_model2 = torch.compile(sp_model2) + + x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) + x_fp32_tp_input = x_fp32.clone() + x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) + + tp_out = tp_model(x_fp32_tp_input) + tp_out.sum().backward() + sp_out = sp_model(x_fp32_sp_input) + sp_out.sum().backward() + global_out = toy_model_fp8(x_fp32) + global_out.sum().backward() + torch.testing.assert_close(tp_out, global_out) + torch.testing.assert_close(sp_out.full_tensor(), global_out) + torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad) + torch.testing.assert_close( + tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad + ) + + sp_out2 = sp_model2(x_fp32_sp_input) + sp_out2.sum().backward() + torch.testing.assert_close(sp_out2.full_tensor(), global_out) + torch.testing.assert_close( + tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad + ) + torch.testing.assert_close( + tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad + ) + + +def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False) + + +def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) + + +if __name__ == "__main__": + # float8 only works on CUDA H100 so we only test cuda and we follow + # other test files to not use TestCase but instead just add the test + # cases in the main func. + device_mesh = setup_distributed() + tests = [ + _test_scaled_mm, + _test_fp8_redistribute, + _test_dtensor_cast_to_fp8, + _test_dtensor_fp8_autograd, + _test_fp8_mlp_tensor_parallelism_eager, + _test_fp8_mlp_tensor_parallelism_compile, + ] + + for test in tqdm(tests, desc="Running tests"): + try: + test(device_mesh) + except Exception as e: + print(f"Test {test.__name__} failed with error: {e}") + raise e + + torch.distributed.destroy_process_group() diff --git a/test/float8/test_dtensor.sh b/test/float8/test_dtensor.sh new file mode 100755 index 000000000..2e38feffe --- /dev/null +++ b/test/float8/test_dtensor.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# terminate script on first error +set -e + +if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then + echo "Skipping test_dtensor.sh because no CUDA devices are available." + exit +fi + +NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/float8/test_dtensor.py diff --git a/test/float8/test_everything.sh b/test/float8/test_everything.sh new file mode 100755 index 000000000..d70833323 --- /dev/null +++ b/test/float8/test_everything.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# terminate script on first error +set -e +IS_ROCM=$(rocm-smi --version || true) + +pytest test/float8/test_base.py +pytest test/float8/test_compile.py +pytest test/float8/test_inference_flows.py +pytest test/float8/test_numerics_integration.py + +# These tests do not work on ROCm yet +if [ -z "$IS_ROCM" ] +then +./test/float8/test_fsdp.sh +./test/float8/test_fsdp_compile.sh +./test/float8/test_dtensor.sh +pytest test/float8/test_fsdp2/test_fsdp2.py +fi + +echo "all tests successful" diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py new file mode 100644 index 000000000..f30878b33 --- /dev/null +++ b/test/float8/test_fsdp.py @@ -0,0 +1,212 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +""" +Test numerics of bf16 versus float8 with FSDP on. At a high level: +1. start with a reference model, with FSDP on +2. run forward + backward + optim for 2 iterations +3. repeat 2 with float8 enabled (2 iterations needed for delayed scaling) +4. compare outputs and state dict between (2) and (3), should be close +""" + +import copy +import os +import pytest +import warnings + +import fire + +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.float8_linear_utils import ( + convert_to_float8_training, + linear_requires_sync, + sync_float8_amax_and_scale_history, +) +from torchao.float8.float8_utils import compute_error +from torch.distributed.fsdp import ( + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, +) + +torch.manual_seed(0) + +B, M, K, N = 8, 8, 32, 32 +lr = 0.01 +N_ITER = 2 + + +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def get_model(K, N, base_dtype=torch.float32): + m = nn.Sequential( + nn.Linear(K, N, dtype=base_dtype), + nn.ReLU(), + nn.Linear(N, N, dtype=base_dtype), + nn.ReLU(), + ) + return m + + +# taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html +# and modified +def fsdp_main(rank, world_size, args): + setup(rank, world_size) + torch.cuda.set_device(rank) + + emulate, base_dtype, compile, use_weight_dynamic_scaling = args + model = get_model(K, N, base_dtype=base_dtype).to(rank) + model_fp8 = copy.deepcopy(model) + + scaling_type_weight = ( + ScalingType.DYNAMIC if use_weight_dynamic_scaling else ScalingType.DELAYED + ) + config = Float8LinearConfig( + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + # TODO(future): delete this arg as it's always False + emulate=False, + ) + + # Note: we only iterate over `scaling_type_weight` because FSDP only interacts + # with weights. + convert_to_float8_training( + model_fp8, + config=config, + ) + + # To compile FSDP, we need use_orig_params to True + model = FSDP(model, use_orig_params=True) + model_fp8 = FSDP(model_fp8, use_orig_params=True) + # TODO: The following line doesn't work. We should fix it. + # model = FSDP(torch.compile(model), use_orig_params=True) + + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + optimizer_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr) + + # Note: we need two different inputs to properly measure the impact of + # delayed scaling, before the first input uses dynamic scaling to + # populate the buffers + ref_input_global = [ + torch.randn(B, M, K).cuda().to(base_dtype), + torch.randn(B, M, K).cuda().to(base_dtype), + ] + ref_grad_global = [ + torch.randn(B, M, N).cuda().to(base_dtype), + torch.randn(B, M, N).cuda().to(base_dtype), + ] + ref_input_local = [] + ref_grad_local = [] + + # basic distributed data sampling + assert B % world_size == 0 + bsz_local_start = int(rank / world_size * B) + bsz_local_end = int((rank + 1) / world_size * B) + for idx in range(N_ITER): + ref_input_local.append( + ref_input_global[idx][bsz_local_start:bsz_local_end].to(rank) + ) + ref_grad_local.append( + ref_grad_global[idx][bsz_local_start:bsz_local_end].to(rank) + ) + + sync_float8_func = sync_float8_amax_and_scale_history + if compile: + sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) + + def forward_backward(model, optim, is_fp8, i): + optim.zero_grad() + y_local = model(ref_input_local[i]) + y_local.backward(ref_grad_local[i]) + if is_fp8 and linear_requires_sync(config): + sync_float8_func(model) + optim.step() + return y_local + + for i in range(N_ITER): + # We first run one iteration without compile, as a workaround to compile float8 layer. + # In the first iter, float8 layers go to the branches of "self.is_amax_initialized == False" + # After that, float8 layers go the the branches of "self.is_amax_initialized == True" + # TODO: Need to fix compile to run wihtout this workaround. + if i == 1 and compile: + model = torch.compile(model) + model_fp8 = torch.compile(model_fp8) + y_local = forward_backward(model, optimizer, is_fp8=False, i=i) + y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i) + local_sqnr = compute_error(y_local, y_local_fp8) # noqa: F841 + + # get global y + y_global = [ + torch.zeros(*y_local.shape, dtype=base_dtype).to(rank) + for r in range(world_size) + ] + dist.all_gather(y_global, y_local) + y_global = torch.cat(y_global, dim=0) + y_global_fp8 = [ + torch.zeros(*y_local_fp8.shape, dtype=base_dtype).to(rank) + for r in range(world_size) + ] + dist.all_gather(y_global_fp8, y_local_fp8) + y_global_fp8 = torch.cat(y_global_fp8, dim=0) + if rank == 0: + sqnr = compute_error(y_global, y_global_fp8) + assert sqnr > 15.0, f"SQNR of {sqnr} is too low" + + # get global state dict + # https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html + dist.barrier() + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): + cpu_state = model.state_dict() + with FSDP.state_dict_type(model_fp8, StateDictType.FULL_STATE_DICT, save_policy): + cpu_state_fp8 = model_fp8.state_dict() + if rank == 0: + for k, v1 in cpu_state.items(): + v2 = cpu_state_fp8[k] + v1, v2 = v1.cpu(), v2.cpu() + sqnr = compute_error(v1, v2) + assert sqnr > 15.0, f"SQNR of {sqnr} is too low, k: {k}, v1: {v1}, v2: {v2}" + + cleanup() + + +def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False): + base_dtype = torch.bfloat16 + + emulate = False + if not torch.cuda.is_available(): + warnings.warn("CUDA not available, running in emulation_mode") + emulate = True + elif torch.cuda.get_device_capability() < (9, 0): + warnings.warn( + f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode" + ) + emulate = True + + WORLD_SIZE = torch.cuda.device_count() + args = (emulate, base_dtype, compile_fsdp, use_weight_dynamic_scaling) + mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) + + +if __name__ == "__main__": + fire.Fire(run) diff --git a/test/float8/test_fsdp.sh b/test/float8/test_fsdp.sh new file mode 100755 index 000000000..3ff19d917 --- /dev/null +++ b/test/float8/test_fsdp.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# terminate script on first error +set -e + +launch() { + echo "launching compile_fsdp $COMPILE, use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING" + + # the NCCL_DEBUG setting is to avoid log spew + # the CUDA_VISIBLE_DEVICES setting is for easy debugging + NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp.py \ + --compile_fsdp $COMPILE --use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING + + echo "✅ All Tests Passed ✅" +} + +if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then + echo "Skipping test_fsdp.sh because no CUDA devices are available." + exit +fi + +# COMPILE, USE_WEIGHT_DYNAMIC_SCALING +for i in False,False False,True True,False True,True +do + IFS=","; set -- $i; + COMPILE=$1; USE_WEIGHT_DYNAMIC_SCALING=$2 + launch +done diff --git a/test/float8/test_fsdp2/fsdp2_common.py b/test/float8/test_fsdp2/fsdp2_common.py new file mode 100644 index 000000000..333206ba4 --- /dev/null +++ b/test/float8/test_fsdp2/fsdp2_common.py @@ -0,0 +1,89 @@ +import contextlib +from typing import List, Optional + +import torchao.float8.config as config + +import torch +import torch.distributed as dist +import torch.nn as nn +from torchao.float8.config import Float8LinearConfig, ScalingType +from torchao.float8.float8_linear_utils import ( + linear_requires_sync, + sync_float8_amax_and_scale_history, +) +from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp + + +def check_parity_no_mp( + test_cls, + ref_model: nn.Module, + ref_optim: torch.optim.Optimizer, + fsdp_model: nn.Module, + fsdp_optim: torch.optim.Optimizer, + local_inp: torch.Tensor, + precompute: bool = False, + config: Optional[Float8LinearConfig] = None, + compile_transformer_block: bool = False, +): + # TODO(before land): reorder args and make config not optional + for iter_idx in range(10): + losses: List[torch.Tensor] = [] + for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)): + optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + losses.append(model(local_inp).sum()) + losses[-1].backward() + if model is ref_model: + for param in model.parameters(): + dist.all_reduce(param.grad) + param.grad.div_(dist.get_world_size()) + + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(model) + + optim.step() + if ( + model is fsdp_model + and precompute + and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC + ): + precompute_float8_dynamic_scale_for_fsdp(model) + + if compile_transformer_block: + test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4) + else: + test_cls.assertEqual(losses[0], losses[1]) + + +def check_parity_bf16_mp( + test_cls, + ref_model: nn.Module, + ref_model_bf16: nn.Module, + ref_optim: torch.optim.Optimizer, + fsdp_model: nn.Module, + fsdp_optim: torch.optim.Optimizer, + local_inp: torch.Tensor, +): + for iter_idx in range(10): + losses: List[torch.Tensor] = [] + for model, optim in ( + (ref_model_bf16, ref_optim), + (fsdp_model, fsdp_optim), + ): + optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + losses.append(model(local_inp).sum()) + losses[-1].backward() + if model is ref_model_bf16: + for param_bf16, param_fp32 in zip( + ref_model_bf16.parameters(), ref_model.parameters() + ): + dist.all_reduce(param_bf16.grad) + param_bf16.grad.div_(dist.get_world_size()) + param_fp32.grad = param_bf16.grad.float() + param_bf16.grad = None + # TODO(future): add amax syncing once delayed scaling is supported + optim.step() + for param_fp32, param_bf16 in zip( + ref_model.parameters(), ref_model_bf16.parameters() + ): + param_bf16.detach().copy_(param_fp32) + test_cls.assertEqual(losses[0], losses[1]) diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py new file mode 100644 index 000000000..7004b3a1c --- /dev/null +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -0,0 +1,561 @@ +import copy +import itertools +import pytest +import threading +import unittest +from typing import Any, List + +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +import torch +import torch._dynamo.testing +import torch.distributed as dist +import torch.nn as nn +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor +from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp +from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy +from torch.distributed._tensor import DTensor +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import ( + FSDPTest, + FSDPTestMultiThread, + MLP, + patch_all_gather, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, + TransformerBlock, +) + +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) +if not is_H100: + pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) + +class TestFloat8Common: + def broadcast_module(self, module: nn.Module) -> None: + # Broadcast for multi-threaded process group tests since seed is per + # process, not per thread + for param in module.parameters(): + dist.broadcast(param, src=0) + + def init_single_module(self) -> nn.Module: + torch.manual_seed(42) + module = nn.Linear(16, 16, device="cuda") + self.broadcast_module(module) + return module + + def init_multi_module(self) -> nn.Module: + torch.manual_seed(42) + module = nn.Sequential(*[MLP(16, device="cuda") for _ in range(3)]) + self.broadcast_module(module) + return module + + def init_transformer(self, weight_tying: bool) -> nn.Module: + torch.manual_seed(42) + args = ModelArgs( + n_layers=3, + dim=768, + n_heads=12, + dropout_p=0.0, + weight_tying=weight_tying, + vocab_size=32, + ) + module = Transformer(args).cuda() + self.broadcast_module(module) + return module + + def get_local_inp(self, dtype: torch.dtype = torch.float32): + torch.manual_seed(42) + global_inp = torch.randn((16 * self.world_size, 16), device="cuda", dtype=dtype) + dist.broadcast(global_inp, src=0) + return global_inp.view(self.world_size, -1)[self.rank].view(16, 16) + + +class TestFloat8MultiProcess(FSDPTest, TestFloat8Common): + @property + def world_size(self) -> int: + return min(torch.cuda.device_count(), 2) + + @skip_if_lt_x_gpu(2) + def test_transformer_parity(self): + self.run_subtests( + { + "enable_fsdp_float8_all_gather": [False, True], + "precompute": [False, True], + "scaling_type_weight": [ + ScalingType.DYNAMIC, + ScalingType.DELAYED, + ], + "compile_transformer_block": [False, True], + }, + self._test_transformer_parity, + ) + + def _test_transformer_parity( + self, + enable_fsdp_float8_all_gather: bool, + precompute: bool, + scaling_type_weight: ScalingType, + compile_transformer_block: bool, + ): + if not enable_fsdp_float8_all_gather and precompute: + return + elif scaling_type_weight is ScalingType.DELAYED and precompute: + return + + # NOTE: Weight-tying does not compose with fp8 all-gather because the + # embedding weight and output linear weight are tied but only the + # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to + # fp8 for that tied weight, incorrectly using fp8 for the embedding. + weight_tying = not enable_fsdp_float8_all_gather + module = self.init_transformer(weight_tying=weight_tying).cuda() + ref_module = copy.deepcopy(module) + float8_linear_config1 = Float8LinearConfig( + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + ) + convert_to_float8_training( + ref_module, + config=float8_linear_config1, + ) + if compile_transformer_block: + for layer_id, transformer_block in ref_module.layers.named_children(): + transformer_block = torch.compile(transformer_block, dynamic=False) + ref_module.layers.register_module(layer_id, transformer_block) + float8_linear_config2 = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + ) + convert_to_float8_training( + module, + config=float8_linear_config2, + ) + for layer_id, transformer_block in module.layers.named_children(): + if compile_transformer_block: + transformer_block = torch.compile(transformer_block, dynamic=False) + fully_shard(transformer_block) + module.layers.register_module(layer_id, transformer_block) + fully_shard(module) + ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) + optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) + local_inp = torch.randint( + 0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda" + ) + check_parity_no_mp( + self, + ref_module, + ref_optim, + module, + optim, + local_inp, + precompute, + config=float8_linear_config2, + compile_transformer_block=compile_transformer_block, + ) + + @skip_if_lt_x_gpu(2) + def test_transformer_memory(self): + """Tests peak active memory in the forward and backward passes.""" + for enable_fsdp_float8_all_gather in [False, True]: + self._test_transformer_memory(enable_fsdp_float8_all_gather) + + def _test_transformer_memory(self, enable_fsdp_float8_all_gather: bool): + torch.manual_seed(42) + # Pre-run a linear forward (gemm and bias) and backward (gemm) to + # allocate the cuBLAS workspaces before measuring the memory usage + # since the workspace size can differ between hardwares + lin = torch.nn.Linear(768, 768, device="cuda") + inp = torch.randn(1, 768, device="cuda") + lin(inp).sum().backward() + torch.cuda.empty_cache() + base_mem_mb = self._get_peak_active_memory_mb() + + vocab_size = 32 + model_args = ModelArgs( + vocab_size=vocab_size, + n_layers=3, + dim=768, + n_heads=12, + weight_tying=False, + ) + model = Transformer(model_args) + # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility + # requirement to use a smaller activation size + float8_linear_config = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + emulate=True, + ) + convert_to_float8_training(model, config=float8_linear_config) + model_unsharded_numel = sum(p.numel() for p in model.parameters()) + model_sharded_numel = (model_unsharded_numel + 1) // 2 + block_lin_weight_numel = 0 + block_other_numel = 0 + for module in model.layers[0].modules(): + for param in module.parameters(recurse=False): + if isinstance(module, nn.Linear): + block_lin_weight_numel += param.numel() + else: + block_other_numel += param.numel() + non_block_numel = round( + sum(p.numel() for p in model.tok_embeddings.parameters()) + + sum(p.numel() for p in model.pos_embeddings.parameters()) + + sum(p.numel() for p in model.norm.parameters()) + + sum(p.numel() for p in model.output.parameters()) + ) + for module in model.modules(): + if isinstance(module, TransformerBlock): + fully_shard(module) + fully_shard(model) + + # Init: Each module is moved to GPU before sharding parameters + peak_mem_mb = self._get_peak_active_memory_mb() + curr_mem_mb = self._get_curr_active_memory_mb() + init_mem_mb = ( + (model_sharded_numel + block_lin_weight_numel + block_other_numel) * 4 / 1e6 + ) + # Allow for some buffer for the peak memory since original parameters + # are not freed until a `fully_shard` call returns + buffer_mb = 4 + self.assertLessEqual(peak_mem_mb - base_mem_mb, init_mem_mb + buffer_mb) + self.assertLessEqual(curr_mem_mb - base_mem_mb, init_mem_mb) + + # Use a small input to minimize activation memory usage + inp = torch.randint(0, vocab_size, (1, 4), device="cuda") + + # Forward: + loss = model(inp) + mem_mb = self._get_peak_active_memory_mb() + # Allow for some buffer for fragmentation/activations (where this + # number is kept much smaller than the actual memory usage, which is on + # the order of 100-200+ MB) + buffer_mb = 16 + if enable_fsdp_float8_all_gather: + # Non-block parameters (fp32), 3x block non-linear-weight + # parameters (fp32) and block linear-weight parameters (fp8) + # (current all-gather, copy-out, and next all-gather), and other + expected_mem_mb = ( + (non_block_numel * 4) + + 3 * (block_lin_weight_numel + block_other_numel * 4) + ) / 1e6 + buffer_mb + else: + # Non-block parameters (fp32), 3x block parameters (fp32) + # (current all-gather, copy-out, and next all-gather), Nx block + # linear-weight parameters (fp8) for N blocks (saved by autograd), + # and other + expected_mem_mb = ( + (non_block_numel + 3 * (block_lin_weight_numel + block_other_numel)) * 4 + + model_args.n_layers * block_lin_weight_numel + ) / 1e6 + buffer_mb + # Sharded parameters + expected_mem_mb += model_sharded_numel * 4 / 1e6 + self.assertLessEqual(mem_mb, expected_mem_mb + base_mem_mb) + + # Backward: + loss.sum().backward() + mem_mb = self._get_peak_active_memory_mb() + if enable_fsdp_float8_all_gather: + # Non-block parameters (fp32), 2x block non-linear weight + # parameters (fp32) and block linear-weight parameters (fp8) + # (current copy-out and next all-gather), 1x block gradients (fp32) + expected_mem_mb = ( + (non_block_numel * 4) + + 2 * (block_lin_weight_numel + block_other_numel * 4) + + 1 * (block_lin_weight_numel + block_other_numel) * 4 + ) / 1e6 + buffer_mb + else: + # Non-block parameters (fp32), 3x block parameters (fp32) (current + # copy-out, next all-gather, current gradients) + expected_mem_mb = ( + non_block_numel + 3 * (block_lin_weight_numel + block_other_numel) * 4 + ) * 4 / 1e6 + buffer_mb + # 2x sharded parameters/gradients + expected_mem_mb += 2 * model_sharded_numel * 4 / 1e6 + self.assertLessEqual(mem_mb, expected_mem_mb + base_mem_mb) + + def _get_peak_active_memory_mb(self) -> int: + mem_stats = torch.cuda.memory_stats() + return round(mem_stats["active_bytes.all.peak"] / 1e6) + + def _get_curr_active_memory_mb(self) -> int: + mem_stats = torch.cuda.memory_stats() + return round(mem_stats["active_bytes.all.current"] / 1e6) + + +class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common): + @property + def world_size(self) -> int: + return 2 + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_weight_subclass_dynamic(self): + tensor_cls = WeightWithDynamicFloat8CastTensor + # Check for a single FSDP paramter group + module_fp32 = self.init_single_module() + float8_linear_config = Float8LinearConfig( + enable_fsdp_float8_all_gather=True, + emulate=True, + ) + module = convert_to_float8_training( + module_fp32, + config=float8_linear_config, + ) + self.assertIsInstance(module.weight, tensor_cls) + fully_shard(module) + for param_name, param in module.named_parameters(): + self.assertIsInstance(param, DTensor) + if "weight" in param_name: + self.assertIsInstance(param.to_local(), tensor_cls) + + # Check for multiple FSDP paramter groups + module = self.init_multi_module() + module = convert_to_float8_training( + module, + config=float8_linear_config, + ) + for param_name, param in module.named_parameters(): + if "weight" in param_name: + self.assertIsInstance(param, tensor_cls) + for mlp in module: + fully_shard(mlp) + fully_shard(module) + for param_name, param in module.named_parameters(): + self.assertIsInstance(param, DTensor) + if "weight" in param_name: + self.assertIsInstance(param.to_local(), tensor_cls) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_fp8_fp32_all_gather_dynamic_comm_size(self): + """ + Tests that fp8 all-gather with dynamic scaling communicates the + expected number of bytes. + """ + orig_all_gather = dist.all_gather_into_tensor + all_gather_sizes: List[int] = [] + lock = threading.Lock() + + def all_gather(*args: Any, **kwargs: Any): + nonlocal all_gather_sizes + if len(args) > 0: + output = args[0] + elif "output_tensor" in kwargs: + output = kwargs["output_tensor"] + else: + raise AssertionError( + f"Cannot get all-gather output from\nargs: {args}\nkwargs: {kwargs}" + ) + with lock: + all_gather_sizes.append(output.numel() * output.itemsize) + return orig_all_gather(*args, **kwargs) + + def get_expected_all_gather_size(module: nn.Module): + size = 0 + for param_name, param in module.named_parameters(): + bytes_per_numel = 1 if "weight" in param_name else param.itemsize + size += param.numel() * bytes_per_numel + return size + + # - Check for a single FSDP parameter group + module_fp32 = self.init_single_module() + ref_module = copy.deepcopy(module_fp32) + float8_linear_config = Float8LinearConfig( + enable_fsdp_float8_all_gather=True, + ) + module_fp32 = convert_to_float8_training( + module_fp32, config=float8_linear_config + ) + module = module_fp32 + fully_shard(module) + local_inp = self.get_local_inp() + expected_all_gather_size = get_expected_all_gather_size(ref_module) + with patch_all_gather(all_gather): + out = module(local_inp) + # For MPTG, one rank runs all all-gathers, each of the same size + if all_gather_sizes: + self.assertEqual(len(all_gather_sizes), self.world_size) + self.assertEqual( + all_gather_sizes, [expected_all_gather_size] * self.world_size + ) + all_gather_sizes.clear() + # Force-reshard the module to check the backward all-gather + module.reshard() + with patch_all_gather(all_gather): + out.sum().backward() + if all_gather_sizes: + self.assertEqual(len(all_gather_sizes), self.world_size) + self.assertEqual( + all_gather_sizes, [expected_all_gather_size] * self.world_size + ) + all_gather_sizes.clear() + + # - Check for multiple FSDP parameter groups + module = self.init_multi_module() + ref_module = copy.deepcopy(module) + module = convert_to_float8_training(module, config=float8_linear_config) + for submodule in module: + fully_shard(submodule) + fully_shard(module) + expected_all_gather_sizes = ( + get_expected_all_gather_size(submodule) for submodule in module + ) + with patch_all_gather(all_gather): + out = module(local_inp) + if all_gather_sizes: + self.assertEqual(len(all_gather_sizes), self.world_size * len(module)) + self.assertEqual( + all_gather_sizes, + [s for s in expected_all_gather_sizes for _ in range(self.world_size)], + ) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_fp32_fp8_single_module_parity(self): + """ + Tests numeric parity for fp32 parameters with fp8 computation with a + single module/FSDP communication group. + """ + choices = itertools.product( + [False, True], + [ScalingType.DYNAMIC, ScalingType.DELAYED], + ) + for enable_fsdp_float8_all_gather, scaling_type_weight in choices: + float8_linear_config1 = Float8LinearConfig( + enable_fsdp_float8_all_gather=False, + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + ) + float8_linear_config2 = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + ) + module_fp32 = self.init_single_module() + ref_module = copy.deepcopy(module_fp32) + ref_module = convert_to_float8_training( + ref_module, + config=float8_linear_config1, + ) + ref_module = ref_module.cuda() + module = convert_to_float8_training( + module_fp32, + config=float8_linear_config2, + ) + fully_shard(module) + ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) + optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) + local_inp = self.get_local_inp() + check_parity_no_mp( + self, + ref_module, + ref_optim, + module, + optim, + local_inp, + config=float8_linear_config2, + ) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_fp32_fp8_multi_module_parity(self): + """ + Tests numeric parity for fp32 parameters with fp8 computation with + multiple modules/FSDP communication groups. + """ + choices = itertools.product( + [False, True], + [ScalingType.DYNAMIC, ScalingType.DELAYED], + ) + for enable_fsdp_float8_all_gather, scaling_type_weight in choices: + float8_linear_config1 = Float8LinearConfig( + enable_fsdp_float8_all_gather=False, + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + ) + float8_linear_config2 = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + ) + module = self.init_multi_module().cuda() + ref_module = copy.deepcopy(module) + ref_module = convert_to_float8_training( + ref_module, + config=float8_linear_config1, + ) + module = convert_to_float8_training( + module, + config=float8_linear_config2, + ) + for submodule in module: + fully_shard(submodule) + fully_shard(module) + ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) + optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) + local_inp = self.get_local_inp() + check_parity_no_mp( + self, + ref_module, + ref_optim, + module, + optim, + local_inp, + config=float8_linear_config2, + ) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_bf16_mp_fp8_dynamic_multi_parity(self): + """ + Tests numeric parity for fp32 parameters with FSDP's bf16 mixed + precision and fp8 computation with multiple modules/FSDP communication + groups. Parameters are all-gathered in bf16 before being cast to fp8. + """ + # NOTE: We cannot test easily with fp8 all-gather because then the scale + # is computed using the fp32 sharded parameters, not the bf16 unsharded + # parameters, changing the numerics. + module = self.init_multi_module() + ref_module_bf16 = copy.deepcopy(module).to(torch.bfloat16) + float8_config = Float8LinearConfig(emulate=True) + ref_module_bf16 = convert_to_float8_training( + ref_module_bf16, + config=float8_config, + ) + ref_module_fp32 = copy.deepcopy(module).cuda() + module = convert_to_float8_training(module, config=float8_config) + mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) + for mlp in module: + fully_shard(mlp, mp_policy=mp_policy) + fully_shard(module, mp_policy=mp_policy) + check_parity_bf16_mp( + self, + ref_module_fp32, + ref_module_bf16, + torch.optim.Adam(ref_module_fp32.parameters(), lr=1e-2), + module, + torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True), + self.get_local_inp(torch.bfloat16), + ) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_delayed_scaling_inplace_update(self): + """ + Verify that `WeightWithDelayedFloat8CastTensor` updates buffers inplace + """ + module = self.init_single_module() + float8_linear_config = Float8LinearConfig( + enable_fsdp_float8_all_gather=True, + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + ) + m_fp8 = convert_to_float8_training( + module, + config=float8_linear_config, + ) + + fp8_amax_weight_old = m_fp8.fp8_amax_weight.clone().detach() + dummy_mesh = None + data, scale = m_fp8.weight.fsdp_pre_all_gather(dummy_mesh) + self.assertNotEqual(fp8_amax_weight_old.item(), m_fp8.fp8_amax_weight.item()) + + +if __name__ == "__main__": + run_tests() diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py new file mode 100644 index 000000000..f4ca160fd --- /dev/null +++ b/test/float8/test_fsdp_compile.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test autocast + torch.compile + FSDP + Float8Linear +""" + +import os +import warnings + +import fire + +import pytest + +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from torchao.float8 import Float8LinearConfig +from torchao.float8.config import CastConfig, ScalingType +from torchao.float8.float8_linear_utils import ( + convert_to_float8_training, + sync_float8_amax_and_scale_history, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +torch.manual_seed(0) + +B, M, K, N = 8, 8, 32, 32 +lr = 0.01 +N_ITER = 1 + + +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): + # composability of torch.compile + FSDP + autocast + Float8Linear + # as fo 2023-12-30 + + # without any changes to the Float8Linear, we get this error: + # https://gist.github.com/vkuzo/3bcb81806cc92f99ac0b9c5fdf287730 + + # if we initialize Float8Linear with is_amax_initialized=True and + # amax_and_scale_synced=True, we get + # https://gist.github.com/vkuzo/ed8e168fd9f7463f1fce34301334ab55 + # to get around this, we can disable amax init + config = Float8LinearConfig( + enable_amax_init=False, + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), + emulate=emulate, + ) + + m = nn.Sequential( + nn.Linear(K, N, dtype=base_dtype), + nn.ReLU(), + ) + convert_to_float8_training( + m, + config=config, + ) + return m + + +# taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html +# and modified +def fsdp_main(rank, world_size, args): + setup(rank, world_size) + torch.cuda.set_device(rank) + + (emulate,) = args + + # finally, if we remove the usage of self.bias_dtype, then + # things work e2e. Note that FSDP does not support full-graph compile + # regardless of float8. + + model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to( + rank + ) + + # To compile FSDP, we need use_orig_params to True + model = FSDP(model, use_orig_params=True) + + optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) + input_local = torch.randn(B, M, K, N, device="cuda") + sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) + + model = torch.compile(model) + + for _iter in range(N_ITER): + optimizer.zero_grad() + with torch.autocast("cuda"): + y_local = model(input_local) + y_local.sum().backward() + sync_float8_func(model) + optimizer.step() + + print("done!") + cleanup() + + +def run(): + emulate = False + if not torch.cuda.is_available(): + warnings.warn("CUDA not available, running in emulation_mode", stacklevel=2) + emulate = True + elif torch.cuda.get_device_capability() < (9, 0): + warnings.warn( + f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode", + stacklevel=2, + ) + emulate = True + + WORLD_SIZE = torch.cuda.device_count() + args = (emulate,) + mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) + + +if __name__ == "__main__": + fire.Fire(run) diff --git a/test/float8/test_fsdp_compile.sh b/test/float8/test_fsdp_compile.sh new file mode 100755 index 000000000..666136aba --- /dev/null +++ b/test/float8/test_fsdp_compile.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# terminate script on first error +set -e +if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then + echo "Skipping test_fsdp_compile.sh because no CUDA devices are available." + exit +fi + +# Code to be executed if CUDA devices are available +NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp_compile.py diff --git a/test/float8/test_inference_flows.py b/test/float8/test_inference_flows.py new file mode 100644 index 000000000..c76a43df0 --- /dev/null +++ b/test/float8/test_inference_flows.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import copy +import io +import random +import unittest + +import pytest + +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchao.float8.config import ScalingType +from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchao.float8.float8_tensor import Float8Tensor +from torchao.float8.float8_utils import compute_error +from torchao.float8.inference import ( + ActivationCasting, + Float8InferenceLinear, + QuantConfig, + quantize_to_float8, +) + + +random.seed(0) +torch.manual_seed(0) + +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + +class FeedForward(nn.Module): + def __init__(self) -> None: + super().__init__() + self.w1 = nn.Linear(4096, 14336, bias=False) + self.w3 = nn.Linear(4096, 14336, bias=False) + self.w2 = nn.Linear(14336, 4096, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + m.reset_parameters() + + +class TestHPTrainToFP8LinearInference: + def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor): + with torch.no_grad(): + base_output = base_mlp(input_tensor) + transformed_output = quantized_mlp(input_tensor) + + # Compute and check SQNR + sqnr = compute_error(base_output, transformed_output) + assert sqnr.item() > 20, f"SQNR is too low: {sqnr.item()} dB" + + @pytest.mark.parametrize("compile_backend", ["eager", "inductor"]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + @unittest.skipIf( + not torch.cuda.is_available() or not is_H100, + "CUDA not available or on non H100 machine", + ) + def test_dynamic_fp8_mlp(self, compile_backend, dtype): + original_mlp = FeedForward().to("cuda", dtype=dtype) + original_mlp.reset_parameters() + + dynamic_fp8_mlp = copy.deepcopy(original_mlp) + + quant_config = QuantConfig(ActivationCasting.DYNAMIC) + quantize_to_float8(dynamic_fp8_mlp, quant_config) + + batch_size = 4 + num_tokens = 1024 + embedding_dim = 4096 + + input_tensor = torch.randn( + batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype + ) + + # Compile the models + compiled_original_mlp = torch.compile( + original_mlp, backend=compile_backend, fullgraph=True + ) + compiled_dynamic_fp8_mlp = torch.compile( + dynamic_fp8_mlp, backend=compile_backend, fullgraph=True + ) + + self.base_test_mlp_transform( + compiled_original_mlp, compiled_dynamic_fp8_mlp, input_tensor + ) + + @pytest.mark.parametrize("compile_backend", ["eager", "inductor"]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + @unittest.skipIf( + not torch.cuda.is_available() or not is_H100, + "CUDA not available or on non H100 machine", + ) + def test_static_fp8_mlp(self, compile_backend, dtype): + original_mlp = FeedForward().to("cuda", dtype=dtype) + original_mlp.reset_parameters() + + static_fp8_mlp = copy.deepcopy(original_mlp) + quant_config = QuantConfig( + ActivationCasting.STATIC, + static_quantization_scale=torch.tensor( + [1.0], device="cuda", dtype=torch.float32 + ), + ) + quantize_to_float8(static_fp8_mlp, quant_config) + + batch_size = 4 + num_tokens = 1024 + embedding_dim = 4096 + + input_tensor = torch.randn( + batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype + ) + + # Compile the models + compiled_original_mlp = torch.compile( + original_mlp, backend=compile_backend, fullgraph=True + ) + compiled_static_fp8_mlp = torch.compile( + static_fp8_mlp, backend=compile_backend, fullgraph=True + ) + + self.base_test_mlp_transform( + compiled_original_mlp, compiled_static_fp8_mlp, input_tensor + ) + + @pytest.mark.parametrize("compile_backend", ["eager", "inductor"]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + @unittest.skipIf( + not torch.cuda.is_available() or not is_H100, + "CUDA not available or on non H100 machine", + ) + def test_weight_only_fp8_mlp(self, compile_backend, dtype): + original_mlp = FeedForward().to("cuda", dtype=dtype) + original_mlp.reset_parameters() + + static_fp8_mlp = copy.deepcopy(original_mlp) + quant_config = QuantConfig(ActivationCasting.WEIGHT_ONLY) + quantize_to_float8(static_fp8_mlp, quant_config) + + batch_size = 4 + num_tokens = 1024 + embedding_dim = 4096 + + input_tensor = torch.randn( + batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype + ) + + # Compile the models + compiled_original_mlp = torch.compile( + original_mlp, backend=compile_backend, fullgraph=True + ) + compiled_static_fp8_mlp = torch.compile( + static_fp8_mlp, backend=compile_backend, fullgraph=True + ) + + self.base_test_mlp_transform( + compiled_original_mlp, compiled_static_fp8_mlp, input_tensor + ) + + +class TestFP8TrainToFP8LinearInference: + def train(self, model: nn.Module, dtype: torch.dtype): + model.train() + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + criterion = nn.MSELoss() + target_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype) + for _ in range(10): + input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype) + optimizer.zero_grad() + output = model(input_tensor) + loss = criterion(output, target_tensor) + loss.backward() + optimizer.step() + model.eval() + return model + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + @unittest.skipIf( + not torch.cuda.is_available() or not is_H100, + "CUDA not available or on non H100 machine", + ) + def test_fp8_save_and_load(self, dtype: torch.dtype): + # Initialize FP8 model + fp8_mlp = FeedForward().to("cuda", dtype=torch.float32) + fp8_mlp.reset_parameters() + convert_to_float8_training(fp8_mlp) + + # Train the model + self.train(fp8_mlp, dtype) + + # Generate input tensor and original out + input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype) + og_out = fp8_mlp(input_tensor) + + # Save model state dict + buffer = io.BytesIO() + torch.save(fp8_mlp.state_dict(), buffer) + + # Reset buffer position to the beginning + buffer.seek(0) + + # Later on you load the model, will be w/ Float8Linear on meta device + with torch.device("meta"): + new_fp8_mlp = FeedForward().to(dtype=dtype) + convert_to_float8_training(new_fp8_mlp) + + # Load the actual data + new_fp8_mlp.load_state_dict( + torch.load(buffer, weights_only=True), strict=True, assign=True + ) + + quant_config = QuantConfig(ActivationCasting.DYNAMIC) + quantize_to_float8(new_fp8_mlp, quant_config) + + fp8_mod_count = 0 + for module in new_fp8_mlp.modules(): + if isinstance(module, Float8InferenceLinear): + assert isinstance(module.weight, Float8Tensor) + assert module.weight.requires_grad is False + fp8_mod_count += 1 + assert fp8_mod_count == 3, "Expected 3 FP8 modules, got {}".format( + fp8_mod_count + ) + + new_out = new_fp8_mlp(input_tensor) + + # Assert exact equality + assert torch.all(og_out == new_out).item() + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py new file mode 100644 index 000000000..fd724b340 --- /dev/null +++ b/test/float8/test_numerics_integration.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +# Tests LLaMa FeedForward numerics with float8 + +import copy +from typing import Optional + +import pytest + +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.float8_linear_utils import ( + convert_to_float8_training, + linear_requires_sync, + sync_float8_amax_and_scale_history, +) +from torchao.float8.float8_utils import compute_error, IS_ROCM + +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + +torch.manual_seed(0) + + +# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TestFloat8NumericsIntegrationTest: + @pytest.mark.parametrize( + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] + ) + @pytest.mark.parametrize( + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] + ) + @pytest.mark.parametrize( + "scaling_type_grad_output", + [ScalingType.DELAYED, ScalingType.DYNAMIC], + ) + @pytest.mark.skipif(not is_H100, reason="requires H100 GPU") + @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") + def test_encoder_fw_bw( + self, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, + ): + # TODO(later): maybe add float16 back if it becomes important + data_dtype = torch.bfloat16 + + # LLaMa 3 70B shapes + model_ref = ( + FeedForward( + dim=4096, + hidden_dim=16384, + multiple_of=1024, + ffn_dim_multiplier=1.3, + ) + .cuda() + .to(data_dtype) + ) + + # for now just test the encoder to simplify things + model_fp8 = copy.deepcopy(model_ref) + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + ) + convert_to_float8_training( + model_fp8, + config=config, + ) + + lr = 0.01 + optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr) + optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr) + + # Note: you need two different inputs to properly test numerics + # of delayed scaling, because the first time around the initialization + # logic of delayed scaling behaves as dynamic scaling + # TODO(future): also make unit tests do this properly + shape = (1, 8192, 4096) + data1 = torch.randn(*shape, device="cuda", dtype=data_dtype) + data2 = torch.randn(*shape, device="cuda", dtype=data_dtype) + + model_ref(data1).sum().backward() + # zero out grads without stepping, since we just want to compare grads + # of the second datum + optim_ref.zero_grad() + model_ref_out = model_ref(data2) + model_ref_out.sum().backward() + + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(model_fp8) + model_fp8(data1).sum().backward() + # zero out grads without stepping, since we just want to compare grads + # of the second datum + optim_fp8.zero_grad() + if linear_requires_sync(config): + sync_float8_amax_and_scale_history(model_fp8) + model_fp8_out = model_fp8(data2) + model_fp8_out.sum().backward() + + out_sqnr = compute_error(model_ref_out, model_fp8_out) + assert out_sqnr > 20.0 + + ref_name_to_grad = { + name: param.grad for name, param in model_ref.named_parameters() + } + + grad_sqnr_threshold = 20.0 + + for name, param in model_fp8.named_parameters(): + ref_grad = ref_name_to_grad[name] + cur_grad = param.grad + sqnr = compute_error(ref_grad, cur_grad) + assert sqnr > grad_sqnr_threshold + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/torchao/float8/README.md b/torchao/float8/README.md new file mode 100644 index 000000000..abab3b9fa --- /dev/null +++ b/torchao/float8/README.md @@ -0,0 +1,159 @@ +# torchao.float8 + +This is an early version of a library for accelerating training with float8 in native PyTorch +according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf. +The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling, +and composable with key systems such as autograd, ```torch.compile``` and distributed. +With ``torch.compile`` on, initial results show +throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs. + +:warning: See the [feature tracker](https://github.com/pytorch-labs/torchao.float8/issues/187) for upcoming features. + +:warning: Backwards compatibility is not guaranteed at this point. The codebase is in active development and +will change rapidly. + +# Single GPU User API + +We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`). + +## float8 linear with dynamic scaling for `input`, `weight` and `grad_output` + +This is the most accurate recipe as every tensor is scaled dynamically. + +```python +from torchao.float8 import ( + convert_to_float8_training, + precompute_float8_dynamic_scale_for_fsdp, +) + +# create model +m = Model(...) + +# optional: filter modules from being eligible for float8 conversion +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the output module + if fqn == "output": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + +# convert all `torch.nn.Linear` modules to `Float8Linear` +convert_to_float8_training(m, module_filter_fn=module_filter_fn) + +# optional: use FSDP +model = FSDP(model, use_orig_params=True) + +# optional: enable torch.compile for improved performance +m = torch.compile(m) + +# toy training loop +for _ in range(N_ITER): + optimizer.zero_grad() + y = m(x) + y.sum().backward() + optimizer.step() + + # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on + # this method is optional but is highly recommended for performance + # it calcuclates scales for all parameters in a single all-reduce + precompute_float8_dynamic_scale_for_fsdp(model) + +``` + +## float8 linear with delayed scaling + +This is theoretically the most performant recipe as it minimizes memory reads. + +```python +from torchao.float8 import ( + convert_to_float8_training, + sync_float8_amax_and_scale_history, + ScalingType, +) + +# create model +m = Model(...) + +# optional: configure for compatibility with FSDP. Note that workarounds +# gated with config.enable_amax_init and +# config.enable_pre_and_post_forward are needed for +# autocast + compile + FSDP + float8 to work +from torchao.float8 import Float8LinearConfig, ScalingType, CastConfig +config = Float8LinearConfig( + enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed + enable_pre_and_post_forward=False # only needed for autocast + compile + FSDP + float8 delayed + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), +) + +# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling +# type +convert_to_float8_training( + m, + config=config, +) + +# optional: use FSDP +model = FSDP(model, use_orig_params=True) + +# optional: enable torch.compile for improved performance +m = torch.compile(m) + +# toy training loop +for _ in range(N_ITER): + optimizer.zero_grad() + y = m(x) + y.sum().backward() + + # specific to float8 with delayed scaling: separate step to sync scales/amaxes + # in the future, this may move to a context manager + sync_float8_amax_and_scale_history(model) + + optimizer.step() +``` + +# Multi GPU User API + +We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html), +such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples +on using `torchao.float8` in a distributed setting. + +# Testing + +```bash +# run single-GPU unit tests +pytest test/float8/test_base.py + +# run single-GPU compile tests +pytest test/float8/test_compile.py + +# run single-GPU numerics integration tests +pytest test/float8/test_numerics_integration.py + +# run a two-GPU integration test on FSDP +./test/float8/test_fsdp.sh + +# run integration tests on the DTensor TP/SP integration +./test/float8/test_dtensor.sh + +# run integration tests on the FSDP2 integration +python test/float8/test_fsdp2/test_fsdp2.py + +# run all of these tests +./test/float8/test_everything.sh +``` + +# Benchmarking + +```bash +# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes +./benchmarks/float8/bench_matmul.py + +# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes +# make sure to turn on torch.compile to get the best performance +./benchmarks/float8/bench_linear_float8.py -o ../tmp/test.txt --compile +``` diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py new file mode 100644 index 000000000..56c7b28f7 --- /dev/null +++ b/torchao/float8/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# Lets define a few top level things here +from torchao.float8.config import ( + CastConfig, + DelayedScalingConfig, + Float8GemmConfig, + Float8LinearConfig, + ScalingType, +) +from torchao.float8.float8_linear import Float8Linear +from torchao.float8.float8_linear_utils import ( + convert_to_float8_training, + linear_requires_sync, + sync_float8_amax_and_scale_history, +) +from torchao.float8.float8_tensor import ( + Float8Tensor, + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, +) +from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp + +# Needed to load Float8Tensor with weights_only = True +from torch.serialization import add_safe_globals + +add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig]) + +__all__ = [ + # configuration + "DelayedScalingConfig", + "ScalingType", + "Float8GemmConfig", + "Float8LinearConfig", + "CastConfig", + # top level UX + "convert_to_float8_training", + "linear_requires_sync", + "sync_float8_amax_and_scale_history", + "precompute_float8_dynamic_scale_for_fsdp", + # note: Float8Tensor and Float8Linear are not public APIs +] diff --git a/torchao/float8/config.py b/torchao/float8/config.py new file mode 100644 index 000000000..5d1bf9f54 --- /dev/null +++ b/torchao/float8/config.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import enum +from dataclasses import dataclass + + +# TODO(future): consider renaming to ScalingType +class ScalingType(enum.Enum): + DELAYED = "delayed" + DYNAMIC = "dynamic" + + def short_str(self): + if self is ScalingType.DELAYED: + return "del" + else: + assert self is ScalingType.DYNAMIC + return "dyn" + + +@dataclass(frozen=True) +class CastConfig: + """ + Configuration for casting a single tensor to float8 + """ + + scaling_type: ScalingType = ScalingType.DYNAMIC + + +@dataclass(frozen=True) +class DelayedScalingConfig: + """ + Configuration for delayed scaling. + + Note: for now, `history_len` values must be the same for all layers in the + model using delayed scaling. + + TODO(future): serialization for recipes + """ + + # Controls the history length of amax buffers + history_len: int = 16 + + # Controls the way to calculate current scale from amax history + # TODO(future): add other functions as needed, hardcoded or user defined + scale_fn_name: str = "max" + + def __post_init__(self): + assert ( + self.scale_fn_name == "max" + ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." + + +@dataclass(frozen=True) +class Float8GemmConfig: + """ + Configuration for a float8 gemm. + """ + + # If True, fast accumulation in lower precision is used. + # Note: this flag is currently a no-op if emulation is turned on. + use_fast_accum: bool = False + + +@dataclass(frozen=True) +class Float8LinearConfig: + """ + Configuration for converting a `torch.nn.Linear` module to float8 + for training. + """ + + # + # Per-tensor configuration for `input`, `weight`, `grad_output` + # + cast_config_input: CastConfig = CastConfig() + cast_config_weight: CastConfig = CastConfig() + cast_config_grad_output: CastConfig = CastConfig() + + # + # Per-gemm configuration for gemms calculating `output`, `grad_input` and + # `grad_weight` + # + gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) + gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() + gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig() + + # + # Per-linear configuration + # + + # If True, on the first iteration of Float8Linear the amaxes will be + # initialized with the incoming data. As of 2023-12-30, this doesn't work + # with autocast + torch.compile + FSDP. Enabling this option is nice for + # testing, but this is not necessary for real training jobs. + enable_amax_init: bool = True + + # If True, pre-forward and post-forward functions are run. As of 2023-12-30, + # this doesn't work with autocast + torch.compile + FSDP. Enabling this + # option is useful for safety, but not strictly necessary. + enable_pre_and_post_forward: bool = True + + # If True, then uses a tensor subclass for the float8 linear module's weight that + # implements pre/post-all-gather methods to do float8 all-gather with FSDP2. + enable_fsdp_float8_all_gather: bool = False + + # If True, then prior to performing the fp8 scaled mamtmul we will pad the + # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls + # _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16. + # This can cause a memory spike however so we keep this off by default. + pad_inner_dim: bool = False + + # If True, emulation is used instead of hardware accelerated gemm + emulate: bool = False + + # Configuration for delayed scaling + # Note: this is actually applied per-tensor, but only using the same + # configuration for all tensors and layers in the model is currently + # supported. If in the future we add support for a more fine grained + # configuration, this field may move to per-tensor configs. + delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() + + +# If True, use 'fnuz' float8 types for calculations. +# Currently, ROCm only supports fnuz variants. +# TODO(future PR): move this to Float8LinearConfig +use_fnuz_dtype = False diff --git a/torchao/float8/distributed_utils.py b/torchao/float8/distributed_utils.py new file mode 100644 index 000000000..ef174b073 --- /dev/null +++ b/torchao/float8/distributed_utils.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from typing import Any + +import torch + +from fairscale.nn.model_parallel.initialize import get_model_parallel_group + +# from float8_tensor import Float8Tensor +from torchao.float8.float8_tensor import Float8Tensor + +# additional differentiable distributed primitives for SP which are not in +# the Fairscale codebase + + +def _gather_along_first_dim(input_: torch.Tensor): + # same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67, + # but gather along first dim instead of last dim + group = get_model_parallel_group() + + # Bypass the function if we are using only 1 GPU. + if torch.distributed.get_world_size(group=group) == 1: + return input_ + + # Size and dimension. + first_dim = 0 + rank = torch.distributed.get_rank(group=group) + world_size = torch.distributed.get_world_size(group=group) + + # If the input is a float8 tensor, we need to do the transformation on the + # inner tensor and then return a new wrapper. + def _transform(t): + # tensors must be contiguous for all_gather to work + input_contig = t.contiguous() + + tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)] + tensor_list[rank] = input_contig + torch.distributed.all_gather(tensor_list, input_contig, group=group) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=first_dim).contiguous() + return output + + if isinstance(input_, Float8Tensor): + new_data = input_._data + new_data = new_data.view(torch.int8) + new_data = _transform(new_data) + new_data = new_data.view(input_._data.dtype) + output = Float8Tensor(new_data, input_._scale, input_._orig_dtype) + else: + output = _transform(input_) + + return output + + +def _reduce_scatter(ctx: Any, input_: torch.Tensor): + group = get_model_parallel_group() + world_size = torch.distributed.get_world_size(group) + + assert input_.shape[0] % world_size == 0 + output_shape = (input_.shape[0] // world_size, *input_.shape[1:]) + output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype) + + torch.distributed.reduce_scatter_tensor(output, input_, group=group) + return output + + +def _split_along_first_dim(input_: torch.Tensor): + # this is needed for testing + + # like fairscale.nn.model_parallel.mappings._split, but + # along the first dim instead of last dim + + group = get_model_parallel_group() + local_rank = torch.distributed.get_rank(group) + world_size = torch.distributed.get_world_size(group) + + assert input_.shape[0] % world_size == 0 + input_list = torch.split(input_, input_.shape[0] // world_size) + return input_list[local_rank] + + +class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return _gather_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _reduce_scatter(ctx, grad_output) + + +class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return _reduce_scatter(ctx, input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +class _AllGatherFwSplitBw(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return _gather_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split_along_first_dim(grad_output) diff --git a/torchao/float8/float8_aten_api.py b/torchao/float8/float8_aten_api.py new file mode 100644 index 000000000..41d5083d6 --- /dev/null +++ b/torchao/float8/float8_aten_api.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +""" +This file defines the aten functions for float8. Today, all of these functions +are emulated. In the future, they should be calling NVIDIA's float8 kernels. +""" + +import torch + +from torch.library import Library + + +def mm_float8_emulated( + m1, # input 1 data + s1, # input 1 scale + m2, # input 2 data + s2, # input 2 scale + dtype3, # output dtype +): + # naive implementation: dq -> op -> q + m1_fp32 = m1.float() / s1 + m2_fp32 = m2.float() / s2 + m3_fp32 = torch.mm(m1_fp32, m2_fp32) + + return m3_fp32.to(dtype3) + + +# +# ATen op placeholders +# + +# Register the aten level functions we need. +# These are mostly placeholder and might need to be implemented in c++ as needed +lib = Library("aten", "FRAGMENT") + +lib.define( + "mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> Tensor" +) +lib.impl("mm_float8_emulated", mm_float8_emulated, "CPU") +lib.impl("mm_float8_emulated", mm_float8_emulated, "CUDA") + + +@torch.library.impl(lib, "mm_float8_emulated", "Meta") +def _mm_float8_emulated_meta(m1, s1, m2, s2, dtype3): + out = torch.mm(m1.float(), m2.float()).to(dtype3) + return out diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py new file mode 100644 index 000000000..dd85af921 --- /dev/null +++ b/torchao/float8/float8_linear.py @@ -0,0 +1,438 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +""" +A simple module swap UX for a float8 version of `torch.nn.Linear`. +""" + +import dataclasses +import enum + +from typing import Optional + +import torch + +from torchao.float8.config import Float8LinearConfig, ScalingType + +from torchao.float8.float8_scaling_utils import ( + _maybe_initialize_amaxes_scales_for_float8_cast, + hp_tensor_to_float8_delayed, + hp_tensor_to_float8_dynamic, + NoopFwToFloat8E5M2BwDelayed, + NoopFwToFloat8E5M2BwDynamic, +) + +from torchao.float8.float8_tensor import ( + Float8Tensor, + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, +) + +from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax + +from torchao.float8.fsdp_utils import ( + WeightWithDelayedFloat8CastTensor, + WeightWithDynamicFloat8CastTensor, +) + + +# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files +@torch._dynamo.allow_in_graph +class manual_float8_matmul(torch.autograd.Function): + """ + Like torch.matmul, but with the arguments in float8 + """ + + @staticmethod + def forward( + ctx, + input_fp8, + weight_fp8_t, + ): + ctx.save_for_backward(input_fp8, weight_fp8_t) + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + orig_shape = input_fp8.shape + input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) + res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) + res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) + return res_bits + + @staticmethod + def backward(ctx, grad_output_fp8): + input_fp8, weight_fp8_t = ctx.saved_tensors + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + grad_output_fp8_orig_shape = grad_output_fp8.shape + grad_output_fp8_reshaped = grad_output_fp8.reshape( + -1, grad_output_fp8_orig_shape[-1] + ) + + # calculate grad_input + grad_input = torch.mm( + grad_output_fp8_reshaped, + weight_fp8_t.t(), + ) + grad_input = grad_input.reshape( + *grad_output_fp8_orig_shape[:-1], grad_input.shape[-1] + ) + + input_fp8_orig_shape = input_fp8.shape + input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1]) + + # calculate grad_weight + # Note: the variant below is slightly faster on LLaMa 3 8B pretraining + # compared to than calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped` + grad_weight = torch.mm( + grad_output_fp8_reshaped.t(), + input_fp8_reshaped, + ) + + return grad_input, grad_weight.t() + + +class Float8Linear(torch.nn.Linear): + """ + Note: this is **not** a public API and is only intended to be used + inside of this repository. Please file an issue if you would benefit + from this being a public API. + + A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks + scales in way friendly to delayed scaling. + """ + + def __init__(self, *args, **kwargs): + """ + Additional arguments on top of `torch.nn.Linear`'s arguments: + * `config`: Float8LinearConfig + """ + + # Amax scales should always be kept as float32. + self.always_float32_buffers = set() + config = kwargs.pop("config") + emulate = config.emulate + super().__init__(*args, **kwargs) + + # Defines the scaling behavior of input, weight, grad_output + self.scaling_type_input = config.cast_config_input.scaling_type + self.scaling_type_weight = config.cast_config_weight.scaling_type + self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type + # Convenience flag to skip code related to delayed scaling + self.has_any_delayed_scaling = ( + self.scaling_type_input is ScalingType.DELAYED + or self.scaling_type_weight is ScalingType.DELAYED + or self.scaling_type_grad_output is ScalingType.DELAYED + ) + + self.config = config + + self.create_buffers() + + self.linear_mm_config = LinearMMConfig( + # output + ScaledMMConfig( + emulate, + self.config.gemm_config_output.use_fast_accum, + False, + self.config.pad_inner_dim, + ), + # grad_input + ScaledMMConfig( + emulate, + self.config.gemm_config_grad_input.use_fast_accum, + False, + self.config.pad_inner_dim, + ), + # grad_weight + ScaledMMConfig( + emulate, + self.config.gemm_config_grad_weight.use_fast_accum, + False, + self.config.pad_inner_dim, + ), + ) + + # Note: is_amax_initialized is not a buffer to avoid data dependent + # control flow visible to dynamo + # TODO(future PR): add serialization for this flag + self.is_amax_initialized = not self.config.enable_amax_init + + # Syncing of amaxes and scales happens outside of this function. This + # flag is here to enforce that the user does not forget to do this. + self.amax_and_scale_synced = not self.config.enable_amax_init + + # This is needed to properly handle autocast in the amax/scale + # update function for torch.float16 + self.last_seen_input_dtype = None + + # pre_forward and post_forward are currently broken with FSDP + # and torch.compile, this option can disable them + # Note that when using `self.config.enable_pre_and_post_forward = False`, + # it's recommended to also set `self.config.enable_amax_init = False`. + # Otherwise, the amax buffer would never be marked as initialized and + # would be initialized in every iteration. + self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward + + def create_buffers(self): + # Default values for history buffers, see above TODO + history_len = self.config.delayed_scaling_config.history_len + device = self.weight.device + # TODO(future PR): dtype values below don't have the other float8 + # flavors, fix it + default_input = torch.finfo(torch.float8_e4m3fn).max + default_weight = torch.finfo(torch.float8_e4m3fn).max + default_grad_output = torch.finfo(torch.float8_e5m2).max + + # Note: for now, create all the buffers if any are needed, to postpone + # the work to make the scale and amax syncing and history calculation + # handle a heterogeneous setup. We can do that work later if benchmarks + # show it is worth doing. + if self.has_any_delayed_scaling: + self.register_always_float32_buffer( + "fp8_amax_input", torch.tensor([default_input], device=device) + ) + self.register_always_float32_buffer( + "fp8_amax_history_input", torch.zeros(history_len, device=device) + ) + self.register_always_float32_buffer( + "fp8_scale_input", torch.tensor([1.0], device=device) + ) + self.register_always_float32_buffer( + "fp8_amax_weight", torch.tensor([default_weight], device=device) + ) + self.register_always_float32_buffer( + "fp8_amax_history_weight", torch.zeros(history_len, device=device) + ) + self.register_always_float32_buffer( + "fp8_scale_weight", torch.tensor([1.0], device=device) + ) + self.register_always_float32_buffer( + "fp8_amax_grad_output", + torch.tensor([default_grad_output], device=device), + ) + self.register_always_float32_buffer( + "fp8_amax_history_grad_output", torch.zeros(history_len, device=device) + ) + self.register_always_float32_buffer( + "fp8_scale_grad_output", torch.tensor([1.0], device=device) + ) + + def register_always_float32_buffer( + self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True + ) -> None: + self.register_buffer(name=name, tensor=tensor, persistent=persistent) + self.always_float32_buffers.add(name) + + def _apply(self, fn, recurse=True): + ret = super()._apply(fn, recurse) + self.convert_amax_buffer_to_float32() + return ret + + def convert_amax_buffer_to_float32(self): + for key in self.always_float32_buffers: + if self._buffers[key] is not None: + self._buffers[key] = self._buffers[key].to(torch.float32) + + def cast_input_to_float8( + self, input: torch.Tensor, is_amax_initialized: bool + ) -> torch.Tensor: + # Duplicate the autocast logic for F.linear, so that the output + # of our module has the right original precision + if torch.is_autocast_enabled(): + # For now, hardcode to GPU's autocast dtype + # if we need CPU support in the future, we can add it + autocast_dtype = torch.get_autocast_gpu_dtype() + input = input.to(autocast_dtype) + + if self.scaling_type_input is ScalingType.DELAYED: + scale_fn_name = self.config.delayed_scaling_config.scale_fn_name + _maybe_initialize_amaxes_scales_for_float8_cast( + input, + self.fp8_amax_input, + self.fp8_amax_history_input, + self.fp8_scale_input, + scale_fn_name, + e4m3_dtype, + is_amax_initialized, + reduce_amax=True, + ) + input_fp8 = hp_tensor_to_float8_delayed( + input, + self.fp8_scale_input, + e4m3_dtype, + self.fp8_amax_input, + linear_mm_config=self.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + ) + else: + assert self.scaling_type_input is ScalingType.DYNAMIC + input_fp8 = hp_tensor_to_float8_dynamic( + input, e4m3_dtype, self.linear_mm_config + ) + return input_fp8 + + def cast_weight_to_float8( + self, weight: torch.Tensor, is_amax_initialized: bool + ) -> torch.Tensor: + if self.scaling_type_weight is ScalingType.DELAYED: + if isinstance(self.weight, Float8Tensor): # cast by FSDP + weight_fp8 = self.weight + else: + scale_fn_name = self.config.delayed_scaling_config.scale_fn_name + _maybe_initialize_amaxes_scales_for_float8_cast( + weight, + self.fp8_amax_weight, + self.fp8_amax_history_weight, + self.fp8_scale_weight, + scale_fn_name, + e4m3_dtype, + is_amax_initialized, + reduce_amax=False, + ) + + weight_fp8 = hp_tensor_to_float8_delayed( + weight, + self.fp8_scale_weight, + e4m3_dtype, + self.fp8_amax_weight, + linear_mm_config=self.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + ) + else: + assert self.scaling_type_weight is ScalingType.DYNAMIC + if isinstance(self.weight, Float8Tensor): # cast by FSDP + weight_fp8 = self.weight + else: + weight_fp8 = hp_tensor_to_float8_dynamic( + self.weight, + e4m3_dtype, + self.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + ) + return weight_fp8 + + def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: + if self.scaling_type_grad_output is ScalingType.DELAYED: + scale_fn_name = self.config.delayed_scaling_config.scale_fn_name + output = NoopFwToFloat8E5M2BwDelayed.apply( + output, + self.fp8_amax_grad_output, + self.fp8_amax_history_grad_output, + self.fp8_scale_grad_output, + scale_fn_name, + self.is_amax_initialized, + self.linear_mm_config, + ) + else: + assert self.scaling_type_grad_output is ScalingType.DYNAMIC + output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config) + return output + + def float8_pre_forward(self, input): + if not self.enable_pre_and_post_forward: + return + if ( + self.is_amax_initialized + and (not self.amax_and_scale_synced) + and torch.is_grad_enabled() + ): + raise AssertionError( + "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward" + ) + self.last_seen_input_dtype = input.dtype + + def float8_post_forward(self): + if not self.enable_pre_and_post_forward: + return + # Ensure that calling forward again will fail until the user syncs + # amaxes and scales + self.is_amax_initialized = True + self.amax_and_scale_synced = False + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.has_any_delayed_scaling: + self.float8_pre_forward(input) + + input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) + weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized) + + output = manual_float8_matmul.apply(input_fp8, weight_fp8.t()) + + # Cast grad_output to float8_e5m2 during backward + output = self.cast_output_to_float8_in_bw(output) + + if self.bias is not None: + output = output + self.bias.to(output.dtype) + + if self.has_any_delayed_scaling: + self.float8_post_forward() + return output + + def scaling_repr(self): + # add scaling settings without using too many characters + # example: "i:del,w:del,go:dyn" + return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}" + + def extra_repr(self): + s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"' + return s + + @classmethod + def from_float( + cls, + mod, + config: Optional[Float8LinearConfig] = None, + ): + """ + Create an nn.Linear with fp8 compute from a regular nn.Linear + + Args: + mod (torch.nn.Linear): nn.Linear to convert + config (Optional[Float8LinearConfig]): configuration for conversion to float8 + """ + if config is None: + config = Float8LinearConfig() + with torch.device("meta"): + new_mod = cls( + mod.in_features, + mod.out_features, + bias=False, + config=config, + ) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + # need to create buffers again when moving from meta device to + # real device + new_mod.create_buffers() + + # If FSDP float8 all-gather is on, wrap the weight in a float8-aware + # tensor subclass. This must happen last because: + # 1. weight needs to be on the correct device to create the buffers + # 2. buffers need to be already created for the delayed scaling version + # of the weight wrapper to be initialized + if config.enable_fsdp_float8_all_gather: + if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC: + new_mod.weight = torch.nn.Parameter( + WeightWithDynamicFloat8CastTensor( + new_mod.weight, + new_mod.linear_mm_config, + ) + ) + else: + assert config.cast_config_weight.scaling_type is ScalingType.DELAYED + new_mod.weight = torch.nn.Parameter( + WeightWithDelayedFloat8CastTensor( + new_mod.weight, + new_mod.fp8_amax_weight, + new_mod.fp8_amax_history_weight, + new_mod.fp8_scale_weight, + new_mod.linear_mm_config, + new_mod.is_amax_initialized, + ) + ) + + return new_mod diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py new file mode 100644 index 000000000..675ed5ee6 --- /dev/null +++ b/torchao/float8/float8_linear_utils.py @@ -0,0 +1,327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import logging +from typing import Callable, List, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torchao.float8.config import Float8LinearConfig, ScalingType +from torchao.float8.float8_linear import Float8Linear + +from torchao.float8.float8_utils import ( + amax_history_to_scale_stack, + e4m3_dtype, + e5m2_dtype, +) +from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor + +log = logging.getLogger(__name__) +log.addHandler(logging.NullHandler()) + + +def linear_requires_sync(config: Float8LinearConfig): + """Returns whether the given linear_type requires sync before forward.""" + return any( + [ + config.cast_config_input.scaling_type is ScalingType.DELAYED, + config.cast_config_weight.scaling_type is ScalingType.DELAYED, + config.cast_config_grad_output.scaling_type is ScalingType.DELAYED, + ] + ) + + +def _update_history_stack( + new_amax: torch.Tensor, amax_history_stack: torch.Tensor +) -> torch.Tensor: + """ + Updates `amax_history` (the last N cur_amax values) inplace with the value + of `new_amax`. + + Args: + new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1) + amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length) + """ + assert ( + amax_history_stack.dim() == 2 + ), f"Expected amat_history_stack to be 2D, got {amax_history_stack.shape()}" + assert new_amax.size(0) == amax_history_stack.size( + 0 + ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}" + new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1) + new_amax_history_stack[:, 0] = new_amax.squeeze(-1) + amax_history_stack.copy_(new_amax_history_stack) + + +def swap_linear_layers( + module: nn.Module, + from_float_func: Callable[[nn.Linear], nn.Linear], + *, + module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, +) -> nn.Module: + """ + Generic function to swap linear layers in a module with a new type of linear layer. + + Note: + If applied to a root-level nn.Linear, the module will not be modified in place + and returned instead + + Args: + module: Module to modify. + from_float_func: Function that accepts a linear layer and returns a new type of linear layer. + module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that + that pass the filter function will be swapped. The inputs to the + filter function are the module instance, and the FQN. + + Returns: + nn.Module: The modified module with swapped linear layers. + """ + if isinstance(module, nn.Linear) and ( + module_filter_fn is None or module_filter_fn(module, "") + ): + if len(list(module.children())) > 0: + raise AssertionError( + f"Does not support a root nn.Linear with children: {module}" + ) + return from_float_func( + module, + ) + + root_module = module + + def post_order_traversal( + module: nn.Module, + cur_fqn: Optional[str] = None, + parent_module: Optional[nn.Module] = None, + ): + if cur_fqn is None: + cur_fqn = "" + + for child_module_name, child_module in module.named_children(): + if cur_fqn == "": + new_fqn = child_module_name + else: + new_fqn = f"{cur_fqn}.{child_module_name}" + + post_order_traversal(child_module, new_fqn, module) + + if isinstance(module, nn.Linear) and ( + module_filter_fn is None or module_filter_fn(module, cur_fqn) + ): + assert ( + parent_module is not None + ), f"Linear root module should return early: {module}" + new_linear_module = from_float_func(module) + cur_module_name = cur_fqn.split(".")[-1] + setattr(parent_module, cur_module_name, new_linear_module) + + post_order_traversal(root_module) + return root_module + + +def convert_to_float8_training( + module: nn.Module, + *, + module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, + config: Float8LinearConfig = None, +) -> nn.Module: + """ + Swaps `torch.nn.Linear` in `module` with `Float8Linear`. + + Args: + module: Module to modify. + module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that + that pass the filter function will be swapped. The inputs to the + filter function are the module instance and the FQN. + config (Float8LinearConfig): configuration for conversion to float8 + + Returns: + nn.Module: The modified module with swapped linear layers. + """ + if config is None: + config = Float8LinearConfig() + from_float = lambda m: Float8Linear.from_float( + m, + config=config, + ) + return swap_linear_layers( + module, + from_float, + module_filter_fn=module_filter_fn, + ) + + +def get_float8_layers(model: torch.nn.Module): + """Iterates through the model and returns all the Float8Linear layers. + Args: + model (torch.nn.Module): The model to look for Float8Linear layers in. + """ + + # Get all fp8 layers and tensors + fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)] + if not torch._dynamo.is_compiling(): + for layer in fp8_layers: + for buf in layer.buffers(): + torch._dynamo.mark_static_address(buf, guard=True) + return fp8_layers + + +@torch.no_grad() +def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None: + """ + Manages the float8 amax and scale bookkeeping. In detail, it does the + following: + 1. in distributed contexts, syncs amax values across workers for activations and gradients + 2. adds the `amax` values to history + 3. calculates the scales to be used for next iteration + 4. sets the `amax_and_scale_synced` flag on the Float8Linear modules + to signal that they have been synced + + TODO(future): design the UX for this (context manager, etc) + + PERFORMANCE NOTE: + When you can, it is much more efficient to call get_float8_layers once at + the beginning of the training loop and pass the result to this function. + Because of how this interacts with torch.compile + + Args: + model (torch.nn.Module): The model to track amaxes for + fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored, + and we loop over all fp8_layers to sync and update amax scale histories. + Users can use get_float8_layers to get all fp8 layers. + """ + if fp8_layers is None: + fp8_layers = get_float8_layers(model) + + if len(fp8_layers) == 0: + log.warn( + "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers" + ) + return + + def inner_func(): + """Why do we have this inner_function? + + There are two portions of the outer sync_function that cause graph_breaks: + 1. The `get_float8_layers` call can cause graph breaks if the user did not pass + in the fp8_layers. + 2. At the end of syncing all the amaxes and scales we set the attr on the module + signaling that we have synced the amaxes and scales and the next forward can be run. + # TODO Maybe we should remove this safety check to remove the graph break? + + By having this inner function, we can ensure that although the outer function may cause graph breaks + the inner function will not. + """ + # Loop over all fp8 layers and grab the needed tensors + fp8_amax_input_tensor_list = [None] * len(fp8_layers) + fp8_amax_weight_tensor_list = [None] * len(fp8_layers) + fp8_amax_grad_output_tensor_list = [None] * len(fp8_layers) + + fp8_input_amax_history_stack = [None] * len(fp8_layers) + fp8_weight_amax_history_stack = [None] * len(fp8_layers) + fp8_grad_output_amax_history_stack = [None] * len(fp8_layers) + + x_dtypes = set() + scale_fn_recipes = set() + + for idx, child in enumerate(fp8_layers): + fp8_amax_input_tensor_list[idx] = child.fp8_amax_input + fp8_amax_weight_tensor_list[idx] = child.fp8_amax_weight + fp8_amax_grad_output_tensor_list[idx] = child.fp8_amax_grad_output + + fp8_input_amax_history_stack[idx] = child.fp8_amax_history_input + fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight + fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output + + x_dtypes.add(child.last_seen_input_dtype) + scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) + + # TODO This way to get the activation dtype is not ideal + if len(x_dtypes) != 1: + raise ValueError( + f"All layers must have the same last seen input_dtype, got {x_dtypes}" + ) + x_dtype = next(iter(x_dtypes)) + + if len(scale_fn_recipes) != 1: + raise ValueError( + f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}" + ) + scale_fn_recipe = next(iter(scale_fn_recipes)) + + assert ( + len(fp8_amax_input_tensor_list) + == len(fp8_amax_weight_tensor_list) + == len(fp8_amax_grad_output_tensor_list) + ), "Mismatched lengths of amax tensors." + + if dist.is_initialized(): + all_amax_tensors = torch.cat( + fp8_amax_input_tensor_list + + fp8_amax_weight_tensor_list + + fp8_amax_grad_output_tensor_list + ) + all_reduced_amax_tensor = all_reduce( + all_amax_tensors, "MAX", list(range(dist.get_world_size())) + ) + if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor): + all_reduced_amax_tensor = all_reduced_amax_tensor.wait() + + ( + reduced_fp8_amax_input_tensor, + reduced_fp8_amax_weight_tensor, + reduced_fp8_amax_grad_output_tensor, + ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_input_tensor_list)) + + for idx, child in enumerate(fp8_layers): + child.fp8_amax_input.copy_(reduced_fp8_amax_input_tensor[idx]) + child.fp8_amax_weight.copy_(reduced_fp8_amax_weight_tensor[idx]) + child.fp8_amax_grad_output.copy_( + reduced_fp8_amax_grad_output_tensor[idx] + ) + + # We create two stacked tensor groups, one for the amax history and one for the current scales + fp8_amax_input_tensors = torch.vstack(fp8_amax_input_tensor_list) + fp8_amax_weight_tensors = torch.vstack(fp8_amax_weight_tensor_list) + fp8_amax_grad_output_tensors = torch.vstack(fp8_amax_grad_output_tensor_list) + + fp8_input_amax_history_stack = torch.vstack(fp8_input_amax_history_stack) + fp8_weight_amax_history_stack = torch.vstack(fp8_weight_amax_history_stack) + fp8_grad_output_amax_history_stack = torch.vstack( + fp8_grad_output_amax_history_stack + ) + + # Update the history stacks with the new amax values + _update_history_stack(fp8_amax_input_tensors, fp8_input_amax_history_stack) + _update_history_stack(fp8_amax_weight_tensors, fp8_weight_amax_history_stack) + _update_history_stack( + fp8_amax_grad_output_tensors, fp8_grad_output_amax_history_stack + ) + + # Calculate the new scales from the updated history stacks + new_input_scales = amax_history_to_scale_stack( + fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe + ) + new_weight_scales = amax_history_to_scale_stack( + fp8_weight_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe + ) + new_grad_output_scales = amax_history_to_scale_stack( + fp8_grad_output_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe + ) + + # Iterate through the layers and update the scales + for idx, child in enumerate(fp8_layers): + child.fp8_scale_input.copy_(new_input_scales[idx]) + child.fp8_scale_weight.copy_(new_weight_scales[idx]) + child.fp8_scale_grad_output.copy_(new_grad_output_scales[idx]) + + # This allows for the compile to succede on the inner func and fail on the graph breaks + # at the beginning and and of syncing + inner_func() + + for child in fp8_layers: + # Set a flag to signal amaxes/scales are ready + child.amax_and_scale_synced = True diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py new file mode 100644 index 000000000..d3c3b405b --- /dev/null +++ b/torchao/float8/float8_ops.py @@ -0,0 +1,363 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from typing import Any, Dict, Tuple + +import torch + +from torchao.float8.float8_python_api import addmm_float8_unwrapped +from torchao.float8.float8_tensor import choose_scaled_mm_config, Float8Tensor +from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul + +from torch.utils._pytree import tree_map + +aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional +FLOAT8_OPS_TABLE: Dict[Any, Any] = {} + + +def implements(aten_ops): + """Register aten ops to the float8 op table""" + + def decorator(func): + for op in aten_ops: + FLOAT8_OPS_TABLE[op] = func + return func + + return decorator + + +@implements( + [ + aten.view.default, + aten._unsafe_view.default, + aten.t.default, + aten.as_strided.default, + aten.clone.default, + aten.detach.default, + aten.slice.Tensor, + aten.transpose.int, + aten.fill_.Scalar, + ] +) +def float8_desugar_op(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + return Float8Tensor( + new_data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + +@implements([aten.split.Tensor]) +def float8_split(aten_op, args, kwargs=None): + new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs) + + def make_float8(data): + return Float8Tensor( + data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + out = map(make_float8, new_data_tensors) + return list(out) + + +# Errors cant `cat_cuda float8 e4m3fn` +@implements([aten.cat.default]) +def float8_cat(aten_op, args, kwargs=None): + chunked_tensors: Tuple[Float8Tensor] = args[0] + + orig_dtype = chunked_tensors[0]._orig_dtype + scale = chunked_tensors[0]._scale + mm_config = chunked_tensors[0]._linear_mm_config + fp8_dtype = chunked_tensors[0]._data.dtype + gemm_input_role = chunked_tensors[0]._gemm_input_role + chunk_data = [] + for chunk in chunked_tensors: + assert isinstance( + chunk, Float8Tensor + ), "Expecting all chunks to be of type Float8Tensor" + assert ( + chunk._orig_dtype == orig_dtype + ), "Expecting all chunks to be of the same dtype" + assert ( + chunk._scale is scale + ), "Expecting all chunks to have thee same scale as a result of a split" + assert ( + chunk._linear_mm_config is mm_config + ), "Expecting all chunks to have thee same mm config as a result of a split" + assert ( + chunk._data.dtype == fp8_dtype + ), "Expecting all chunks to be of the same dtype as a result of a split" + assert ( + chunk._gemm_input_role is gemm_input_role + ), "Expecting all chunks to have the same gemm_input_role as a result of a split" + chunk_data.append(chunk._data.view(torch.uint8)) + + new_data = aten_op(chunk_data, *args[1:], **kwargs) + new_data = new_data.view(fp8_dtype) + return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role) + + +@implements([aten.sum.dim_IntList]) +def float8_cast_up_op(aten_op, args, kwargs=None): + """Be careful with this function, this is a "fallback" op that + casts the output of the op to the original precision. And performs the op. + + We currently need this to support the backward for admmm bias. + "addmm" -> out + "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" + """ + + def unwrap(x): + if isinstance(x, Float8Tensor): + return x.to_original_precision() + return x + + new_args = tree_map(unwrap, args) + new_kwargs = tree_map(unwrap, kwargs) + return aten_op(*new_args, **new_kwargs) + + +def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): + a_data = a._data + a_scale = a._scale + b_data = b._data + + scaled_mm_config = choose_scaled_mm_config( + a._gemm_input_role, + a._linear_mm_config, + b._gemm_input_role, + b._linear_mm_config, + ) + + if scaled_mm_config.pad_inner_dim: + assert a._data.size(1) == b._data.size( + 0 + ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}" + a_data = pad_tensor_for_matmul(a_data, dims=1) + b_data = pad_tensor_for_matmul(b_data, dims=0) + + if not is_row_major(a_data.stride()): + a_data = a_data.contiguous() + if is_row_major(b_data.stride()): + b_data = b_data.t().contiguous().t() + b_scale = b._scale + return a_data, a_scale, b_data, b_scale + + +@implements([aten.mm.default, aten.matmul.default]) +def float8_mm(aten_op, args, kwargs=None): + a = args[0] + b = args[1] + + assert isinstance(a, Float8Tensor) and isinstance( + b, Float8Tensor + ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format( + type(a), type(b) + ) + a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) + output_dtype = a._orig_dtype + scaled_mm_config = choose_scaled_mm_config( + a._gemm_input_role, + a._linear_mm_config, + b._gemm_input_role, + b._linear_mm_config, + ) + if scaled_mm_config.emulate: + return torch.ops.aten.mm_float8_emulated( + a._data, a._scale, b._data, b._scale, output_dtype + ) + tensor_out = addmm_float8_unwrapped( + a_data, + a_scale, + b_data, + b_scale, + output_dtype, + output_scale=None, + bias=None, + use_fast_accum=scaled_mm_config.use_fast_accum, + ) + return tensor_out + + +@implements([aten.addmm.default]) +def float8_addmm(aten_op, args, kwargs=None): + assert ( + isinstance(args[0], torch.Tensor) + and isinstance(args[1], Float8Tensor) + and isinstance(args[2], Float8Tensor) + ) + bias = args[0] + a = args[1] + b = args[2] + a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) + output_dtype = a._orig_dtype + assert bias.dtype == output_dtype, "bias dtype must match output dtype" + scaled_mm_config = choose_scaled_mm_config( + a._gemm_input_role, + a._linear_mm_config, + b._gemm_input_role, + b._linear_mm_config, + ) + if scaled_mm_config.emulate: + out = torch.ops.aten.mm_float8_emulated( + a._data, a._scale, b._data, b._scale, output_dtype + ) + return out + bias + tensor_out = addmm_float8_unwrapped( + a_data, + a_scale, + b_data, + b_scale, + output_dtype, + output_scale=None, + bias=bias, + use_fast_accum=scaled_mm_config.use_fast_accum, + ) + return tensor_out + + +@implements([aten.is_same_size.default]) +def float8_is_same_size(aten_op, args, kwargs=None): + return args[0].shape == args[1].shape + + +@implements([aten._to_copy.default]) +def autocast_to_copy(aten_op, args, kwargs=None): + """This gets called when running matmul under autocast + when the input is a Float8Tensor, presenting as a fp32 + tensor. + """ + assert isinstance(args[0], Float8Tensor) + assert ( + len(kwargs) == 1 and "dtype" in kwargs + ), "Only support dtype kwarg for autocast" + assert kwargs["dtype"] in { + torch.float16, + torch.bfloat16, + }, "Only support floating point conversion for autocast w/ Float8Tensor" + return Float8Tensor( + args[0]._data, + args[0]._scale, + kwargs["dtype"], + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + +@implements( + [ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + ] +) +def allgather_fp8(aten_op, args, kwargs=None): + """ + override funcol with FP8 handling + """ + fp8_input = args[0] + assert isinstance( + fp8_input, Float8Tensor + ), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}" + + fp8_data = fp8_input._data + fp8_data = fp8_data.contiguous() + fp8_out = aten_op(fp8_data, *args[1:], **kwargs) + return Float8Tensor( + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._linear_mm_config, + fp8_input._gemm_input_role, + ) + + +@implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default]) +def wait_tensor_fp8(aten_op, args, kwargs=None): + fp8_input = args[0] + assert isinstance(fp8_input, Float8Tensor) + + fp8_data = fp8_input._data + fp8_out = aten_op(fp8_data, *args[1:], **kwargs) + return Float8Tensor( + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._linear_mm_config, + fp8_input._gemm_input_role, + ) + + +@implements([aten.index_put_.default]) +def index_put_fp8(aten_op, args, kwargs=None): + fp8_self = args[0] + fp8_values = args[2] + assert isinstance(fp8_self, Float8Tensor) + assert isinstance(fp8_values, Float8Tensor) + assert fp8_self._scale == fp8_values._scale + assert fp8_self.dtype == fp8_values.dtype + assert fp8_self._orig_dtype == fp8_values._orig_dtype + + fp8_data = fp8_self._data + fp8_values_data = fp8_values._data + fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs) + return Float8Tensor( + fp8_out, + fp8_self._scale, + fp8_self._orig_dtype, + fp8_self._linear_mm_config, + fp8_self._gemm_input_role, + ) + + +@implements([aten.copy_.default]) +def copy_fp8(aten_op, args, kwargs=None): + # For a copy op with Float8Tensors involved, only the following combinations are allowed: + # 1. self is a high precision (hp) tensor, src is a Float8Tensor: + # in this case src is upcasted and unscaled to go into the hp tensor + # 2. self and src are Float8Tensors: + # the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat) + # Every other combination is banned as the semantics are not well defined + + self = args[0] + src = args[1] + + if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + src_hp = src.to_original_precision() + return aten_op(self, src_hp, *args[2:], **kwargs) + elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + assert ( + self._orig_dtype == src._orig_dtype + ), "Expecting both Float8Tensors to be of the same dtype" + assert ( + self._scale == src._scale + ), "Expecting both Float8Tensors to have thee same scale" + assert ( + self._linear_mm_config == src._linear_mm_config + ), "Expecting both Float8Tensors to have thee same mm config" + assert ( + self._data.dtype == src._data.dtype + ), "Expecting both Float8Tensors to be of the same dtypet" + assert ( + self._gemm_input_role == src._gemm_input_role + ), "Expecting both Float8Tensors to have the same gemm_input_role" + fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs) + return Float8Tensor( + fp8_out, + self._scale, + self._orig_dtype, + self._linear_mm_config, + self._gemm_input_role, + ) + else: + raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor") diff --git a/torchao/float8/float8_python_api.py b/torchao/float8/float8_python_api.py new file mode 100644 index 000000000..16e270574 --- /dev/null +++ b/torchao/float8/float8_python_api.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +""" +This file defines the Python functions for float8 which expect inputs +of class `Float8Tensor`. This is a thin wrapper on top of the aten API +to simplify the product code. +""" + +from typing import Optional + +import torchao.float8.float8_aten_api # noqa + +import torch + + +# [Note] Usage of scales +# The meaning of scale in this library can be found in the definition of the Float8Tensor +# Cublas defines scale to always mean a multiplicative factor for the respective matrices +# For a,b going from fp8 -> fp32 we multiple by the inverse of the scale +# For output going from fp32 -> fp8 we multiply by the scale +def addmm_float8_unwrapped( + a_data: torch.Tensor, + a_scale: torch.Tensor, + b_data: torch.Tensor, + b_scale: torch.tensor, + output_dtype: torch.dtype, + output_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + """ + This is the unwrapped version of addmm_float8, which does not take in Float8Tensors + as inputs. This is used to standardize the logic between subclassed and non subclassed + versions of the linear module. + """ + a_inverse_scale = a_scale.reciprocal() + b_inverse_scale = b_scale.reciprocal() + if output_dtype == torch.float32 and bias is not None: + # Bias is not supported by _scaled_mm when output is fp32 + output = torch._scaled_mm( + a_data, + b_data, + scale_a=a_inverse_scale, + scale_b=b_inverse_scale, + scale_result=output_scale, + out_dtype=output_dtype, + use_fast_accum=use_fast_accum, + ) + output += bias + return output + output = torch._scaled_mm( + a_data, + b_data, + scale_a=a_inverse_scale, + scale_b=b_inverse_scale, + bias=bias, + scale_result=output_scale, + out_dtype=output_dtype, + use_fast_accum=use_fast_accum, + ) + return output diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py new file mode 100644 index 000000000..bbf140eff --- /dev/null +++ b/torchao/float8/float8_scaling_utils.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for scaling high precision tensors to float8. +""" + +from typing import Optional + +import torch + +from torchao.float8.float8_tensor import ( + Float8Tensor, + GemmInputRole, + hp_tensor_and_scale_to_float8, + LinearMMConfig, + ScaledMMConfig, + tensor_already_casted_to_fp8, +) + +from torchao.float8.float8_utils import ( + amax_history_to_scale, + e4m3_dtype, + e5m2_dtype, + tensor_to_amax, + tensor_to_scale, +) + + +def hp_tensor_to_float8_dynamic( + hp_tensor: torch.Tensor, + float8_dtype: torch.dtype, + linear_mm_config: LinearMMConfig, + reduce_amax: bool = False, + gemm_input_role: GemmInputRole = GemmInputRole.INPUT, +) -> Float8Tensor: + """ + Given a high precision tensor `hp_tensor`, + scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result. + + Args: + hp_tensor: the tensor to convert + float8_dtype: the float8 dtype to use + linear_mm_config: Defines the configuration for the scaled_mm for + the 3 fwd/bwd gemms of linear + reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks + gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in + the 3 fwd/bwd gemms of linear + """ + if tensor_already_casted_to_fp8(hp_tensor): + return hp_tensor + scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) + return hp_tensor_and_scale_to_float8( + hp_tensor, + scale, + float8_dtype, + linear_mm_config, + gemm_input_role, + ) + + +def hp_tensor_to_float8_delayed( + hp_tensor: torch.Tensor, + s: torch.Tensor, + float8_dtype: torch.dtype, + amax_buffer: torch.Tensor, + linear_mm_config: Optional[LinearMMConfig] = None, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, +) -> Float8Tensor: + """ + Given a high precision tensor `hp_tensor` and relevant metadata, scales it using + delayed scaling and returns a `Float8Tensor` of the result. Specifically: + 1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace + 2. scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor + + Args: + hp_tensor: the tensor to convert + s: the scale to use to convert the tensor + float8_dtype: the float8 dtype to use + amax_buffer: the buffer to modify inplace with max(abs(hp_tensor)) + linear_mm_config: Defines the configuration for the scaled_mm for + the 3 fwd/bwd gemms of linear + gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in + the 3 fwd/bwd gemms of linear + """ + amax_buffer.fill_(tensor_to_amax(hp_tensor)) + return hp_tensor_and_scale_to_float8( + hp_tensor, + s, + float8_dtype, + linear_mm_config, + gemm_input_role, + ) + + +def _maybe_initialize_amaxes_scales_for_float8_cast( + x, + cur_amax, + amax_history, + scale, + scale_fn_name, + float8_dtype, + is_initialized, + reduce_amax, +): + """ + If x is about to be cast to `float8` and the amax buffers are not initialized, + initializes them inplace. + """ + if is_initialized: + return + with torch.no_grad(): + # Note: we need to enable distributed reduction here in order + # to match numerics between single GPU and multi GPU code for + # activations and gradients + new_amax = tensor_to_amax(x, reduce_amax=reduce_amax) + cur_amax.fill_(new_amax) + amax_history[0] = new_amax + new_scale = amax_history_to_scale( + amax_history, float8_dtype, x.dtype, scale_fn_name + ) + scale.copy_(new_scale) + + +@torch._dynamo.allow_in_graph +class NoopFwToFloat8E5M2BwDelayed(torch.autograd.Function): + """ + Forward: no-op + Backward: convert to float8_e5m2 with delayed scaling, initialize if needed + """ + + @staticmethod + def forward( + ctx, + tensor, + fp8_amax_grad_output, + fp8_amax_history_grad_output, + fp8_scale_grad_output, + scale_fn_name, + is_amax_initialized, + linear_mm_config: LinearMMConfig, + ): + ctx.save_for_backward( + fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output + ) + ctx.scale_fn_name = scale_fn_name + ctx.is_amax_initialized = is_amax_initialized + ctx.linear_mm_config = linear_mm_config + return tensor + + @staticmethod + def backward(ctx, go): + ( + fp8_amax_grad_output, + fp8_amax_history_grad_output, + fp8_scale_grad_output, + ) = ctx.saved_tensors + scale_fn_name = ctx.scale_fn_name + is_amax_initialized = ctx.is_amax_initialized + + _maybe_initialize_amaxes_scales_for_float8_cast( + go, + fp8_amax_grad_output, + fp8_amax_history_grad_output, + fp8_scale_grad_output, + scale_fn_name, + e5m2_dtype, + is_amax_initialized, + reduce_amax=True, + ) + + fp8_amax_grad_output.fill_(tensor_to_amax(go)) + + res = hp_tensor_and_scale_to_float8( + go, + fp8_scale_grad_output, + e5m2_dtype, + ctx.linear_mm_config, + GemmInputRole.GRAD_OUTPUT, + ) + empty_grads = None, None, None, None, None, None + return res, *empty_grads + + +@torch._dynamo.allow_in_graph +class NoopFwToFloat8E5M2BwDynamic(torch.autograd.Function): + """ + Forward: no-op + Backward: convert to float8_e5m2 with dynamic scaling + """ + + @staticmethod + def forward( + ctx, + tensor, + linear_mm_config: LinearMMConfig, + ): + ctx.linear_mm_config = linear_mm_config + return tensor + + @staticmethod + def backward(ctx, gradY): + if tensor_already_casted_to_fp8(gradY): + return gradY, None + gradY_scale = tensor_to_scale(gradY, e5m2_dtype) + fp8_tensor = hp_tensor_and_scale_to_float8( + gradY, + gradY_scale, + e5m2_dtype, + ctx.linear_mm_config, + GemmInputRole.GRAD_OUTPUT, + ) + return fp8_tensor, None diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py new file mode 100644 index 000000000..a858408fe --- /dev/null +++ b/torchao/float8/float8_tensor.py @@ -0,0 +1,363 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import enum +from collections import namedtuple +from typing import Dict, Optional + +import torch + +import torch.distributed._functional_collectives as funcol +from torchao.float8.float8_utils import ( + e4m3_dtype, + tensor_to_amax, + to_fp8_saturated, +) +from torch.distributed._tensor import DTensor + +aten = torch.ops.aten + +# +# A note on configuration of float8 logic in a linear +# TODO(future): move all the configs to separate file +# TODO(future): change this to input/weight/grad_output notation, +# can be separate PR because none of this is user facing +# +# There are three gemms in a forward + backward of a Linear layer: +# +# 1. input @ weight_t = output (forward pass) +# 2. grad_output @ weight = grad_input (backward pass) +# 3. input_t @ grad_output = grad_weight (backward pass) +# +# In the formulas above, there are: +# A. six input tensors (input, input_t, weight, weight_t, grad_output, grad_output_t). +# - Note that grad_output_t is implied because of memory format requirements +# of float8 gemms +# B. three output tensors (output, grad_input, grad_weight) +# +# We want each input tensor, gemm, and output tensor to be configurable. +# The state of this configuration today is: +# +# i. pairs of input tensors (non-t and t variants) have their scaling +# configurable via the scaling_type_* arguments to Float8Linear +# ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing +# iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed +# to configure all three gemms, also not user facing + + +# ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass. +# emulate: whether to emulate the matmuls in fp32 +# use_fast_accum: whether to use the fast-accumulation option for scaled_mm +# fp8_output: whether to output the result of the scaled_mm in fp8 +# pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16. +ScaledMMConfig = namedtuple( + "ScaledMMConfig", + ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"], + defaults=[False, False, False, False], +) + +# The object below is not user facing and exists for convenience, +# to allow Float8Tensor to use +# the right config based on which gemm from gemms with outputs +# `output`, `grad_input`, `grad_weight` is +# being called. +LinearMMConfig = namedtuple( + "LinearMMConfig", + ["output", "grad_input", "grad_weight"], + defaults=[ + ScaledMMConfig(False, True, False, False), + ScaledMMConfig(False, False, False, False), + ScaledMMConfig(False, False, False, False), + ], +) + + +class GemmInputRole(enum.Enum): + """ + Given a Float8Tensor, the enum below describes the expected role of this + tensor in the three gemms present in the fw + bw pass of a Linear layer. + This is used to choose the right config for a float8 gemm when the + gemm is performed. + """ + + INPUT = "input" + WEIGHT = "weight" + GRAD_OUTPUT = "grad_output" + + +# choose which scaled_mm_config to use based on gemm inputs +def choose_scaled_mm_config( + a_role: GemmInputRole, + a_linear_mm_config: LinearMMConfig, + b_role: GemmInputRole, + b_linear_mm_config: LinearMMConfig, +): + if a_role is GemmInputRole.INPUT and b_role is GemmInputRole.WEIGHT: + assert ( + a_linear_mm_config.output == b_linear_mm_config.output + ), f"linear_mm_config.output mismatch: {a_linear_mm_config.output} vs {b_linear_mm_config.output}" + return a_linear_mm_config.output + elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.WEIGHT: + assert ( + a_linear_mm_config.grad_input == b_linear_mm_config.grad_input + ), f"linear_mm_config.grad_input mismatch: {a_linear_mm_config.grad_input} vs {b_linear_mm_config.grad_input}" + return a_linear_mm_config.grad_input + elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.INPUT: + assert ( + a_linear_mm_config.grad_weight == b_linear_mm_config.grad_weight + ), f"linear_mm_config.grad_weight mismatch: {a_linear_mm_config.grad_weight} vs {b_linear_mm_config.grad_weight}" + return a_linear_mm_config.grad_weight + else: + raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}") + + +def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: + """ + Check if the tensor is already casted to fp8 + """ + if isinstance(tensor, Float8Tensor): + return True + elif isinstance(tensor, DTensor): + # TODO: shall we stick to public API and directly use tensor.to_local() here? + return tensor_already_casted_to_fp8(tensor._local_tensor) + elif isinstance(tensor, funcol.AsyncCollectiveTensor): + return tensor_already_casted_to_fp8(tensor.elem) + + return False + + +@torch._dynamo.allow_in_graph +class _ToFloat8ConstrFunc(torch.autograd.Function): + """ + A differentiable conversion to fp8. + * forward: convert from high precision to float8 + * backward: pass the gradient without changes + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + scale: torch.Tensor, + float8_dtype=e4m3_dtype, + linear_mm_config: Optional[LinearMMConfig] = None, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, + ): + """ + This function will apply the scaling, and then convert to a Float8Tensor + + Note: + We will call this function with a DTensor subclass. Ideally this would be an aten OP + that DTensor could overload to ensure proper semantics. There are some techincal issues + with that composing with FakeTensor, so we special case here. + + DTensor Invariant: DTensor must always be the outer most tensor subclass + """ + tensor_scaled = tensor * scale + bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) + + if isinstance(bits_fp8, DTensor): + assert isinstance( + scale, DTensor + ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor" + bits_mesh = bits_fp8.device_mesh + bits_placements = bits_fp8.placements + local_bits = bits_fp8.to_local() + local_scale = scale.to_local() + inner_float8_tensor = Float8Tensor( + local_bits, + local_scale, + tensor.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) + return DTensor.from_local( + inner_float8_tensor, + bits_mesh, + bits_placements, + run_check=False, + shape=bits_fp8.size(), + stride=bits_fp8.stride(), + ) + + return Float8Tensor( + bits_fp8, + scale, + tensor.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) + + @staticmethod + def backward(ctx, g): + return g, None, None, None, None, None + + +@torch._dynamo.allow_in_graph +class _FromFloat8ConstrFunc(torch.autograd.Function): + """ + A differentiable conversion from fp8. + * forward: convert from float8 to high precision + * backward: pass the gradient without changes + """ + + @staticmethod + def forward(ctx, tensor): + return tensor._data.to(tensor._orig_dtype) / tensor._scale + + @staticmethod + def backward(ctx, g): + return g, None, None + + +def hp_tensor_and_scale_to_float8( + hp_tensor: torch.Tensor, + s: torch.Tensor, + float8_dtype=e4m3_dtype, + linear_mm_config: Optional[LinearMMConfig] = None, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, +): + """ + Given a high precision tensor `hp_tensor` and a precalculated scale `s`, + scales `hp_tensor` by `s` and returns a `Float8Tensor` of the result. + + Autograd-aware, the derivative is pass-through. + DTensor-aware, if the input is a DTensor the output will be DTensor(Float8Tensor). + + Args: + hp_tensor: the tensor to convert + s: the scale to use to convert the tensor + float8_dtype: the float8 dtype to use + linear_mm_config: Defines the configuration for the scaled_mm for + the 3 fwd/bwd gemms of linear + gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in + the 3 fwd/bwd gemms of linear + """ + return _ToFloat8ConstrFunc.apply( + hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role + ) + + +class Float8Tensor(torch.Tensor): + """ + Note: this is **not** a public API and is only intended to be used + inside of this repository. Please file an issue if you would benefit + from this being a public API. + + A Python-only Float8 tensor subclass. Contains: + * `_data`: the underlying e4m3 or e5m2 data + * `_scale`: the scale used to scale the original fp32 tensor. We multiply + by scale to go from fp32 range to fp8 range, and divide by scale to go + from fp8 range to fp32 range. + * `_orig_dtype`: the original dtype of the tensor used to create this + tensor. + * `_emulate`: if true using fp32 emulation for the matmuls, helpful + if you don't have access to h100 hardware. + + Intended usage of this abstraction: + 1. to bundle raw data + fp8 metadata together for easy passing through + Python PyTorch systems. + 2. Float8-aware user code can use the private fields on these tensors + to call into float8 operations. + 3. Float8-agnostic user code can use these tensors as is - they will + convert to original precision in `__torch_dispatch__`. + """ + + _data: torch.Tensor + _scale: torch.Tensor + _orig_dtype: torch.dtype + _linear_mm_config: LinearMMConfig + __slots__ = ["_data", "_scale", "_orig_dtype", "_linear_mm_config"] + + def __new__( + cls, + data: torch.Tensor, + scale: torch.Tensor, + orig_dtype: torch.dtype, + linear_mm_config: Optional[LinearMMConfig], + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, + ): + assert ( + scale.numel() == 1 + ), "Scale should contain a single value, but got: {} elements".format( + scale.numel() + ) + + self = torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=orig_dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + self._data = data + self._scale = scale + self._orig_dtype = orig_dtype + self._linear_mm_config = ( + linear_mm_config if linear_mm_config is not None else LinearMMConfig() + ) + self._gemm_input_role = gemm_input_role + + return self + + def __repr__(self): + return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" + + def __tensor_flatten__(self): + ctx = { + "_orig_dtype": self._orig_dtype, + "_linear_mm_config": self._linear_mm_config, + "_gemm_input_role": self._gemm_input_role, + } + return ["_data", "_scale"], ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride): + assert len(inner_tensors) == 2 + return Float8Tensor( + inner_tensors["_data"], + inner_tensors["_scale"], + metadata["_orig_dtype"], + metadata["_linear_mm_config"], + metadata["_gemm_input_role"], + ) + + def to_original_precision(self): + return _FromFloat8ConstrFunc.apply(self) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # 1. tracing through __torch_function__ logic is not supported yet in + # PT2.0, so we explicitly disallow it here for callsites from user code. + # 2. We do need to handle a couple of ops in order for + # TorchDynamo tracing to succeed. + + # Lazy import to avoid circular dependency + from torchao.float8.float8_ops import FLOAT8_OPS_TABLE + + # All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs + # And don't support mixed tensor subclasses. This will trigger the handler for + # the next type in the dispatch list + def allowed_subclasses(type): + return ( + issubclass(cls, type) + or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) + or issubclass( + torch._subclasses.functional_tensor.FunctionalTensor, type + ) + ) + + if not all(allowed_subclasses(t) for t in types): + return NotImplemented + + if func in FLOAT8_OPS_TABLE: + return FLOAT8_OPS_TABLE[func](func, args, kwargs) + raise NotImplementedError(f"attempting to run {func}, this is not supported") + + # Do not force the Float8Tensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py new file mode 100644 index 000000000..affec2a76 --- /dev/null +++ b/torchao/float8/float8_tensor_parallel.py @@ -0,0 +1,235 @@ +import torch +import torch.nn as nn +from torchao.float8.config import ScalingType +from torchao.float8.float8_scaling_utils import ( + hp_tensor_to_float8_dynamic, + NoopFwToFloat8E5M2BwDynamic, +) +from torchao.float8.float8_tensor import GemmInputRole +from torchao.float8.float8_utils import e4m3_dtype +from torch.distributed._tensor import DTensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, +) + +# subclass the ColwiseParallel and RowwiseParallel classes +# to add the float8 support +# The parameter sharding stays the same as the core +# ColwiseParallel and RowwiseParallel, the only difference +# here is that in input/output handling we do casting after +# creating the DTensor. + +# NOTE: This only works and tested with the dynamic scaling + + +def _float8_linear_supports_float8_allgather(m): + # TODO(future): add support for delayed scaling for activations + # and gradients + return ( + m.scaling_type_input == ScalingType.DYNAMIC + and m.scaling_type_grad_output == ScalingType.DYNAMIC + ) + + +class Float8ColwiseParallel(ColwiseParallel): + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) + + input_tensor = hp_tensor_to_float8_dynamic( + input_tensor, + e4m3_dtype, + mod.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + ) # DTensor(Float8Tensor) + + # transform the input layouts to the desired layouts of ColwiseParallel + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) + return input_tensor + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # outputs is a shard on last dimension DTensor, i.e. Shard(-1) + if outputs.placements != output_layouts: + outputs = outputs.redistribute( + placements=output_layouts, async_op=True + ) # DTensor(torch.Tensor) + + # fwd noop bwd cast to DTensor(Float8Tensor) + outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) + + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + from torchao.float8.float8_linear import Float8Linear + + if not isinstance(module, Float8Linear): + raise ValueError( + f"Expecting module to be Float8Linear but found {type(module)}" + ) + elif isinstance( + module, Float8Linear + ) and not _float8_linear_supports_float8_allgather(module): + raise AssertionError("unsupported") + + return super()._apply(module, device_mesh) + + +class Float8RowwiseParallel(RowwiseParallel): + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) + + input_tensor = hp_tensor_to_float8_dynamic( + input_tensor, + e4m3_dtype, + mod.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + ) # DTensor(Float8Tensor) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) + return input_tensor + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # Rowwise sharding produces partial output, depending on output layouts: + # 1. to replicate -> allreduce + # 2. to shard -> reduce_scatter + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + + # fwd noop bwd cast to DTensor(Float8Tensor) + outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) + + # back to local tensor if use_local_output is True + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + from torchao.float8.float8_linear import Float8Linear + + if not isinstance(module, Float8Linear): + raise ValueError( + f"Expecting module to be Float8Linear but found {type(module)}" + ) + elif isinstance( + module, Float8Linear + ) and not _float8_linear_supports_float8_allgather(module): + raise AssertionError("unsupported") + + return super()._apply(module, device_mesh) + + +class PrepareFloat8ModuleInput(PrepareModuleInput): + # subclass the PrepareModuleInput classes to implement fp8 specific logic, the only difference is that + # after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor) + # This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate) + # so that if there are multiple float8 users of the input activation, we perform fp8 allgather + # only once. + # FP8 Args: + # float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input, + # we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn + # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used + # for the float8 cast. If not specified, we will search for the Float8Linear in the submodules + # and use the forward config from that module, in this case all module's forward config must be + # the same. + + def __init__( + self, + *, + input_layouts=None, + desired_input_layouts=None, + input_kwarg_layouts=None, + desired_input_kwarg_layouts=None, + use_local_output=False, + float8_dtype=torch.float8_e4m3fn, + fwd_config_submodule_fqn=None, + ): + super().__init__( + input_layouts=input_layouts, + desired_input_layouts=desired_input_layouts, + input_kwarg_layouts=input_kwarg_layouts, + desired_input_kwarg_layouts=desired_input_kwarg_layouts, + use_local_output=use_local_output, + ) + + # fp8 specific fields + self.float8_dtype = float8_dtype + self.linear_mm_config = None + self.fwd_config_submodule_fqn = fwd_config_submodule_fqn + + if self.float8_dtype != torch.float8_e4m3fn: + raise NotImplementedError( + "PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now" + ) + + def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): + if input_layout is not None: + if isinstance(input, DTensor): + # TODO: re-enable the check once we fix the compile path + # assert inp.placements[0] == input_layout + dt_inp = input + else: + assert isinstance( + input, torch.Tensor + ), "expecting input to be a torch.Tensor!" + dt_inp = DTensor.from_local( + input, mesh, (input_layout,), run_check=False + ) + + dt_inp = hp_tensor_to_float8_dynamic( + dt_inp, + e4m3_dtype, + self.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + ) # DTensor(Float8Tensor) + if desired_layout is not None and input_layout != desired_layout: + dt_inp = dt_inp.redistribute(placements=(desired_layout,)) + + return dt_inp.to_local() if self.use_local_output else dt_inp + else: + return input + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + from torchao.float8.float8_linear import Float8Linear + + if self.fwd_config_submodule_fqn is not None: + fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn) + assert isinstance(fwd_linear, Float8Linear) + self.linear_mm_config = fwd_linear.linear_mm_config + else: + # search for ScaledMM configs for all the submodules and make sure they are the same + for mod in module.modules(): + if isinstance(mod, Float8Linear): + if self.linear_mm_config is None: + self.linear_mm_config = mod.linear_mm_config + else: + assert ( + self.linear_mm_config == mod.linear_mm_config + ), "All the Float8Linear modules should have same linear_mm_config!" + + assert self.linear_mm_config is not None + super()._apply(module, device_mesh) + return module diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py new file mode 100644 index 000000000..1d6c69d17 --- /dev/null +++ b/torchao/float8/float8_utils.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Iterable, Literal, Tuple, Union + +import torchao.float8.config as config + +import torch +import torch.distributed as dist + +# Helpful visualizer for debugging (only supports fp32): +# https://www.h-schmidt.net/FloatConverter/IEEE754.html + +# avoid division by zero when calculating scale +# TODO: align this value with NVIDIA's assumptions (current value is a guess) +EPS = 1e-12 + +IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None +FP8_TYPES = { + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, +} + + +# User defined type for using the individual F8 type based on config +e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz +e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz + + +@torch.no_grad() +def amax_to_scale( + amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype +): + """Converts the amax value of a tensor to the fp8 scale. + Args: + amax: The amax value of the tensor. + float8_dtype: The float8 dtype. + orig_dtype: The original dtype of the tensor. + """ + scale = torch.empty_like(amax, dtype=torch.float32) + if float8_dtype in FP8_TYPES: + res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) + else: + raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") + + # Ensure that the scale is representable in float16, + # this helps when amax is small. We are assuming that we don't need + # to care about this for float32/bfloat16. + if orig_dtype is torch.float16: + res = torch.clamp(res, max=torch.finfo(torch.float16).max) + scale.copy_(res) + return scale + + +@torch.no_grad() +def amax_history_to_scale( + amax_history: torch.Tensor, + float8_dtype: torch.Tensor, + orig_dtype: torch.dtype, + history_to_scale_fn_type: Literal["max"], +): + """Takes in a history of amax values and returns a scale tensor. + Args: + amax_history: A tensor containing the history of amax values. + float8_dtype: The float8 dtype. + orig_dtype: The original dtype of the tensor. + history_to_scale_fn_type: The type of function to use to convert the history to a scale. + """ + if history_to_scale_fn_type == "max": + amax = torch.max(amax_history) + return amax_to_scale(amax, float8_dtype, orig_dtype) + raise NotImplementedError() + + +@torch.no_grad() +def amax_history_to_scale_stack( + amax_history: torch.Tensor, + float8_dtype: torch.dtype, + orig_dtype: torch.dtype, + history_to_scale_fn_type: Literal["max"], +) -> torch.Tensor: + """Takes in a stack of amax_history tensors and returns a scale tensor. + Args: + amax_history: A 2D tensor containing a stack of amax histories. + float8_dtype: The float8 dtype. + orig_dtype: The original dtype of the tensor. + history_to_scale_fn_type: The type of function to use to convert the history to a scale. + """ + if history_to_scale_fn_type == "max": + amax_stack = torch.max(amax_history, dim=1).values + return amax_to_scale(amax_stack, float8_dtype, orig_dtype) + raise NotImplementedError( + f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}" + ) + + +@torch.no_grad() +def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: + amax = torch.max(torch.abs(x)) + + # If the user asked for distributed reduction, do it. + # If the user did not ask for it, assume that it will + # happen elsewhere. + if reduce_amax and dist.is_initialized(): + dist.all_reduce(amax, op=dist.ReduceOp.MAX) + + return amax + + +@torch.no_grad() +def tensor_to_scale( + x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False +) -> torch.Tensor: + amax = tensor_to_amax(x, reduce_amax=reduce_amax) + return amax_to_scale(amax, float8_dtype, x.dtype) + + +def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): + """Converts a tensor to a saturated fp8 tensor. + + Note: + The default behavior in PyTorch for casting to `float8_e4m3fn` + and `e5m2` is to not saturate. In this context, we should saturate. + A common case where we want to saturate is when the history of a + tensor has a maximum value of `amax1`, and the current amax value + is `amax2`, where `amax1 < amax2`. This is common when using delayed + scaling. + """ + if float8_dtype in FP8_TYPES: + max_value = torch.finfo(float8_dtype).max + x = x.clamp(min=-max_value, max=max_value) + return x.to(float8_dtype) + else: + raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") + + +def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes the error between two tensors in dB. + + For more details see: + https://en.wikipedia.org/wiki/Signal-to-noise_ratio + + Args: + x: The original tensor. + y: The tensor to compare to the original tensor. + """ + Ps = torch.norm(x) + Pn = torch.norm(x - y) + return 20 * torch.log10(Ps / Pn) + + +def fp8_tensor_statistics( + tensor: torch.Tensor, float8_dtype=e4m3_dtype +) -> Tuple[int, ...]: + """Calculate FP8 tensor stats + + Args: + tensor: The tensor to calculate stats for. + float8_dtype: The float8 dtype. + + Returns: + A tuple containing the number of zeros and the number of max values. + """ + if float8_dtype in FP8_TYPES: + FP8_MAX = torch.finfo(float8_dtype).max + else: + raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") + tensor_orig_type = tensor._data.to(dtype=tensor._orig_dtype) + num_max = (torch.abs(tensor_orig_type) == FP8_MAX).sum().item() + num_zero = (tensor_orig_type == 0).sum().item() + return (num_zero, num_max) + + +def is_row_major(stride): + assert len(stride) == 2, "is_row_major only supports 2D tensors" + return stride[0] > stride[1] and stride[1] == 1 + + +def _get_min_alignment(size: int, alignment_value: int) -> int: + """ + Returns the minimum alignment value that is greater than or equal to the given size. + + Args: + size: The size of the data to be aligned. + alignment_value: The alignment value to be used. + + Returns: + int: The minimum alignment value that is greater than or equal to the given size. + + Usage: + ``` + >>> _get_min_alignment(10, 8) + 16 + ``` + """ + if size % alignment_value == 0: + return size + return (1 + (size // alignment_value)) * alignment_value + + +def pad_tensor_for_matmul( + tensor: torch.Tensor, dims: Union[int, Iterable[int]] +) -> torch.Tensor: + """ + Pads a 2D tensor with zeros to ensure that its dimensions are multiples of 16, which is required `torch._scaled_mm` + + Args: + tensor: The tensor to pad. + both: Whether to pad both dimensions or just the second dimension. + + Returns: + torch.Tensor: The padded tensor. + + Usage: + ``` + >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape + torch.Size([16, 10]) + >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape + torch.Size([10, 16]) + >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape + torch.Size([16, 16]) + ``` + """ + assert tensor.dim() == 2 + dim1, dim2 = tensor.shape + + if isinstance(dims, int): + dims = (dims,) + + # Calculate aligned dimensions based on the specified dims + dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1 + dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2 + + # Check if padding is needed for either dimension + if dim1 == dim1_aligned and dim2 == dim2_aligned: + return tensor + + # Calculate padding values for both dimensions + pad_dim1 = dim1_aligned - dim1 + pad_dim2 = dim2_aligned - dim2 + + return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py new file mode 100644 index 000000000..5f53f5d82 --- /dev/null +++ b/torchao/float8/fsdp_utils.py @@ -0,0 +1,388 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.utils._pytree as pytree +from torchao.float8.float8_scaling_utils import ( + hp_tensor_to_float8_delayed, + hp_tensor_to_float8_dynamic, +) + +from torchao.float8.float8_tensor import ( + Float8Tensor, + GemmInputRole, + hp_tensor_and_scale_to_float8, + LinearMMConfig, +) + +from torchao.float8.float8_utils import e4m3_dtype, EPS +from torch._prims_common import suggest_memory_format + + +@torch.no_grad() +def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: + """ + Calculate scale dynamically for all float8 parameters. + This should be run after the optimizer step. It performs a single all-reduce to compute the + scales for all float8 weights. + Example usage: + model(input).sum().backward() + optim.step() + precompute_float8_dynamic_scale_for_fsdp(model) + """ + from torchao.float8.config import ScalingType + from torchao.float8.float8_linear import Float8Linear + from torch.distributed._tensor import DTensor + + if any( + isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED + for m in module.modules() + ): + raise NotImplementedError("Only supports delayed scaling") + float8_linears: List[Float8Linear] = [ + m + for m in module.modules() + if isinstance(m, Float8Linear) + and isinstance(m.weight, DTensor) + and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) + ] + weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] + + if not weights: + return + + # inf-norm is equivalent to max(abs(w)) + max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial + amax_tensor = torch.stack(max_weights) # Partial + # clamp is dispatched through DTensor + # it will issue a single all-reduce + amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate + scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate + if amax_tensor.dtype is torch.float16: + scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) + local_scale_tensor = scale_tensor.to_local() + for i, float8_linear in enumerate(float8_linears): + float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i] + + +# FSDP pads its local tensor on dim-0. The subclass should be preserved such +# that the padded local tensor (and any transformations like copying to GPU) +# is of the subclass as well. +_ops_to_preserve_subclass = { + torch.ops.aten.empty_like.default, + torch.ops.aten.new_zeros.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.copy_.default, + torch.ops.aten.view.default, + torch.ops.aten.as_strided.default, + torch.ops.aten._to_copy.default, + torch.ops.aten._pin_memory.default, +} + + +class WeightWithDynamicFloat8CastTensor(torch.Tensor): + @staticmethod + def __new__( + cls, + tensor: torch.Tensor, + linear_mm_config: LinearMMConfig, + precomputed_scale: Optional[torch.Tensor] = None, + ): + return torch.Tensor._make_wrapper_subclass( + cls, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + memory_format=suggest_memory_format(tensor), + dtype=tensor.dtype, + layout=tensor.layout, + device=tensor.device, + pin_memory=tensor.is_pinned(), + requires_grad=tensor.requires_grad, + ) + + def __init__( + self, + tensor: torch.Tensor, + linear_mm_config: LinearMMConfig, + precomputed_scale: Optional[torch.Tensor] = None, + ): + self._tensor = tensor + self._linear_mm_config = linear_mm_config + # for dynamic scaling + # `precompute_float8_dynamic_scale_for_fsdp` calculates scales + # for all float8 parameters after optimizer step + self._precomputed_scale = precomputed_scale + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func == torch.ops.aten.detach.default: + return WeightWithDynamicFloat8CastTensor( + args[0]._tensor, args[0]._linear_mm_config + ) + mm_config: Optional[LinearMMConfig] = None + + def unwrap(t): + nonlocal mm_config + if mm_config is None: + mm_config = t._linear_mm_config + else: + assert t._linear_mm_config == mm_config + return t._tensor + + args, kwargs = pytree.tree_map_only( + WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {}) + ) + out = func(*args, **kwargs) + if func not in _ops_to_preserve_subclass: + return out + return pytree.tree_map_only( + torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out + ) + + def __tensor_flatten__(self): + if self._precomputed_scale: + return ["_tensor", "_precomputed_scale"], self._linear_mm_config + else: + return ["_tensor"], self._linear_mm_config + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + mm_config = flatten_spec + return WeightWithDynamicFloat8CastTensor( + inner_tensors["_tensor"], + mm_config, + getattr(inner_tensors, "_precomputed_scale", None), + ) + + def __repr__(self): + return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config})" + + def fsdp_pre_all_gather(self, mesh): + if self._precomputed_scale is not None: + float8_tensor = hp_tensor_and_scale_to_float8( + self._tensor, + self._precomputed_scale, + torch.float8_e4m3fn, + self._linear_mm_config, + GemmInputRole.WEIGHT, + ) + else: + float8_tensor = hp_tensor_to_float8_dynamic( + self._tensor, + e4m3_dtype, + self._linear_mm_config, + reduce_amax=True, + gemm_input_role=GemmInputRole.WEIGHT, + ) + return (float8_tensor._data,), (float8_tensor._scale,) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[torch.Tensor] = None, + ): + (data,) = all_gather_outputs + (scale,) = metadata + if out is not None: + assert isinstance(out, Float8Tensor), f"{type(out)}" + out._scale = scale + return + return Float8Tensor( + data, + scale, + param_dtype, + self._linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + ), (data,) + + +class WeightWithDelayedFloat8CastTensor(torch.Tensor): + @staticmethod + def __new__( + cls, + tensor: torch.Tensor, + amax_buffer: torch.Tensor, + amax_history_buffer: torch.Tensor, + scale_buffer: torch.Tensor, + linear_mm_config: LinearMMConfig, + is_amax_initialized: bool, + ): + return torch.Tensor._make_wrapper_subclass( + cls, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + memory_format=suggest_memory_format(tensor), + dtype=tensor.dtype, + layout=tensor.layout, + device=tensor.device, + pin_memory=tensor.is_pinned(), + requires_grad=tensor.requires_grad, + ) + + def __init__( + self, + tensor: torch.Tensor, + amax_buffer: torch.Tensor, + amax_history_buffer: torch.Tensor, + scale_buffer: torch.Tensor, + linear_mm_config: LinearMMConfig, + is_amax_initialized: bool, + ): + self._tensor = tensor + self._amax_buffer = amax_buffer + self._amax_history_buffer = amax_history_buffer + self._scale_buffer = scale_buffer + self._linear_mm_config = linear_mm_config + + # Note: is_amax_initialized is not a buffer to avoid data dependent + # control flow visible to dynamo + # TODO(future PR): add serialization for this flag + self.is_amax_initialized = is_amax_initialized + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func == torch.ops.aten.detach.default: + return WeightWithDelayedFloat8CastTensor( + args[0]._tensor, + args[0]._amax_buffer, + args[0]._amax_history_buffer, + args[0]._scale_buffer, + args[0]._linear_mm_config, + args[0].is_amax_initialized, + ) + mm_config: Optional[LinearMMConfig] = None + amax_buffer: Optional[torch.Tensor] = None + amax_history_buffer: Optional[torch.Tensor] = None + scale_buffer: Optional[torch.Tensor] = None + is_amax_initialized: Optional[bool] = None + + def unwrap(t): + nonlocal mm_config + if mm_config is None: + mm_config = t._linear_mm_config + else: + assert t._linear_mm_config == mm_config + nonlocal amax_buffer + if amax_buffer is None: + amax_buffer = t._amax_buffer + nonlocal amax_history_buffer + if amax_history_buffer is None: + amax_history_buffer = t._amax_history_buffer + nonlocal scale_buffer + if scale_buffer is None: + scale_buffer = t._scale_buffer + nonlocal is_amax_initialized + if is_amax_initialized is None: + is_amax_initialized = t.is_amax_initialized + return t._tensor + + args, kwargs = pytree.tree_map_only( + WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {}) + ) + out = func(*args, **kwargs) + if func not in _ops_to_preserve_subclass: + return out + return pytree.tree_map_only( + torch.Tensor, + lambda x: WeightWithDelayedFloat8CastTensor( + x, + amax_buffer, + amax_history_buffer, + scale_buffer, + mm_config, + is_amax_initialized, + ), + out, + ) + + def __tensor_flatten__(self): + return ( + [ + "_tensor", + "_amax_buffer", + "_amax_history_buffer", + "_scale_buffer", + ], + { + "mm_config": self._linear_mm_config, + "is_amax_initialized": self.is_amax_initialized, + }, + ) + + @staticmethod + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): + return WeightWithDelayedFloat8CastTensor( + inner_tensors["_tensor"], + inner_tensors["_amax_buffer"], + inner_tensors["_amax_history_buffer"], + inner_tensors["_scale_buffer"], + metadata["mm_config"], + metadata["is_amax_initialized"], + ) + + def __repr__(self): + return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config})" + + def fsdp_pre_all_gather(self, mesh): + # initialize if needed + # TODO(before land): ensure settings are consistent between Float8Linear and here + if not self.is_amax_initialized: + from torchao.float8.float8_linear import ( + _maybe_initialize_amaxes_scales_for_float8_cast, + ) + + _maybe_initialize_amaxes_scales_for_float8_cast( + self._tensor, + self._amax_buffer, + self._amax_history_buffer, + self._scale_buffer, + "max", # TODO(before land): read this from parent + e4m3_dtype, + self.is_amax_initialized, + reduce_amax=True, + ) + self.is_amax_initialized = True + + float8_tensor = hp_tensor_to_float8_delayed( + self._tensor, + self._scale_buffer, + e4m3_dtype, + self._amax_buffer, + self._linear_mm_config, + GemmInputRole.WEIGHT, + ) + return (float8_tensor._data,), (float8_tensor._scale,) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[torch.Tensor] = None, + ): + (data,) = all_gather_outputs + (scale,) = metadata + if out is not None: + assert isinstance(out, Float8Tensor), f"{type(out)}" + out._scale = scale + return + return Float8Tensor( + data, + scale, + param_dtype, + self._linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + ), (data,) diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py new file mode 100644 index 000000000..f5c504503 --- /dev/null +++ b/torchao/float8/inference.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +""" +Defines an nn module designed to be used during inference +""" + +from dataclasses import dataclass + +from enum import auto, Enum +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +from torchao.float8.float8_linear_utils import swap_linear_layers + +from torchao.float8.float8_tensor import ( + Float8Tensor, + GemmInputRole, + hp_tensor_and_scale_to_float8, + LinearMMConfig, + ScaledMMConfig, + tensor_already_casted_to_fp8, +) +from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale + + +class ActivationCasting(Enum): + """Types of quantization to perform on the activations + + WEIGHT_ONLY: Only quantize the weight, no activation casting, weight will be dequantized in the forward pass + STATIC: Activation is quantized during model initialization with a static scale + DYNAMIC: Activation is quantized during forward pass with a dynamic scale calculated from the input activation + """ + + # TODO: A better name would be NONE, we should unify this with torchao + WEIGHT_ONLY = auto() + DYNAMIC = auto() + STATIC = auto() + + +@dataclass(frozen=True) +class QuantConfig: + """Defines the configuration for the quantization to fp8 of a linear module + + Args: + activation_casting: The type of quantization to perform on the activations + static_quantization_scale: The scale of the input to this linear module, used for static quantization only + """ + + activation_casting: ActivationCasting + static_quantization_scale: Optional[torch.Tensor] = None + + # If True, then prior to performing the fp8 scaled mamtmul we will pad the + # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls + # _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16. + # This can cause a memory spike however so we keep this off by default. + pad_inner_dim = False + + def __post_init__(self): + if self.activation_casting == ActivationCasting.STATIC: + assert isinstance( + self.static_quantization_scale, torch.Tensor + ), "When activation_casting is 'static', activation_scale must be a tensor." + + +class Float8InferenceLinear(torch.nn.Linear): + """ + This is a wrapper around torch.nn.Linear that supports FP8 inference + Supported forms of inference: + - FP8 inference with high precision matmul - weight only + - FP8 inference with fp8 matmul and dynamic weight casting + - FP8 inference with fp8 matmul and static weight casting + """ + + def __init__( + self, + # FP8 specific arguments + quant_config: QuantConfig, + linear_mm_config: LinearMMConfig, + # nn.Linear arguments + in_features: int, + out_features: int, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + # Construct the superclass this will create dummy weights and biases + super().__init__(in_features, out_features, bias, device, dtype) + self.linear_mm_config = linear_mm_config + self.activation_casting = quant_config.activation_casting + if self.activation_casting == ActivationCasting.STATIC: + self.register_buffer( + "static_quantization_scale", quant_config.static_quantization_scale + ) + else: + self.static_quantization_scale = None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.activation_casting == ActivationCasting.WEIGHT_ONLY: + return torch.nn.functional.linear( + input, self.weight.to_original_precision() + ) + + x_fp8 = cast_to_float8_e4m3_inference( + input, + self.linear_mm_config, + static_quantization_scale=self.static_quantization_scale, + ) + return torch.nn.functional.linear(x_fp8, self.weight, self.bias) + + # Builder functions for Float8LinearInference + def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: + """This functions converts the weight to a Float8Tensor and sets its requires_grad to False. + + Args: + dtype: The dtype to quantize the weight to. Default is e4m3_dtype. + + Note: + This function is typically called during inference to quantize the weight once since + the weight is not updated during inference. + + """ + assert not isinstance( + self.weight, Float8Tensor + ), "Weight has already been quantized, cannot quantize again." + scale = tensor_to_scale(self.weight, dtype) + quantized_weight = hp_tensor_and_scale_to_float8( + self.weight, + scale, + dtype, + self.linear_mm_config, + GemmInputRole.WEIGHT, + ) + self.weight = nn.Parameter(quantized_weight) + self.weight.requires_grad = False + + def set_weight_and_bias( + self, weight: torch.nn.Parameter, bias: Optional[torch.nn.Parameter] + ): + self.weight = weight + self.bias = bias + + @classmethod + def from_float( + cls, module: nn.Module, quant_config: QuantConfig, use_fast_accum: bool + ) -> "Float8InferenceLinear": + """ + Create an nn.Linear with fp8 compute from another nn.Linear + + Args: + mod (torch.nn.Linear): nn.Linear to convert + quant_config (QuantConfig): Configuration for the weight and activation casting + """ + forward_config = ScaledMMConfig( + False, use_fast_accum, pad_inner_dim=quant_config.pad_inner_dim + ) + linear_mm_config = LinearMMConfig( + forward_config, forward_config, forward_config + ) + linear = cls( + quant_config, + linear_mm_config, + module.in_features, + module.out_features, + False, + device=torch.device("meta"), + ) + linear.set_weight_and_bias(module.weight, module.bias) + linear.quantize_weight() + return linear + + +def cast_to_float8_e4m3_inference( + inpt_tensor: torch.Tensor, + linear_mm_config: LinearMMConfig, + reduce_amax: bool = False, + static_quantization_scale: Optional[torch.Tensor] = None, +) -> Float8Tensor: + """Casts an input tensor to the Float8 (e4m3fn*) + + Args: + inpt_tensor: The input tensor to be cast. + linear_mm_config: Configuration settings for the matrix multiplication + reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group. + static_quantization_scale: Optional tensor specifying the scale for activation. Default is None. + + Returns: + Float8Tensor: The input tensor cast to Float8 (e4m3fn) format. + + Note: + If the input tensor is already in Float8 format, it is returned as is without re-casting. + """ + if tensor_already_casted_to_fp8(inpt_tensor): + return inpt_tensor + scale = ( + static_quantization_scale + if static_quantization_scale is not None + else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) + ) + return hp_tensor_and_scale_to_float8( + inpt_tensor, + scale, + e4m3_dtype, + linear_mm_config, + GemmInputRole.INPUT, + ) + + +def quantize_to_float8( + module: nn.Module, + quant_config: QuantConfig, + *, + module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, + use_fast_accum: bool = True, +) -> nn.Module: + """ + Converts torch.nn.Linear layers in the given module to Float8InferenceLinear. + + Note: + If applied to a root-level nn.Linear, the module will not be modified in place + and returned instead + + Args: + module (nn.Module): The module to modify. + quant_config (QuantConfig): Quantization configuration for Float8 conversion. + module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that + that pass the filter function will be swapped. The inputs to the + filter function are the module instance and the FQN. + use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True. + + Returns: + nn.Module: The modified module with applicable Linear layers converted to Float8. + + Raises: + AssertionError: If a root-level nn.Linear with children is encountered. + """ + return swap_linear_layers( + module, + lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum), + module_filter_fn=module_filter_fn, + )