Skip to content

Commit aa000a3

Browse files
authored
refactor TrainSpec to remove the name field (#1850)
1 parent 98d904f commit aa000a3

File tree

19 files changed

+28
-37
lines changed

19 files changed

+28
-37
lines changed

docs/extension.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The extension points and protocols mentioned in this note are subject to change.
1414
The coarse-level abstraction tries to strike a balance between flexible component swapping and a straightforward train script ([train.py](../torchtitan/train.py)).
1515
Note that among all training components, currently [`CheckpointManager`](../torchtitan/components/checkpoint.py) and [`FTManager`](../torchtitan/components/ft.py) are not configurable since we do not expect them to be customized, but we are open to requests.
1616

17-
To register a `TrainSpec`, please follow the example of [Llama 3.1](../torchtitan/models/llama3/__init__.py) to `register_train_spec`. Please make sure the registration code is called before training initialization. In torchtitan, it is performed during [module import](../torchtitan/__init__.py).
17+
To register a `TrainSpec`, please use the `register_train_spec` API, and make sure registration happens before `get_train_spec` is called during training initialization. In torchtitan, `get_train_spec` will dynamically look for models in `torchtitan/models` or `torchtitan/experiments`.
1818

1919

2020
### `ModelConverter`

scripts/estimate/estimation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def estimate_memory(job_config: JobConfig):
9595
else contextlib.nullcontext()
9696
):
9797
logger.info(
98-
f"Building {train_spec.name} {job_config.model.flavor} with {model_args}"
98+
f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}"
9999
)
100100
with torch.device("meta"):
101101
model = train_spec.model_cls(model_args)

tests/unit_tests/test_train_spec.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ class TestTrainSpec:
7676
def test_register_train_spec(self):
7777
fake_config = {"fake": BaseModelArgs()}
7878
spec = TrainSpec(
79-
name="fake",
8079
model_cls=FakeModel,
8180
model_args=fake_config,
8281
parallelize_fn=parallelize_llama,
@@ -87,7 +86,7 @@ def test_register_train_spec(self):
8786
build_tokenizer_fn=build_hf_tokenizer,
8887
build_loss_fn=build_cross_entropy_loss,
8988
)
90-
register_train_spec(spec)
89+
register_train_spec("fake", spec)
9190
new_spec = get_train_spec("fake")
9291
assert new_spec == spec
9392

@@ -98,7 +97,6 @@ def test_optim_hook(self):
9897
fake_config = {"fake": BaseModelArgs()}
9998

10099
spec = TrainSpec(
101-
name="fake2",
102100
model_cls=FakeModel,
103101
model_args=fake_config,
104102
parallelize_fn=parallelize_llama,
@@ -109,7 +107,7 @@ def test_optim_hook(self):
109107
build_tokenizer_fn=build_hf_tokenizer,
110108
build_loss_fn=build_cross_entropy_loss,
111109
)
112-
register_train_spec(spec)
110+
register_train_spec("fake2", spec)
113111
new_spec = get_train_spec("fake2")
114112

115113
model = new_spec.model_cls(BaseModelArgs())

torchtitan/experiments/deepseek_v3/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040

4141

4242
register_train_spec(
43+
"deepseek3",
4344
TrainSpec(
44-
name="deepseek3",
4545
model_cls=DeepseekForCausalLM,
4646
model_args=deepseek_configs,
4747
parallelize_fn=parallelize_deepseek,
@@ -51,5 +51,5 @@
5151
build_dataloader_fn=build_hf_dataloader,
5252
build_tokenizer_fn=get_hf_tokenizer,
5353
build_loss_fn=build_cross_entropy_loss,
54-
)
54+
),
5555
)

torchtitan/experiments/flux/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@
109109

110110
def get_train_spec() -> TrainSpec:
111111
return TrainSpec(
112-
name="flux",
113112
model_cls=FluxModel,
114113
model_args=flux_configs,
115114
parallelize_fn=parallelize_flux,

torchtitan/experiments/forge/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def __init__(self, job_config: ForgeJobConfig):
167167
if parallel_dims.pp_enabled:
168168
if not self.train_spec.pipelining_fn:
169169
raise RuntimeError(
170-
f"Pipeline Parallel is enabled but {self.train_spec.name} "
170+
f"Pipeline Parallel is enabled but {job_config.model.name} "
171171
f"does not support pipelining"
172172
)
173173

torchtitan/experiments/forge/example_train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, job_config: JobConfig):
6666

