diff --git a/test_runner.py b/test_runner.py
index f2f80504..c84ca6af 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -276,34 +276,34 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.enable_fp8_linear",
+                    "--training.enable_float8_linear",
                 ]
             ],
             "FSDP2 with original dtype",
-            "fp8_fsdp2_orig_all_gather",
+            "float8_fsdp2_orig_all_gather",
             ngpu=4,
         ),
         OverrideDefinitions(
             [
                 [
-                    "--training.enable_fp8_linear",
-                    "--training.enable_fsdp_fp8_all_gather",
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
                 ]
             ],
-            "FSDP2 with fp8 all-gather",
-            "fsdp2_fp8_all_gather",
+            "FSDP2 with float8 all-gather",
+            "fsdp2_float8_all_gather",
             ngpu=4,
         ),
         OverrideDefinitions(
             [
                 [
-                    "--training.enable_fp8_linear",
-                    "--training.enable_fsdp_fp8_all_gather",
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
                     "--training.precompute_float8_dynamic_scale_for_fsdp",
                 ]
             ],
-            "FSDP2 with fp8 all-gather and precomputed dynamic scales",
-            "fsdp2_fp8_all_gather_precompute_dynamic_scales",
+            "FSDP2 with float8 all-gather and precomputed dynamic scales",
+            "fsdp2_float8_all_gather_precompute_dynamic_scales",
             ngpu=4,
         ),
     ]
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 0dfe1bb0..2bd6e370 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -338,7 +338,7 @@ def __init__(self):
             help="Whether to compile the model",
         )
         self.parser.add_argument(
-            "--training.enable_fp8_linear",
+            "--training.enable_float8_linear",
             action="store_true",
             help="""
                 If true, swaps `torch.nn.Linear` with `Float8Linear` with
@@ -348,16 +348,16 @@ def __init__(self):
             """,
         )
         self.parser.add_argument(
-            "--training.enable_fsdp_fp8_all_gather",
+            "--training.enable_fsdp_float8_all_gather",
             action="store_true",
             default=False,
-            help="Whether enable fp8 all-gather in FSDP",
+            help="Whether enable float8 all-gather in FSDP",
         )
         self.parser.add_argument(
             "--training.precompute_float8_dynamic_scale_for_fsdp",
             action="store_true",
             default=False,
-            help="Whether precompute fp8 scales dynamically for FSDP",
+            help="Whether precompute float8 scales dynamically for FSDP",
         )
         self.parser.add_argument(
             "--training.gc_freq",
diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index 9b92400c..50c971ae 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -25,7 +25,7 @@
 
 
 @contextlib.contextmanager
-def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
+def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
     import float8_experimental.config as config
 
     prev = config.enable_fsdp_fp8_all_gather
@@ -53,8 +53,8 @@ def maybe_build_fp8_linear(
 
     This will mutate the model inplace.
     """
-    enable_fp8_linear = job_config.training.enable_fp8_linear
-    if not enable_fp8_linear:
+    enable_float8_linear = job_config.training.enable_float8_linear
+    if not enable_float8_linear:
         return
     if not is_sm90_or_later():
         warning_once(
@@ -69,15 +69,15 @@ def maybe_build_fp8_linear(
         )
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        enable_fsdp_fp8_all_gather = (
-            job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
+        enable_fsdp_float8_all_gather = (
+            job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
-        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
+        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
             swap_linear_with_float8_linear(
                 model, scaling_type_w=TensorScalingType.DYNAMIC
             )
         logger.info(
-            f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}"
+            f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
         )
     except ImportError as exc:
         raise ImportError(
@@ -89,8 +89,8 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
     model: nn.Module, job_config: JobConfig
 ):
     if not (
-        job_config.training.enable_fp8_linear
-        and job_config.training.enable_fsdp_fp8_all_gather
+        job_config.training.enable_float8_linear
+        and job_config.training.enable_fsdp_float8_all_gather
         and job_config.training.precompute_float8_dynamic_scale_for_fsdp
     ):
         return
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index b33e8870..ec0f6763 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -123,18 +123,10 @@ def get_tp_parallel_strategy(
 
     This function handles the special case of using float8 with tensor parallelism.
     """
-    if job_config.training.enable_fp8_linear:
-        from float8_experimental.float8_linear import Float8Linear, TensorScalingType
-
-        if any(
-            isinstance(m, Float8Linear)
-            and m.scaling_type_w is TensorScalingType.DELAYED
-            for m in model.modules()
-        ):
-            raise NotImplementedError(
-                "1D TP fp8 all-gather only supports dynamic scaling"
-            )
-
+    if job_config.training.enable_float8_linear:
+        # TODO(future PR): once float8 configuration supports delayed
+        # scaling, add a check here to enforce supported float8 all-gather
+        # configurations
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 6064ced1..7c849976 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -37,7 +37,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4_mini"  # supported datasets: c4_mini (45K), c4 (177M)
 
diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml
index f4061ad0..2dc29f2e 100644
--- a/train_configs/llama2_13b.toml
+++ b/train_configs/llama2_13b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
 
diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml
index 19e033b8..f17496c5 100644
--- a/train_configs/llama2_70b.toml
+++ b/train_configs/llama2_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8  # 8-way TP
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
 
diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml
index 95d67667..69ae7285 100644
--- a/train_configs/llama2_7b.toml
+++ b/train_configs/llama2_7b.toml
@@ -32,7 +32,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1  # dp-only would be sufficient for 7B
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
 
diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml
index ac6b31c1..660f2c0b 100644
--- a/train_configs/llama3_70b.toml
+++ b/train_configs/llama3_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8  # 8-way TP
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
 
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 2c3c6e63..7e5ac63c 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
 
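Usage sketch (editor's note, not part of the diff): after this rename, the float8 path is driven by the new option names shown above. A minimal hypothetical [training] excerpt follows; only enable_float8_linear appears in the train_configs above, and treating the other two keys as TOML options (rather than command-line --training.* overrides, as used in test_runner.py) is an assumption. Values are illustrative, not defaults.

[training]
# hypothetical excerpt; key names come from the renamed options in this diff
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true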