diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index b82120a6..bb21293b 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -1,5 +1,5 @@ torch >= 2.2.0.dev -datasets +datasets >= 2.19.0 tomli >= 1.1.0 ; python_version < "3.11" tensorboard sentencepiece diff --git a/.github/workflows/integration_test_periodic.yaml b/.github/workflows/integration_test_periodic.yaml index bc717cd1..488fc4da 100644 --- a/.github/workflows/integration_test_periodic.yaml +++ b/.github/workflows/integration_test_periodic.yaml @@ -34,6 +34,7 @@ jobs: - name: Install dependencies run: | pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly python -m pip install -r requirements.txt python -m pip install -r dev-requirements.txt - name: Run test_runner.py diff --git a/.github/workflows/unit_test_4gpu.yaml b/.github/workflows/unit_test_4gpu.yaml index e59dff34..871cfc01 100644 --- a/.github/workflows/unit_test_4gpu.yaml +++ b/.github/workflows/unit_test_4gpu.yaml @@ -31,5 +31,6 @@ jobs: pip config --user set global.progress_bar off python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/ mkdir artifacts-to-be-uploaded python ./test_runner.py artifacts-to-be-uploaded --ngpu 4 diff --git a/.github/workflows/unit_test_cpu.yaml b/.github/workflows/unit_test_cpu.yaml index dd318dbb..2482bd51 100644 --- a/.github/workflows/unit_test_cpu.yaml +++ b/.github/workflows/unit_test_cpu.yaml @@ -25,4 +25,5 @@ jobs: pip config --user set global.progress_bar off pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly pytest test --cov=. --cov-report=xml --durations=20 -vv diff --git a/README.md b/README.md index 21634d0b..a8d1fcc4 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ git clone https://github.com/pytorch/torchtitan cd torchtitan pip install -r requirements.txt pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118 +pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly ``` ### Downloading a tokenizer diff --git a/pyproject.toml b/pyproject.toml index 2a8f9557..a5c1b72f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ keywords = ["pytorch", "training", "llm"] dependencies = [ # Hugging Face integrations - "datasets", + "datasets>=2.19.0", # Tokenization "blobfile", diff --git a/test/__init__.py b/test/__init__.py index e69de29b..2e41cd71 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/datasets/__init__.py b/test/datasets/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/test/datasets/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/datasets/test_checkpoint.py b/test/datasets/test_checkpoint.py new file mode 100644 index 00000000..6f04dd23 --- /dev/null +++ b/test/datasets/test_checkpoint.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchtitan.datasets.hf_datasets import build_hf_data_loader +from torchtitan.datasets.tokenizer import create_tokenizer + + +class TestCheckpoint: + def test_c4_resumption(self): + dataset_name = "c4_mini" + dataset_path = "./torchtitan/datasets/c4_mini" + batch_size = 1 + seq_len = 1024 + world_size = 4 + rank = 0 + + dl = self._build_dataloader( + dataset_name, dataset_path, batch_size, seq_len, world_size, rank + ) + + it = iter(dl) + for _ in range(250): + next(it) + state = dl.state_dict() + expected_input_ids, expected_labels = next(it) + + # Create new dataloader, restore checkpoint, and check if next data yielded is the same as above + dl = self._build_dataloader( + dataset_name, dataset_path, batch_size, seq_len, world_size, rank + ) + dl.load_state_dict(state) + input_ids, labels = next(iter(dl)) + + assert torch.equal(input_ids, expected_input_ids) + assert torch.equal(labels, expected_labels) + + def _build_dataloader( + self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank + ): + tokenizer_type = "tiktoken" + tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model") + return build_hf_data_loader( + dataset_name=dataset_name, + dataset_path=dataset_path, + tokenizer=tokenizer, + batch_size=1, + seq_len=1024, + world_size=4, + rank=0, + ) diff --git a/test_runner.py b/test_runner.py index 12085917..0a0be689 100755 --- a/test_runner.py +++ b/test_runner.py @@ -11,6 +11,8 @@ from dataclasses import dataclass from typing import Sequence +from torchtitan.logging_utils import logger + try: import tomllib except ModuleNotFoundError: @@ -211,7 +213,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str): cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh" if override_arg: cmd += " " + " ".join(override_arg) - print( + logger.info( f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" ) @@ -223,14 +225,14 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str): assert ( dump_folder_arg is not None ), "Can't use seed checkpoint if folder is not specified" - print("Creating seed checkpoint") + logger.info("Creating seed checkpoint") result = _run_cmd( f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}" ) - print(result.stdout) + logger.info(result.stdout) result = _run_cmd(cmd) - print(result.stdout) + logger.info(result.stdout) if result.returncode != 0: raise Exception( f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}" diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 81bdf592..fb7c41c8 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -22,6 +22,7 @@ set_optimizer_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import DataLoader from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import init_logger, logger @@ -103,6 +104,7 @@ def __init__( model: nn.Module, optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler, + dataloader: DataLoader, states: Dict[str, Any], job_config: JobConfig, ) -> None: @@ -118,6 +120,7 @@ def __init__( "model": ModelWrapper(model), "optimizer": OptimizerWrapper(model, optimizer), "lr_scheduler": lr_scheduler, + "dataloader": dataloader, } ) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 1e13a677..da80b425 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -270,9 +270,9 @@ def __init__(self): type=int, default=None, help=""" - How many microbatches to split the full training batch into when using pipeline parallelism. + How many microbatches to split the global training batch into when using pipeline parallelism. - The overall training batch size must be evenly divisible by the number of microbatches. + The global training batch size must be evenly divisible by the number of microbatches. The default value will be the number of pipeline stages, if unspecified. """, @@ -500,7 +500,8 @@ def parse_args_from_command_line( "--" + arg, action="store_true" if val else "store_false" ) elif arg == "experimental.pipeline_parallel_split_points": - # type inference breaks here, since the type is just 'list' and it ends up flattening + # without this special case, type inference breaks here, + # since the inferred type is just 'list' and it ends up flattening # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...] aux_parser.add_argument("--" + arg, type=string_list) else: diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index f6d09faa..d0306663 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -4,10 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional +import pickle +from typing import Any, Dict, List, Optional import torch -from torch.utils.data import DataLoader, IterableDataset +from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import IterableDataset +from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.datasets.tokenizer import Tokenizer from torchtitan.logging_utils import logger @@ -23,7 +26,7 @@ } -class HuggingFaceDataset(IterableDataset): +class HuggingFaceDataset(IterableDataset, Stateful): """PyTorch Representation of the HuggingFace Dataset. Args: @@ -99,32 +102,90 @@ def __init__( self.seq_len = seq_len self.infinite = infinite + # variables for checkpointing + self._sample_idx = 0 + self._all_tokens: List[int] = [] + def __iter__(self): max_buffer_token_len = 1 + self.seq_len - all_tokens: List[int] = [] while True: - for sample in iter(self._data): + for sample in self._get_data_iter(): sample_text = sample["text"] sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) - all_tokens.extend(sample_tokens) + self._all_tokens.extend(sample_tokens) + self._sample_idx += 1 - while len(all_tokens) >= max_buffer_token_len: - x = torch.LongTensor(all_tokens[:max_buffer_token_len]) + while len(self._all_tokens) >= max_buffer_token_len: + x = torch.LongTensor(self._all_tokens[:max_buffer_token_len]) # update tokens to the remaining tokens - all_tokens = all_tokens[max_buffer_token_len:] + self._all_tokens = self._all_tokens[max_buffer_token_len:] input = x[:-1] label = x[1:] yield input, label + if not self.infinite: logger.warning(f"Dataset {self.dataset_name} has run out of data.") break else: + # Reset offset for the next iteration + self._sample_idx = 0 logger.warning( f"Dataset {self.dataset_name} is being re-looped. " "Loss related metrics might be misleading." ) + def _get_data_iter(self): + if self._sample_idx == 0: + return iter(self._data) + + # Skip samples + if isinstance(self._data, IterableDataset): + it = iter(self._data) + # Naively iterate through the samples as skip may not be supported + for _ in range(self._sample_idx): + next(it) + return it + + # As skipping to the end throws an error in case of map-style dataset, return an empty iterator + if self._sample_idx == len(self._data): + return iter([]) + return iter(self._data.skip(self._sample_idx)) + + def load_state_dict(self, state_dict): + self._sample_idx = state_dict["sample_idx"] + self._all_tokens = state_dict["token_buffer"] + + def state_dict(self): + return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} + + +class DPAwareDataLoader(StatefulDataLoader, Stateful): + """ + A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank. + """ + + def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int): + super().__init__(hf_ds, batch_size) + self._dp_rank = dp_rank + self._rank_id = f"dp_rank_{dp_rank}" + + def state_dict(self) -> Dict[str, Any]: + # Store state only for dp rank to avoid replicating the same state across other dimensions + return {self._rank_id: pickle.dumps(super().state_dict())} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + # State being empty is valid, don't log a warning + if not state_dict: + return + + if self._rank_id not in state_dict: + logger.warning( + f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}." + ) + return + super().load_state_dict(pickle.loads(state_dict[self._rank_id])) + def build_hf_data_loader( dataset_name: str, @@ -140,4 +201,4 @@ def build_hf_data_loader( dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite ) - return DataLoader(hf_ds, batch_size=batch_size) + return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 1265495f..61cf79fe 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -35,7 +35,6 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import logger -from torchtitan.parallelisms.pipelining_utils import split_stage_fqns # for selective AC no_recompute_list = { @@ -134,19 +133,6 @@ def get_tp_parallel_strategy( return RowwiseParallel, ColwiseParallel -def _llama_fqns(num_layers): - return ( - [ - "tok_embeddings", - ] - + [f"layers.{i}" for i in range(num_layers)] - + [ - "norm", - "output", - ] - ) - - def pipeline_llama( model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict ): @@ -177,9 +163,12 @@ def pipeline_llama_manual( model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict ): """ - This API gets individual torch.nn.Module objects for each pipeline stage (including virtual stages). + This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. - The SPMD parallelisms should be applied to + It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. + + The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD + parallelism. """ pp_mesh = world_mesh["pp"] pp_rank = pp_mesh.get_local_rank() @@ -188,74 +177,57 @@ def pipeline_llama_manual( job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp ) stage_idx = pp_rank - this_stage_layer_names = split_stage_fqns( - _llama_fqns(len(model.layers)), - job_config.experimental.pipeline_parallel_split_points, - pp_rank, - ) - if pp_rank < pp_size - 1: - model.norm = None - model.output = None + splits = job_config.experimental.pipeline_parallel_split_points + start_layer = splits[stage_idx - 1] if stage_idx > 0 else None + stop_layer = splits[stage_idx] if stage_idx < pp_size - 1 else None + if pp_rank > 0: model.tok_embeddings = None - names = list(model.layers.keys()) - for name in names: - if f"layers.{name}" not in this_stage_layer_names: + + drop_layers = True + for name in list(model.layers.keys()): + # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) + if start_layer is None or f"layers.{name}" == start_layer: + drop_layers = False + if stop_layer is not None and f"layers.{name}" == stop_layer: + drop_layers = True + if drop_layers: del model.layers[name] + if pp_rank < pp_size - 1: + model.norm = None + model.output = None + logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}") # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the # layers of the model that map to this stage, not the whole model. - + mp_arg = job_config.training.mixed_precision_param + mp_dtype = TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else torch.float32 + batch_size = job_config.training.batch_size + local_seq_len = int(job_config.training.seq_len // parallel_dims.tp) + layers_io_shape = (batch_size, local_seq_len, model_config.dim) + output_layer_shape = (batch_size, local_seq_len, model_config.vocab_size) if pp_rank == 0: # first layer input = torch.randint( model_config.vocab_size, - size=(job_config.training.batch_size, job_config.training.seq_len), + size=(batch_size, job_config.training.seq_len), dtype=torch.int64, device=device, ) else: # later layers (assume all start w/ a transformer layer) - input = torch.rand( - size=( - job_config.training.batch_size, - int(job_config.training.seq_len // parallel_dims.tp), - model_config.dim, - ), - dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] - if parallel_dims.dp_enabled - else torch.float32, - device=device, - ) + input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) if pp_rank == pp_size - 1: # last layer - output = torch.rand( - size=( - job_config.training.batch_size, - int(job_config.training.seq_len // parallel_dims.tp), - model_config.vocab_size, - ), - dtype=torch.float32, - device=device, - ) + output = torch.rand(output_layer_shape, dtype=torch.float32, device=device) else: # earlier layers (assume all end in a transformer layer) - output = torch.rand( - size=( - job_config.training.batch_size, - int(job_config.training.seq_len // parallel_dims.tp), - model_config.dim, - ), - dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] - if parallel_dims.dp_enabled - else torch.float32, - device=device, - ) + output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) model.to_empty(device=device) stage = ManualPipelineStage( diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py index 2c3d3bcc..24752e4b 100644 --- a/torchtitan/parallelisms/pipelining_utils.py +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -24,24 +24,3 @@ def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn): n_microbatches=stage.chunks, loss_fn=loss_fn, ) - - -def split_stage_fqns(fqns, split_points, stage_id): - """Helper for splitting ordered list of layer names into layers per stage. - - split_points is a list of layer names, each layer will be the first layer in a stage - """ - stages = [] - cur = [] - - for name in fqns: - if name in split_points: - assert len( - cur - ), f"{name} is not a valid split point, do not specify the first layer of stage 0" - stages.append(cur) - cur = [] - cur.append(name) - - stages.append(cur) - return stages[stage_id] diff --git a/train.py b/train.py index 90a745e5..e13acb3d 100644 --- a/train.py +++ b/train.py @@ -268,6 +268,7 @@ def loss_fn(pred, labels): model=model, optimizer=optimizer, lr_scheduler=scheduler, + dataloader=data_loader, states={"train_state": train_state}, job_config=job_config, )