From e35f66cdc700195a02d455c1fa4c98625bc62745 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Tue, 30 Jul 2024 23:52:12 -0700
Subject: [PATCH] Update on "[BE][1/n] simplify train.py and rename things for consistency"

Including:
- move TrainState from `train.py` to `checkpoint.py`
- create `optimizer.py` to include things related to optimizers and lr schedulers
- rename `logging_utils.py` to `logging.py`
- unify various build prefixes (`build_`, `create_`, `get_`) to `build_`
- simplify utils import by doing `import torchtitan.utils as utils`
- move `get_metrics_rank` from `utils.py` to `metrics.py` to make `build_metric_logger` simpler
- create `GarbageCollection` in `utils.py` to hide gc details from `train.py`
- reorder definition and initialization of some objects in `train.py` to be closer to where they are first used
- expose `build_pipeline_schedule` in `torchtitan.parallelisms`
- other minor improvements to reduce the number of imports in `train.py`

After this refactoring, the import section of `train.py` drops from 51 to 23 LoC.

[ghstack-poisoned]
---
 estimation.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/estimation.py b/estimation.py
index f5527f74..3adcf663 100644
--- a/estimation.py
+++ b/estimation.py
@@ -9,24 +9,22 @@
 import os
 
 import torch
-import torch.nn.functional as F
 from torch._guards import active_fake_mode
 from torch._subclasses.fake_tensor import FakeTensorMode
-from torch.distributed import destroy_process_group
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
 from torchtitan.config_manager import JobConfig
-from torchtitan.datasets import create_tokenizer
+from torchtitan.datasets import build_tokenizer
 from torchtitan.float8_linear import (
     maybe_build_fp8_linear,
     maybe_precompute_fp8_dynamic_scale_for_fsdp,
 )
-from torchtitan.logging_utils import init_logger, logger
-from torchtitan.lr_scheduling import get_lr_schedulers
+from torchtitan.logging import init_logger, logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
+from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
-from train import build_optimizers, get_train_context
+from train import get_train_context
 
 
 def estimate_memory(job_config: JobConfig):
@@ -97,7 +95,7 @@ def estimate_memory(job_config: JobConfig):
 
     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
-    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
+    tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
 
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
@@ -106,7 +104,9 @@ def estimate_memory(job_config: JobConfig):
 
     # loss fn can be shared by pipeline-parallel or non-pp execution
    def loss_fn(pred, labels):
-        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+        return torch.nn.functional.cross_entropy(
+            pred.flatten(0, 1), labels.flatten(0, 1)
+        )
 
     # build model (using meta init)
     model_cls = model_name_to_cls[model_name]
@@ -146,7 +146,7 @@ def loss_fn(pred, labels):
 
     # build optimizer after applying parallelisms to the model
     optimizers = build_optimizers(model_parts, job_config)
-    lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)
+    lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
 
     for model in model_parts:
         model.train()
@@ -224,4 +224,4 @@ def loss_fn(pred, labels):
     try:
         estimate_memory(config)
     finally:
-        destroy_process_group()
+        torch.distributed.destroy_process_group()
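
For context on the `GarbageCollection` helper mentioned in the commit message, below is a minimal sketch of what such a wrapper might look like, assuming it simply disables Python's automatic collector and triggers a collection every `gc_freq` steps. The class body, the `gc_freq` parameter, and the `run()` method name are illustrative assumptions, not the actual `torchtitan.utils` implementation.

```python
import gc


class GarbageCollection:
    """Hypothetical sketch of a helper that hides gc details from train.py.

    Assumption: automatic garbage collection is disabled up front, and a
    generation-1 collection is run manually every `gc_freq` training steps.
    """

    def __init__(self, gc_freq: int = 1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        gc.disable()   # turn off automatic collection
        gc.collect(1)  # start from a clean slate

    def run(self, step_count: int) -> None:
        # collect periodically instead of letting gc pause arbitrary steps
        if step_count > 1 and step_count % self.gc_freq == 0:
            gc.collect(1)
```

With a helper along these lines, the training loop would only need a single call per iteration (e.g. something like `gc_handler.run(train_state.step)`), keeping all gc policy out of `train.py`.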