Commit 4caa379

Deduplicate TorchTitan main function (#1995)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* #2002
* #2001
* __->__ #1995

People are creating different train.py files and duplicating the `main` function, but in reality they just want to use different Trainer subclasses. This PR creates a main() in torchtitan/train.py to deduplicate the code.
1 parent 157d30d commit 4caa379
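
For context, the pattern this change enables downstream: a custom train.py now only defines a Trainer subclass and hands it to the shared entry point, instead of copying the config-parsing and teardown boilerplate. The sketch below is illustrative only; MyTrainer and its close() override are hypothetical names, not part of this commit.

    from torchtitan.train import main, Trainer


    class MyTrainer(Trainer):
        def close(self) -> None:
            # Hypothetical hook: release any resources this trainer added,
            # then let the base class finish its own cleanup.
            super().close()


    if __name__ == "__main__":
        # main() handles logger init, config parsing, optional seed-checkpoint
        # creation, training, trainer.close(), and process-group teardown.
        main(MyTrainer)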

File tree

4 files changed: +30 -93 lines changed

torchtitan/experiments/forge/example_train.py

Lines changed: 5 additions & 19 deletions
@@ -7,7 +7,7 @@
 import importlib
 import time
 from datetime import timedelta
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable

 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -17,15 +17,16 @@
 from torchtitan.components.metrics import build_metrics_processor
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
-from torchtitan.config import ConfigManager, JobConfig
+from torchtitan.config import JobConfig
 from torchtitan.distributed import utils as dist_utils
 from torchtitan.hf_datasets.text_datasets import build_text_dataloader
 from torchtitan.tools import utils
-from torchtitan.tools.logging import init_logger, logger
+from torchtitan.tools.logging import logger
 from torchtitan.tools.profiling import (
     maybe_enable_memory_snapshot,
     maybe_enable_profiling,
 )
+from torchtitan.train import main

 from .engine import ForgeEngine

@@ -350,19 +351,4 @@ def close(self) -> None:


 if __name__ == "__main__":
-    init_logger()
-    config_manager = ConfigManager()
-    config = config_manager.parse_args()
-    trainer: Optional[Trainer] = None
-
-    try:
-        trainer = Trainer(config)
-        trainer.train()
-    except Exception:
-        if trainer:
-            trainer.close()
-        raise
-    else:
-        trainer.close()
-        torch.distributed.destroy_process_group()
-        logger.info("Process group destroyed.")
+    main(Trainer)

torchtitan/experiments/torchcomms/train.py

Lines changed: 9 additions & 38 deletions
@@ -4,15 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import os
-from typing import Optional
-
-import torch
-
-from torchtitan.config import ConfigManager
 from torchtitan.distributed import ParallelDims
-from torchtitan.tools.logging import init_logger, logger
-from torchtitan.train import Trainer
+from torchtitan.train import main, Trainer

 from .parallel_dims import TorchCommsParallelDims

@@ -32,35 +25,13 @@ def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
             world_size=world_size,
         )

+    def close(self) -> None:
+        # Call finalize on all comms after training and before destroying process group.
+        if hasattr(self, "parallel_dims"):
+            for comm in self.parallel_dims.comms:
+                comm.finalize()
+        super().close()

-if __name__ == "__main__":
-    init_logger()
-    config_manager = ConfigManager()
-    config = config_manager.parse_args()
-    trainer: Optional[TorchCommsTrainer] = None
-
-    try:
-        trainer = TorchCommsTrainer(config)

-        if config.checkpoint.create_seed_checkpoint:
-            assert (
-                int(os.environ["WORLD_SIZE"]) == 1
-            ), "Must create seed checkpoint using a single device, to disable sharding."
-            assert (
-                config.checkpoint.enable
-            ), "Must enable checkpointing when creating a seed checkpoint."
-            trainer.checkpointer.save(curr_step=0, last_step=True)
-            logger.info("Created seed checkpoint")
-        else:
-            trainer.train()
-        # Call finalize on all comms after training and before destroying process group.
-        for comm in trainer.parallel_dims.comms:
-            comm.finalize()
-    except Exception:
-        if trainer:
-            trainer.close()
-        raise
-    else:
-        trainer.close()
-        torch.distributed.destroy_process_group()
-        logger.info("Process group destroyed")
+if __name__ == "__main__":
+    main(TorchCommsTrainer)
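
A note on the design choice in this file: previously comm.finalize() ran only on the successful path of the old __main__ block, after train(). Folding it into an overridden close() lets the shared main() in torchtitan/train.py (the last file in this diff) trigger comm finalization on both the success and error paths, since main() calls trainer.close() in either case.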

torchtitan/models/flux/train.py

Lines changed: 3 additions & 32 deletions
@@ -4,12 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import os
-from typing import Optional
-
 import torch

-from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
+from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import utils as dist_utils

 from torchtitan.models.flux.infra.parallelize import parallelize_encoders
@@ -20,8 +17,7 @@
     pack_latents,
     preprocess_data,
 )
-from torchtitan.tools.logging import init_logger, logger
-from torchtitan.train import Trainer
+from torchtitan.train import main, Trainer


 class FluxTrainer(Trainer):
@@ -175,29 +171,4 @@ def forward_backward_step(


 if __name__ == "__main__":
-    init_logger()
-    config_manager = ConfigManager()
-    config = config_manager.parse_args()
-    trainer: Optional[FluxTrainer] = None
-
-    try:
-        trainer = FluxTrainer(config)
-        if config.checkpoint.create_seed_checkpoint:
-            assert (
-                int(os.environ["WORLD_SIZE"]) == 1
-            ), "Must create seed checkpoint using a single device, to disable sharding."
-            assert (
-                config.checkpoint.enable
-            ), "Must enable checkpointing when creating a seed checkpoint."
-            trainer.checkpointer.save(curr_step=0, last_step=True)
-            logger.info("Created seed checkpoint")
-        else:
-            trainer.train()
-    except Exception:
-        if trainer:
-            trainer.close()
-        raise
-    else:
-        trainer.close()
-        torch.distributed.destroy_process_group()
-        logger.info("Process group destroyed.")
+    main(FluxTrainer)

torchtitan/train.py

Lines changed: 13 additions & 4 deletions
@@ -8,7 +8,7 @@
 import os
 import time
 from datetime import timedelta
-from typing import Any, Generator, Iterable, Optional
+from typing import Any, Generator, Iterable

 import torch

@@ -703,14 +703,19 @@ def close(self) -> None:
         self.metrics_processor.close()


-if __name__ == "__main__":
+def main(trainer_class: type[Trainer]) -> None:
+    """Main entry point for training with a specified trainer class.
+
+    Args:
+        trainer_class: The trainer class to instantiate (e.g., Trainer, FluxTrainer, TorchCommsTrainer)
+    """
     init_logger()
     config_manager = ConfigManager()
     config = config_manager.parse_args()
-    trainer: Optional[Trainer] = None
+    trainer: Trainer | None = None

     try:
-        trainer = Trainer(config)
+        trainer = trainer_class(config)

         if config.checkpoint.create_seed_checkpoint:
             assert (
@@ -731,3 +736,7 @@ def close(self) -> None:
         trainer.close()
     torch.distributed.destroy_process_group()
     logger.info("Process group destroyed")
+
+
+if __name__ == "__main__":
+    main(Trainer)
