73 changes: 39 additions & 34 deletions tests/unit_tests/test_job_config.py
@@ -52,73 +52,78 @@ def test_job_config_file_cmd_overrides(self):
         )
         assert config.job.dump_folder == "/tmp/test_tt/"

-    def test_parse_pp_split_points(self):
-        toml_splits = ["layers.2", "layers.4", "layers.6"]
-        cmdline_splits = ["layers.1", "layers.3", "layers.5"]
-        # no split points specified
-        config_manager = ConfigManager()
-        config = config_manager.parse_args(
-            [
-                "--job.config_file",
-                "./torchtitan/models/llama3/train_configs/debug_model.toml",
-            ]
-        )
-        assert config.parallelism.pipeline_parallel_split_points == []
+    def test_parse_module_names_per_model_chunk(self):
+        toml_chunks = [
+            ["tok_embeddings", "layers.0"],
+            ["layers.1", "layers.2"],
+            ["layers.3", "norm", "output"],
+        ]
+        cmdline_chunks = [
+            ["tok_embeddings", "layers.0", "layers.1"],
+            ["layers.2", "layers.3", "norm", "output"],
+        ]

-        # toml has no split points, but cmdline splits are specified
+        # no module names specified
         config_manager = ConfigManager()
         config = config_manager.parse_args(
             [
                 "--job.config_file",
                 "./torchtitan/models/llama3/train_configs/debug_model.toml",
-                "--parallelism.pipeline_parallel_split_points",
-                ",".join(cmdline_splits),
             ]
         )
-        assert (
-            config.parallelism.pipeline_parallel_split_points == cmdline_splits
-        ), config.parallelism.pipeline_parallel_split_points
+        assert config.parallelism.module_names_per_model_chunk == []

-        # toml has split points, cmdline does not
+        # toml has module names, cmdline does not
         with tempfile.NamedTemporaryFile() as fp:
             with open(fp.name, "wb") as f:
                 tomli_w.dump(
                     {
                         "parallelism": {
-                            "pipeline_parallel_split_points": toml_splits,
+                            "module_names_per_model_chunk": toml_chunks,
                         }
                     },
                     f,
                 )
             config_manager = ConfigManager()
             config = config_manager.parse_args(["--job.config_file", fp.name])
             assert (
-                config.parallelism.pipeline_parallel_split_points == toml_splits
-            ), config.parallelism.pipeline_parallel_split_points
+                config.parallelism.module_names_per_model_chunk == toml_chunks
+            ), config.parallelism.module_names_per_model_chunk

-        # toml has split points, cmdline overrides them
+        # test that the field accepts list of lists structure
         with tempfile.NamedTemporaryFile() as fp:
             with open(fp.name, "wb") as f:
                 tomli_w.dump(
                     {
                         "parallelism": {
-                            "pipeline_parallel_split_points": toml_splits,
+                            "module_names_per_model_chunk": cmdline_chunks,
                         }
                     },
                     f,
                 )
             config_manager = ConfigManager()
-            config = config_manager.parse_args(
-                [
-                    "--job.config_file",
-                    fp.name,
-                    "--parallelism.pipeline_parallel_split_points",
-                    ",".join(cmdline_splits),
-                ]
-            )
+            config = config_manager.parse_args(["--job.config_file", fp.name])
             assert (
-                config.parallelism.pipeline_parallel_split_points == cmdline_splits
-            ), config.parallelism.pipeline_parallel_split_points
+                config.parallelism.module_names_per_model_chunk == cmdline_chunks
+            ), config.parallelism.module_names_per_model_chunk
+
+        # test empty chunks are handled correctly
+        empty_chunks = [[], ["tok_embeddings"], []]
+        with tempfile.NamedTemporaryFile() as fp:
+            with open(fp.name, "wb") as f:
+                tomli_w.dump(
+                    {
+                        "parallelism": {
+                            "module_names_per_model_chunk": empty_chunks,
+                        }
+                    },
+                    f,
+                )
+            config_manager = ConfigManager()
+            config = config_manager.parse_args(["--job.config_file", fp.name])
+            assert (
+                config.parallelism.module_names_per_model_chunk == empty_chunks
+            ), config.parallelism.module_names_per_model_chunk

     def test_parse_exclude_from_loading(self):
         toml_splits = ["optimizer", "dataloader"]
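Aside for reviewers: the round-trip these tests exercise is easy to reproduce standalone. The sketch below is distilled from the test above; it assumes ConfigManager is importable from torchtitan.config_manager (as the file path later in this diff suggests) and that it runs from a torchtitan checkout. Note that in a TOML file the new field is simply an array of arrays under the [parallelism] table.

import tempfile

import tomli_w

from torchtitan.config_manager import ConfigManager

# tomli_w serializes this list of lists to TOML equivalent to:
#   [parallelism]
#   module_names_per_model_chunk = [["tok_embeddings", "layers.0"], ["layers.1", "norm", "output"]]
chunks = [["tok_embeddings", "layers.0"], ["layers.1", "norm", "output"]]

with tempfile.NamedTemporaryFile() as fp:
    with open(fp.name, "wb") as f:
        tomli_w.dump({"parallelism": {"module_names_per_model_chunk": chunks}}, f)

    # Parse the config back and confirm the list-of-lists survives intact.
    config = ConfigManager().parse_args(["--job.config_file", fp.name])
    assert config.parallelism.module_names_per_model_chunk == chunks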
11 changes: 11 additions & 0 deletions torchtitan/config_manager.py
@@ -312,6 +312,7 @@ class Parallelism:

     pipeline_parallel_split_points: list[str] = field(default_factory=list)
     """
+    DEPRECATED: Use module_names_per_model_chunk instead.
     Specify comma-separated names of modules to use as the beginning of a split point.
     e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
     the first containing all the layers up to layers.0,
@@ -321,6 +322,16 @@ class Parallelism:
     but currently the split points must be specified manually.
     """

+    module_names_per_model_chunk: list[list[str]] = field(default_factory=list)
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
+    Each inner list represents one model chunk and contains the module names that belong to that chunk.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 chunks: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    This provides more explicit control over which modules belong to each chunk compared to split points.
+    """
+
     pipeline_parallel_layers_per_stage: int | None = None
     """
     The number of layers per (virtual) pipeline stage. If specified, the split points will be
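For anyone migrating existing configs, the relationship between the deprecated option and its replacement can be sketched in a few lines. split_points_to_chunks below is a hypothetical helper, not part of this PR: it assumes each split point names the first module of a new stage, as the pipeline_parallel_split_points docstring describes, and torchtitan's real splitting logic may differ in edge cases.

def split_points_to_chunks(
    module_order: list[str], split_points: list[str]
) -> list[list[str]]:
    # Each split point begins a new chunk; every module lands in the
    # chunk opened by the most recent split point before it.
    chunks: list[list[str]] = [[]]
    for name in module_order:
        if name in split_points:
            chunks.append([])
        chunks[-1].append(name)
    return chunks

# The docstring's example: "layers.0,layers.2" yields 3 stages.
order = ["tok_embeddings", "layers.0", "layers.1", "layers.2", "norm", "output"]
assert split_points_to_chunks(order, ["layers.0", "layers.2"]) == [
    ["tok_embeddings"],
    ["layers.0", "layers.1"],
    ["layers.2", "norm", "output"],
]

If the very first module were itself a split point, this sketch would produce a leading empty chunk; that shape is now representable, since the new field admits empty inner lists (see the empty-chunks test above).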