Commit 624e014

Update
[ghstack-poisoned]

2 parents cabf5b4 + 4820195

11 files changed: +258 -161 lines

.github/workflows/integration_test_4gpu.yaml

Lines changed: 1 addition & 1 deletion
@@ -39,6 +39,6 @@ jobs:
         python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
         python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
-        python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git
+        USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
         mkdir artifacts-to-be-uploaded
         python ./test_runner.py artifacts-to-be-uploaded --ngpu 4

estimation.py

Lines changed: 17 additions & 11 deletions
@@ -14,17 +14,19 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.distributed import destroy_process_group
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
-from torch.distributed.tensor.parallel import loss_parallel
 from torch.testing._internal.distributed.fake_pg import FakeStore

 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import create_tokenizer
-from torchtitan.float8_linear import build_fp8_linear
+from torchtitan.float8_linear import (
+    maybe_build_fp8_linear,
+    maybe_precompute_fp8_dynamic_scale_for_fsdp,
+)
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_schedulers
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
-from train import build_optimizers
+from train import build_optimizers, get_train_context


 def estimate_memory(job_config: JobConfig):

@@ -61,9 +63,10 @@ def estimate_memory(job_config: JobConfig):
         logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
         job_config.model.norm_type = "rmsnorm"

-    if job_config.training.compile:
+    if job_config.training.compile or job_config.experimental.enable_compiled_autograd:
         logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.training.compile = False
+        job_config.experimental.enable_compiled_autograd = False

     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,

@@ -97,9 +100,9 @@ def estimate_memory(job_config: JobConfig):
     tokenizer_type = model_name_to_tokenizer[model_name]
     tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

-    # loss_parallel enables dispatching to efficient loss operators
-    loss_parallel_ctx = (
-        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
+    train_context = get_train_context(
+        parallel_dims.loss_parallel_enabled,
+        job_config.experimental.enable_compiled_autograd,
     )

     # loss fn can be shared by pipeline-parallel or non-pp execution

@@ -125,9 +128,8 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)

-    # apply fp8 linear module swap
-    if job_config.training.enable_fp8_linear:
-        build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
+    # swap to Float8Linear based on fp8 config
+    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

     # apply PT-D DP/TP parallelisms and activation checkpointing
     model_parts = [whole_model]

@@ -172,7 +174,7 @@ def loss_fn(pred, labels):
    for iter_idx in range(2):
        input_ids, labels = batch
        # train step
-       with loss_parallel_ctx():
+       with train_context():
            pred = whole_model(input_ids)
            loss = loss_fn(pred, labels)
            del pred

@@ -186,6 +188,10 @@ def loss_fn(pred, labels):
        # optimizer step
        optimizers.step()
        lr_schedulers.step()
+       # when the fp8 config is on,
+       # calculate float8 dynamic amax/scale for all parameters for FSDP2;
+       # it issues a single all-reduce for all parameters at once for better performance
+       maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config)
        optimizers.zero_grad()
        print(f"Peak Memory at iter: {iter_idx}")
        fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
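
For context, the `train_context()` used above comes from `get_train_context` in train.py, whose hunk is not part of this view. Below is a minimal sketch of what such a factory can look like, assuming only the two-argument signature visible at the call site; the compiled-autograd toggle is left as a placeholder because its exact API is not shown in this diff.

import contextlib

from torch.distributed.tensor.parallel import loss_parallel


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
    """Sketch: build a context manager that stacks the optional training contexts."""

    @contextlib.contextmanager
    def context():
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                # dispatch to efficient sharded loss operators under tensor parallelism
                stack.enter_context(loss_parallel())
            if enable_compiled_autograd:
                # placeholder: enter the compiled-autograd context here
                pass
            yield

    return context

With this shape, `with train_context():` behaves like the old `loss_parallel_ctx()` when only loss parallelism is enabled, and the extra toggle slots in without another nested `with`.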

test_runner.py

Lines changed: 17 additions & 33 deletions
@@ -46,6 +46,21 @@ def build_test_list():
     """
     integration_tests_flavors = defaultdict(list)
     integration_tests_flavors["debug_model.toml"] = [
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--experimental.pipeline_parallel_degree 4",
+                    "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
+                    "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
+                    "--model.norm_type rmsnorm",  # fused_rmsnorm throws cuda context error with pp
+                ],
+            ],
+            "PP looped flexible 1f1b test",
+            "pp_looped_flexible_1f1b",
+            requires_seed_checkpoint=True,
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [

@@ -273,39 +288,6 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
-        OverrideDefinitions(
-            [
-                [
-                    <<<<<<< HEAD
-                    "--training.enable_float8_linear",
-                ]
-            ],
-            "FSDP2 with original dtype",
-            "float8_fsdp2_orig_all_gather",
-            ngpu=4,
-        ),
-        OverrideDefinitions(
-            [
-                [
-                    "--training.enable_float8_linear",
-                    "--training.enable_fsdp_float8_all_gather",
-                ]
-            ],
-            "FSDP2 with float8 all-gather",
-            "fsdp2_float8_all_gather",
-            ngpu=4,
-        ),
-        OverrideDefinitions(
-            [
-                [
-                    "--training.enable_float8_linear",
-                    "--training.enable_fsdp_float8_all_gather",
-                    "--training.precompute_float8_dynamic_scale_for_fsdp",
-                ]
-            ],
-            "FSDP2 with float8 all-gather and precomputed dynamic scales",
-            "fsdp2_float8_all_gather_precompute_dynamic_scales",
-        ),
         OverrideDefinitions(
             [
                 [

@@ -347,6 +329,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):

     for override_arg in test_flavor.override_args:
         cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
+        if test_name == "fsdp2_mem_tracker":
+            cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_memory_estimation.sh"
         cmd += " " + dump_folder_arg
         cmd += " " + model_flavor_arg
         if override_arg:
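
For reference, here is a small illustrative snippet of the command that run_test assembles for the new pipeline-parallel flavor, mirroring the f-string logic above; the config path and rank list are assumed values, not taken from this diff.

# Sketch: how run_test expands the new PP test entry into a launch command.
full_path = "./train_configs/debug_model.toml"  # assumed location of the debug config
ngpu, all_ranks = 4, "0,1,2,3"                  # assumed LOG_RANK formatting

override_args = [
    "--checkpoint.enable_checkpoint",
    "--experimental.pipeline_parallel_degree 4",
    "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
    "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
    "--model.norm_type rmsnorm",
]

cmd = f"CONFIG_FILE={full_path} NGPU={ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
cmd += " " + " ".join(override_args)
print(cmd)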

torchtitan/config_manager.py

Lines changed: 23 additions & 5 deletions
@@ -275,7 +275,7 @@ def __init__(self):
         self.parser.add_argument(
             "--experimental.pipeline_parallel_schedule",
             type=str,
-            choices=["1f1b", "gpipe", "interleaved_1f1b"],
+            choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"],
             default="1f1b",
             help="""
                 Specify the Pipeline Parallel schedule to use.

@@ -358,10 +358,9 @@ def __init__(self):
             "--training.enable_float8_linear",
             action="store_true",
             help="""
-                If true, swaps `torch.nn.Linear` with `Float8Linear` with
-                default settings (dynamic scaling).
-                This feature requires you to install 'float8_experimental' which can be found
-                here: https://github.com/pytorch-labs/float8_experimental
+                If true, swaps `torch.nn.Linear` with `Float8Linear`.
+                This feature requires you to install 'torchao' which can be found
+                here: https://github.com/pytorch/ao
             """,
         )
         self.parser.add_argument(

@@ -376,6 +375,25 @@ def __init__(self):
             default=False,
             help="Whether precompute float8 scales dynamically for FSDP",
         )
+        self.parser.add_argument(
+            "--training.float8_scaling_type_input",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for input, dynamic (default) or delayed",
+            choices=["dynamic", "delayed"],
+        )
+        self.parser.add_argument(
+            "--training.float8_scaling_type_weight",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for weight, dynamic (default) or delayed",
+        )
+        self.parser.add_argument(
+            "--training.float8_scaling_type_grad_output",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for grad_output, dynamic (default) or delayed",
+        )
         self.parser.add_argument(
             "--training.gc_freq",
             type=int,
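
A small hypothetical example of exercising the new scaling-type knobs through JobConfig; it assumes JobConfig.parse_args accepts an explicit argument list and that dotted flags surface as nested attributes, as they do elsewhere in torchtitan.

from torchtitan.config_manager import JobConfig

# Sketch: float8 linear with delayed scaling for weights, dynamic elsewhere.
config = JobConfig()
config.parse_args(
    [
        "--training.enable_float8_linear",
        "--training.float8_scaling_type_input", "dynamic",
        "--training.float8_scaling_type_weight", "delayed",
        "--training.float8_scaling_type_grad_output", "dynamic",
    ]
)
assert config.training.float8_scaling_type_weight == "delayed"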

torchtitan/float8_linear.py

Lines changed: 61 additions & 28 deletions
@@ -4,15 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# [Note] Getting the 'float8_experimental' package:
-# This script requires the 'float8_experimental' package to function correctly.
+# [Note] Getting the 'torchao' package:
+# This script requires the 'torchao' package to function correctly.
 # Please ensure you have this package installed from the appropriate repository.
-# You can obtain it from https://github.com/pytorch-labs/float8_experimental.
-# Either clone and run `pip install .` or run `pip install git+https://github.com/pytorch-labs/float8_experimental.git`
+# You can obtain it from https://github.com/pytorch/ao by following the
+# installation instructions.

 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
-import contextlib
 import functools
 from typing import Optional

@@ -24,20 +23,6 @@
 from torchtitan.logging_utils import logger


-@contextlib.contextmanager
-def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
-    import float8_experimental.config as config
-
-    prev = config.enable_fsdp_fp8_all_gather
-    torch.distributed.barrier()
-    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
-    try:
-        yield
-    finally:
-        torch.distributed.barrier()
-        config.enable_fsdp_fp8_all_gather = prev
-
-
 @functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs

@@ -63,25 +48,42 @@ def maybe_build_fp8_linear(
        )
        return
    try:
-        from float8_experimental.float8_linear import TensorScalingType
-        from float8_experimental.float8_linear_utils import (
-            swap_linear_with_float8_linear,
+        from torchao.float8 import (
+            CastConfig,
+            convert_to_float8_training,
+            Float8LinearConfig,
+            ScalingType,
        )

        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
        enable_fsdp_float8_all_gather = (
            job_config.training.enable_fsdp_float8_all_gather and dp_enabled
        )
-        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
-            swap_linear_with_float8_linear(
-                model, scaling_type_w=TensorScalingType.DYNAMIC
-            )
+        scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input)
+        scaling_type_weight = ScalingType(
+            job_config.training.float8_scaling_type_weight
+        )
+        scaling_type_grad_output = ScalingType(
+            job_config.training.float8_scaling_type_grad_output
+        )
+        float8_config = Float8LinearConfig(
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            cast_config_input=CastConfig(scaling_type=scaling_type_input),
+            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+            enable_pre_and_post_forward=False,
+        )
+        convert_to_float8_training(
+            model,
+            config=float8_config,
+            module_filter_fn=lambda mod, fqn: fqn != "output",
+        )
        logger.info(
            f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
        )
    except ImportError as exc:
        raise ImportError(
-            "float8_experimental is not installed. Please install it to use fp8 linear layers."
+            "torchao is not installed. Please install it to use fp8 linear layers."
        ) from exc
@@ -100,6 +102,37 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
            "Skipped precomputing fp8 scales because SM90 or later is not available",
        )
        return
-    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+    from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

     precompute_float8_dynamic_scale_for_fsdp(model)
+
+
+_sync_float8_amax_and_scale_history = None
+
+
+def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig):
+    if not (
+        job_config.training.enable_float8_linear
+        and (
+            job_config.training.float8_scaling_type_input == "delayed"
+            or job_config.training.float8_scaling_type_weight == "delayed"
+            or job_config.training.float8_scaling_type_grad_output == "delayed"
+        )
+    ):
+        return
+
+    from torchao.float8 import sync_float8_amax_and_scale_history
+
+    # TODO(future): see if precalculating the modules to sync over is going to
+    # meaningfully help performance
+
+    global _sync_float8_amax_and_scale_history
+    if _sync_float8_amax_and_scale_history is None:
+        if job_config.training.compile:
+            _sync_float8_amax_and_scale_history = torch.compile(
+                sync_float8_amax_and_scale_history
+            )
+        else:
+            _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history
+
+    _sync_float8_amax_and_scale_history(model)
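
The new maybe_sync_float8_amax_and_scale_history helper is intended to be called from the training loop (the train.py hunk is not shown in this section). Below is a rough sketch of the intended ordering around one optimizer step, following torchao's delayed-scaling recipe; the surrounding loop and objects are illustrative, not part of this commit.

from torchtitan.float8_linear import (
    maybe_precompute_fp8_dynamic_scale_for_fsdp,
    maybe_sync_float8_amax_and_scale_history,
)


def train_step(model, optimizer, lr_scheduler, loss_fn, batch, job_config):
    """Illustrative ordering of the float8 hooks around one optimizer step."""
    input_ids, labels = batch
    optimizer.zero_grad()
    loss = loss_fn(model(input_ids), labels)
    loss.backward()

    # Delayed scaling: sync amax/scale history across ranks after backward and
    # before the optimizer step (a no-op unless a delayed scaling type is set).
    maybe_sync_float8_amax_and_scale_history(model, job_config)

    optimizer.step()
    lr_scheduler.step()

    # Dynamic scaling with FSDP2 float8 all-gather: precompute next step's
    # scales with a single all-reduce (a no-op unless the flags are enabled).
    maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)
    return loss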

torchtitan/lr_scheduling.py

Lines changed: 15 additions & 15 deletions
@@ -4,43 +4,43 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import functools
+
 from torch.optim.lr_scheduler import LambdaLR
 from torchtitan.config_manager import JobConfig

-# global states for scheduling
-# these are needed as LambdaLR does not support argument passing
-_warmup_steps = 200
-_decay_steps = 0
-

-def linear_warmup_linear_decay(current_step: int) -> float:
+def linear_warmup_linear_decay(
+    warmup_steps: int, decay_steps: int, current_step: int
+) -> float:
     """Computes linear warmup followed by linear decay.
     Per LambdaLR requirement, this is accomplished by returning
     a multiplicative factor to adjust the learning rate to
     create the desired schedule.
     """
-    if current_step < _warmup_steps:
+    if current_step < warmup_steps:
         # linear warmup
         # 0-indexed step, hence + 1 adjustments
         current_step += 1
-        curr_adjustment = float(current_step / (_warmup_steps + 1))
+        curr_adjustment = float(current_step / (warmup_steps + 1))

     else:
         # linear decay
-        normalized_step = _decay_steps - (current_step - _warmup_steps)
-        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps
+        normalized_step = decay_steps - (current_step - warmup_steps)
+        curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps

     return curr_adjustment


 def get_lr_schedulers(optimizers, job_config: JobConfig):
     def _get_lr_scheduler(optimizer):
         """Build a linear warmup and linear decay scheduler"""
-        global _warmup_steps, _decay_steps
-        _warmup_steps = int(job_config.training.warmup_steps)
-        _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
-
-        warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
+        warmup_steps = int(job_config.training.warmup_steps)
+        decay_steps = float(max(1, job_config.training.steps - warmup_steps))
+        lr_lambda = functools.partial(
+            linear_warmup_linear_decay, warmup_steps, decay_steps
+        )
+        warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
         return warmup_scheduler

 class SchedulersContainer:
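
To make the refactor concrete, here is a minimal standalone example of the functools.partial pattern that replaces the old module-level globals; the toy model and step counts are placeholders.

import functools

import torch
from torch.optim.lr_scheduler import LambdaLR

from torchtitan.lr_scheduling import linear_warmup_linear_decay

# Toy setup: 10 warmup steps followed by 90 decay steps.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
warmup_steps, decay_steps = 10, 90

# Bind the schedule parameters up front; LambdaLR only passes the step index.
lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for _ in range(100):
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]["lr"])  # 0.0 at the end of the schedule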
