Add Llama Training scripts (#10)
* more to do

* why model no worky

* looking better

* little cleaner, overfit works

* cool
drisspg authored Dec 20, 2023
1 parent 8054124 commit 6db36c3
Showing 9 changed files with 878 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ufmt.yml
@@ -20,7 +20,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
-pip install ufmt
+pip install black==23.3.0 usort==1.0.6 ufmt==2.1.0 libcst==1.0.1
- name: Analyzing the code with ufmt
run: |
ufmt check .
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
# benchmarks / data will be used for experiment results
benchmarks/data/*
transformer_nuggets/llama/data/*
.vscode

# Byte-compiled / optimized / DLL files
16 changes: 12 additions & 4 deletions pyproject.toml
@@ -18,17 +18,18 @@ classifiers = [
]

dependencies = [
"torch >= 2.0.1",
"torch >= 2.1.1",
"scipy >= 1.9.1",
"tqdm",
"tabulate"
]

[project.optional-dependencies]
dev = [
"black",
"usort",
"libcst",
"black==23.3.0",
"usort==1.0.6",
"ufmt==2.1.0",
"libcst==1.0.1",
"bumpver",
"pip-tools",
"pytest"
@@ -37,6 +38,13 @@ dev = [
qlora = ['bitsandbytes']
flash = ['triton']

llama = [
"sentencepiece==0.1.99",
"datasets==2.15.0",
"fire==0.5.0",
"float8_experimental",
]

[tool.usort]
first_party_detection = false

44 changes: 44 additions & 0 deletions transformer_nuggets/llama/README.md
@@ -0,0 +1,44 @@
# Llama Pretraining

This directory contains the code for pretraining Llama. The model definition is from [gpt-fast](https://github.com/pytorch-labs/gpt-fast), slightly modified to remove the KV cache since it is not needed during pretraining.
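For orientation, here is a minimal sketch (not part of this commit) of how the model defined in `transformer_nuggets/llama/model.py` can be exercised, assuming the package layout in this commit is importable. A tiny config is used so the forward pass is cheap; `ModelArgs.from_name("7B")` would give the full-size model instead.

``` Python
import torch

from transformer_nuggets.llama.model import ModelArgs, Transformer

# Tiny illustrative config; swap for ModelArgs.from_name("7B") for the real model.
config = ModelArgs(n_layer=2, n_head=4, dim=256, vocab_size=32000, block_size=512)
model = Transformer(config)
model.init_parameters()
# Precomputes the rotary embedding table (no KV cache is allocated for pretraining).
model.setup_caches(max_batch_size=1, max_seq_length=128, device=torch.device("cpu"))

tokens = torch.randint(0, config.vocab_size, (1, 128))
input_pos = torch.arange(128)
logits = model(tokens, input_pos)  # shape: (1, 128, vocab_size)
```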

The tokenizer is from the original [Llama repo](https://github.com/facebookresearch/llama) and uses SentencePiece under the hood. Instead of training a tokenizer from scratch, the `tokenizer.model` file from the Llama 2 release is used.
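As a rough illustration (again, not part of this commit), the downloaded tokenizer can be loaded with `sentencepiece` directly; the path below is just wherever you placed the file.

``` Python
import sentencepiece as spm

# Hypothetical location of the downloaded Llama 2 tokenizer.
sp = spm.SentencePieceProcessor(model_file="transformer_nuggets/llama/data/tokenizer.model")

ids = sp.encode("The quick brown fox")  # list of token ids
text = sp.decode(ids)                   # round-trips back to the original string
print(sp.vocab_size())                  # 32000 for the Llama 2 tokenizer
```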

The training loop can be found in `train.py`. It expects that the `prepare_data.py` script has been run to generate the training data, which should live in the `data/` directory.
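`train.py` itself is not shown in this commit view, but the heart of any such pretraining loop is a next-token cross-entropy loss over the prepared token data. A hypothetical sketch of a single step follows; the names are illustrative and not taken from `train.py`.

``` Python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, tokens):
    """One illustrative pretraining step on a token batch of shape (batch, seq_len + 1)."""
    inputs, targets = tokens[:, :-1], tokens[:, 1:]  # shift by one for next-token targets
    input_pos = torch.arange(inputs.size(1), device=tokens.device)
    logits = model(inputs, input_pos)                # (batch, seq_len, vocab_size)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.item()
```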

### Usage

#### Install dependencies
``` Shell
pip install -e .
pip install -e ".[llama]"
```
Get the Llama 2 tokenizer file and place it inside the `llama/data` directory.

The following paths assume you are running from the top-level `transformer_nuggets/` directory.

#### Prepare Data

Then run the following command:
``` Shell
python transformer_nuggets/llama/prepare_data.py \
--tokenizer_path=transformer_nuggets/llama/data/tokenizer.model \
--output_dir=transformer_nuggets/llama/data/
```
This should take around 3 minutes to run and prepare the training data.

#### Train Model
To edit the training configs, take a look at `transformer_nuggets/llama/train.py`. The `entrypoint` function constructs the hyperparameter configs as well as the
training configs. By default this will train a 7B model and save the checkpoints to `transformer_nuggets/llama/data/out/`. It will also save the loss
logs to `transformer_nuggets/llama/data/logs`.


To train the model using delayed scaling with `torch.compile`, run the following command:
``` Shell
python transformer_nuggets/llama/train.py \
--fp8_linear_type "delayed" --compile True
```


### Notes
To get the Llama 2 tokenizer, go to https://huggingface.co/meta-llama/Llama-2-7b and go through the steps to obtain access. This will give you the pretrained weights as well as the tokenizer.
Empty file.
265 changes: 265 additions & 0 deletions transformer_nuggets/llama/model.py
@@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
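    """Round n up to the nearest multiple of k."""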
if n % k == 0:
return n
return n + k - (n % k)


@dataclass
class ModelArgs:
block_size: int = 2048
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
    intermediate_size: Optional[int] = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5

def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head

@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [
config
for config in transformer_configs
if config in str(name).upper() or config in str(name)
]
assert len(config) == 1, name
return cls(**transformer_configs[config[0]])


transformer_configs = {
"CodeLlama-7b-Python-hf": dict(
block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
),
"7B": dict(n_layer=32, n_head=32, dim=4096),
"13B": dict(n_layer=40, n_head=40, dim=5120),
"30B": dict(n_layer=60, n_head=52, dim=6656),
"34B": dict(
n_layer=48,
n_head=64,
dim=8192,
vocab_size=32000,
n_local_heads=8,
intermediate_size=22016,
rope_base=1000000,
), # CodeLlama-34B-Python-hf
"70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
}


class KVCache(nn.Module):
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[0] == k_val.shape[2]

k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out


class Transformer(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config

self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

self.freqs_cis: Optional[Tensor] = None
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1

def setup_caches(self, max_batch_size, max_seq_length, device: torch.device):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size

self.freqs_cis = precompute_freqs_cis(
max_seq_length,
head_dim,
device,
self.config.rope_base,
)

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)

for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis)
x = self.norm(x)
logits = self.output(x)
return logits

@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))

def init_parameters(self):
"""Initialize the parameters, taken from nanogpt"""

def _init_weights(module: nn.Module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

self.apply(_init_weights)


class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis)
out = h + self.feed_forward(self.ffn_norm(h))
return out


class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0

total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)

self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(self, state_dict, prefix, *args):
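        # Checkpoints that store wq/wk/wv separately (e.g. original Llama weights) are fused into the single wqkv projection on load.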
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def forward(self, x: Tensor, freqs_cis: Tensor) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)

q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

y = self.wo(y)
return y


class FeedForward(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight


def precompute_freqs_cis(
seq_len: int, n_elem: int, device: torch.device, base: int = 10000
) -> Tensor:
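    # RoPE frequency table: for position t and channel pair i, store (cos, sin) of the angle t * base**(-2 * i / n_elem).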
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
t = torch.arange(seq_len, device=freqs.device)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
return cache.to(dtype=torch.bfloat16, device=device)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
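    # Rotate each consecutive (even, odd) channel pair of x by the precomputed (cos, sin) angles.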
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
],
-1,
)

x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)