experiments/auto_parallel: 2 files changed, +4 −2 lines

First changed file (the hunk header shows `def parallelize_llama(`):

```diff
@@ -32,6 +32,7 @@ def parallelize_llama(
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
```
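The hunk is cut off inside `input_fn`, so the diff only shows the negative-batch-size guard. As a hedged sketch of one common convention (not the repository's actual code; every name below is assumed for illustration): a negative `global_batch_size` falls back to the per-rank batch size times the data-parallel degree, and the inputs are built on the meta device to match the docstring's note.

```python
import torch

def make_input_fn(global_batch_size: int, local_batch_size: int,
                  dp_degree: int, seq_len: int):
    # Hypothetical sketch; the real body of input_fn is truncated in the diff above.
    if global_batch_size < 0:
        # Assumed convention: negative means "derive from local batch * DP degree".
        global_batch_size = local_batch_size * dp_degree

    def input_fn():
        # Placeholder token ids; on the meta device only shapes/dtypes matter,
        # so zeros are enough for shape propagation during tracing.
        return torch.zeros(
            (global_batch_size, seq_len), dtype=torch.int64, device="meta"
        )

    return input_fn
```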
Second changed file (imports reordered so the `torch.distributed.tensor` import sits with the other `torch` imports; the `@@` headers below are reconstructed from the rendered line numbers):

```diff
@@ -12,6 +12,8 @@
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor import DTensor
+
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
@@ -23,7 +25,6 @@
 )
 from torchtitan.config_manager import ConfigManager, JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torch.distributed.tensor import DTensor
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
```
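Behaviorally this hunk is a no-op: the `DTensor` import simply moves up so the `torch.*` imports precede the `torchtitan.*` ones. For context, a minimal hedged illustration (not from this PR) of the kind of check a training loop typically needs `DTensor` for:

```python
import torch
from torch.distributed.tensor import DTensor

def global_param_norm(param: torch.Tensor) -> torch.Tensor:
    # DTensor parameters are sharded across ranks; materialize the full tensor
    # before computing a global quantity such as a norm. full_tensor() is a
    # collective, so every rank must call it.
    if isinstance(param, DTensor):
        param = param.full_tensor()
    return param.norm()
```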
```diff
@@ -115,7 +116,7 @@ def __init__(self, job_config: JobConfig):

     # TODO(whc)
     # I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering
-    torch._inductor.config.force_disable_caches = True
+    torch._inductor.config.force_disable_caches = True

     # allow configuring inductor comms optimizations from torchtitan commandline
     torch._inductor.config.reorder_for_compute_comm_overlap = (
```

The changed line renders identically on both sides of the diff, so the edit is evidently whitespace-only.
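Both knobs touched here are real `torch._inductor.config` options. A minimal sketch of the wiring the comment describes, with the config-flag plumbing assumed rather than taken from this diff:

```python
import torch

# Disabling inductor caches forces compilation passes such as comms reordering
# to re-run on every compile instead of being skipped via a cache hit (the
# situation the TODO above describes).
torch._inductor.config.force_disable_caches = True

# Hypothetical plumbing: expose the compute/communication-overlap reordering
# pass as a command-line/job-config flag (the exact torchtitan option name is
# an assumption, not taken from this diff).
enable_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap = enable_comm_overlap
```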