Support Llama3 8b/70b (pytorch#256)
This PR adds support for Llama3 8b/70b; mainly it:
- add tiktoken tokenizer, add instructions to download tokenizer.model
- add options for the llama model to support Llama3
- add Llama3 8b/70b configs
wanchaol authored Apr 20, 2024
1 parent 9642f58 commit a46079f
Showing 21 changed files with 440 additions and 70 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ __pycache__
build
outputs
dist/*
+*.model

# data
data
19 changes: 15 additions & 4 deletions README.md
@@ -47,15 +47,26 @@ Install PyTorch from source or install the latest pytorch nightly, then install
pip install -r requirements.txt
```

-Install additional dev requirements if you want to contribute to the repo:
+### Downloading a tokenizer.model

+`torchtitan` currently supports training Llama3 (8B, 70B) and Llama2 (13B, 70B) out of the box. To get started training these models, you need to download a tokenizer.model. Follow the instructions in the official [meta-llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B) repository to ensure you have access to the Llama model weights.

+Once you have confirmed access, run the following commands to download the Llama2/3 tokenizer to your local machine.

```
-pip install -r dev-requirements.txt
+# pass your hf_token in order to download tokenizer.model
+# llama3 tokenizer.model
+python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...
+# llama2 tokenizer.model
+python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-2-13b-hf --hf_token=...
```

-run the llama debug model locally to verify the setup is correct:
+Run the llama3 8B model locally on 8 GPUs:

```
-./run_llama_train.sh
+CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
```


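For orientation, here is a hedged sketch of the model-related fields such a config might contain. The actual contents of `train_configs/llama3_8b.toml` are not shown in this diff, so every key and value below is an assumption:

```
# Hypothetical excerpt of train_configs/llama3_8b.toml -- field names are
# assumptions, not the committed config.
[model]
name = "llama3"
flavor = "8B"
tokenizer_path = "./torchtitan/datasets/tokenizer/llama3/original/tokenizer.model"
```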
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,5 +1,6 @@
torch >= 2.2.0.dev
-sentencepiece
datasets
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
+sentencepiece
+tiktoken
32 changes: 24 additions & 8 deletions torchtitan/datasets/download_tokenizer.py
@@ -4,21 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-import os
from typing import Optional

from requests.exceptions import HTTPError


-def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
+def hf_download(
+    repo_id: str, tokenizer_path: str, local_dir: str, hf_token: Optional[str] = None
+) -> None:
    from huggingface_hub import hf_hub_download

-    os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
+    tokenizer_path = (
+        f"{tokenizer_path}/tokenizer.model" if tokenizer_path else "tokenizer.model"
+    )

    try:
        hf_hub_download(
            repo_id,
-            "tokenizer.model",
-            local_dir="torchtitan/datasets/tokenizer/",
+            tokenizer_path,
+            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
        )
@@ -38,12 +42,24 @@
    parser.add_argument(
        "--repo_id",
        type=str,
-        default="meta-llama/llama-2-70b",
-        help="Repository ID to download from.",
+        default="meta-llama/Meta-Llama-3-8B",
+        help="Repository ID to download from. Defaults to Llama-3-8B.",
    )
+    parser.add_argument(
+        "--tokenizer_path",
+        type=str,
+        default="",
+        help="the tokenizer.model path relative to repo_id",
+    )
    parser.add_argument(
        "--hf_token", type=str, default=None, help="HuggingFace API token."
    )
+    parser.add_argument(
+        "--local_dir",
+        type=str,
+        default="torchtitan/datasets/tokenizer/llama3/",
+        help="local directory to save the tokenizer.model",
+    )

    args = parser.parse_args()
-    hf_download(args.repo_id, args.hf_token)
+    hf_download(args.repo_id, args.tokenizer_path, args.local_dir, args.hf_token)
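For reference, a minimal sketch of calling the updated `hf_download` from Python rather than the CLI; the token value is a placeholder and the paths mirror the new defaults:

```
# Minimal usage sketch of the updated hf_download; hf_token is a placeholder.
from torchtitan.datasets.download_tokenizer import hf_download

hf_download(
    repo_id="meta-llama/Meta-Llama-3-8B",
    tokenizer_path="original",  # tokenizer.model sits under original/ in the Llama3 repo
    local_dir="torchtitan/datasets/tokenizer/llama3/",
    hf_token="hf_...",  # your HuggingFace API token
)
```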
8 changes: 4 additions & 4 deletions torchtitan/datasets/hf_datasets.py
@@ -9,7 +9,7 @@
import torch
from torch.utils.data import DataLoader, IterableDataset

-from torchtitan.datasets.tokenizer import TokenizerIf
+from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging_utils import logger

from datasets import load_dataset, load_from_disk
@@ -29,7 +29,7 @@ class HuggingFaceDataset(IterableDataset):
        dataset_path (Optional[str]):
            Path to the dataset in the file system. If provided, data will be loaded
            from this path instead of downloaded.
-        tokenizer (TokenizerIf):
+        tokenizer (Tokenizer):
            Tokenizer used to encode data. Tokenizer must implement an `encode` and `decode` method.
        seq_len (int): max sequence length
        world_size (int): number of data parallel processes participating in training
@@ -59,7 +59,7 @@ def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
-        tokenizer: TokenizerIf,
+        tokenizer: Tokenizer,
        seq_len: int = 2048,
        world_size: int = 1,
        rank: int = 0,
@@ -132,7 +132,7 @@ def __iter__(self):
def build_hf_data_loader(
    dataset_name: str,
    dataset_path: Optional[str],
-    tokenizer: TokenizerIf,
+    tokenizer: Tokenizer,
    batch_size: int,
    seq_len: int,
    world_size,
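To show how the renamed `Tokenizer` type flows through, a hedged sketch of building a data loader; the `"c4"` dataset name and the trailing `rank` argument are assumptions, since the signature is truncated above:

```
# Hedged sketch of build_hf_data_loader with the new Tokenizer type.
# The "c4" dataset name and the trailing rank argument are assumptions.
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer

tokenizer = create_tokenizer(
    "tiktoken", "torchtitan/datasets/tokenizer/llama3/original/tokenizer.model"
)
data_loader = build_hf_data_loader(
    "c4",       # dataset_name (assumed)
    None,       # dataset_path: None downloads instead of loading from disk
    tokenizer,  # Tokenizer built by the factory in tokenizer/__init__.py
    8,          # batch_size
    2048,       # seq_len (the documented default)
    1,          # world_size
    0,          # rank (assumed)
)
```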
21 changes: 21 additions & 0 deletions torchtitan/datasets/tokenizer/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.datasets.tokenizer.sentencepiece import SentencePieceTokenizer
+from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer
+from torchtitan.datasets.tokenizer.tokenizer import Tokenizer
+
+from torchtitan.logging_utils import logger
+
+
+def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
+    logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
+    if tokenizer_type == "sentencepiece":
+        return SentencePieceTokenizer(tokenizer_path)
+    elif tokenizer_type == "tiktoken":
+        return TikTokenizer(tokenizer_path)
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer_type}")
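A short usage sketch of the new factory; the tokenizer.model paths below assume the download destinations from the README instructions:

```
# Hedged usage of the create_tokenizer factory added in this PR;
# both paths assume the README's download destinations.
from torchtitan.datasets.tokenizer import create_tokenizer

# Llama3 uses the new tiktoken-based tokenizer ...
llama3_tokenizer = create_tokenizer(
    "tiktoken", "torchtitan/datasets/tokenizer/llama3/original/tokenizer.model"
)
# ... while Llama2 stays on sentencepiece; the path depends on your --local_dir.
llama2_tokenizer = create_tokenizer(
    "sentencepiece", "torchtitan/datasets/tokenizer/tokenizer.model"
)
```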
torchtitan/datasets/tokenizer/sentencepiece.py
@@ -6,46 +6,15 @@

# copied and adjusted from https://github.com/facebookresearch/llama/blob/main/llama/tokenizer.py

-import os
-from abc import ABC, abstractmethod
from typing import List

from sentencepiece import SentencePieceProcessor

+from torchtitan.datasets.tokenizer.tokenizer import Tokenizer
from torchtitan.logging_utils import logger


-class TokenizerIf(ABC):
-    # tokenizer interface
-    def __init__(self, tokenizer_path: str):
-        assert os.path.exists(
-            tokenizer_path
-        ), f"The tokenizer path does not exist: {tokenizer_path}"
-        assert os.path.isfile(tokenizer_path), tokenizer_path
-        self._n_words = 8
-
-    @abstractmethod
-    def encode(self, *args, **kwargs) -> List[int]:
-        ...
-
-    @abstractmethod
-    def decode(self, *args, **kwargs) -> str:
-        ...
-
-    @property
-    def n_words(self) -> int:
-        return self._n_words
-
-
-def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> TokenizerIf:
-    logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
-    if tokenizer_type == "sentencepiece":
-        return SentencePieceTokenizer(tokenizer_path)
-    else:
-        raise ValueError(f"Unknown tokenizer type: {args.type}")
-
-
-class SentencePieceTokenizer(TokenizerIf):
+class SentencePieceTokenizer(Tokenizer):
"""
Tokenizing and encoding/decoding text based on a SentencePiece model.
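The `Tokenizer` base class imported above from `torchtitan/datasets/tokenizer/tokenizer.py` is not shown in this diff; below is a plausible reconstruction, assuming it is the removed `TokenizerIf` under a new name:

```
# Assumed contents of torchtitan/datasets/tokenizer/tokenizer.py,
# reconstructed from the TokenizerIf interface removed above.
import os
from abc import ABC, abstractmethod
from typing import List


class Tokenizer(ABC):
    # basic tokenizer interface
    def __init__(self, tokenizer_path: str):
        assert os.path.exists(
            tokenizer_path
        ), f"The tokenizer path does not exist: {tokenizer_path}"
        assert os.path.isfile(tokenizer_path), tokenizer_path
        self._n_words = 8

    @abstractmethod
    def encode(self, *args, **kwargs) -> List[int]:
        ...

    @abstractmethod
    def decode(self, *args, **kwargs) -> str:
        ...

    @property
    def n_words(self) -> int:
        return self._n_words
```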
(The remaining changed files are not shown here.)
