Skip to content

Commit 08fa31e

Browse files
committed
[refactor] dynamically import TrainSpec
1 parent 85d92de commit 08fa31e

File tree

15 files changed

+87
-87
lines changed

15 files changed

+87
-87
lines changed

torchtitan/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,3 @@
66

77
# Import to register quantization modules.
88
import torchtitan.components.quantization # noqa: F401
9-
10-
# Import the built-in models here so that the corresponding register_model_spec()
11-
# will be called.
12-
import torchtitan.experiments # noqa: F401
13-
import torchtitan.models # noqa: F401

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,4 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import torchtitan.experiments.llama4 # noqa: F401
8-
import torchtitan.experiments.qwen3
9-
import torchtitan.experiments.simple_fsdp # noqa: F401
10-
import torchtitan.experiments.vlm # noqa: F401
7+
_supported_experiments = ["flux", "llama4", "qwen3", "simple_fsdp", "vlm"]

torchtitan/experiments/flux/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from torchtitan.components.lr_scheduler import build_lr_schedulers
1111
from torchtitan.components.optimizer import build_optimizers
12-
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
12+
from torchtitan.protocols.train_spec import TrainSpec
1313

1414
from .dataset.flux_dataset import build_flux_dataloader
1515
from .infra.parallelize import parallelize_flux
@@ -107,8 +107,8 @@
107107
}
108108

109109

