Commit be211c8

separate out diloco configs (#1516)
1 parent a9aa506 commit be211c8

File tree

6 files changed: +96 -61 lines changed

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.components.ft.config.job_config import FaultTolerance, JobConfig
+
+
+__all__ = [
+    "FaultTolerance",
+    "JobConfig",
+]
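
This new module re-exports the DiLoCo-aware config classes so that downstream code can import them from torchtitan.components.ft.config, which is exactly what the import changes further down in this commit switch to. Below is a minimal usage sketch, not part of the diff; it assumes semi_sync_method is a field inherited from the base FaultTolerance config (the docstrings in the next file refer to it) and that the remaining base fields have defaults.

# Sketch (assumptions noted above): build the DiLoCo-aware config via the new package path.
from torchtitan.components.ft.config import FaultTolerance, JobConfig

ft = FaultTolerance(
    semi_sync_method="diloco",  # assumed base-class field; gates the DiLoCo options
    sync_steps=10,              # synchronize replicas every 10 inner steps
    num_fragments=2,            # auto-split the model into 2 fragments
)
job = JobConfig(fault_tolerance=ft)
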
Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+
+from torchtitan.config.job_config import FaultTolerance as BaseFaultTolerance
+
+
+@dataclass
+class FaultTolerance(BaseFaultTolerance):
+    """
+    Extends fault tolerance to also support Streaming DiLoCo.
+    """
+
+    sync_steps: int = 5
+    """
+    Number of steps to wait before performing synchronization. This is only used when "semi_sync_method"
+    is set.
+    """
+
+    should_quantize: bool = False
+    """
+    Whether to quantize the gradients before allreduce.
+
+    Disabled by default since quantization utilizes the GPU
+    and uses more collectives. Enabling this requires understanding
+    the tradeoffs between GPU utilization and communication.
+
+    This is only used when "semi_sync_method" is set.
+    """
+
+    fragment_sync_delay: int = 0
+    """
+    Controls the number of inner steps to wait before blocking on a
+    model fragment's synchronization. This is the "tau" parameter in
+    the Streaming DiLoCo paper.
+
+    By default, each model fragment will be synced at the same step
+    at which the allreduce is issued. Enabling delay can improve
+    communication and computation overlap, but at the cost of compromising
+    model quality.
+
+    This is only used when "semi_sync_method" is set.
+    """
+
+    fragment_update_alpha: float = 0.0
+    """
+    Determines how to mix the local and global optimized parameters.
+
+    By default, we just use the global parameters. This ensures all
+    DDP replicas have the same parameters after synchronizing on
+    the fragment. Tuning this can also affect the model quality.
+
+    This is only used when "semi_sync_method" is set.
+    """
+
+    module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list)
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment.
+    Each inner list represents one model fragment and contains the module names that belong to that fragment.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 chunks: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    """
+
+    num_fragments: int = 1
+    """
+    Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco".
+    This is used to automatically split the model into fragments, provided that the model
+    implements FaultTolerantTrainSpec.
+    """
+
+
+@dataclass
+class JobConfig:
+    fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
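
The fields above are the same DiLoCo options that this commit removes from the base FaultTolerance in torchtitan/config/job_config.py (see the deletion hunk below). As a hedged sketch, the snippet below shows how explicit per-fragment module FQNs might be supplied, reusing the example from the module_fqns_per_model_fragment docstring; the field names come from the dataclass above, while semi_sync_method and the surrounding training setup are assumptions.

# Sketch (illustrative, not from the diff): split a model into three DiLoCo fragments by module FQN.
from torchtitan.components.ft.config import FaultTolerance, JobConfig

ft = FaultTolerance(
    semi_sync_method="diloco",   # assumed base-class field enabling semi-sync training
    sync_steps=5,                # default shown above: sync after every 5 inner steps
    fragment_sync_delay=1,       # overlap a fragment's allreduce with one extra inner step
    fragment_update_alpha=0.0,   # keep the default: take the global parameters on merge
    module_fqns_per_model_fragment=[
        ["tok_embeddings", "layers.0"],
        ["layers.1", "layers.2"],
        ["layers.3", "layers.4"],
    ],
)
job = JobConfig(fault_tolerance=ft)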

torchtitan/components/ft/diloco/utils.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch.nn as nn
-from torchtitan.config.job_config import FaultTolerance as FTConfig
+from torchtitan.components.ft.config import FaultTolerance as FTConfig
 from torchtitan.distributed.pipeline import generate_llm_fqn_per_model_part
torchtitan/components/ft/manager.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 import torch.nn as nn
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.distributed_c10d import ReduceOp
-from torchtitan.config.job_config import FaultTolerance as FTConfig
+from torchtitan.components.ft.config import FaultTolerance as FTConfig

 if importlib.util.find_spec("torchft") is not None:
     import torchft as ft

torchtitan/config/job_config.py

Lines changed: 0 additions & 59 deletions

@@ -653,65 +653,6 @@ class FaultTolerance:
     (https://github.com/pytorch/torchft/blob/360c5c534bdeac959507e9d238ba9f3902d3fda9/torchft/local_sgd.py#L41)
     """

-    sync_steps: int = 5
-    """
-    Number of steps to wait before performing synchronization. This is only used when "semi_sync_method"
-    is set.
-    """
-
-    should_quantize: bool = False
-    """
-    Whether to quantize the gradients before allreduce.
-
-    Disabled by default since the quantization does utilize the GPU
-    and uses more collectives. Enabling this requires knowing about
-    the tradeoffs between GPU utilization and communication.
-
-    This is only used when "semi_sync_method" is set.
-    """
-
-    fragment_sync_delay: int = 0
-    """
-    Controls the number of inner steps to wait before blocking on a
-    model fragment's synchronization. This is the "tao" parameter in
-    the Streaming DiLoCo paper.
-
-    By default, each model fragment will be synced at the same step
-    at which the allreduce is issued. Enabling delay can improve
-    communication and computation overlap, but at the cost of compromising
-    model quality
-
-    This is only used when "semi_sync_method" is set.
-    """
-
-    fragment_update_alpha: float = 0.0
-    """
-    Determines how to mix the local and global optimized parameters
-
-    By default, we just use the global parameters. This ensures all
-    DDP replicas have the same parameters after syncrhonizing on
-    the fragment. Tuning this can also affect the model quality.
-
-    This is only used when "semi_sync_method" is set.
-    """
-
-    module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list)
-    """
-    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment.
-    Each inner list represents one model fragment and contains the module names that belong to that fragment.
-    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
-    will create 3 chunks: the first containing tok_embeddings and layers.0,
-    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
-    """
-
-    num_fragments: int = 1
-    """
-    Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco".
-    This is used to automatically split the model into fragments provided that the model
-    implements FaultTolerantTrainSpec
-    """
-

 @dataclass
 class Experimental:

torchtitan/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -9,3 +9,4 @@
 # will be called.
 import torchtitan.models.deepseek_v3 # noqa: F401
 import torchtitan.models.llama3 # noqa: F401
+import torchtitan.models.llama3_ft # noqa: F401
