Support Llama3 8b/70b #256

Merged 2 commits on Apr 20, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ __pycache__
build
outputs
dist/*
*.model

# data
data
19 changes: 15 additions & 4 deletions README.md
@@ -47,15 +47,26 @@ Install PyTorch from source or install the latest pytorch nightly, then install
pip install -r requirements.txt
```

Install additional dev requirements if you want to contribute to the repo:
### Downloading a tokenizer.model

`torchtitan` currently supports training Llama3 (8B, 70B), and Llama2 (13B, 70B) out of the box. To get started training these models, we need to download a tokenizer.model. Follow the instructions on the official [meta-llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B) repository to ensure you have access to the Llama model weights.

Once you have confirmed access, you can run the following command to download the Llama2/3 tokenizer to your local machine.

```
pip install -r dev-requirements.txt
# pass your hf_token in order to download tokenizer.model

# llama3 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...

# llama2 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-2-13b-hf --hf_token=...
```

run the llama debug model locally to verify the setup is correct:
Run the llama3 8B model locally on 8 GPUs:

```
./run_llama_train.sh
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
```


3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,5 +1,6 @@
torch >= 2.2.0.dev
sentencepiece
datasets
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
sentencepiece
tiktoken
32 changes: 24 additions & 8 deletions torchtitan/datasets/download_tokenizer.py
@@ -4,21 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

from requests.exceptions import HTTPError


def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
def hf_download(
repo_id: str, tokenizer_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
from huggingface_hub import hf_hub_download

os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
tokenizer_path = (
f"{tokenizer_path}/tokenizer.model" if tokenizer_path else "tokenizer.model"
)

try:
hf_hub_download(
repo_id,
"tokenizer.model",
local_dir="torchtitan/datasets/tokenizer/",
tokenizer_path,
local_dir=local_dir,
local_dir_use_symlinks=False,
token=hf_token,
)
@@ -38,12 +42,24 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
parser.add_argument(
"--repo_id",
type=str,
default="meta-llama/llama-2-70b",
help="Repository ID to download from.",
default="meta-llama/Meta-Llama-3-8B",
help="Repository ID to download from. default to Llama-3-8B",
)
parser.add_argument(
"--tokenizer_path",
type=str,
default="",
help="the tokenizer.model path relative to repo_id",
)
parser.add_argument(
"--hf_token", type=str, default=None, help="HuggingFace API token."
)
parser.add_argument(
"--local_dir",
type=str,
default="torchtitan/datasets/tokenizer/llama3/",
help="local directory to save the tokenizer.model",
)

args = parser.parse_args()
hf_download(args.repo_id, args.hf_token)
hf_download(args.repo_id, args.tokenizer_path, args.local_dir, args.hf_token)
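
For reference, a minimal sketch of calling the updated `hf_download` programmatically rather than through the CLI; the token value is a placeholder and the paths simply mirror the Llama 3 defaults shown in the argument parser above.

```
# Sketch only: token and paths are placeholders mirroring the CLI defaults above.
from torchtitan.datasets.download_tokenizer import hf_download

hf_download(
    repo_id="meta-llama/Meta-Llama-3-8B",
    tokenizer_path="original",  # resolved internally to "original/tokenizer.model"
    local_dir="torchtitan/datasets/tokenizer/llama3/",
    hf_token="hf_...",  # your Hugging Face access token
)
```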
8 changes: 4 additions & 4 deletions torchtitan/datasets/hf_datasets.py
@@ -9,7 +9,7 @@
import torch
from torch.utils.data import DataLoader, IterableDataset

from torchtitan.datasets.tokenizer import TokenizerIf
from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging_utils import logger

from datasets import load_dataset, load_from_disk
@@ -29,7 +29,7 @@ class HuggingFaceDataset(IterableDataset):
dataset_path (Optional[str]):
Path to the dataset in the file system. If provided, data will be loaded
from this path instead of downloaded.
tokenizer (TokenizerIf):
tokenizer (Tokenizer):
Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
seq_len (int): max sequence length
world_size (int): number of data parallel processes participating in training
@@ -59,7 +59,7 @@ def __init__(
self,
dataset_name: str,
dataset_path: Optional[str],
tokenizer: TokenizerIf,
tokenizer: Tokenizer,
seq_len: int = 2048,
world_size: int = 1,
rank: int = 0,
@@ -132,7 +132,7 @@ def __iter__(self):
def build_hf_data_loader(
dataset_name: str,
dataset_path: Optional[str],
tokenizer: TokenizerIf,
tokenizer: Tokenizer,
batch_size: int,
seq_len: int,
world_size,
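
A rough sketch (not part of this PR) of how the renamed `Tokenizer` type flows into `build_hf_data_loader`; the dataset name, batch size, tokenizer path, and trailing rank argument are assumptions for illustration only.

```
# Illustrative values only; dataset name, paths, batch size, and rank are assumptions.
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer

tokenizer = create_tokenizer(
    "tiktoken", "torchtitan/datasets/tokenizer/llama3/original/tokenizer.model"
)
data_loader = build_hf_data_loader(
    "alpaca",    # dataset_name (assumed example dataset)
    None,        # dataset_path: fetch from the Hugging Face hub instead of local disk
    tokenizer,   # now typed as Tokenizer instead of TokenizerIf
    8,           # batch_size
    2048,        # seq_len
    1,           # world_size
    0,           # rank (assumed to follow world_size in the truncated signature)
)
```

The point is only that any `Tokenizer` subclass, SentencePiece or tiktoken based, plugs into the same loader.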
21 changes: 21 additions & 0 deletions torchtitan/datasets/tokenizer/__init__.py
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.tokenizer.sentencepiece import SentencePieceTokenizer
from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer
from torchtitan.datasets.tokenizer.tokenizer import Tokenizer

from torchtitan.logging_utils import logger


def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
if tokenizer_type == "sentencepiece":
return SentencePieceTokenizer(tokenizer_path)
elif tokenizer_type == "tiktoken":
return TikTokenizer(tokenizer_path)
else:
raise ValueError(f"Unknown tokenizer type: {tokenizer_type}")
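
As a quick usage sketch (not part of the diff), the factory picks the backend by name; the paths below assume the default locations used by download_tokenizer.py, and the `bos`/`eos` flags and `n_words` follow the tokenizer interface shown elsewhere in this PR.

```
# Sketch: paths assume the download script's default locations; adjust as needed.
from torchtitan.datasets.tokenizer import create_tokenizer

llama3_tok = create_tokenizer(
    "tiktoken", "torchtitan/datasets/tokenizer/llama3/original/tokenizer.model"
)
llama2_tok = create_tokenizer(
    "sentencepiece", "torchtitan/datasets/tokenizer/tokenizer.model"
)

# encode/decode flags assumed to match the upstream Llama tokenizers.
ids = llama3_tok.encode("hello torchtitan", bos=True, eos=False)
print(llama3_tok.n_words, llama3_tok.decode(ids))
```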
torchtitan/datasets/tokenizer/sentencepiece.py
@@ -6,46 +6,15 @@

# copied and adjusted from https://github.com/facebookresearch/llama/blob/main/llama/tokenizer.py

import os
from abc import ABC, abstractmethod
from typing import List

from sentencepiece import SentencePieceProcessor

from torchtitan.datasets.tokenizer.tokenizer import Tokenizer
from torchtitan.logging_utils import logger


class TokenizerIf(ABC):
# tokenizer interface
def __init__(self, tokenizer_path: str):
assert os.path.exists(
tokenizer_path
), f"The tokenizer path does not exist: {tokenizer_path}"
assert os.path.isfile(tokenizer_path), tokenizer_path
self._n_words = 8

@abstractmethod
def encode(self, *args, **kwargs) -> List[int]:
...

@abstractmethod
def decode(self, *args, **kwargs) -> str:
...

@property
def n_words(self) -> int:
return self._n_words


def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> TokenizerIf:
logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
if tokenizer_type == "sentencepiece":
return SentencePieceTokenizer(tokenizer_path)
else:
raise ValueError(f"Unknown tokenizer type: {args.type}")


class SentencePieceTokenizer(TokenizerIf):
class SentencePieceTokenizer(Tokenizer):
"""
Tokenizing and encoding/decoding text based on a SentencePiece model.

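
The new `Tokenizer` base class imported above (torchtitan/datasets/tokenizer/tokenizer.py) is not included in this diff; a plausible reconstruction, based on the removed `TokenizerIf` interface, would look roughly like this:

```
# Plausible reconstruction of the Tokenizer base class; the actual file may differ.
import os
from abc import ABC, abstractmethod
from typing import List


class Tokenizer(ABC):
    # Shared interface for SentencePieceTokenizer and TikTokenizer.
    def __init__(self, tokenizer_path: str):
        assert os.path.exists(
            tokenizer_path
        ), f"The tokenizer path does not exist: {tokenizer_path}"
        assert os.path.isfile(tokenizer_path), tokenizer_path
        self._n_words = 8  # placeholder; subclasses set the real vocab size

    @abstractmethod
    def encode(self, *args, **kwargs) -> List[int]:
        ...

    @abstractmethod
    def decode(self, *args, **kwargs) -> str:
        ...

    @property
    def n_words(self) -> int:
        return self._n_words
```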