diff --git a/.gitignore b/.gitignore
index 818b4b76..8ce0ab88 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ __pycache__
 build
 outputs
 dist/*
+*.model
 
 # data
 data
diff --git a/README.md b/README.md
index ef5b93d9..9453f676 100644
--- a/README.md
+++ b/README.md
@@ -47,15 +47,26 @@ Install PyTorch from source or install the latest pytorch nightly, then install
 pip install -r requirements.txt
 ```
 
-Install additional dev requirements if you want to contribute to the repo:
+### Downloading a tokenizer.model
+
+`torchtitan` currently supports training Llama3 (8B, 70B) and Llama2 (13B, 70B) out of the box. To get started training these models, you first need to download a tokenizer.model. Follow the instructions on the official [meta-llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B) repository to ensure you have access to the Llama model weights.
+
+Once you have confirmed access, you can run the following commands to download the Llama2/3 tokenizer to your local machine.
+
 ```
-pip install -r dev-requirements.txt
+# pass your hf_token in order to download tokenizer.model
+
+# llama3 tokenizer.model
+python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...
+
+# llama2 tokenizer.model
+python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-2-13b-hf --hf_token=...
 ```
 
-run the llama debug model locally to verify the setup is correct:
+Run the llama3 8B model locally on 8 GPUs:
 ```
-./run_llama_train.sh
+CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
 ```
diff --git a/requirements.txt b/requirements.txt
index 8e089a3e..f70c7566 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 torch >= 2.2.0.dev
-sentencepiece
 datasets
 tomli >= 1.1.0 ; python_version < "3.11"
 tensorboard
+sentencepiece
+tiktoken
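The README commands above wrap the reworked `hf_download()` helper in the next file. A minimal sketch of calling it directly from Python, assuming a valid HF token with approved Llama access (the other values mirror the argparse defaults below; the token is a placeholder you supply yourself):

```
# Sketch: calling the reworked helper directly instead of via the CLI.
from torchtitan.datasets.download_tokenizer import hf_download

hf_download(
    repo_id="meta-llama/Meta-Llama-3-8B",
    tokenizer_path="original",  # tokenizer.model sits under original/ in the HF repo
    local_dir="torchtitan/datasets/tokenizer/llama3/",
    hf_token="hf_...",          # placeholder; needs access to the gated repo
)
```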
diff --git a/torchtitan/datasets/download_tokenizer.py b/torchtitan/datasets/download_tokenizer.py
index 41db3c81..eb740d98 100644
--- a/torchtitan/datasets/download_tokenizer.py
+++ b/torchtitan/datasets/download_tokenizer.py
@@ -4,21 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import os
 from typing import Optional
 
 from requests.exceptions import HTTPError
 
 
-def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
+def hf_download(
+    repo_id: str, tokenizer_path: str, local_dir: str, hf_token: Optional[str] = None
+) -> None:
     from huggingface_hub import hf_hub_download
 
-    os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
+    tokenizer_path = (
+        f"{tokenizer_path}/tokenizer.model" if tokenizer_path else "tokenizer.model"
+    )
+
     try:
         hf_hub_download(
             repo_id,
-            "tokenizer.model",
-            local_dir="torchtitan/datasets/tokenizer/",
+            tokenizer_path,
+            local_dir=local_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
         )
@@ -38,12 +42,24 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
     parser.add_argument(
         "--repo_id",
         type=str,
-        default="meta-llama/llama-2-70b",
-        help="Repository ID to download from.",
+        default="meta-llama/Meta-Llama-3-8B",
+        help="Repository ID to download from. Defaults to Llama-3-8B.",
+    )
+    parser.add_argument(
+        "--tokenizer_path",
+        type=str,
+        default="",
+        help="the tokenizer.model path relative to repo_id",
     )
     parser.add_argument(
         "--hf_token", type=str, default=None, help="HuggingFace API token."
     )
+    parser.add_argument(
+        "--local_dir",
+        type=str,
+        default="torchtitan/datasets/tokenizer/llama3/",
+        help="local directory to save the tokenizer.model",
+    )
 
     args = parser.parse_args()
-    hf_download(args.repo_id, args.hf_token)
+    hf_download(args.repo_id, args.tokenizer_path, args.local_dir, args.hf_token)
diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py
index b6ffb006..0b8aa015 100644
--- a/torchtitan/datasets/hf_datasets.py
+++ b/torchtitan/datasets/hf_datasets.py
@@ -9,7 +9,7 @@
 import torch
 from torch.utils.data import DataLoader, IterableDataset
 
-from torchtitan.datasets.tokenizer import TokenizerIf
+from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging_utils import logger
 
 from datasets import load_dataset, load_from_disk
@@ -29,7 +29,7 @@ class HuggingFaceDataset(IterableDataset):
         dataset_path (Optional[str]): Path to the dataset in the file system.
             If provided, data will be loaded from this path instead of downloaded.
-        tokenizer (TokenizerIf):
+        tokenizer (Tokenizer):
             Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
         seq_len (int): max sequence length
         world_size (int): number of data parallel processes participating in training
@@ -59,7 +59,7 @@ def __init__(
         self,
         dataset_name: str,
         dataset_path: Optional[str],
-        tokenizer: TokenizerIf,
+        tokenizer: Tokenizer,
         seq_len: int = 2048,
         world_size: int = 1,
         rank: int = 0,
@@ -132,7 +132,7 @@ def __iter__(self):
 def build_hf_data_loader(
     dataset_name: str,
     dataset_path: Optional[str],
-    tokenizer: TokenizerIf,
+    tokenizer: Tokenizer,
     batch_size: int,
     seq_len: int,
     world_size,
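With `TokenizerIf` renamed to `Tokenizer`, the factory and the data loader compose as sketched below. This is an assumption about the call site (the real plumbing lives in `train.py`, which is not part of this diff), and the trailing arguments follow the `build_hf_data_loader` signature fragment visible above:

```
# Sketch: composing the renamed pieces; the call site is assumed, not shown in this diff.
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer

tokenizer = create_tokenizer(
    "tiktoken", "./torchtitan/datasets/tokenizer/llama3/tokenizer.model"
)
data_loader = build_hf_data_loader(
    "c4",       # dataset_name
    None,       # dataset_path: None -> download from the hub instead of local disk
    tokenizer,
    2,          # batch_size
    8192,       # seq_len
    1,          # world_size
    0,          # rank (assumed to follow world_size, as in HuggingFaceDataset)
)
```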
diff --git a/torchtitan/datasets/tokenizer/__init__.py b/torchtitan/datasets/tokenizer/__init__.py
new file mode 100644
index 00000000..346caf83
--- /dev/null
+++ b/torchtitan/datasets/tokenizer/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.datasets.tokenizer.sentencepiece import SentencePieceTokenizer
+from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer
+from torchtitan.datasets.tokenizer.tokenizer import Tokenizer
+
+from torchtitan.logging_utils import logger
+
+
+def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
+    logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
+    if tokenizer_type == "sentencepiece":
+        return SentencePieceTokenizer(tokenizer_path)
+    elif tokenizer_type == "tiktoken":
+        return TikTokenizer(tokenizer_path)
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer_type}")
diff --git a/torchtitan/datasets/tokenizer.py b/torchtitan/datasets/tokenizer/sentencepiece.py
similarity index 68%
rename from torchtitan/datasets/tokenizer.py
rename to torchtitan/datasets/tokenizer/sentencepiece.py
index 9b3ef49f..7229daa3 100644
--- a/torchtitan/datasets/tokenizer.py
+++ b/torchtitan/datasets/tokenizer/sentencepiece.py
@@ -6,46 +6,15 @@
 
 # copied and adjusted from https://github.com/facebookresearch/llama/blob/main/llama/tokenizer.py
 
-import os
-from abc import ABC, abstractmethod
 from typing import List
 
 from sentencepiece import SentencePieceProcessor
+from torchtitan.datasets.tokenizer.tokenizer import Tokenizer
 from torchtitan.logging_utils import logger
 
 
-class TokenizerIf(ABC):
-    # tokenizer interface
-    def __init__(self, tokenizer_path: str):
-        assert os.path.exists(
-            tokenizer_path
-        ), f"The tokenizer path does not exist: {tokenizer_path}"
-        assert os.path.isfile(tokenizer_path), tokenizer_path
-        self._n_words = 8
-
-    @abstractmethod
-    def encode(self, *args, **kwargs) -> List[int]:
-        ...
-
-    @abstractmethod
-    def decode(self, *args, **kwargs) -> str:
-        ...
-
-    @property
-    def n_words(self) -> int:
-        return self._n_words
-
-
-def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> TokenizerIf:
-    logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
-    if tokenizer_type == "sentencepiece":
-        return SentencePieceTokenizer(tokenizer_path)
-    else:
-        raise ValueError(f"Unknown tokenizer type: {args.type}")
-
-
-class SentencePieceTokenizer(TokenizerIf):
+class SentencePieceTokenizer(Tokenizer):
     """
     Tokenizing and encoding/decoding text based on a SentencePiece model.
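The factory above dispatches on `tokenizer_type`, and both concrete classes honor the `encode`/`decode` contract from the `Tokenizer` ABC. A round-trip sketch; the `bos`/`eos` keyword names follow the Llama reference tokenizers, an assumption since the interface only declares `*args, **kwargs`:

```
# Round-trip sketch under the Tokenizer contract (keyword names assumed).
from torchtitan.datasets.tokenizer import create_tokenizer

tok = create_tokenizer("sentencepiece", "torchtitan/datasets/tokenizer/tokenizer.model")
ids = tok.encode("hello world", bos=True, eos=True)  # List[int], BOS/EOS added
text = tok.decode(ids)                               # SentencePiece drops control tokens
print(tok.n_words, ids[:3], text)
```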
+ """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501, B950 + + def __init__(self, model_path: str): + super().__init__(model_path) + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + logger.info(f"Reloaded tiktoken model from {model_path}") + + self._n_words: int = self.model.n_vocab + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.pad_id: int = -1 + self.stop_tokens = { + self.special_tokens["<|end_of_text|>"], + self.special_tokens["<|eot_id|>"], + } + logger.info( + f"#words: {self._n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None, + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + assert type(s) is str + allowed_special = allowed_special or set() + disallowed_special = disallowed_special or () + + # The tiktoken tokenizer can handle <=400k chars without + # pyo3_runtime.PanicException. + TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + + # https://github.com/openai/tiktoken/issues/195 + # Here we iterate over subsequences and split if we exceed the limit + # of max consecutive non-whitespace or whitespace characters. 
+ MAX_NO_WHITESPACES_CHARS = 25_000 + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. + return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] diff --git a/torchtitan/datasets/tokenizer/tokenizer.model b/torchtitan/datasets/tokenizer/tokenizer.model deleted file mode 100644 index 22bccbcb..00000000 Binary files a/torchtitan/datasets/tokenizer/tokenizer.model and /dev/null differ diff --git a/torchtitan/datasets/tokenizer/tokenizer.py b/torchtitan/datasets/tokenizer/tokenizer.py new file mode 100644 index 00000000..128dc6d2 --- /dev/null +++ b/torchtitan/datasets/tokenizer/tokenizer.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +from abc import ABC, abstractmethod +from typing import List + + +class Tokenizer(ABC): + # basic tokenizer interface, for typing purpose mainly + def __init__(self, tokenizer_path: str): + assert os.path.exists( + tokenizer_path + ), f"The tokenizer path does not exist: {tokenizer_path}" + assert os.path.isfile(tokenizer_path), tokenizer_path + self._n_words = 8 + + @abstractmethod + def encode(self, *args, **kwargs) -> List[int]: + ... + + @abstractmethod + def decode(self, *args, **kwargs) -> str: + ... + + @property + def n_words(self) -> int: + return self._n_words diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index a42fe4a1..d5f95355 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
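The ABC fixes only the surface needed for typing: a path-validated constructor, `encode`/`decode`, and the `n_words` property. A hypothetical toy subclass, purely to illustrate what a third implementation would have to provide (not part of this diff):

```
# Hypothetical toy subclass illustrating the minimal Tokenizer contract.
from typing import Dict, List

from torchtitan.datasets.tokenizer.tokenizer import Tokenizer


class WhitespaceTokenizer(Tokenizer):
    def __init__(self, tokenizer_path: str):
        super().__init__(tokenizer_path)  # the ABC asserts the path exists
        self.vocab: Dict[str, int] = {}

    def encode(self, s: str) -> List[int]:
        # grow the vocab on the fly; fine for a toy, not for real training
        return [self.vocab.setdefault(w, len(self.vocab)) for w in s.split()]

    def decode(self, ids: List[int]) -> str:
        rev = {i: w for w, i in self.vocab.items()}
        return " ".join(rev[i] for i in ids)
```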
diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index a42fe4a1..d5f95355 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -4,12 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch.nn as nn
+
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
-from torchtitan.models.llama import Transformer
 
 
-def build_fp8_linear(model: Transformer, job_config: JobConfig):
+def build_fp8_linear(model: nn.Module, job_config: JobConfig):
     """
     This function converts the linear layers to one of the fp8 types:
     - Float8DynamicLinear: Dynamic quantization of the weights and the activations
diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py
index 5d982729..c7bb16c6 100644
--- a/torchtitan/models/__init__.py
+++ b/torchtitan/models/__init__.py
@@ -4,16 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.models.llama import llama_configs, Transformer
+from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer
 
 models_config = {
-    "llama": llama_configs,
+    "llama2": llama2_configs,
+    "llama3": llama3_configs,
 }
 
-model_name_to_cls = {
-    "llama": Transformer,
-}
+model_name_to_cls = {"llama2": Transformer, "llama3": Transformer}
 
 model_name_to_tokenizer = {
-    "llama": "sentencepiece",
+    "llama2": "sentencepiece",
+    "llama3": "tiktoken",
 }
diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py
index e6e0122c..2393d92f 100644
--- a/torchtitan/models/llama/__init__.py
+++ b/torchtitan/models/llama/__init__.py
@@ -11,7 +11,7 @@
 
 __all__ = ["Transformer"]
 
-llama_configs = {
+llama2_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
     "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
     "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
@@ -27,3 +27,25 @@
         multiple_of=4096,
     ),
 }
+
+llama3_configs = {
+    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16, rope_theta=500000),
+    "8B": ModelArgs(
+        dim=4096,
+        n_layers=32,
+        n_heads=32,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=1024,
+        rope_theta=500000,
+    ),
+    "70B": ModelArgs(
+        dim=8192,
+        n_layers=80,
+        n_heads=64,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=4096,
+        rope_theta=500000,
+    ),
+}
diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index a647a2d0..6ded3306 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -27,6 +27,7 @@ class ModelArgs:
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
+    rope_theta: float = 10000
 
     max_batch_size: int = 32
     max_seq_len: int = 2048
@@ -366,6 +367,7 @@ def __init__(self, model_args: ModelArgs):
                 # Need to compute until at least the max token limit for generation
                 # (use 2x max sequence length to be safe)
                 model_args.max_seq_len * 2,
+                model_args.rope_theta,
             ),
             persistent=True,
         )
@@ -399,6 +401,7 @@ def init_weights(self):
             # Need to compute until at least the max token limit for generation
             # (use 2x max sequence length to be safe)
             self.model_args.max_seq_len * 2,
+            self.model_args.rope_theta,
         )
         nn.init.normal_(self.tok_embeddings.weight)
         for layer in self.layers:
diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py
index 3d11a68c..fba9bff8 100644
--- a/torchtitan/parallelisms/__init__.py
+++ b/torchtitan/parallelisms/__init__.py
@@ -12,7 +12,8 @@
 from torchtitan.parallelisms.parallelize_llama import parallelize_llama
 
 models_parallelize_fns = {
-    "llama": parallelize_llama,
+    "llama2": parallelize_llama,
+    "llama3": parallelize_llama,
 }
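The three registries now key on `llama2`/`llama3` consistently, with both names sharing the `Transformer` class and `parallelize_llama`. The lookup flow a training entry point would use (a sketch; the real plumbing lives in `train.py`, outside this diff):

```
# Sketch: resolving a model flavor through the renamed registries.
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.parallelisms import models_parallelize_fns

model_name, flavor = "llama3", "8B"

model_cls = model_name_to_cls[model_name]             # Transformer (shared)
model_args = models_config[model_name][flavor]        # ModelArgs(dim=4096, ..., rope_theta=500000)
tokenizer_type = model_name_to_tokenizer[model_name]  # "tiktoken"
parallelize_fn = models_parallelize_fns[model_name]   # parallelize_llama (shared)
```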
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 5f5f0cc4..7a7e0de1 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -2,8 +2,9 @@
 
 [job]
 dump_folder = "./outputs"
-description = "LLaMA debug training"
-use_for_integration_test = true
+description = "Llama debug training"
+# TODO: turn this back on once CI has the tokenizer
+use_for_integration_test = false
 
 [profiling]
 enable_profiling = false
@@ -17,10 +18,10 @@ enable_tensorboard = true
 save_tb_folder = "tb"
 
 [model]
-name = "llama"
+name = "llama3"
 flavor = "debugmodel"
 norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
-tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"
+tokenizer_path = "./torchtitan/datasets/tokenizer/llama3/tokenizer.model"
 
 [optimizer]
 name = "AdamW"
diff --git a/train_configs/llama_13b.toml b/train_configs/llama2_13b.toml
similarity index 95%
rename from train_configs/llama_13b.toml
rename to train_configs/llama2_13b.toml
index cdd9a3f0..67d5ecaf 100644
--- a/train_configs/llama_13b.toml
+++ b/train_configs/llama2_13b.toml
@@ -3,7 +3,7 @@
 
 [job]
 dump_folder = "./outputs"
-description = "LLaMA 13B training"
+description = "Llama2 13B training"
 
 [profiling]
 enable_profiling = true
@@ -16,7 +16,7 @@ enable_tensorboard = true
 save_tb_folder = "tb"
 
 [model]
-name = "llama"
+name = "llama2"
 flavor = "13B"
 norm_type = "fused_rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
 tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"
diff --git a/train_configs/llama_70b.toml b/train_configs/llama2_70b.toml
similarity index 95%
rename from train_configs/llama_70b.toml
rename to train_configs/llama2_70b.toml
index e5557937..22d71c81 100644
--- a/train_configs/llama_70b.toml
+++ b/train_configs/llama2_70b.toml
@@ -3,7 +3,7 @@
 
 [job]
 dump_folder = "./outputs"
-description = "LLaMA 70B training"
+description = "Llama2 70B training"
 
 [profiling]
 enable_profiling = true
@@ -16,7 +16,7 @@ enable_tensorboard = true
 save_tb_folder = "tb"
 
 [model]
-name = "llama"
+name = "llama2"
 flavor = "70B"
 norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
 tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"
diff --git a/train_configs/llama_7b.toml b/train_configs/llama2_7b.toml
similarity index 95%
rename from train_configs/llama_7b.toml
rename to train_configs/llama2_7b.toml
index f121d570..fe244cd0 100644
--- a/train_configs/llama_7b.toml
+++ b/train_configs/llama2_7b.toml
@@ -2,7 +2,7 @@
 
 [job]
 dump_folder = "./outputs"
-description = "LLaMA 7B training"
+description = "Llama2 7B training"
 
 [profiling]
 enable_profiling = true
@@ -15,7 +15,7 @@ enable_tensorboard = true
 save_tb_folder = "tb"
 
 [model]
-name = "llama"
+name = "llama2"
 flavor = "7B"
 norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"
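The 70B preset that follows targets 64 A100s with 8-way tensor parallelism; since `data_parallel_degree = -1`, the data-parallel degree presumably resolves to the remaining factor. The per-step token budget then works out as below (a sketch of the arithmetic, assuming `-1` means "use all remaining GPUs"):

```
# Arithmetic for the 64-GPU preset below (assumption: dp fills the non-TP factor).
gpus, tp = 64, 8
dp = gpus // tp                               # data_parallel_degree = -1 -> 8
batch_size, seq_len = 16, 8192                # from the [training] section below
tokens_per_step = batch_size * seq_len * dp   # 1,048,576 tokens per optimizer step
```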
diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml
new file mode 100644
index 00000000..4c757f7e
--- /dev/null
+++ b/train_configs/llama3_70b.toml
@@ -0,0 +1,50 @@
+# torchtitan Config.toml
+# NOTE: this toml config is a preset for 64 A100 GPUs.
+
+[job]
+dump_folder = "./outputs"
+description = "Llama 3 70B training"
+
+[profiling]
+enable_profiling = true
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = true
+save_tb_folder = "tb"
+
+[model]
+name = "llama3"
+flavor = "70B"
+norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
+tokenizer_path = "./torchtitan/datasets/tokenizer/llama3/tokenizer.model"
+
+[optimizer]
+name = "AdamW"
+lr = 1.5e-4
+
+[training]
+batch_size = 16
+seq_len = 8192
+warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
+max_norm = 1.0 # grad norm clipping
+steps = 1000
+data_parallel_degree = -1
+tensor_parallel_degree = 8 # 8-way TP
+pipeline_parallel_degree = 1
+fp8_linear = ""
+compile = false
+dataset = "c4"
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval_type = "steps"
+interval = 500
+model_weights_only = false
+export_dtype = "float32"
+
+[activation_checkpoint]
+mode = 'full'
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
new file mode 100644
index 00000000..257a9732
--- /dev/null
+++ b/train_configs/llama3_8b.toml
@@ -0,0 +1,48 @@
+# torchtitan Config.toml
+[job]
+dump_folder = "./outputs"
+description = "Llama 3 8B training"
+
+[profiling]
+enable_profiling = true
+save_traces_folder = "profiling/traces"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = true
+save_tb_folder = "tb"
+
+[model]
+name = "llama3"
+flavor = "8B"
+norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+tokenizer_path = "./torchtitan/datasets/tokenizer/llama3/tokenizer.model"
+
+[optimizer]
+name = "AdamW"
+lr = 3e-4
+
+[training]
+batch_size = 2
+seq_len = 8192
+warmup_steps = 200 # lr scheduler warm up
+max_norm = 1.0 # grad norm clipping
+steps = 1000
+data_parallel_degree = -1
+tensor_parallel_degree = 1
+pipeline_parallel_degree = 1
+fp8_linear = ""
+compile = false
+dataset = "c4"
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval_type = "steps"
+interval = 500
+model_weights_only = false
+export_dtype = "float32"
+
+[activation_checkpoint]
+mode = "full"
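A quick sanity check that ties the pieces of this PR together: the new toml parses, names the llama3 model, and points at a tokenizer.model that `download_tokenizer.py` should have fetched. A sketch; `tomllib` is the stdlib counterpart of the `tomli` fallback pinned in requirements.txt:

```
# Sketch: validate a new config before launching training.
import os
import tomllib  # Python 3.11+; use tomli on older interpreters

with open("train_configs/llama3_8b.toml", "rb") as f:
    cfg = tomllib.load(f)

assert cfg["model"]["name"] == "llama3"
assert os.path.isfile(cfg["model"]["tokenizer_path"]), "run download_tokenizer.py first"
```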