Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
wconstab committed May 21, 2024
1 parent 301818b commit 79e799f
Show file tree
Hide file tree
Showing 16 changed files with 187 additions and 100 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch >= 2.2.0.dev
datasets
datasets >= 2.19.0
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
sentencepiece
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/integration_test_periodic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
python -m pip install -r requirements.txt
python -m pip install -r dev-requirements.txt
- name: Run test_runner.py
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/unit_test_4gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded
1 change: 1 addition & 0 deletions .github/workflows/unit_test_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ jobs:
pip config --user set global.progress_bar off
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
pytest test --cov=. --cov-report=xml --durations=20 -vv
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
```

### Downloading a tokenizer
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ authors = [
keywords = ["pytorch", "training", "llm"]
dependencies = [
# Hugging Face integrations
"datasets",
"datasets>=2.19.0",

# Tokenization
"blobfile",
Expand Down
5 changes: 5 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
5 changes: 5 additions & 0 deletions test/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
54 changes: 54 additions & 0 deletions test/datasets/test_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer


class TestCheckpoint:
def test_c4_resumption(self):
dataset_name = "c4_mini"
dataset_path = "./torchtitan/datasets/c4_mini"
batch_size = 1
seq_len = 1024
world_size = 4
rank = 0

dl = self._build_dataloader(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)

it = iter(dl)
for _ in range(250):
next(it)
state = dl.state_dict()
expected_input_ids, expected_labels = next(it)

# Create new dataloader, restore checkpoint, and check if next data yielded is the same as above
dl = self._build_dataloader(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)
dl.load_state_dict(state)
input_ids, labels = next(iter(dl))

assert torch.equal(input_ids, expected_input_ids)
assert torch.equal(labels, expected_labels)

def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer_type = "tiktoken"
tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
return build_hf_data_loader(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
batch_size=1,
seq_len=1024,
world_size=4,
rank=0,
)
10 changes: 6 additions & 4 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from dataclasses import dataclass
from typing import Sequence

from torchtitan.logging_utils import logger

try:
import tomllib
except ModuleNotFoundError:
Expand Down Expand Up @@ -191,7 +193,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
if override_arg:
cmd += " " + " ".join(override_arg)
print(
logger.info(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)

Expand All @@ -203,14 +205,14 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
assert (
dump_folder_arg is not None
), "Can't use seed checkpoint if folder is not specified"
print("Creating seed checkpoint")
logger.info("Creating seed checkpoint")
result = _run_cmd(
f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
)
print(result.stdout)
logger.info(result.stdout)

result = _run_cmd(cmd)
print(result.stdout)
logger.info(result.stdout)
if result.returncode != 0:
raise Exception(
f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}"
Expand Down
3 changes: 3 additions & 0 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import init_logger, logger

Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
model: nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
dataloader: DataLoader,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
Expand All @@ -118,6 +120,7 @@ def __init__(
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
"lr_scheduler": lr_scheduler,
"dataloader": dataloader,
}
)

Expand Down
7 changes: 4 additions & 3 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,9 @@ def __init__(self):
type=int,
default=None,
help="""
How many microbatches to split the full training batch into when using pipeline parallelism.
How many microbatches to split the global training batch into when using pipeline parallelism.
The overall training batch size must be evenly divisible by the number of microbatches.
The global training batch size must be evenly divisible by the number of microbatches.
The default value will be the number of pipeline stages, if unspecified.
""",
Expand Down Expand Up @@ -500,7 +500,8 @@ def parse_args_from_command_line(
"--" + arg, action="store_true" if val else "store_false"
)
elif arg == "experimental.pipeline_parallel_split_points":
# type inference breaks here, since the type is just 'list' and it ends up flattening
# without this special case, type inference breaks here,
# since the inferred type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
else:
Expand Down
81 changes: 71 additions & 10 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional
import pickle
from typing import Any, Dict, List, Optional

import torch
from torch.utils.data import DataLoader, IterableDataset
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
Expand All @@ -23,7 +26,7 @@
}


class HuggingFaceDataset(IterableDataset):
class HuggingFaceDataset(IterableDataset, Stateful):
"""PyTorch Representation of the HuggingFace Dataset.
Args:
Expand Down Expand Up @@ -99,32 +102,90 @@ def __init__(
self.seq_len = seq_len
self.infinite = infinite

# variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len
all_tokens: List[int] = []

while True:
for sample in iter(self._data):
for sample in self._get_data_iter():
sample_text = sample["text"]
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
all_tokens.extend(sample_tokens)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

while len(all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(all_tokens[:max_buffer_token_len])
while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# update tokens to the remaining tokens
all_tokens = all_tokens[max_buffer_token_len:]
self._all_tokens = self._all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield input, label

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data.")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(
f"Dataset {self.dataset_name} is being re-looped. "
"Loss related metrics might be misleading."
)

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# Skip samples
if isinstance(self._data, IterableDataset):
it = iter(self._data)
# Naively iterate through the samples as skip may not be supported
for _ in range(self._sample_idx):
next(it)
return it

# As skipping to the end throws an error in case of map-style dataset, return an empty iterator
if self._sample_idx == len(self._data):
return iter([])
return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]

def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


class DPAwareDataLoader(StatefulDataLoader, Stateful):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
"""

def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
super().__init__(hf_ds, batch_size)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> Dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions
return {self._rank_id: pickle.dumps(super().state_dict())}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid, don't log a warning
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}."
)
return
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


def build_hf_data_loader(
dataset_name: str,
Expand All @@ -140,4 +201,4 @@ def build_hf_data_loader(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return DataLoader(hf_ds, batch_size=batch_size)
return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size)
Loading

0 comments on commit 79e799f

Please sign in to comment.