diff --git a/test_runner.py b/test_runner.py
index 12085917..b17bc948 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -40,6 +40,21 @@ def build_test_list(args):
     """
     integration_tests_flavors = defaultdict(list)
     integration_tests_flavors["debug_model.toml"] = [
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_1f1b_3stage/",
+                    "--experimental.pipeline_parallel_degree 3",
+                    "--experimental.pipeline_parallel_split_points layers.1, layers.2",
+                    "--experimental.pipeline_parallel_schedule 1f1b",
+                    "--training.data_parallel_degree 1",
+                ],
+            ],
+            "PP 1D test 1f1b with 3 PP stages",
+            requires_seed_checkpoint=True,
+            ngpu=3,
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py
index 2393d92f..3c08a065 100644
--- a/torchtitan/models/llama/__init__.py
+++ b/torchtitan/models/llama/__init__.py
@@ -12,7 +12,7 @@
 __all__ = ["Transformer"]
 
 llama2_configs = {
-    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
+    "debugmodel": ModelArgs(dim=256, n_layers=3, n_heads=16),
     "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
     "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
@@ -29,7 +29,7 @@
 }
 
 llama3_configs = {
-    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16, rope_theta=500000),
+    "debugmodel": ModelArgs(dim=256, n_layers=3, n_heads=16, rope_theta=500000),
     "8B": ModelArgs(
         dim=4096,
         n_layers=32,
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 909cd8d3..c7466c5f 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -205,6 +205,12 @@ def pipeline_llama_manual(
 
     logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}")
 
+    # TODO, support this? or just guard against it inside the lib
+    if job_config.training.batch_size % parallel_dims.pp != 0:
+        raise ValueError(
+            f"batch_size {job_config.training.batch_size} not divisible by pp dim, currently unsupported"
+        )
+
     # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
     # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
     # layers of the model that map to this stage, not the whole model.
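
For context on the new guard in pipeline_llama_manual: with the 1F1B schedule the local batch is split into one microbatch per pipeline stage, so the batch size must divide evenly by the pipeline-parallel degree. The sketch below is illustrative only, not torchtitan code; split_into_microbatches is a hypothetical helper, and it assumes the microbatch count equals the PP degree.

# Illustrative sketch only -- not torchtitan code. Assumes the number of
# microbatches equals the pipeline-parallel degree, which is why the guard
# above requires batch_size % pp == 0.
import torch

def split_into_microbatches(batch: torch.Tensor, pp_degree: int) -> list:
    # Split the local batch along dim 0 into one microbatch per PP stage.
    batch_size = batch.size(0)
    if batch_size % pp_degree != 0:
        raise ValueError(
            f"batch_size {batch_size} not divisible by pp degree {pp_degree}"
        )
    return list(torch.chunk(batch, pp_degree, dim=0))

# Example: a batch of 8 sequences with 3 PP stages would raise; 9 splits cleanly.
microbatches = split_into_microbatches(torch.zeros(9, 2048), pp_degree=3)
assert len(microbatches) == 3 and microbatches[0].size(0) == 3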