diff --git a/estimation.py b/estimation.py
index f5527f74..3adcf663 100644
--- a/estimation.py
+++ b/estimation.py
@@ -9,24 +9,22 @@
 import os

 import torch
-import torch.nn.functional as F
 from torch._guards import active_fake_mode
 from torch._subclasses.fake_tensor import FakeTensorMode
-from torch.distributed import destroy_process_group
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore

 from torchtitan.config_manager import JobConfig
-from torchtitan.datasets import create_tokenizer
+from torchtitan.datasets import build_tokenizer
 from torchtitan.float8_linear import (
     maybe_build_fp8_linear,
     maybe_precompute_fp8_dynamic_scale_for_fsdp,
 )
-from torchtitan.logging_utils import init_logger, logger
-from torchtitan.lr_scheduling import get_lr_schedulers
+from torchtitan.logging import init_logger, logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
+from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
-from train import build_optimizers, get_train_context
+from train import get_train_context


 def estimate_memory(job_config: JobConfig):
@@ -97,7 +95,7 @@ def estimate_memory(job_config: JobConfig):

     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
-    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
+    tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
@@ -106,7 +104,9 @@

     # loss fn can be shared by pipeline-parallel or non-pp execution
     def loss_fn(pred, labels):
-        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+        return torch.nn.functional.cross_entropy(
+            pred.flatten(0, 1), labels.flatten(0, 1)
+        )

     # build model (using meta init)
     model_cls = model_name_to_cls[model_name]
@@ -146,7 +146,7 @@ def loss_fn(pred, labels):

     # build optimizer after applying parallelisms to the model
     optimizers = build_optimizers(model_parts, job_config)
-    lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)
+    lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

     for model in model_parts:
         model.train()
@@ -224,4 +224,4 @@ def loss_fn(pred, labels):
     try:
         estimate_memory(config)
     finally:
-        destroy_process_group()
+        torch.distributed.destroy_process_group()