 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# [Note] Getting the 'float8_experimental' package:
-# This script requires the 'float8_experimental' package to function correctly.
+# [Note] Getting the 'torchao' package:
+# This script requires the 'torchao' package to function correctly.
 # Please ensure you have this package installed from the appropriate repository.
-# You can obtain it from https://github.com/pytorch-labs/float8_experimental.
-# Either clone and run `pip install .` or run `pip install git+https://github.com/pytorch-labs/float8_experimental.git`
+# You can obtain it from https://github.com/pytorch/ao by following the
+# installation instructions.
 
 # Note: Performance
 # Float8 experimental is intended to be run under `torch.compile` for competitive performance
-import contextlib
 import functools
 from typing import Optional
 
 from torchtitan.logging_utils import logger
 
 
-@contextlib.contextmanager
-def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
-    import float8_experimental.config as config
-
-    prev = config.enable_fsdp_fp8_all_gather
-    torch.distributed.barrier()
-    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
-    try:
-        yield
-    finally:
-        torch.distributed.barrier()
-        config.enable_fsdp_fp8_all_gather = prev
-
-
 @functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
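# --- Editor's sketch, not part of the diff: the hunk cuts off before the body of
# is_sm90_or_later(). A plausible implementation, assuming the standard
# torch.cuda.get_device_capability() API and that "SM90" corresponds to CUDA
# compute capability (9, 0): ---
import functools
import torch


@functools.lru_cache(None)
def is_sm90_or_later_sketch():
    # H100-class GPUs report compute capability >= (9, 0).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)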
@@ -63,25 +48,42 @@ def maybe_build_fp8_linear(
         )
         return
     try:
-        from float8_experimental.float8_linear import TensorScalingType
-        from float8_experimental.float8_linear_utils import (
-            swap_linear_with_float8_linear,
+        from torchao.float8 import (
+            CastConfig,
+            convert_to_float8_training,
+            Float8LinearConfig,
+            ScalingType,
         )
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
             job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
-        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
-            swap_linear_with_float8_linear(
-                model, scaling_type_w=TensorScalingType.DYNAMIC
-            )
+        scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input)
+        scaling_type_weight = ScalingType(
+            job_config.training.float8_scaling_type_weight
+        )
+        scaling_type_grad_output = ScalingType(
+            job_config.training.float8_scaling_type_grad_output
+        )
+        float8_config = Float8LinearConfig(
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            cast_config_input=CastConfig(scaling_type=scaling_type_input),
+            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+            enable_pre_and_post_forward=False,
+        )
+        convert_to_float8_training(
+            model,
+            config=float8_config,
+            module_filter_fn=lambda mod, fqn: fqn != "output",
+        )
         logger.info(
             f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
         )
     except ImportError as exc:
         raise ImportError(
-            "float8_experimental is not installed. Please install it to use fp8 linear layers."
+            "torchao is not installed. Please install it to use fp8 linear layers."
         ) from exc
 
 
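# --- Editor's sketch, not part of the diff: standalone use of the torchao float8
# API that the hunk above adopts. The toy model, shapes, and variable names are
# made up for illustration; the torchao calls mirror the ones introduced above,
# and, per the header note, the converted model is run under torch.compile. ---
import torch
import torch.nn as nn
from torchao.float8 import (
    CastConfig,
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
)

toy_model = nn.Sequential(
    nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)
).to("cuda", torch.bfloat16)

fp8_config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

# Swaps nn.Linear modules for Float8Linear in place; the filter keeps the last
# Linear in bf16, analogous to how the diff skips the "output" projection.
convert_to_float8_training(
    toy_model,
    config=fp8_config,
    module_filter_fn=lambda mod, fqn: fqn != "2",
)

toy_model = torch.compile(toy_model)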
@@ -100,6 +102,37 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
100102 "Skipped precomputing fp8 scales because SM90 or later is not available" ,
101103 )
102104 return
103- from float8_experimental . fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
105+ from torchao . float8 import precompute_float8_dynamic_scale_for_fsdp
104106
105107 precompute_float8_dynamic_scale_for_fsdp (model )
+
+
+_sync_float8_amax_and_scale_history = None
+
+
+def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig):
+    if not (
+        job_config.training.enable_float8_linear
+        and (
+            job_config.training.float8_scaling_type_input == "delayed"
+            or job_config.training.float8_scaling_type_weight == "delayed"
+            or job_config.training.float8_scaling_type_grad_output == "delayed"
+        )
+    ):
+        return
+
+    from torchao.float8 import sync_float8_amax_and_scale_history
+
+    # TODO(future): see if precalculating the modules to sync over is going to
+    # meaningfully help performance
+
+    global _sync_float8_amax_and_scale_history
+    if _sync_float8_amax_and_scale_history is None:
+        if job_config.training.compile:
+            _sync_float8_amax_and_scale_history = torch.compile(
+                sync_float8_amax_and_scale_history
+            )
+        else:
+            _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history
+
+    _sync_float8_amax_and_scale_history(model)
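# --- Editor's sketch, not part of the diff: how these helpers are typically
# ordered inside a training step. The loop below is a made-up minimal example;
# the ordering (sync amax/scale history after backward for delayed scaling,
# precompute dynamic scales for FSDP2 after the optimizer step) follows the
# torchao float8 usage pattern. The (model, job_config) signature of
# maybe_precompute_fp8_dynamic_scale_for_fsdp is assumed, since the hunk above
# truncates its argument list. ---
def train_step_sketch(model, optimizer, loss_fn, batch, labels, job_config):
    optimizer.zero_grad()
    loss = loss_fn(model(batch), labels)
    loss.backward()

    # Delayed scaling only: update fp8 amax/scale buffers before the weights move.
    maybe_sync_float8_amax_and_scale_history(model, job_config)

    optimizer.step()

    # FSDP2 fp8 all-gather with dynamic scaling: compute all per-parameter scales
    # in a single call once the weights have been updated.
    maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)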