6767
model_args = self.model_args
6868
logger.info(
69-
f"Built {self.train_spec.name} {job_config.model.flavor} with {model_args}"
69+
f"Built {job_config.model.name} {job_config.model.flavor} with {model_args}"
7070
)
7171

7272
# metrics logging
@@ -78,7 +78,7 @@ def __init__(self, job_config: JobConfig):
7878
self.metrics_processor.num_flops_per_token = self.num_flops_per_token
7979

8080
logger.info(
81-
f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
81+
f"{color.blue}Model {job_config.model.name} {job_config.model.flavor} "
8282
f"{color.red}size: {self.model_param_count:,} total parameters{color.reset}"
8383
)
8484

torchtitan/experiments/forge/train_spec.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
@dataclass
2323
class ForgeTrainSpec:
24-
name: str
2524
model_cls: type[ModelProtocol]
2625
model_args: Mapping[str, BaseModelArgs]
2726
parallelize_fn: ParallelizeFunction
@@ -39,7 +38,6 @@ def _transform_train_spec(original_spec: TrainSpec):
3938
"""Transform the original train spec to ForgeTrainSpec format."""
4039
# Create a new TrainSpec with only the fields we need in forge
4140
return ForgeTrainSpec(
42-
name=original_spec.name,
4341
model_cls=original_spec.model_cls,
4442
model_args=original_spec.model_args,
4543
parallelize_fn=original_spec.parallelize_fn,
@@ -51,13 +49,13 @@ def _transform_train_spec(original_spec: TrainSpec):
5149
)
5250

5351

54-
def register_train_spec(train_spec: ForgeTrainSpec) -> None:
52+
def register_train_spec(name: str, train_spec: ForgeTrainSpec) -> None:
5553
global _extra_train_specs
56-
if train_spec.name in _extra_train_specs:
57-
raise ValueError(f"ForgeTrainSpec {train_spec.name} is already registered.")
54+
if name in _extra_train_specs:
55+
raise ValueError(f"ForgeTrainSpec {name} is already registered.")
5856

5957
# user can define a ForgeTrainSpec from outside of torchtitan
60-
_extra_train_specs[train_spec.name] = train_spec
58+
_extra_train_specs[name] = train_spec
6159

6260

6361
def get_train_spec(name: str) -> ForgeTrainSpec:

torchtitan/experiments/llama4/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torchtitan.components.lr_scheduler import build_lr_schedulers
99
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
1010
from torchtitan.components.tokenizer import build_hf_tokenizer
11+
from torchtitan.components.validate import build_validator
1112
from torchtitan.datasets.hf_datasets import build_hf_dataloader
1213
from torchtitan.models.llama3 import pipeline_llama
1314
from torchtitan.models.moe import MoEArgs
@@ -103,7 +104,6 @@
103104

104105
def get_train_spec() -> TrainSpec:
105106
return TrainSpec(
106-
name="llama4",
107107
model_cls=Transformer,
108108
model_args=llama4_configs,
109109
parallelize_fn=parallelize_llama,
@@ -113,5 +113,6 @@ def get_train_spec() -> TrainSpec:
113113
build_dataloader_fn=build_hf_dataloader,
114114
build_tokenizer_fn=build_hf_tokenizer,
115115
build_loss_fn=build_cross_entropy_loss,
116+
build_validator_fn=build_validator,
116117
state_dict_adapter=Llama4StateDictAdapter,
117118
)

torchtitan/experiments/multimodal/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
}
2323

2424
register_train_spec(
25+
"llama4_multimodal",
2526
TrainSpec(
26-
name="llama4_multimodal",
2727
model_cls=MultimodalDecoder,
2828
model_args=llama4_mm_configs,
2929
parallelize_fn=parallelize_llama,
@@ -33,5 +33,5 @@
3333
build_dataloader_fn=build_mm_dataloader,
3434
build_tokenizer_fn=build_hf_tokenizer,
3535
build_loss_fn=build_cross_entropy_loss,
36-
)
36+
),
3737
)

0 commit comments

Comments
 (0)