110-
register_train_spec(
111-
TrainSpec(
110+
def get_train_spec() -> TrainSpec:
111+
return TrainSpec(
112112
name="flux",
113113
model_cls=FluxModel,
114114
model_args=flux_configs,
@@ -122,4 +122,3 @@
122122
build_validator_fn=build_flux_validator,
123123
state_dict_adapter=FluxStateDictAdapter,
124124
)
125-
)

torchtitan/experiments/forge/example_train.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,11 @@ def train(self):
297297
break
298298

299299
# Run validation if validator is available
300-
if self.job_config.enable and self.validator.should_validate(self.step):
301-
self.validator.validate(self.model_parts)
300+
if (
301+
self.job_config.validation.enable
302+
and self.validator.should_validate(self.step)
303+
):
304+
self.validator.validate(self.model_parts, self.step)
302305

303306
self.checkpointer.save(
304307
self.step, last_step=(self.step == job_config.training.steps)

torchtitan/experiments/forge/train_spec.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from dataclasses import dataclass
8+
from importlib import import_module
9+
from typing import Mapping
810

9-
# Import torchtitan.models to ensure all train specs are registered
10-
import torchtitan.models # noqa: F401
1111
from torchtitan.protocols import BaseModelArgs, BaseStateDictAdapter, ModelProtocol
1212
from torchtitan.protocols.train_spec import (
13-
_train_specs,
1413
LossFunctionBuilder,
1514
LRSchedulersBuilder,
1615
OptimizersBuilder,
@@ -24,7 +23,7 @@
2423
class ForgeTrainSpec:
2524
name: str
2625
model_cls: type[ModelProtocol]
27-
model_args: dict[str, BaseModelArgs]
26+
model_args: Mapping[str, BaseModelArgs]
2827
parallelize_fn: ParallelizeFunction
2928
pipelining_fn: PipeliningFunction | None
3029
build_optimizers_fn: OptimizersBuilder
@@ -33,24 +32,7 @@ class ForgeTrainSpec:
3332
state_dict_adapter: type[BaseStateDictAdapter] | None = None
3433

3534

36-
# Copy and transform train specs from torchtitan.protocols.train_spec._train_specs
37-
# This happens during import after all models have been registered
38-
_forge_train_specs = {}
39-
40-
41-
def register_train_spec(train_spec: ForgeTrainSpec) -> None:
42-
global _forge_train_specs
43-
if train_spec.name in _forge_train_specs:
44-
raise ValueError(f"Model {train_spec.name} is already registered.")
45-
46-
_forge_train_specs[train_spec.name] = train_spec
47-
48-
49-
def get_train_spec(name: str) -> ForgeTrainSpec:
50-
global _forge_train_specs
51-
if name not in _forge_train_specs:
52-
raise ValueError(f"Model {name} is not registered.")
53-
return _forge_train_specs[name]
35+
_extra_train_specs = {}
5436

5537

5638
def _transform_train_spec(original_spec: TrainSpec):
@@ -69,6 +51,29 @@ def _transform_train_spec(original_spec: TrainSpec):
6951
)
7052

7153

72-
# Populate _forge_train_specs with transformed specs
73-
for name, spec in _train_specs.items():
74-
register_train_spec(_transform_train_spec(spec))
54+
def register_train_spec(train_spec: ForgeTrainSpec) -> None:
55+
global _extra_train_specs
56+
if train_spec.name in _extra_train_specs:
57+
raise ValueError(f"ForgeTrainSpec {train_spec.name} is already registered.")
58+
59+
# user can define a ForgeTrainSpec from outside of torchtitan
60+
_extra_train_specs[train_spec.name] = train_spec
61+
62+
63+
def get_train_spec(name: str) -> ForgeTrainSpec:
64+
# user-defined ForgeTrainSpec has higher priority
65+
global _extra_train_specs
66+
if name in _extra_train_specs:
67+
return _extra_train_specs[name]
68+
69+
from torchtitan.experiments import _supported_experiments
70+
from torchtitan.models import _supported_models
71+
72+
if name in _supported_models:
73+
module = import_module(f"torchtitan.models.{name}")
74+
return _transform_train_spec(module.get_train_spec())
75+
elif name in _supported_experiments:
76+
module = import_module(f"torchtitan.experiments.{name}")
77+
return _transform_train_spec(module.get_train_spec())
78+
79+
raise ValueError(f"ForgeTrainSpec {name} is not registered.")

torchtitan/experiments/llama4/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchtitan.datasets.hf_datasets import build_hf_dataloader
1212
from torchtitan.models.llama3 import pipeline_llama
1313
from torchtitan.models.moe import MoEArgs
14-
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
14+
from torchtitan.protocols.train_spec import TrainSpec
1515

1616
from .infra.parallelize import parallelize_llama
1717
from .model.args import TransformerModelArgs
@@ -97,8 +97,8 @@
9797
}
9898

9999

100-
register_train_spec(
101-
TrainSpec(
100+
def get_train_spec() -> TrainSpec:
101+
return TrainSpec(
102102
name="llama4",
103103
model_cls=Transformer,
104104
model_args=llama4_configs,
@@ -111,4 +111,3 @@
111111
build_loss_fn=build_cross_entropy_loss,
112112
state_dict_adapter=Llama4StateDictAdapter,
113113
)
114-
)

torchtitan/experiments/simple_fsdp/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
from torchtitan.components.tokenizer import build_hf_tokenizer
1313
from torchtitan.datasets.hf_datasets import build_hf_dataloader
1414
from torchtitan.models.llama3 import llama3_configs, pipeline_llama
15-
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
15+
from torchtitan.protocols.train_spec import TrainSpec
1616

1717
from .model import SimpleFSDPTransformer
1818
from .parallelize import parallelize_llama
1919

20-
register_train_spec(
21-
TrainSpec(
20+
21+
def get_train_spec() -> TrainSpec:
22+
return TrainSpec(
2223
name="llama3_simple_fsdp",
2324
model_cls=SimpleFSDPTransformer,
2425
model_args=llama3_configs,
@@ -30,4 +31,3 @@
3031
build_tokenizer_fn=build_hf_tokenizer,
3132
build_loss_fn=build_cross_entropy_loss,
3233
)
33-
)

torchtitan/experiments/vlm/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torchtitan.components.tokenizer import build_hf_tokenizer
1313
from torchtitan.components.validate import build_validator
1414
from torchtitan.models.llama3 import llama3_configs
15-
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
15+
from torchtitan.protocols.train_spec import TrainSpec
1616

1717
from .datasets.mm_datasets import build_mm_dataloader
1818
from .infra.parallelize import parallelize_vlm
@@ -40,8 +40,8 @@
4040
}
4141

4242

43-
register_train_spec(
44-
TrainSpec(
43+
def get_train_spec() -> TrainSpec:
44+
return TrainSpec(
4545
name="llama3-siglip2",
4646
model_cls=Llama3Siglip2Transformer,
4747
model_args=llama3_siglip2_configs,
@@ -54,4 +54,3 @@
5454
build_loss_fn=build_cross_entropy_loss,
5555
build_validator_fn=build_validator,
5656
)
57-
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
einops
2+
pillow

torchtitan/models/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@ The folder should be organized as follows
3939
- Include other util files if necessary.
4040
- `__init__.py`
4141
- A dictionary of the actual model configurations, of the type `[str: ModelArgs]`.
42-
- Call `register_train_spec` to specify a [`TrainSpec`](/torchtitan/protocols/train_spec.py), consisting a tuple of
42+
- Define `get_train_spec` to return a [`TrainSpec`](/torchtitan/protocols/train_spec.py), consisting of a tuple of
4343
- model name, model class, model args
44+
- Model name should be the same as the folder name, which should be added to `torchtitan/models/__init__.py` or `torchtitan/experiments/__init__.py`.
4445
- parallelizing function, pipelining function
4546
- builder functions for optimizer, lr scheduler, data loader, tokenizer, and loss function
4647
- More often than not, existing components can be reused.
4748
- Adding new datasets requires the `torchtitan` team’s review and legal approval.
4849
- Try to have minimal dependency on external libraries, if any.
4950
- state dict adapter
51+
- If developing outside of torchtitan, one can call `register_train_spec` to register a `TrainSpec` so that `train.py` can be reused.
5052
- Read [more](/docs/extension.md#trainspec) on `TrainSpec`.
5153
- `README.md`
5254
- Include [instructions](/README.md#downloading-a-tokenizer) to download tokenizers / encoders.

0 commit comments

Comments
 (0)