Commit

Update on "[BE][1/n] simplify train.py and rename things for consistency"


Including:
- move TrainState from `train.py` to `checkpoint.py`
- create `optimizer.py` to include things related to optimizers and lr schedulers (a sketch of a compatible interface follows the import-count summary below)
- rename `logging_utils.py` to `logging.py`
- unify various build prefixes (`build_`, `create_`, `get_`) to `build_`
- simplify utils import by doing `import torchtitan.utils as utils`
- move `get_metrics_rank` from `utils.py` to `metrics.py` to make `build_metric_logger` simpler
- create `GarbageCollection` in `utils.py` to hide gc details from `train.py` (see the sketch right after this list)
- reorder definition and initialization of some objects in `train.py` to be closer to where they are first used
- expose `build_pipeline_schedule` to `torchtitan.parallelisms`
- other minor improvements to reduce the number of imports in `train.py`
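
The `GarbageCollection` bullet above only names the helper, so here is a minimal sketch of what such a wrapper might look like, assuming it disables automatic collection and runs `gc.collect()` every `gc_freq` training steps; the class name matches the bullet, but the constructor argument and the `run` method are illustrative rather than taken from this commit:

```python
import gc


class GarbageCollection:
    """Hide CPython gc details from train.py: turn off automatic collection
    and run a manual collection every `gc_freq` training steps instead."""

    def __init__(self, gc_freq: int = 1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        gc.disable()   # no automatic collections at unpredictable points
        gc.collect(1)  # clean up the young generations once up front

    def run(self, step_count: int):
        # collect on a fixed cadence so pauses land at step boundaries
        if step_count > 1 and step_count % self.gc_freq == 0:
            gc.collect(1)
```

With a helper like this, `train.py` only constructs the object once before the training loop and calls something like `gc_handler.run(train_state.step)` each iteration; the config knob that sets the frequency is not visible in this commit.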

After this refactoring, the import LoC in `train.py` drops from 51 to 23.
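
The new `torchtitan.optimizer` module is consumed in the diff below as `build_optimizers(model_parts, job_config)` followed by `build_lr_schedulers(optimizers.optimizers, job_config)`. Below is a rough sketch of an interface consistent with those call sites, assuming one optimizer per model part wrapped in a small container; the container name, the config field, and the constant schedule are illustrative, not the actual implementation:

```python
from dataclasses import dataclass, field

import torch


@dataclass
class OptimizersContainer:
    # one optimizer per (possibly pipeline-split) model part
    optimizers: list[torch.optim.Optimizer] = field(default_factory=list)

    def step(self):
        for optimizer in self.optimizers:
            optimizer.step()

    def zero_grad(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()


def build_optimizers(model_parts, job_config) -> OptimizersContainer:
    # hypothetical config field; the real knobs live in JobConfig
    lr = job_config.optimizer.lr
    return OptimizersContainer(
        [torch.optim.AdamW(m.parameters(), lr=lr) for m in model_parts]
    )


def build_lr_schedulers(optimizers, job_config):
    # one scheduler per optimizer, mirroring the container above;
    # a constant schedule stands in for the real warmup/decay logic
    return [
        torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 1.0)
        for opt in optimizers
    ]
```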

[ghstack-poisoned]
tianyu-l committed Jul 31, 2024
1 parent 0566d80 commit e35f66c
Showing 1 changed file with 10 additions and 10 deletions.
estimation.py (20 changes: 10 additions & 10 deletions)
@@ -9,24 +9,22 @@
 import os

 import torch
-import torch.nn.functional as F
 from torch._guards import active_fake_mode
 from torch._subclasses.fake_tensor import FakeTensorMode
-from torch.distributed import destroy_process_group
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore

 from torchtitan.config_manager import JobConfig
-from torchtitan.datasets import create_tokenizer
+from torchtitan.datasets import build_tokenizer
 from torchtitan.float8_linear import (
     maybe_build_fp8_linear,
     maybe_precompute_fp8_dynamic_scale_for_fsdp,
 )
-from torchtitan.logging_utils import init_logger, logger
-from torchtitan.lr_scheduling import get_lr_schedulers
+from torchtitan.logging import init_logger, logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
+from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
-from train import build_optimizers, get_train_context
+from train import get_train_context


 def estimate_memory(job_config: JobConfig):
@@ -97,7 +95,7 @@ def estimate_memory(job_config: JobConfig):

     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
-    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
+    tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
@@ -106,7 +104,9 @@

     # loss fn can be shared by pipeline-parallel or non-pp execution
     def loss_fn(pred, labels):
-        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+        return torch.nn.functional.cross_entropy(
+            pred.flatten(0, 1), labels.flatten(0, 1)
+        )

     # build model (using meta init)
     model_cls = model_name_to_cls[model_name]
@@ -146,7 +146,7 @@ def loss_fn(pred, labels):

     # build optimizer after applying parallelisms to the model
     optimizers = build_optimizers(model_parts, job_config)
-    lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)
+    lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

     for model in model_parts:
         model.train()
@@ -224,4 +224,4 @@ def loss_fn(pred, labels):
     try:
         estimate_memory(config)
     finally:
-        destroy_process_group()
+        torch.distributed.destroy_process_group()
