Enable FP8 full finetune distributed #2546
New test file added in this PR:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch

from torchao.float8.float8_linear import Float8Linear

from torchtune.models.llama3 import base_llama_tp_plan
from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan
from torchtune.training.quantization import (
    _validate_float8_tp_plan,
    convert_to_float8_training,
    is_fp8_tensorwise_scaling,
)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
        self.output = torch.nn.Linear(256, 512, bias=False).to(torch.float)

    def example_inputs(self):
        return (torch.randn(1, 512).to(torch.float),)

    def forward(self, x):
        x = self.linear(x)
        x = self.output(x)
        return x


class TestFloat8:
    def test_convert_to_float8_training(self):
        """
        Test that target linear layers are converted to Float8Linear.
        """
        m = M()
        example_inputs = torch.randn(1, 512).to(torch.float)
        m = convert_to_float8_training(m)
        assert isinstance(m.linear, Float8Linear)
        assert not isinstance(m.output, Float8Linear)
        with pytest.raises(Exception):
            m = convert_to_float8_training(m, "unrecognized_recipe_name")

    # TODO: enable when FP8 + TP is supported
    def _test_validate_float8_tp_plan(self):
        """
        Test that the float8 TP plan is only valid for "tensorwise" float8 recipes.
        """
        _validate_float8_tp_plan(base_llama_tp_plan())
        _validate_float8_tp_plan(base_llama_tp_plan(), "anything")
        _validate_float8_tp_plan(_fp8_llama_tp_plan())
        _validate_float8_tp_plan(_fp8_llama_tp_plan(), "tensorwise")
        with pytest.raises(ValueError):
            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise")
        with pytest.raises(ValueError):
            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise_with_gw_hp")

    def test_is_fp8_tensorwise_scaling(self):
        """
        Test that `is_fp8_tensorwise_scaling` returns True only for tensorwise scaling.
        """
        assert is_fp8_tensorwise_scaling(None)
        assert is_fp8_tensorwise_scaling("tensorwise")
        assert not is_fp8_tensorwise_scaling("rowwise")
        assert not is_fp8_tensorwise_scaling("rowwise_with_gw_hp")
```
Changes to `torchtune.training.quantization`:

```diff
@@ -4,18 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional
 
 from torch import nn
+from torch.distributed.tensor.parallel.style import ParallelStyle
 
 from torchao.dtypes import TensorCoreTiledLayout
+from torchao.float8 import (
+    convert_to_float8_training as _convert_to_float8_training_torchao,
+    Float8LinearConfig,
+)
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
 
 from torchao.quantization.qat import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
@@ -26,6 +33,7 @@
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
 )
 
+from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
@@ -219,3 +227,61 @@ def swap_lora_linear_with_qat(
             activation_qat_config,
             weight_qat_config,
         )
+
+
+def convert_to_float8_training(
+    model: nn.Module,
+    fp8_recipe_name: Optional[str] = None,
+) -> nn.Module:
+    """
+    Prepare the model for float8 training by swapping all `nn.Linear` with `Float8Linear`.
+
+    Args:
+        model (nn.Module): The model to swap linear layers on
+        fp8_recipe_name (Optional[str]): name identifying one of the pre-made recipes,
+            one of "tensorwise", "rowwise", and "rowwise_with_gw_hp". If not specified,
+            defaults to "tensorwise" with "enable_fsdp_float8_all_gather=True". See
+            https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150
+            for more details.
+
+    Returns:
+        (nn.Module) The new model with `Float8Linear`.
+    """
+    if fp8_recipe_name is not None:
+        fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
+    else:
+        fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
+    return _convert_to_float8_training_torchao(
+        model,
+        config=fp8_config,
+        module_filter_fn=lambda mod, fqn: fqn != "output",
+    )
+
+
+# TODO: validate this in full_finetune_distributed recipe once FP8 + TP is enabled
+def _validate_float8_tp_plan(
+    tp_plan: Optional[Dict[str, ParallelStyle]],
+    fp8_recipe_name: Optional[str] = None,
+) -> None:
+    """
+    Validate that the provided tensor parallel plan is compatible with the
+    float8 settings. Specifically, float8 tensor parallel plans are only
+    supported when using 'tensorwise' float8 recipes.
+    """
+    if tp_plan is None or is_fp8_tensorwise_scaling(fp8_recipe_name):
+        return
+    for parallel_style in tp_plan.values():
+        if isinstance(parallel_style, Float8ColwiseParallel) or isinstance(
+            parallel_style, Float8RowwiseParallel
+        ):
+            raise ValueError(
+                "%s and %s are only compatible with 'tensorwise' float8 recipes"
+                % (Float8ColwiseParallel.__name__, Float8RowwiseParallel.__name__)
+            )
+
+
+def is_fp8_tensorwise_scaling(fp8_recipe_name: Optional[str]):
+    """
+    Return True if the fp8 recipe name refers to 'tensorwise' scaling.
+    """
+    return fp8_recipe_name is None or fp8_recipe_name == "tensorwise"
```

An inline review thread on the docstring above discusses the choice of default recipe:

Noob float8 q: seems like tensorwise gives the best perf in your runs, is that typically the case? I assume the rowwise scaling will have less degradation, though it doesn't seem to show up in the eval numbers at all. (Mainly want to gain some understanding of what a sensible default is for folks to use.)

My understanding is rowwise is typically a bit slower but more accurate. I think we don't see the "more accurate" part in the experiments because for this workload tensorwise is already preserving accuracies very well, but it's possible we'll need rowwise for other workloads still.
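Following up on that discussion, here is a minimal usage sketch of the new helper. `ToyModel` and the tensor shapes are illustrative only (they mirror the `M` module in the test file); `convert_to_float8_training` is the helper added in this PR and `Float8Linear` is the torchao class it swaps in:

```python
import torch

from torchao.float8.float8_linear import Float8Linear
from torchtune.training.quantization import convert_to_float8_training


class ToyModel(torch.nn.Module):
    """Small stand-in model; any model with nn.Linear layers works the same way."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(512, 256, bias=False)
        self.output = torch.nn.Linear(256, 512, bias=False)

    def forward(self, x):
        return self.output(self.linear(x))


# Default path: "tensorwise" scaling with enable_fsdp_float8_all_gather=True.
model = convert_to_float8_training(ToyModel())
assert isinstance(model.linear, Float8Linear)
# The module named "output" is skipped by the module_filter_fn in the helper.
assert not isinstance(model.output, Float8Linear)

# Named-recipe path: the string is forwarded to Float8LinearConfig.from_recipe_name.
model_rowwise = convert_to_float8_training(ToyModel(), fp8_recipe_name="rowwise")

# Running an actual forward/backward in float8 requires fp8-capable hardware
# (e.g. H100), so no forward pass is shown here.
```

Skipping the module named `output` matches the common practice of keeping the final output projection in higher precision.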
@vkuzo does this level of configuration look good to you? Any other fields you think I should expose?
I remember you mentioned a performance regression in previous experiments, and we then root caused it to the fact that Float8Linear was applied to parts of the model where torch.compile was not applied. I'm not familiar with how torch.compile is applied in torchtune, but it would be good to ensure Float8Linear is not applied to any regions which are not using compile.
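One way to act on this, sketched here purely as an illustration (the helper name and the skip set are hypothetical, not part of the PR), is to widen the `module_filter_fn` already passed to torchao's `convert_to_float8_training` so that modules living outside compiled regions keep their plain `nn.Linear`:

```python
from torch import nn

from torchao.float8 import Float8LinearConfig
from torchao.float8 import convert_to_float8_training as convert_torchao

# Hypothetical set of fully-qualified module names that are not wrapped by
# torch.compile and should therefore stay in high precision.
FQNS_TO_SKIP = {"output"}  # extend with any other non-compiled modules


def _float8_module_filter(mod: nn.Module, fqn: str) -> bool:
    # Return True only for modules we want converted to Float8Linear.
    return fqn not in FQNS_TO_SKIP


def convert_compiled_regions_only(model: nn.Module) -> nn.Module:
    config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
    return convert_torchao(model, config=config, module_filter_fn=_float8_module_filter)
```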
@vkuzo @andrewor14 From the torchao source it appears rowwise scaling options do not support `force_recompute_fp8_weight_in_bwd` by default. Given we are defaulting to `enable_fsdp_float8_all_gather=True` in the case where a recipe name isn't provided, should we also set `force_recompute_fp8_weight_in_bwd`, or should this be an option here in the config? I noticed a TODO in torchao to set this to True by default in the future; I wonder if we should do that in tune from now.
I think @vkuzo mentioned that `force_recompute_fp8_weight_in_bwd` is no longer needed after a certain commit. This PR gates on the nightlies, so we don't need this flag anymore.
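For reference, a hedged sketch of what spelling out these `Float8LinearConfig` flags by hand could look like; per the comment above the flag is reportedly unnecessary on recent torchao nightlies, so treat this as illustrative rather than recommended (the function name is hypothetical):

```python
import torch

from torchao.float8 import Float8LinearConfig
from torchao.float8 import convert_to_float8_training as convert_torchao


def convert_with_explicit_flags(model: torch.nn.Module) -> torch.nn.Module:
    # Spell out the flags discussed above instead of relying on recipe defaults.
    fp8_config = Float8LinearConfig(
        enable_fsdp_float8_all_gather=True,
        # Reportedly no longer needed on recent torchao nightlies; shown only for illustration.
        force_recompute_fp8_weight_in_bwd=True,
    )
    return convert_torchao(
        model,
        config=fp8_config,
        module_filter_fn=lambda mod, fqn: fqn != "output",
    )
```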
Regarding @vkuzo's comment, torchtune does per-layer compile with separate compilation of the loss. You can see the utility we use here. This means that there may be two cases that are not covered:
Given the numbers in the test plan it seems (1) is not a hard blocker. We could try out with e.g. Llama 3.2 Vision just to make sure there are no regressions there. Otherwise I don't have any huge concerns here
Sounds good. Yes we already do (1) in this PR, will try Llama 3.2 Vision
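For readers unfamiliar with the per-layer compile approach mentioned above, here is a rough sketch of the general pattern rather than torchtune's actual utility; it assumes the model exposes its transformer blocks as an `nn.ModuleList` named `layers`, which is an assumption of this sketch:

```python
from typing import Callable

import torch


def compile_per_layer(model: torch.nn.Module, loss_fn: Callable) -> Callable:
    """Compile each transformer block individually and the loss separately,
    returning the compiled loss. Module swaps such as Float8Linear then sit
    inside compiled regions."""
    for i, layer in enumerate(model.layers):
        # Replace each block with its compiled counterpart.
        model.layers[i] = torch.compile(layer)
    return torch.compile(loss_fn)
```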