From a4a1f74890a8719382fc2fb47ca93a8fd2a3841d Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Mon, 6 Jan 2025 10:26:35 -0800
Subject: [PATCH 1/6] integrate float8nocompile, an experimental feature for
 high performance float8 training in eager mode

---
 .python-version              |  1 +
 torchtitan/config_manager.py |  5 +++++
 torchtitan/float8.py         | 30 +++++++++++++++++++++++-------
 train_configs/llama3_8b.toml |  1 +
 4 files changed, 30 insertions(+), 7 deletions(-)
 create mode 100644 .python-version

diff --git a/.python-version b/.python-version
new file mode 100644
index 00000000..8cc1b46f
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.10.15
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index d59e34bc..2177000d 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -567,6 +567,11 @@ def __init__(self):
             default="dynamic",
             help="float8 scaling for input, dynamic (default) or delayed",
         )
+        self.parser.add_argument(
+            "--float8.no_compile",
+            action="store_true",
+            help="use the float8nocompile prototype implementation",
+        )
 
         # communications library settings
         self.parser.add_argument(
diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index 1dd0d0bb..638c1ddc 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -47,6 +47,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 "torchao is not installed. Please install it to use float8 linear layers."
             ) from e
 
+        self.use_float8nocompile = float8_config.no_compile
+
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
             parallel_dims.dp_shard_enabled
@@ -90,14 +92,28 @@ def convert_to_float8_training(self, model: nn.Module):
         if not self.enabled:
             return
 
-        from torchao.float8 import convert_to_float8_training
+        # TODO: should we implicitly use this if self.compile is False, rather
+        # than having an explicit flag?
+        if self.use_float8nocompile:
+            logger.info("Using float8nocompile prototype")
+            from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
+                convert_to_float8_nocompile_training,
+            )
 
-        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
-        convert_to_float8_training(
-            model,
-            config=self.config,
-            module_filter_fn=lambda mod, fqn: fqn != "output",
-        )
+            convert_to_float8_nocompile_training(
+                model,
+                config=self.config,
+                module_filter_fn=lambda mod, fqn: fqn != "output",
+            )
+        else:
+            from torchao.float8 import convert_to_float8_training
+
+            # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
+            convert_to_float8_training(
+                model,
+                config=self.config,
+                module_filter_fn=lambda mod, fqn: fqn != "output",
+            )
         logger.info(
             "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
             f"{self.config.enable_fsdp_float8_all_gather}"
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 3001ec74..9873d129 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -56,3 +56,4 @@ selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac ba
 
 [float8]
 enable_float8_linear = false
+no_compile = false # TODO: should this go in [experimental]?

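Both conversion entry points in this patch receive the same module_filter_fn=lambda mod, fqn: fqn != "output" argument, so every nn.Linear is selected for swapping except the final output projection, which stays a regular nn.Linear. Below is a minimal, self-contained sketch of what such a filter selects; the would_convert helper and the toy ModuleDict are illustrative only and are not part of torchtitan or torchao:

import torch.nn as nn


def would_convert(model: nn.Module, module_filter_fn) -> list[str]:
    """Return the FQNs of the nn.Linear modules the filter would allow to be swapped."""
    return [
        fqn
        for fqn, mod in model.named_modules()
        if isinstance(mod, nn.Linear) and module_filter_fn(mod, fqn)
    ]


toy = nn.ModuleDict(
    {
        "attn": nn.Linear(64, 64),
        "mlp": nn.Linear(64, 256),
        "output": nn.Linear(256, 128),  # excluded by the fqn != "output" filter
    }
)
print(would_convert(toy, lambda mod, fqn: fqn != "output"))  # -> ['attn', 'mlp']

The real conversion is done by convert_to_float8_training (or convert_to_float8_nocompile_training when --float8.no_compile is set), which mutates the selected modules in place rather than merely listing them.
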
From 9e19e7fa91933e5c3010a215813e5237ce30c5b1 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Fri, 10 Jan 2025 09:39:08 -0800
Subject: [PATCH 2/6] add options to use AC in float8nocompile linear layers;
 add support for only compiling linear layers

---
 torchtitan/config_manager.py                 |  7 ++++++-
 torchtitan/float8.py                         |  5 ++++-
 torchtitan/parallelisms/parallelize_llama.py | 21 +++++++++++++++-----
 train_configs/llama3_8b.toml                 |  3 ++-
 4 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 2177000d..1fc6916b 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -568,10 +568,15 @@ def __init__(self):
             help="float8 scaling for input, dynamic (default) or delayed",
         )
         self.parser.add_argument(
-            "--float8.no_compile",
+            "--float8.float8nocompile",
             action="store_true",
             help="use the float8nocompile prototype implementation",
         )
+        self.parser.add_argument(
+            "--float8.float8nocompile_ac",
+            action="store_true",
+            help="use activation checkpointing with float8nocompile linear layers",
+        )
 
         # communications library settings
         self.parser.add_argument(
diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index 638c1ddc..20540d7b 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -47,7 +47,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 "torchao is not installed. Please install it to use float8 linear layers."
             ) from e
 
-        self.use_float8nocompile = float8_config.no_compile
+        self.use_float8nocompile = float8_config.float8nocompile
+        self.use_float8nocompile_ac = float8_config.float8nocompile_ac
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
@@ -104,8 +105,10 @@ def convert_to_float8_training(self, model: nn.Module):
                 model,
                 config=self.config,
                 module_filter_fn=lambda mod, fqn: fqn != "output",
+                use_activation_checkpointing=self.use_float8nocompile_ac,
             )
         else:
+            logger.info("Using float8 training")
             from torchao.float8 import convert_to_float8_training
 
             # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 9728569a..f2888350 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -7,6 +7,7 @@
 # This file applies the PT-D parallelisms (except pipeline parallelism) and various
 # training techniques (e.g. activation checkpointing and compile) to the Llama model.
 
+import os
 from collections import defaultdict
 
 import torch
@@ -299,11 +300,21 @@ def apply_compile(model: nn.Module):
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
""" - for layer_id, transformer_block in model.layers.named_children(): - transformer_block = torch.compile(transformer_block, fullgraph=True) - model.layers.register_module(layer_id, transformer_block) - - logger.info("Compiling each TransformerBlock with torch.compile") + compile_linear_only = bool(os.environ.get("TORCHTITAN_COMPILE_LINEAR_ONLY", False)) + + if compile_linear_only: + logger.info("Compiling linear layers with torch.compile") + for name, child in model.named_children(): + if isinstance(child, torch.nn.Linear): + new_child = torch.compile(child) + setattr(model, name, new_child) + else: + apply_compile(child) + else: + logger.info("Compiling each TransformerBlock with torch.compile") + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = torch.compile(transformer_block, fullgraph=True) + model.layers.register_module(layer_id, transformer_block) def apply_fsdp( diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 9873d129..7890eb0c 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -56,4 +56,5 @@ selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac ba [float8] enable_float8_linear = false -no_compile = false # TODO: should this go in [experimental]? +float8nocompile = false # TODO: should this go in [experimental]? +float8nocompile_ac = false From 73715c679b76ee0f663f4f02f2bd8a1e6beea0d4 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 13 Jan 2025 13:29:01 -0800 Subject: [PATCH 3/6] rename ac -> no_precompute_for_backward to avoid confusion with torch native ac --- torchtitan/config_manager.py | 2 +- torchtitan/float8.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 1fc6916b..ed4a3a0b 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -573,7 +573,7 @@ def __init__(self): help="use the float8nocompile prototype implementation", ) self.parser.add_argument( - "--float8.float8nocompile_ac", + "--float8.float8nocompile_no_precompute_for_backward", action="store_true", help="use activation checkpointing with float8nocompile linear layers", ) diff --git a/torchtitan/float8.py b/torchtitan/float8.py index 20540d7b..d5aa8f28 100644 --- a/torchtitan/float8.py +++ b/torchtitan/float8.py @@ -48,7 +48,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ) from e self.use_float8nocompile = float8_config.float8nocompile - self.use_float8nocompile_ac = float8_config.float8nocompile_ac + self.use_float8nocompile_no_precompute_for_backward = ( + float8_config.float8nocompile_no_precompute_for_backward + ) # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear enable_fsdp_float8_all_gather = ( @@ -105,7 +107,7 @@ def convert_to_float8_training(self, model: nn.Module): model, config=self.config, module_filter_fn=lambda mod, fqn: fqn != "output", - use_activation_checkpointing=self.use_float8nocompile_ac, + no_precompute_for_backward=self.use_float8nocompile_no_precompute_for_backward, ) else: logger.info("Using float8 training") From cbedb73f80b3ee038fe04e7617dcbfd4b4910567 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 14 Jan 2025 09:54:52 -0800 Subject: [PATCH 4/6] add handling for selective per layer ac in float8nocompile --- torchtitan/float8.py | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/torchtitan/float8.py 
b/torchtitan/float8.py index d5aa8f28..b9d60959 100644 --- a/torchtitan/float8.py +++ b/torchtitan/float8.py @@ -13,7 +13,7 @@ # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance -from typing import List, Union +from typing import Callable, List, Union import torch import torch.nn as nn @@ -48,9 +48,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ) from e self.use_float8nocompile = float8_config.float8nocompile - self.use_float8nocompile_no_precompute_for_backward = ( - float8_config.float8nocompile_no_precompute_for_backward - ) + self.ac_config = job_config.activation_checkpoint # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear enable_fsdp_float8_all_gather = ( @@ -95,20 +93,30 @@ def convert_to_float8_training(self, model: nn.Module): if not self.enabled: return - # TODO: should we implicitly use this if self.compile is False, rather - # than having an explicit flag? if self.use_float8nocompile: logger.info("Using float8nocompile prototype") from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) + # for full AC or no AC + no_precompute_for_backward = self.ac_config.mode == "full" convert_to_float8_nocompile_training( model, config=self.config, module_filter_fn=lambda mod, fqn: fqn != "output", - no_precompute_for_backward=self.use_float8nocompile_no_precompute_for_backward, + no_precompute_for_backward=no_precompute_for_backward, ) + + # for selective per layer AC + if ( + self.ac_config.mode == "selective" + and self.ac_config.selective_ac_option.isdigit() + ): + no_precompute_for_backward_every_nth_layer( + model, + int(self.ac_config.selective_ac_option), + ) else: logger.info("Using float8 training") from torchao.float8 import convert_to_float8_training @@ -166,3 +174,22 @@ def sync_float8_amax_and_scale_history( models = [model] if isinstance(model, nn.Module) else model for m in models: self._sync_float8_amax_and_scale_history(m) + + +def no_precompute_for_backward_every_nth_layer(model: nn.Module, n: int): + """Set no_precompute_for_backward to True for every nth layer in the model.""" + for layer_idx, (layer_id, transformer_block) in enumerate( + model.layers.named_children() + ): + if layer_idx % n == 0: + logger.info(f"Enabling no_precompute_for_backward to layer {layer_id}") + _enable_no_precompute_for_backward(transformer_block) + + +def _enable_no_precompute_for_backward(model: nn.Module): + """Recursively set no_precompute_for_backward to True for all linear layers in the given model.""" + for layer in model.children(): + if isinstance(layer, nn.Linear): + layer.no_precompute_for_backward = True + else: + _enable_no_precompute_for_backward(layer) From 8b2979832330b5ab09a597b8848d4f2bf25700f2 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 14 Jan 2025 13:02:34 -0800 Subject: [PATCH 5/6] move helper to float8nocompile as it is unique to the prototype --- torchtitan/float8.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/torchtitan/float8.py b/torchtitan/float8.py index b9d60959..97f09aee 100644 --- a/torchtitan/float8.py +++ b/torchtitan/float8.py @@ -97,6 +97,7 @@ def convert_to_float8_training(self, model: nn.Module): logger.info("Using float8nocompile prototype") from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, + no_precompute_for_backward_every_nth_layer, ) # for full 
AC or no AC @@ -174,22 +175,3 @@ def sync_float8_amax_and_scale_history( models = [model] if isinstance(model, nn.Module) else model for m in models: self._sync_float8_amax_and_scale_history(m) - - -def no_precompute_for_backward_every_nth_layer(model: nn.Module, n: int): - """Set no_precompute_for_backward to True for every nth layer in the model.""" - for layer_idx, (layer_id, transformer_block) in enumerate( - model.layers.named_children() - ): - if layer_idx % n == 0: - logger.info(f"Enabling no_precompute_for_backward to layer {layer_id}") - _enable_no_precompute_for_backward(transformer_block) - - -def _enable_no_precompute_for_backward(model: nn.Module): - """Recursively set no_precompute_for_backward to True for all linear layers in the given model.""" - for layer in model.children(): - if isinstance(layer, nn.Linear): - layer.no_precompute_for_backward = True - else: - _enable_no_precompute_for_backward(layer) From 4ab683cf15c7f7060bc4dce0ed9eea9f18f44b77 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 15 Jan 2025 16:52:20 -0800 Subject: [PATCH 6/6] use helpers for selective per layer ac --- torchtitan/float8.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/torchtitan/float8.py b/torchtitan/float8.py index 97f09aee..f8aac538 100644 --- a/torchtitan/float8.py +++ b/torchtitan/float8.py @@ -97,7 +97,6 @@ def convert_to_float8_training(self, model: nn.Module): logger.info("Using float8nocompile prototype") from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, - no_precompute_for_backward_every_nth_layer, ) # for full AC or no AC @@ -175,3 +174,20 @@ def sync_float8_amax_and_scale_history( models = [model] if isinstance(model, nn.Module) else model for m in models: self._sync_float8_amax_and_scale_history(m) + + +def no_precompute_for_backward_every_nth_layer(model: nn.Module, n: int): + """Set no_precompute_for_backward to True for every nth layer in the model.""" + for layer_idx, (layer_id, layer) in enumerate(model.layers.named_children()): + if layer_idx % n == 0: + logger.info(f"Enabling no_precompute_for_backward for layer {layer_id}") + _enable_no_precompute_for_backward(layer) + + +def _enable_no_precompute_for_backward(model: nn.Module): + """Recursively set no_precompute_for_backward to True for all linear layers in the given model.""" + for child_layer in model.children(): + if isinstance(child_layer, nn.Linear): + child_layer.no_precompute_for_backward = True + else: + _enable_no_precompute_for_backward(child_layer)
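With the series applied, the selective path works as follows: when activation_checkpoint.mode == "full", no_precompute_for_backward=True is passed for the whole model at conversion time (and False when AC is off); when mode == "selective" with an integer option, no_precompute_for_backward_every_nth_layer walks model.layers after conversion and recursively sets no_precompute_for_backward = True on the linear layers of every nth block, which the float8nocompile prototype presumably reads at runtime to skip precomputing the float8 tensors needed for backward. A self-contained sketch of just the tagging logic on a toy model follows; the Block and ToyModel classes and the final printout are illustrative stand-ins, not torchtitan code:

import torch.nn as nn


class Block(nn.Module):
    """Stand-in for a TransformerBlock: just a couple of linear layers."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        self.w1 = nn.Linear(dim, 4 * dim)


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): Block() for i in range(n_layers)})


def _enable_no_precompute_for_backward(model: nn.Module):
    # Recursively tag every nn.Linear in the block, mirroring the helper above.
    for child in model.children():
        if isinstance(child, nn.Linear):
            child.no_precompute_for_backward = True
        else:
            _enable_no_precompute_for_backward(child)


def no_precompute_for_backward_every_nth_layer(model: nn.Module, n: int):
    # Tag every nth transformer block, mirroring the helper above.
    for idx, (_layer_id, block) in enumerate(model.layers.named_children()):
        if idx % n == 0:
            _enable_no_precompute_for_backward(block)


model = ToyModel()
no_precompute_for_backward_every_nth_layer(model, n=2)
tagged = [
    name
    for name, m in model.named_modules()
    if isinstance(m, nn.Linear) and getattr(m, "no_precompute_for_backward", False)
]
print(tagged)  # ['layers.0.wq', 'layers.0.w1', 'layers.2.wq', 'layers.2.w1']

Note that the first block is always tagged, since the check is layer_idx % n == 0.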