Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama3 conversion scripts 🦙 #174

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions src/nanotron/config/models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class LlamaConfig:
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope_interleaved: bool = (
True # The default value has been True, but for loading Llama3 checkpoints you have to set it to False
)
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
Expand Down
53 changes: 13 additions & 40 deletions src/nanotron/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch LLaMa model."""

from typing import Dict, Optional, Union, List
from typing import Dict, Optional, Union

import torch
from torch import nn
Expand Down Expand Up @@ -189,35 +189,21 @@ def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArg
@checkpoint_method(attr_name="checkpoint_attention")
def forward(
self,
query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim]
key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
query_states: torch.Tensor, # [batch_size, q_length, n_local_q_heads, inner_dim]
key_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim]
value_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim]
):
from flash_attn.flash_attn_interface import flash_attn_varlen_func

# TODO @thomasw21: Compute once, instead of computing for each layers.
cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])

# TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
# what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
causal = False if q_sequence_mask.shape[1] == 1 else True
from flash_attn.flash_attn_interface import flash_attn_func

# NOTE: this scale is for µTransfer,
# in SP, we use sqrt(1/d_h)
softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
attn_output = flash_attn_varlen_func(
# For now we are assuming that we use causual mask. No magic here
causal = True
attn_output = flash_attn_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_sequence_mask.shape[1],
max_seqlen_k=kv_sequence_mask.shape[1],
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=causal,
Expand Down Expand Up @@ -325,7 +311,9 @@ def __init__(
)

# NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True)
self.flash_rotary_embedding = FlashRotaryEmbedding(
dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta
)

self.o_proj = TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
Expand Down Expand Up @@ -567,29 +555,14 @@ def forward(
# [batch_size, seq_length, num_heads, d_qk]
key_states, value_states = torch.split(key_value_states, 1, dim=2)

q_sequence_mask = sequence_mask
kv_sequence_mask = sequence_mask

kv_length = key_states.shape[1]
# [batch_size, seq_length, num_heads, d_qk]
# Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func`
query_states = query_states.view(
batch_size * q_length, self.n_local_q_heads, self.d_qk
) # [batch_size * q_length, self.n_heads, d_qk]

key_states = key_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_qk
) # [batch_size * kv_length, self.n_heads, d_qk]
value_states = value_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_v
) # [batch_size * kv_length, self.n_heads, d_v]
key_states = key_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_qk)
value_states = value_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_v)

attention_output = self.attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
q_sequence_mask=q_sequence_mask,
kv_sequence_mask=kv_sequence_mask,
)

attention_output = (
Expand Down
32 changes: 32 additions & 0 deletions tools/llama3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Llama3 Weight conversion tool
This directory contains the scripts to convert the Llama3 checkpoints from HuggingFace to Nanotron and vice versa.

## Downloading Llama3 weights
We will use the Llama3 checkpoints stored in the HuggingFace Hub for the conversion. Despite being able to download the checkpoints setting `--pretrained-model-name-or-pathmeta-llama/Meta-Llama-3-8B-Instruct`, this is not recommended since it will download the pretrained weights to the [HuggingFace Cache](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubcache). We encourage to download the checkpoints explicityly to a folder with the following script:
```python
from huggingface_hub import snapshot_download

snapshot_download(repo_id="meta-llama/Meta-Llama-3-8B",
local_dir = "models/Meta-Llama-3-8B",
local_dir_use_symlinks=False,
ignore_patterns=["original/*"]) # Llama3 models in the Hub contain the original checkpoints. We just want the HF checkpoint stored in the safetensor format
```

## Conversion

- Convert from HuggingFace to Nanotron

`torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct`
- Convert from Nanotron to HuggingFace

`torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama3-8B --hugging-face-checkpoint-path hf_checkpoints/Converted-Nanotron-Llama-3-8B`

In summary, we will do the following:
- Initialize the HuggingFace model with the pretrained weights. The model definition is [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py).
- Initialize a Nanotron model with empty weights. The model definition is [here](https://github.com/huggingface/nanotron/blob/main/src/nanotron/models/llama.py).
- Copy the parameters layer by layer from one model to the other.
- Store the Nanotron model along with the tokenizer.

When comparing the HuggingFace implementation with the Nanotron implementation, the main difference lies in the Q, K & V matrices and in the MLP projections. In the HuggingFace implementation, these matrices are separated [[1]](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L415), [[2]](https://github.com/huggingface/transformers/blob/1518508467d96b3866fc4ebcb7a5b3a2e0df2aa4/src/transformers/models/llama/modeling_llama.py#L194), while in the Nanotron implementation, they are concatenated [[1b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L310), [[2b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L149). It is crucial to pay attention to these details to convert the models correctly.

To perform the conversion, we will need at least **1 GPU**, although the operations will be carried out on the **CPU**. We will convert the models with a parallel configuration of DP = PP = TP = 1, but it should be noted that the checkpoints generated by Nanotron are topology agnostic.
Loading
Loading