Commit cabf5b4

Update
[ghstack-poisoned]

2 parents: 15ed7ee + 79403b5

File tree: 13 files changed (+142, -22 lines)


.github/workflows/integration_test_4gpu.yaml

Lines changed: 1 addition & 0 deletions
@@ -39,5 +39,6 @@ jobs:
 python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
 python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
+python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git
 mkdir artifacts-to-be-uploaded
 python ./test_runner.py artifacts-to-be-uploaded --ngpu 4

estimation.py

Lines changed: 2 additions & 2 deletions
@@ -126,8 +126,8 @@ def loss_fn(pred, labels):
     whole_model = model_cls.from_model_args(model_config)

     # apply fp8 linear module swap
-    if job_config.training.fp8_linear:
-        build_fp8_linear(whole_model, job_config)
+    if job_config.training.enable_fp8_linear:
+        build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

     # apply PT-D DP/TP parallelisms and activation checkpointing
     model_parts = [whole_model]

test_runner.py

Lines changed: 33 additions & 0 deletions
@@ -273,6 +273,39 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_float8_linear",
+                ]
+            ],
+            "FSDP2 with original dtype",
+            "float8_fsdp2_orig_all_gather",
+            ngpu=4,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
+                ]
+            ],
+            "FSDP2 with float8 all-gather",
+            "fsdp2_float8_all_gather",
+            ngpu=4,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
+                    "--training.precompute_float8_dynamic_scale_for_fsdp",
+                ]
+            ],
+            "FSDP2 with float8 all-gather and precomputed dynamic scales",
+            "fsdp2_float8_all_gather_precompute_dynamic_scales",
+        ),
         OverrideDefinitions(
             [
                 [

torchtitan/config_manager.py

Lines changed: 13 additions & 1 deletion
@@ -355,7 +355,7 @@ def __init__(self):
             help="Whether to compile the model",
         )
         self.parser.add_argument(
-            "--training.fp8_linear",
+            "--training.enable_float8_linear",
             action="store_true",
             help="""
                 If true, swaps `torch.nn.Linear` with `Float8Linear` with
@@ -364,6 +364,18 @@ def __init__(self):
                 here: https://github.com/pytorch-labs/float8_experimental
             """,
         )
+        self.parser.add_argument(
+            "--training.enable_fsdp_float8_all_gather",
+            action="store_true",
+            default=False,
+            help="Whether enable float8 all-gather in FSDP",
+        )
+        self.parser.add_argument(
+            "--training.precompute_float8_dynamic_scale_for_fsdp",
+            action="store_true",
+            default=False,
+            help="Whether precompute float8 scales dynamically for FSDP",
+        )
         self.parser.add_argument(
             "--training.gc_freq",
             type=int,
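
For reference, the three knobs above are plain store_true flags on the training namespace, so they compose freely. Below is a minimal sketch of driving them programmatically, assuming JobConfig.parse_args accepts an argv-style list of strings (not shown in this diff); the flag names are the ones added here.

# Minimal sketch (not part of the commit): parse the new float8 flags via JobConfig.
from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--training.enable_float8_linear",
        "--training.enable_fsdp_float8_all_gather",
        "--training.precompute_float8_dynamic_scale_for_fsdp",
    ]
)

# Each flag defaults to False and independently flips one knob.
assert config.training.enable_float8_linear
assert config.training.enable_fsdp_float8_all_gather
assert config.training.precompute_float8_dynamic_scale_for_fsdp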

torchtitan/float8_linear.py

Lines changed: 70 additions & 7 deletions
@@ -12,31 +12,94 @@

 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
+import contextlib
+import functools
+from typing import Optional

+import torch
 import torch.nn as nn
+from torch._logging import warning_once

 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger


-def build_fp8_linear(model: nn.Module, job_config: JobConfig):
+@contextlib.contextmanager
+def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
+    import float8_experimental.config as config
+
+    prev = config.enable_fsdp_fp8_all_gather
+    torch.distributed.barrier()
+    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
+    try:
+        yield
+    finally:
+        torch.distributed.barrier()
+        config.enable_fsdp_fp8_all_gather = prev
+
+
+@functools.lru_cache(None)
+def is_sm90_or_later():
+    # Float8 is only supported on H100+ GPUs
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+def maybe_build_fp8_linear(
+    model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
+):
     """
     This function converts the linear layers to `Float8Linear`. Note that today,
     only dynamic tensor scaling (the default) is supported.

     This will mutate the model inplace.
     """
-    use_fp8_linear = job_config.training.fp8_linear
+    enable_float8_linear = job_config.training.enable_float8_linear
+    if not enable_float8_linear:
+        return
+    if not is_sm90_or_later():
+        warning_once(
+            logger,
+            "Failed to swap to Float8Linear because SM90 or later is not available",
+        )
+        return
     try:
-        from float8_experimental.float8_linear import Float8Linear
+        from float8_experimental.float8_linear import TensorScalingType
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
+
+        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        enable_fsdp_float8_all_gather = (
+            job_config.training.enable_fsdp_float8_all_gather and dp_enabled
+        )
+        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
+            swap_linear_with_float8_linear(
+                model, scaling_type_w=TensorScalingType.DYNAMIC
+            )
+        logger.info(
+            f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
+        )
     except ImportError as exc:
         raise ImportError(
             "float8_experimental is not installed. Please install it to use fp8 linear layers."
         ) from exc
-    if use_fp8_linear:
-        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        swap_linear_with_float8_linear(model, Float8Linear)
-        logger.info("Swapped to Float8Linear layers")
+
+
+def maybe_precompute_fp8_dynamic_scale_for_fsdp(
+    model: nn.Module, job_config: JobConfig
+):
+    if not (
+        job_config.training.enable_float8_linear
+        and job_config.training.enable_fsdp_float8_all_gather
+        and job_config.training.precompute_float8_dynamic_scale_for_fsdp
+    ):
+        return
+    if not is_sm90_or_later():
+        warning_once(
+            logger,
+            "Skipped precomputing fp8 scales because SM90 or later is not available",
+        )
+        return
+    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+
+    precompute_float8_dynamic_scale_for_fsdp(model)
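
A hedged usage sketch of the helpers above, mirroring how train.py (later in this commit) wires them in. Assumptions not shown in the diff: the job runs under torchrun with torch.distributed initialized (set_enable_fsdp_float8_all_gather issues barriers), float8_experimental is installed, and the toy model is purely illustrative.

# Illustrative sketch, not the commit's code: swap Linear -> Float8Linear before
# applying FSDP/TP. maybe_build_fp8_linear is a no-op unless
# --training.enable_float8_linear is set and an SM90+ GPU is present; float8
# all-gather is only requested when the job also uses data parallelism.
import torch
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.float8_linear import maybe_build_fp8_linear

torch.distributed.init_process_group(backend="nccl")

job_config = JobConfig()
job_config.parse_args(
    ["--training.enable_float8_linear", "--training.enable_fsdp_float8_all_gather"]
)

model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

# dp_enabled mirrors parallel_dims.dp_enabled in train.py; with dp_enabled=False
# the swap can still happen, but the FSDP float8 all-gather path stays disabled.
maybe_build_fp8_linear(model, job_config, dp_enabled=True)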

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 6 additions & 2 deletions
@@ -125,12 +125,16 @@ def selective_checkpointing_context_fn():

 def get_tp_parallel_strategy(
     job_config: JobConfig,
+    model: nn.Module,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.

     This function handles the special case of using float8 with tensor parallelism.
     """
-    if job_config.training.fp8_linear == "dynamic":
+    if job_config.training.enable_float8_linear:
+        # TODO(future PR): once float8 configuration supports delayed
+        # scaling, add a check here to enforce supported float8 all-gather
+        # configurations
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
@@ -354,7 +358,7 @@ def apply_tp(
         rowwise_parallel_weight,
         colwise_parallel_weight,
         prepare_module_input,
-    ) = get_tp_parallel_strategy(job_config)
+    ) = get_tp_parallel_strategy(job_config, model)
     loss_parallel = parallel_dims.loss_parallel_enabled

     # 1. Parallelize the embedding and shard its outputs (which are the first
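
The hunk above only shows the import side of the change; the intent is that get_tp_parallel_strategy returns float8-aware tensor-parallel plans when --training.enable_float8_linear is set. A hypothetical, simplified selector illustrating that pattern (the real function also receives the model and may adjust the module-input preparation):

# Hypothetical sketch of the selection pattern, not the commit's exact code.
from typing import Tuple

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def select_tp_plans(enable_float8_linear: bool) -> Tuple[type, type, type]:
    if enable_float8_linear:
        # float8-aware TP plans provided by float8_experimental
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareModuleInput
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput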

train.py

Lines changed: 11 additions & 4 deletions
@@ -29,7 +29,10 @@
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_hf_data_loader, create_tokenizer
-from torchtitan.float8_linear import build_fp8_linear
+from torchtitan.float8_linear import (
+    maybe_build_fp8_linear,
+    maybe_precompute_fp8_dynamic_scale_for_fsdp,
+)
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_schedulers
 from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
@@ -249,9 +252,8 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)

-    # apply fp8 linear module swap
-    if job_config.training.fp8_linear:
-        build_fp8_linear(whole_model, job_config)
+    # swap to Float8Linear base on fp8 config
+    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

     # log model size
     model_param_count = get_num_params(whole_model)
@@ -437,6 +439,11 @@ def loss_fn(pred, labels):
         optimizers.step()
         lr_schedulers.step()

+        # when fp8 config is on,
+        # calculate float8 dynamic amax/scale for all-parameter for FSDP2
+        # it issues a single all-reduce for all parameters at once for better performance
+        maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)
+
         losses_since_last_log.append(loss)

         # log metrics
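
To make the ordering explicit: the new call runs after optimizers.step() and lr_schedulers.step(), so the dynamic float8 scales are refreshed from the just-updated weights before the next forward pass's FSDP all-gather, with the per-parameter amax/scale updates batched into a single all-reduce. A compressed, illustrative version of one training step is sketched below; the argument names are placeholders rather than train.py's actual objects.

# Illustrative per-step ordering, not the commit's code.
from torchtitan.float8_linear import maybe_precompute_fp8_dynamic_scale_for_fsdp


def train_step(model, optimizers, lr_schedulers, batch, loss_fn, job_config):
    inputs, labels = batch
    optimizers.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    optimizers.step()
    lr_schedulers.step()
    # Runs only when enable_float8_linear, enable_fsdp_float8_all_gather, and
    # precompute_float8_dynamic_scale_for_fsdp are all set (and SM90+ is present);
    # it then recomputes fp8 scales for all FSDP parameters with one all-reduce.
    maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)
    return loss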

train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)

train_configs/llama2_13b.toml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"

train_configs/llama2_70b.toml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-fp8_linear = false
+enable_float8_linear = false
 compile = false
 dataset = "c4"
