clean up float8 configs in torchtitan (pytorch#466)
Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

vkuzo authored Jul 17, 2024
Commit 69fe8de (1 parent: 3760bcf)
Showing 10 changed files with 33 additions and 41 deletions.
test_runner.py | 10 additions, 10 deletions

@@ -276,34 +276,34 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.enable_fp8_linear",
+                    "--training.enable_float8_linear",
                 ]
             ],
             "FSDP2 with original dtype",
-            "fp8_fsdp2_orig_all_gather",
+            "float8_fsdp2_orig_all_gather",
             ngpu=4,
         ),
         OverrideDefinitions(
             [
                 [
-                    "--training.enable_fp8_linear",
-                    "--training.enable_fsdp_fp8_all_gather",
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
                 ]
             ],
-            "FSDP2 with fp8 all-gather",
-            "fsdp2_fp8_all_gather",
+            "FSDP2 with float8 all-gather",
+            "fsdp2_float8_all_gather",
             ngpu=4,
         ),
         OverrideDefinitions(
             [
                 [
-                    "--training.enable_fp8_linear",
-                    "--training.enable_fsdp_fp8_all_gather",
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
                     "--training.precompute_float8_dynamic_scale_for_fsdp",
                 ]
             ],
-            "FSDP2 with fp8 all-gather and precomputed dynamic scales",
-            "fsdp2_fp8_all_gather_precompute_dynamic_scales",
+            "FSDP2 with float8 all-gather and precomputed dynamic scales",
+            "fsdp2_float8_all_gather_precompute_dynamic_scales",
             ngpu=4,
         ),
     ]
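For context, one way entries like these could be expanded into launch commands is sketched below; the `OverrideDefinitions` field names and the `run_llama_train.sh` invocation are assumptions inferred from the entries above, not `test_runner.py`'s actual implementation.

```python
# Hypothetical sketch: turn one OverrideDefinitions-style entry (as in the diff
# above) into shell commands. Field names and the launch script are assumptions.
from dataclasses import dataclass
from typing import Sequence


@dataclass
class OverrideDefinitions:
    override_args: Sequence[Sequence[str]]  # one inner list of CLI flags per run
    test_descr: str
    test_name: str
    ngpu: int = 4


def build_commands(defn: OverrideDefinitions, config_file: str) -> list[str]:
    """Build one launch command per list of override flags."""
    return [
        f"NGPU={defn.ngpu} CONFIG_FILE={config_file} ./run_llama_train.sh "
        + " ".join(flags)
        for flags in defn.override_args
    ]


fsdp2_float8 = OverrideDefinitions(
    [
        [
            "--training.enable_float8_linear",
            "--training.enable_fsdp_float8_all_gather",
        ]
    ],
    "FSDP2 with float8 all-gather",
    "fsdp2_float8_all_gather",
    ngpu=4,
)
print(build_commands(fsdp2_float8, "./train_configs/debug_model.toml"))
```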
torchtitan/config_manager.py | 4 additions, 4 deletions

@@ -338,7 +338,7 @@ def __init__(self):
             help="Whether to compile the model",
         )
         self.parser.add_argument(
-            "--training.enable_fp8_linear",
+            "--training.enable_float8_linear",
             action="store_true",
             help="""
                 If true, swaps `torch.nn.Linear` with `Float8Linear` with
@@ -348,16 +348,16 @@ def __init__(self):
             """,
         )
         self.parser.add_argument(
-            "--training.enable_fsdp_fp8_all_gather",
+            "--training.enable_fsdp_float8_all_gather",
             action="store_true",
             default=False,
-            help="Whether enable fp8 all-gather in FSDP",
+            help="Whether enable float8 all-gather in FSDP",
         )
         self.parser.add_argument(
             "--training.precompute_float8_dynamic_scale_for_fsdp",
             action="store_true",
             default=False,
-            help="Whether precompute fp8 scales dynamically for FSDP",
+            help="Whether precompute float8 scales dynamically for FSDP",
         )
         self.parser.add_argument(
             "--training.gc_freq",
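The dotted flag names above parse with plain `argparse`, as in this small, self-contained sketch; it stands in for, and is not, torchtitan's `JobConfig`. Note that the dots survive into the `Namespace`, so the attributes are read with `getattr`.

```python
# Standalone sketch (plain argparse, not torchtitan's JobConfig) showing how the
# renamed float8 flags behave as boolean store_true options.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--training.enable_float8_linear", action="store_true")
parser.add_argument(
    "--training.enable_fsdp_float8_all_gather", action="store_true", default=False
)
parser.add_argument(
    "--training.precompute_float8_dynamic_scale_for_fsdp",
    action="store_true",
    default=False,
)

args = parser.parse_args(
    ["--training.enable_float8_linear", "--training.enable_fsdp_float8_all_gather"]
)
# Dotted option names stay dotted in the Namespace, so read them with getattr.
assert getattr(args, "training.enable_float8_linear") is True
assert getattr(args, "training.enable_fsdp_float8_all_gather") is True
assert getattr(args, "training.precompute_float8_dynamic_scale_for_fsdp") is False
```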
torchtitan/float8_linear.py | 9 additions, 9 deletions

@@ -25,7 +25,7 @@
 
 
 @contextlib.contextmanager
-def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
+def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
     import float8_experimental.config as config
 
     prev = config.enable_fsdp_fp8_all_gather
@@ -53,8 +53,8 @@ def maybe_build_fp8_linear(
     This will mutate the model inplace.
     """
-    enable_fp8_linear = job_config.training.enable_fp8_linear
-    if not enable_fp8_linear:
+    enable_float8_linear = job_config.training.enable_float8_linear
+    if not enable_float8_linear:
         return
     if not is_sm90_or_later():
         warning_once(
@@ -69,15 +69,15 @@
         )
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        enable_fsdp_fp8_all_gather = (
-            job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
+        enable_fsdp_float8_all_gather = (
+            job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
-        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
+        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
             swap_linear_with_float8_linear(
                 model, scaling_type_w=TensorScalingType.DYNAMIC
             )
         logger.info(
-            f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}"
+            f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
         )
     except ImportError as exc:
         raise ImportError(
@@ -89,8 +89,8 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
     model: nn.Module, job_config: JobConfig
 ):
     if not (
-        job_config.training.enable_fp8_linear
-        and job_config.training.enable_fsdp_fp8_all_gather
+        job_config.training.enable_float8_linear
+        and job_config.training.enable_fsdp_float8_all_gather
         and job_config.training.precompute_float8_dynamic_scale_for_fsdp
     ):
         return
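The renamed `set_enable_fsdp_float8_all_gather` helper follows a common pattern: temporarily flip a module-level config flag, then restore the previous value. A minimal sketch of that pattern, using a stand-in `config` object rather than `float8_experimental.config`:

```python
# Sketch of the "temporarily toggle a global config flag" pattern used above.
# `config` is a stand-in object, not float8_experimental's real config module.
import contextlib
from types import SimpleNamespace

config = SimpleNamespace(enable_fsdp_fp8_all_gather=False)


@contextlib.contextmanager
def set_enable_fsdp_float8_all_gather(enabled: bool):
    prev = config.enable_fsdp_fp8_all_gather
    config.enable_fsdp_fp8_all_gather = enabled
    try:
        yield
    finally:
        # Restore the previous value even if the body raised.
        config.enable_fsdp_fp8_all_gather = prev


with set_enable_fsdp_float8_all_gather(True):
    # The linear-swapping call would run here and observe the enabled flag.
    assert config.enable_fsdp_fp8_all_gather is True
assert config.enable_fsdp_fp8_all_gather is False
```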
torchtitan/parallelisms/parallelize_llama.py | 4 additions, 12 deletions

@@ -123,18 +123,10 @@ def get_tp_parallel_strategy(
     This function handles the special case of using float8 with tensor parallelism.
 
-    if job_config.training.enable_fp8_linear:
-        from float8_experimental.float8_linear import Float8Linear, TensorScalingType
-
-        if any(
-            isinstance(m, Float8Linear)
-            and m.scaling_type_w is TensorScalingType.DELAYED
-            for m in model.modules()
-        ):
-            raise NotImplementedError(
-                "1D TP fp8 all-gather only supports dynamic scaling"
-            )
-
+    if job_config.training.enable_float8_linear:
+        # TODO(future PR): once float8 configuration supports delayed
+        # scaling, add a check here to enforce supported float8 all-gather
+        # configurations
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
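With the delayed-scaling check removed, what remains is a conditional choice of tensor-parallel styles. A rough sketch of that selection follows; it is not torchtitan's exact `get_tp_parallel_strategy`, and it assumes `float8_experimental` is installed whenever the flag is on.

```python
# Rough sketch (not torchtitan's exact helper): pick float8-aware TP styles when
# float8 is enabled, otherwise the stock ones. Assumes float8_experimental is
# installed whenever enable_float8_linear is True.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def pick_tp_styles(enable_float8_linear: bool):
    if enable_float8_linear:
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel
    return RowwiseParallel, ColwiseParallel
```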
train_configs/debug_model.toml | 1 addition, 1 deletion

@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)
train_configs/llama2_13b.toml | 1 addition, 1 deletion

@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
train_configs/llama2_70b.toml | 1 addition, 1 deletion

@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
train_configs/llama2_7b.toml | 1 addition, 1 deletion

@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
train_configs/llama3_70b.toml | 1 addition, 1 deletion

@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
train_configs/llama3_8b.toml | 1 addition, 1 deletion

@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
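For completeness, a tiny sketch of reading the renamed key out of a `[training]` table; it uses the standard library's `tomllib` (Python 3.11+) rather than torchtitan's own config loader.

```python
# Minimal sketch: read enable_float8_linear from a [training] table with the
# standard library's tomllib (Python 3.11+); torchtitan's loader is not used here.
import tomllib

toml_text = """
[training]
steps = 10
enable_float8_linear = false
compile = false
"""

cfg = tomllib.loads(toml_text)
enable_float8_linear = cfg["training"].get("enable_float8_linear", False)
print(f"{enable_float8_linear=}")  # -> enable_float8_linear=False
```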
