
Commit 4a2de42

clean up float8 configs in torchtitan (#466)
Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent: 183390e

File tree: 10 files changed, +33 −41 lines

test_runner.py
Lines changed: 10 additions & 10 deletions

@@ -276,34 +276,34 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.enable_fp8_linear",
+                    "--training.enable_float8_linear",
                 ]
             ],
             "FSDP2 with original dtype",
-            "fp8_fsdp2_orig_all_gather",
+            "float8_fsdp2_orig_all_gather",
             ngpu=4,
         ),
         OverrideDefinitions(
             [
                 [
-                    "--training.enable_fp8_linear",
-                    "--training.enable_fsdp_fp8_all_gather",
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
                 ]
             ],
-            "FSDP2 with fp8 all-gather",
-            "fsdp2_fp8_all_gather",
+            "FSDP2 with float8 all-gather",
+            "fsdp2_float8_all_gather",
             ngpu=4,
         ),
         OverrideDefinitions(
             [
                 [
-                    "--training.enable_fp8_linear",
-                    "--training.enable_fsdp_fp8_all_gather",
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
                     "--training.precompute_float8_dynamic_scale_for_fsdp",
                 ]
             ],
-            "FSDP2 with fp8 all-gather and precomputed dynamic scales",
-            "fsdp2_fp8_all_gather_precompute_dynamic_scales",
+            "FSDP2 with float8 all-gather and precomputed dynamic scales",
+            "fsdp2_float8_all_gather_precompute_dynamic_scales",
             ngpu=4,
         ),
     ]
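For readability, here is one of the renamed test entries assembled from the hunk above. The `OverrideDefinitions` class below is a minimal stand-in, not torchtitan's real descriptor from `test_runner.py`; its field names beyond what the diff shows are assumptions.

```python
from dataclasses import dataclass
from typing import List

# Minimal stand-in for torchtitan's OverrideDefinitions, just enough to show
# how the renamed flags group into a test entry; the real class may carry
# more fields than sketched here.
@dataclass
class OverrideDefinitions:
    override_args: List[List[str]]
    test_descr: str
    test_name: str
    ngpu: int = 4

entry = OverrideDefinitions(
    [
        [
            "--training.enable_float8_linear",
            "--training.enable_fsdp_float8_all_gather",
        ]
    ],
    "FSDP2 with float8 all-gather",
    "fsdp2_float8_all_gather",
    ngpu=4,
)
print(entry.test_name)  # fsdp2_float8_all_gather
```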

torchtitan/config_manager.py
Lines changed: 4 additions & 4 deletions

@@ -338,7 +338,7 @@ def __init__(self):
             help="Whether to compile the model",
         )
         self.parser.add_argument(
-            "--training.enable_fp8_linear",
+            "--training.enable_float8_linear",
             action="store_true",
             help="""
                 If true, swaps `torch.nn.Linear` with `Float8Linear` with
@@ -348,16 +348,16 @@ def __init__(self):
             """,
         )
         self.parser.add_argument(
-            "--training.enable_fsdp_fp8_all_gather",
+            "--training.enable_fsdp_float8_all_gather",
             action="store_true",
             default=False,
-            help="Whether enable fp8 all-gather in FSDP",
+            help="Whether enable float8 all-gather in FSDP",
         )
         self.parser.add_argument(
             "--training.precompute_float8_dynamic_scale_for_fsdp",
             action="store_true",
             default=False,
-            help="Whether precompute fp8 scales dynamically for FSDP",
+            help="Whether precompute float8 scales dynamically for FSDP",
         )
         self.parser.add_argument(
             "--training.gc_freq",

torchtitan/float8_linear.py
Lines changed: 9 additions & 9 deletions

@@ -25,7 +25,7 @@


 @contextlib.contextmanager
-def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
+def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
     import float8_experimental.config as config

     prev = config.enable_fsdp_fp8_all_gather
@@ -53,8 +53,8 @@ def maybe_build_fp8_linear(

     This will mutate the model inplace.
     """
-    enable_fp8_linear = job_config.training.enable_fp8_linear
-    if not enable_fp8_linear:
+    enable_float8_linear = job_config.training.enable_float8_linear
+    if not enable_float8_linear:
         return
     if not is_sm90_or_later():
         warning_once(
@@ -69,15 +69,15 @@ def maybe_build_fp8_linear(
         )

         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        enable_fsdp_fp8_all_gather = (
-            job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
+        enable_fsdp_float8_all_gather = (
+            job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
-        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
+        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
             swap_linear_with_float8_linear(
                 model, scaling_type_w=TensorScalingType.DYNAMIC
             )
         logger.info(
-            f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}"
+            f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
         )
     except ImportError as exc:
         raise ImportError(
@@ -89,8 +89,8 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
     model: nn.Module, job_config: JobConfig
 ):
     if not (
-        job_config.training.enable_fp8_linear
-        and job_config.training.enable_fsdp_fp8_all_gather
+        job_config.training.enable_float8_linear
+        and job_config.training.enable_fsdp_float8_all_gather
         and job_config.training.precompute_float8_dynamic_scale_for_fsdp
     ):
         return
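The renamed context manager follows a common save/set/restore pattern for a module-level config flag. Below is a generic, self-contained sketch of that pattern; `_FakeConfig` stands in for `float8_experimental.config`, which the hunk above shows exposes an `enable_fsdp_fp8_all_gather` attribute, and the restore-on-exit shape is an assumption based on the visible `prev = ...` line.

```python
import contextlib

# Stand-in for float8_experimental.config so the sketch runs on its own.
class _FakeConfig:
    enable_fsdp_fp8_all_gather = False

@contextlib.contextmanager
def set_enable_fsdp_float8_all_gather(enable: bool):
    # Save the previous value, flip the flag, and restore it on exit even if
    # the body raises -- the general shape of the context manager above.
    prev = _FakeConfig.enable_fsdp_fp8_all_gather
    _FakeConfig.enable_fsdp_fp8_all_gather = enable
    try:
        yield
    finally:
        _FakeConfig.enable_fsdp_fp8_all_gather = prev

with set_enable_fsdp_float8_all_gather(True):
    assert _FakeConfig.enable_fsdp_fp8_all_gather is True
assert _FakeConfig.enable_fsdp_fp8_all_gather is False
```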

torchtitan/parallelisms/parallelize_llama.py
Lines changed: 4 additions & 12 deletions

@@ -123,18 +123,10 @@ def get_tp_parallel_strategy(

     This function handles the special case of using float8 with tensor parallelism.
     """
-    if job_config.training.enable_fp8_linear:
-        from float8_experimental.float8_linear import Float8Linear, TensorScalingType
-
-        if any(
-            isinstance(m, Float8Linear)
-            and m.scaling_type_w is TensorScalingType.DELAYED
-            for m in model.modules()
-        ):
-            raise NotImplementedError(
-                "1D TP fp8 all-gather only supports dynamic scaling"
-            )
-
+    if job_config.training.enable_float8_linear:
+        # TODO(future PR): once float8 configuration supports delayed
+        # scaling, add a check here to enforce supported float8 all-gather
+        # configurations
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
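For context, choosing between float8-aware and standard tensor-parallel styles typically looks like the sketch below. This is an illustrative helper, not torchtitan's actual `get_tp_parallel_strategy` (whose full signature and return value are not shown in this hunk); the float8 import mirrors the one added above, and the fallback classes are PyTorch's standard TP styles.

```python
def pick_tp_parallel_styles(enable_float8_linear: bool):
    # Illustrative only: use float8-aware colwise/rowwise parallel styles when
    # float8 is enabled, otherwise PyTorch's standard ones. The real function
    # in parallelize_llama.py reads this from job_config and may differ.
    if enable_float8_linear:
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
        )
        return Float8RowwiseParallel, Float8ColwiseParallel
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
    return RowwiseParallel, ColwiseParallel
```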

train_configs/debug_model.toml
Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)
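The same one-line rename applies to each train_configs/*.toml file below. As a quick check, the sketch reads the renamed key back with the standard library; the `[training]` table name is inferred from the `--training.*` CLI flags above and is an assumption, as is the surrounding excerpt. `tomllib` requires Python 3.11+.

```python
import tomllib  # stdlib TOML parser, Python 3.11+

# Assumed excerpt of a train_configs/*.toml after the rename.
cfg_text = """
[training]
tensor_parallel_degree = 1
enable_float8_linear = false
compile = false
"""

cfg = tomllib.loads(cfg_text)
print(cfg["training"]["enable_float8_linear"])  # False
```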

train_configs/llama2_13b.toml
Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"

train_configs/llama2_70b.toml
Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"

train_configs/llama2_7b.toml
Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"

train_configs/llama3_70b.toml
Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"

train_configs/llama3_8b.toml
Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-enable_fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
