diff --git a/run_generate.py b/run_generate.py index f389770d..ad116d73 100644 --- a/run_generate.py +++ b/run_generate.py @@ -7,9 +7,7 @@ torchrun --nproc_per_node=4 run_generate.py ---ckpt-path checkpoints/test/4 ``` """ - import argparse -import os from pathlib import Path import torch @@ -21,12 +19,14 @@ ParallelismArgs, get_config_from_file, ) +from nanotron.distributed import get_global_rank from nanotron.generation.decode import ( - GenerationInput, - TokenizerConfig, - decode_text, - decode_tokenized, + GenerationInputs, + GenerationStates, + run_one_inference_step, ) +from nanotron.generation.generate_store import Store +from nanotron.generation.sampler import BasicSampler, GreedySampler, SamplerType, TopKSampler, TopPSampler from nanotron.logging import log_rank, set_ranks_logging_level from nanotron.models import build_model from nanotron.parallel import ParallelContext @@ -34,7 +34,6 @@ from nanotron.parallel.pipeline_parallel.engine import ( OneForwardOneBackwardPipelineEngine, ) -from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.random import ( RandomStates, @@ -50,6 +49,7 @@ except ImportError: AutoTokenizer = None + logger = logging.get_logger(__name__) @@ -57,9 +57,10 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") parser.add_argument("--dp", type=int, default=1) - parser.add_argument("--pp", type=int, default=0) - parser.add_argument("--tp", type=int, default=0) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--tp", type=int, default=1) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") + parser.add_argument("--use-cache", action="store_true", help="Use cache for generation") return parser.parse_args() @@ -73,9 +74,9 @@ def main(): tokenizer_path = config.tokenizer.tokenizer_name_or_path parallel_config = ParallelismArgs( - dp=args.dp or config.parallelism.dp, - pp=args.pp or config.parallelism.pp, - tp=args.tp or config.parallelism.tp, + dp=args.dp, + pp=args.pp, + tp=args.tp, pp_engine=OneForwardOneBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, @@ -163,86 +164,147 @@ def main(): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.padding_side = "left" tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? + dummy_inputs = [ "The future of AI is", # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", - "def fib(n)", + # "def fib(n)", # 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. 
Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.', # "Advancements in technology will lead to", # "Tomorrow's world is shaped by", ] - outputs = decode_text( - input_iter=(GenerationInput(text=text) for text in dummy_inputs), - tokenizer=tokenizer, - # TODO @thomasw21: From ModelWithLoss extract the model. - model=model.model, - parallel_context=parallel_context, - max_new_tokens=args.max_new_tokens, - max_micro_batch_size=2, - generation_config=GenerationArgs(sampler="greedy", use_cache=True), - tokenizer_config=TokenizerConfig(max_input_length=None), - is_bench=os.environ.get("USE_BENCH", "0") == "1", + log_rank(f"Using cache for generation: {args.use_cache}", logger=logger, level=logging.INFO, rank=0) + + # NOTE: This doesn't support micro-batches and batch inference + device = torch.cuda.current_device() + generation_config = GenerationArgs(sampler="greedy", use_cache=args.use_cache) + logits_are_batch_first = True + + if generation_config: + if isinstance(generation_config.sampler, str): + sampler_type = SamplerType(generation_config.sampler.upper()) + else: + sampler_type = generation_config.sampler + else: + sampler_type = SamplerType.GREEDY + + tokenized_prompts = tokenizer( + dummy_inputs, + return_tensors="pt", + return_attention_mask=True, + padding=True, ) - for output in outputs: - input_ids = output.input_ids - generated_ids = output.generation_ids - if isinstance(input_ids, TensorPointer): - assert isinstance(generated_ids, TensorPointer) - continue - assert isinstance(generated_ids, torch.Tensor) - - log_rank( - f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", - logger=logger, - level=logging.INFO, - rank=0, - ) + tokenized_prompts["input_ids"] = tokenized_prompts["input_ids"].to(device) + tokenized_prompts["attention_mask"] = tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device) + + store = Store() + batch_prompts = None + + for i in range(args.max_new_tokens): + + if generation_config.use_cache: + # Prepare the batch prompts + batch_prompts = GenerationStates( + new_input_ids=tokenized_prompts["input_ids"] + if i == 0 + else tokenized_prompts["input_ids"][:, -1].unsqueeze(0), + new_input_mask=tokenized_prompts["attention_mask"] + if i == 0 + else tokenized_prompts["attention_mask"][:, -1].unsqueeze(0), + store=store, + generation_ids=tokenized_prompts["input_ids"], + 
generation_mask=tokenized_prompts["attention_mask"], + ) + else: + batch_prompts = GenerationInputs( + input_ids=tokenized_prompts["input_ids"], + input_masks=tokenized_prompts["attention_mask"], + ) - log_rank( - f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}", - logger=logger, - level=logging.INFO, - rank=0, + logits = run_one_inference_step( + model, batch_prompts, parallel_context, device, use_cache=generation_config.use_cache, store=store ) - log_rank( - "--------------------------------------------------", - logger=logger, - level=logging.INFO, - rank=0, + # Sample new token + if parallel_context.is_pipeline_last_stage: + assert logits is not None and isinstance(logits, torch.Tensor) + + # Get sampler + if sampler_type == SamplerType.GREEDY: + sampler = GreedySampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_K: + sampler = TopKSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_P: + sampler = TopPSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.BASIC: + sampler = BasicSampler(pg=parallel_context.tp_pg) + else: + raise NotImplementedError(f"Sampler type {sampler_type} is not implemented") + + if logits_are_batch_first: + logits = logits.transpose(0, 1) + + # Predict next token + next_token = sampler(sharded_logits=logits[:, -1]) + + # Extend the tokenized prompts to insert the new token + tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token], dim=-1) + tokenized_prompts["attention_mask"] = torch.cat( + [ + tokenized_prompts["attention_mask"], + torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.bool, device=device), + ], + dim=-1, + ) + else: + # Extend the tokenized prompts to receive the new token + tokenized_prompts["input_ids"] = torch.zeros( + (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), + dtype=torch.int64, + device=device, + ) + tokenized_prompts["attention_mask"] = torch.zeros( + ( + tokenized_prompts["attention_mask"].shape[0], + tokenized_prompts["attention_mask"].shape[1] + 1, + ), + dtype=torch.bool, + device=device, + ) + + # Broadcast the new token to all the pipeline stages + dist.broadcast( + tokenized_prompts["input_ids"], + src=get_global_rank( + group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank + ), + group=parallel_context.pp_pg, ) - else: - outputs = decode_tokenized( - input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"), - input_mask=torch.ones(1, 1).to(dtype=torch.bool, device="cuda"), - model=model.model, - parallel_context=parallel_context, - generation_config=GenerationArgs(sampler="greedy", use_cache=True), - max_micro_batch_size=1, - max_new_tokens=12, - returns_logits=False, - ) - for output in outputs: - input_ids = output.input_ids - generated_ids = output.generation_ids - if isinstance(input_ids, TensorPointer): - assert isinstance(generated_ids, TensorPointer) - continue - assert isinstance(generated_ids, torch.Tensor) - log_rank( - f"generation: {generated_ids[len(input_ids) :]}", - logger=logger, - level=logging.INFO, - rank=0, + dist.broadcast( + tokenized_prompts["attention_mask"], + src=get_global_rank( + group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank + ), + group=parallel_context.pp_pg, ) - log_rank( - "--------------------------------------------------", - logger=logger, - level=logging.INFO, - rank=0, - ) + # Decode the generated text + if 
dist.get_rank() == 0: + for i, prompt in enumerate(dummy_inputs): + if generation_config.use_cache: + tokenized_outputs = torch.cat( + [tokens.view(1, -1) for tokens in batch_prompts.generation_ids], dim=1 + ) + outputs = tokenizer.decode(tokenized_outputs[0], clean_up_tokenization_spaces=False) + else: + tokenized_outputs = tokenized_prompts["input_ids"][ + i, tokenized_prompts["input_ids"].shape[1] - args.max_new_tokens : + ] + outputs = tokenizer.decode(tokenized_outputs, clean_up_tokenization_spaces=False) + + log_rank(f"Input: {prompt}", logger=logger, level=logging.INFO, rank=0) + log_rank(f"Output: {outputs}", logger=logger, level=logging.INFO, rank=0) dist.barrier() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index ba4559cf..2630e1d6 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -47,6 +47,10 @@ class LlamaConfig: pretraining_tp: int = 1 rms_norm_eps: float = 1e-6 rope_scaling: Optional[dict] = None + rope_theta: float = 10000.0 + rope_interleaved: bool = ( + True # The default value has been True, but for loading Llama3 checkpoints you have to set it to False + ) tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 6ab71fad..dc021085 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -772,6 +772,67 @@ def generator(): ) +@torch.inference_mode() +def run_one_inference_step(model, batch, parallel_context, device, use_cache, store): + if dist.get_world_size(group=parallel_context.pp_pg) == 1: + if use_cache: + with attach_store(model=model, store=store): + return model.model(batch.new_input_ids, batch.new_input_mask) + return model.model(batch.input_ids, batch.input_masks) + + pipeline_state = PipelineEvalBatchState() + with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state): + batch_size = batch.new_input_ids.shape[0] if use_cache else batch.input_ids.shape[0] + seq_len = batch.new_input_ids.shape[1] if use_cache else batch.input_ids.shape[1] + + # Preallocate memory for output logits. 
+ logits = None + if parallel_context.is_pipeline_last_stage: + logits = torch.empty((seq_len, batch_size, model.config.vocab_size), dtype=torch.float32, device=device) + + if use_cache: + batch2use = GenerationStates( + new_input_ids=batch.new_input_ids + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + new_input_mask=batch.new_input_mask + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + store=store, + generation_ids=batch.generation_ids, + generation_mask=batch.generation_mask, + ) + with attach_store(model=model, store=store): + output_tensor = model.model(batch2use.new_input_ids, batch2use.new_input_mask) + else: + batch2use = GenerationInputs( + input_ids=batch.input_ids + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + input_masks=batch.input_masks + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + ) + + output_tensor = model.model(batch2use.input_ids, batch2use.input_masks) + + nb_send = len(pipeline_state.microbatches_activations_to_send) + assert nb_send <= 2 + for _ in range(nb_send): + # Send activations to the next stage + # Send attention_mask to the next stage + pipeline_state.run_communication() + + # Copy logits. + if parallel_context.is_pipeline_last_stage: + logits = output_tensor + + # Wait for all the communication to complete. + dist.barrier(group=parallel_context.world_pg) + + return logits + + # Distributed utilities def broadcast_tensors( tensors: List[Union[torch.Tensor, TensorPointer]], group_src: int, group: Optional[ProcessGroup] = None diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 32aab9cd..c2c07614 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. 
"""PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, Optional, Union import torch from torch import nn @@ -117,6 +117,73 @@ def forward( return x_out.type(dtype) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +### llama +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim: int, end: int, theta: float = 500000.0): + super().__init__() + self.dim = dim + self.end = end + self.theta = theta + self.init_rotary_embeddings() + + def init_rotary_embeddings(self): + inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda") / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward( + self, + x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk] + position_ids: Optional[torch.LongTensor], # [batch_size, seq_length] + ): + # x: [bs, num_attention_heads, seq_len, head_size] + # print("rotary") + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 class GLUActivation(nn.Module):
     def __init__(self, act_fn_name: str):
         super().__init__()
@@ -188,35 +255,21 @@ def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArg
     @checkpoint_method(attr_name="checkpoint_attention")
     def forward(
         self,
-        query_states: torch.Tensor,  # [batch_size * q_length, n_local_q_heads, inner_dim]
-        key_states: torch.Tensor,  # [batch_size * kv_length, n_local_kv_heads, inner_dim]
-        value_states: torch.Tensor,  # [batch_size * kv_length, n_local_kv_heads, inner_dim]
-        q_sequence_mask: torch.Tensor,  # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
-        kv_sequence_mask: torch.Tensor,  # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
+        query_states: torch.Tensor,  # [batch_size, q_length, n_local_q_heads, inner_dim]
+        key_states: torch.Tensor,  # [batch_size, kv_length, n_local_kv_heads, inner_dim]
+        value_states: torch.Tensor,  # [batch_size, kv_length, n_local_kv_heads, inner_dim]
     ):
-        from flash_attn.flash_attn_interface import flash_attn_varlen_func
-
-        # TODO @thomasw21: Compute once, instead of computing for each layers.
-        cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
-        cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
-        torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
-        torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
-
-        # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
-        # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
-        causal = False if q_sequence_mask.shape[1] == 1 else True
+        from flash_attn.flash_attn_interface import flash_attn_func
 
         # NOTE: this scale is for µTransfer,
         # in SP, we use sqrt(1/d_h)
         softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
-        attn_output = flash_attn_varlen_func(
+        # For now we assume the attention mask is always causal.
+        causal = True
+        attn_output = flash_attn_func(
             q=query_states,
             k=key_states,
             v=value_states,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=q_sequence_mask.shape[1],
-            max_seqlen_k=kv_sequence_mask.shape[1],
             dropout_p=0.0,
             softmax_scale=softmax_scale,
             causal=causal,
@@ -317,13 +370,22 @@ def __init__(
             contiguous_chunks=qkv_contiguous_chunks,
         )
         # TODO(kunhao): We want to have only one version per device and not one version per layer.
- self.rotary_embedding = RotaryEmbedding( - dim=self.d_qk, - end=config.max_position_embeddings, - ) + if config.rope_interleaved: + self.rotary_embedding = RotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + ) + else: + self.rotary_embedding = LlamaRotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + ) + self.rope_interleaved = config.rope_interleaved # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) - self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, interleaved=True) + self.flash_rotary_embedding = FlashRotaryEmbedding( + dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta + ) self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, @@ -348,6 +410,7 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] + position_ids: Optional[torch.LongTensor] = None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -402,8 +465,19 @@ def forward( # Compute rotary embeddings # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache old_rotary_embed_end = self.rotary_embedding.end - query_states = self.rotary_embedding(query_states, position_ids=position_ids) - key_states = self.rotary_embedding(key_states, position_ids=position_ids) + + # Rotate half rotary_embedding + # cos, sin = self.rotary_embedding(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # interleaved + if self.rope_interleaved: + query_states = self.rotary_embedding(query_states, position_ids=position_ids) + key_states = self.rotary_embedding(key_states, position_ids=position_ids) + # llama rotary position embedding + else: + cos, sin = self.rotary_embedding(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if "key" not in store: # First inference iteration (Prefill) @@ -565,29 +639,14 @@ def forward( # [batch_size, seq_length, num_heads, d_qk] key_states, value_states = torch.split(key_value_states, 1, dim=2) - q_sequence_mask = sequence_mask - kv_sequence_mask = sequence_mask - kv_length = key_states.shape[1] - # [batch_size, seq_length, num_heads, d_qk] - # Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func` - query_states = query_states.view( - batch_size * q_length, self.n_local_q_heads, self.d_qk - ) # [batch_size * q_length, self.n_heads, d_qk] - - key_states = key_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_qk - ) # [batch_size * kv_length, self.n_heads, d_qk] - value_states = value_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_v - ) # [batch_size * kv_length, self.n_heads, d_v] + key_states = key_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_qk) + value_states = value_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_v) attention_output = self.attention( query_states=query_states, key_states=key_states, value_states=value_states, - q_sequence_mask=q_sequence_mask, - kv_sequence_mask=kv_sequence_mask, ) attention_output = ( diff --git a/src/nanotron/parallel/context.py b/src/nanotron/parallel/context.py index e04e26f5..aba26dfa 100644 --- a/src/nanotron/parallel/context.py +++ b/src/nanotron/parallel/context.py @@ -1,5 +1,5 @@ import os -from typing import Literal, Tuple, Annotated +from typing import Literal, 
Tuple import numpy as np import torch @@ -62,6 +62,20 @@ def __init__( self._init_parallel_groups() + self.pipeline_parallel_last_rank = self.pipeline_parallel_size - 1 + self.is_pipeline_first_stage = self.pp_pg.rank() == 0 + self.is_pipeline_last_stage = self.pp_pg.rank() == self.pipeline_parallel_last_rank + self.pipeline_parallel_next_rank = ( + None + if self.is_pipeline_last_stage + else int(self.world_rank_matrix[self.tp_pg.rank(), self.pp_pg.rank() + 1, self.dp_pg.rank()]) + ) + self.pipeline_parallel_prev_rank = ( + None + if self.is_pipeline_first_stage + else int(self.world_rank_matrix[self.tp_pg.rank(), self.pp_pg.rank() - 1, self.dp_pg.rank()]) + ) + def _init_parallel_groups(self): """Initialize 3D parallelism's all process groups.""" dist.barrier() @@ -152,4 +166,4 @@ def get_global_rank( :return: numpy.int64, The global rank. """ - return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank] \ No newline at end of file + return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank] diff --git a/tools/llama3/README.md b/tools/llama3/README.md new file mode 100644 index 00000000..57a31b5e --- /dev/null +++ b/tools/llama3/README.md @@ -0,0 +1,19 @@ +# Llama3 Weight conversion tool +This directory contains the scripts to convert the Llama3 checkpoints from HuggingFace to Nanotron and vice versa. + +- Convert from HuggingFace to Nanotron + +`torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct` +- Convert from Nanotron to HuggingFace + +`torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama3-8B --hugging-face-checkpoint-path hf_checkpoints/Converted-Nanotron-Llama-3-8B` + +In summary, we will do the following: +- Initialize the HuggingFace model with the pretrained weights. The model definition is [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py). +- Initialize a Nanotron model with empty weights. The model definition is [here](https://github.com/huggingface/nanotron/blob/main/src/nanotron/models/llama.py). +- Copy the parameters layer by layer from one model to the other. +- Store the Nanotron model along with the tokenizer. + +When comparing the HuggingFace implementation with the Nanotron implementation, the main difference lies in the Q, K & V matrices and in the MLP projections. In the HuggingFace implementation, these matrices are separated [[1]](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L415), [[2]](https://github.com/huggingface/transformers/blob/1518508467d96b3866fc4ebcb7a5b3a2e0df2aa4/src/transformers/models/llama/modeling_llama.py#L194), while in the Nanotron implementation, they are concatenated [[1b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L310), [[2b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L149). It is crucial to pay attention to these details to convert the models correctly. + +To perform the conversion, we will need at least **1 GPU**, although the operations will be carried out on the **CPU**. We will convert the models with a parallel configuration of DP = PP = TP = 1, but it should be noted that the checkpoints generated by Nanotron are topology agnostic. 
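
To make the weight-layout difference described above concrete, here is a minimal sketch (editorial note, not part of the patch) of the fused-QKV round trip that `convert_hf_to_nanotron.py` and `convert_nanotron_to_hf.py` perform; the tiny dimensions are illustrative stand-ins for the real Llama3 sizes.

```python
import torch

# Illustrative sizes only; Llama3-8B itself uses num_heads=32, num_kv_heads=8, head_dim=128, hidden=4096.
num_heads, num_kv_heads, d_qk, hidden = 4, 2, 8, 32

# HuggingFace keeps separate projection matrices (rows = output features).
q_proj = torch.randn(num_heads * d_qk, hidden)
k_proj = torch.randn(num_kv_heads * d_qk, hidden)
v_proj = torch.randn(num_kv_heads * d_qk, hidden)

# HF -> Nanotron: concatenate along the output dimension into a single qkv_proj weight.
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)

# Nanotron -> HF: split back using the head counts.
q_back, k_back, v_back = torch.split(
    qkv_proj,
    [num_heads * d_qk, num_kv_heads * d_qk, num_kv_heads * d_qk],
    dim=0,
)
assert torch.equal(q_back, q_proj) and torch.equal(k_back, k_proj) and torch.equal(v_back, v_proj)
```

The MLP `gate_proj`/`up_proj` pair is fused and split the same way along `dim=0`, with two `intermediate_size`-sized chunks.
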
diff --git a/tools/llama3/convert_hf_to_nanotron.py b/tools/llama3/convert_hf_to_nanotron.py new file mode 100644 index 00000000..e30610a3 --- /dev/null +++ b/tools/llama3/convert_hf_to_nanotron.py @@ -0,0 +1,266 @@ +""" +torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct +""" +import argparse +import json +from dataclasses import asdict +from pathlib import Path + +import torch +import yaml +from nanotron import logging +from nanotron.config import Config, GeneralArgs, LoggingArgs, ModelArgs, ParallelismArgs, TokenizerArgs +from nanotron.config.models_config import ExistingCheckpointInit +from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import TrainingMetadata, save_meta, save_weights +from nanotron.serialize.metadata import DataStageMetadata +from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +logger = logging.get_logger(__name__) + +DEVICE = torch.device("cpu") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Llama3-8B HF model + log_rank( + f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + hf_model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + ).to(DEVICE) + hf_config = hf_model.config + + # Set Nanotron LlamaConfig + nanotron_llama_config = LlamaConfigNanotron( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + intermediate_size=hf_config.intermediate_size, + is_llama_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + pretraining_tp=hf_config.pretraining_tp, + rms_norm_eps=hf_config.rms_norm_eps, + 
rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + rope_interleaved=False, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size, + ) + + # Init Llama3-8B Nanotron model + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_llama_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # Copy params from HF to Nanotron + log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_( + hf_model.model.embed_tokens.weight + ) + + # Decoder layers + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_( + hf_model.model.layers[i].input_layernorm.weight + ) + + # Self attn + ## QKV + tmp_qkv_proj = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.weight, + hf_model.model.layers[i].self_attn.k_proj.weight, + hf_model.model.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_( + hf_model.model.layers[i].self_attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + tmp_gate_up_proj = torch.cat( + [ + hf_model.model.layers[i].mlp.gate_proj.weight, + hf_model.model.layers[i].mlp.up_proj.weight, + ], + dim=0, + ) + + assert tmp_gate_up_proj.shape == nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj) + + ## Down Proj + assert ( + hf_model.model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.copy_( + hf_model.model.layers[i].mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + 
hf_model.model.layers[i].post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + with torch.no_grad(): + nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + with torch.no_grad(): + nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) + + log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + # Store weights + nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) + + # Store metadata + log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0) + training_metadata = TrainingMetadata( + last_train_step=0, + consumed_train_samples=0, + data_stages=[DataStageMetadata(name="Empty", consumed_train_samples=0, start_training_step=0)], + ) + save_meta( + root_folder=nanotron_checkpoint_path, parallel_context=parallel_context, training_metadata=training_metadata + ) + # Store Tokenizer into Nanotron Checkpoint folder + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokenizer.save_pretrained(nanotron_checkpoint_path) + + # Store Config and Model Config files + with open(nanotron_checkpoint_path / "config.yaml", "w") as f: + config = Config( + general=GeneralArgs(project="Nanotron", run="Llama3"), + parallelism=parallel_config, + model=ModelArgs( + init_method=ExistingCheckpointInit(nanotron_checkpoint_path), + model_config=nanotron_llama_config, + ), + tokenizer=TokenizerArgs(nanotron_checkpoint_path), + ) + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) + yaml.dump(config.as_dict(), f) + + with open(nanotron_checkpoint_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(asdict(nanotron_llama_config), f) + + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/tools/llama3/convert_nanotron_to_hf.py b/tools/llama3/convert_nanotron_to_hf.py new file mode 100644 index 00000000..c5fb1940 --- /dev/null +++ b/tools/llama3/convert_nanotron_to_hf.py @@ -0,0 +1,229 @@ +""" +torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --hugging-face-checkpoint-path hf_checkpoints/Converted-Nanotron-Llama-3-8B +""" +import argparse +import os +from dataclasses import asdict +from pathlib import Path + +import torch +from nanotron import logging +from nanotron.config import Config, LoggingArgs, ParallelismArgs, get_config_from_file +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm 
+from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.models.llama import LlamaConfig as LlamaConfigHF + +logger = logging.get_logger(__name__) + +DEVICE = torch.device("cpu") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory with a Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--hugging-face-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted checkpoint", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Nanotron checkpoint config + log_rank( + f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}", + logger=logger, + level=logging.INFO, + rank=0, + ) + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + nanotron_llama_config = nanotron_config.model.model_config + + # Init Llama3-8B Nanotron model + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + + nanotron_model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_config.model.model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # Load Nanotron Checkpoint + log_rank("Loading Nanotron Llama3 Model...", logger=logger, level=logging.INFO, rank=0) + load_weights( + model=nanotron_model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path) + ) + + # Build empty HF Model + log_rank("Init empty HF Llama3 Model", logger=logger, level=logging.INFO, rank=0) + hf_model = AutoModelForCausalLM.from_config( # WARN This takes a long time + config=LlamaConfigHF(**asdict(nanotron_llama_config)), + torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + ).to(DEVICE) + + # Copy params from Nanotron to HF + log_rank("Copying weights from Nanotron model to HF model...", logger=logger, level=logging.INFO, rank=0) + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + with torch.no_grad(): + hf_model.model.embed_tokens.weight.copy_( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight + ) + + # Decoder layers + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == 
nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].input_layernorm.weight.copy_( + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight + ) + + # Self attn + # Split Nanotrn qkv projection into q, k, v + q, k, v = torch.split( + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight, + [ + nanotron_llama_config.num_attention_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + nanotron_llama_config.num_key_value_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + nanotron_llama_config.num_key_value_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + ], + ) + assert q.shape == hf_model.model.layers[i].self_attn.q_proj.weight.shape + assert k.shape == hf_model.model.layers[i].self_attn.k_proj.weight.shape + assert v.shape == hf_model.model.layers[i].self_attn.v_proj.weight.shape + + with torch.no_grad(): + hf_model.model.layers[i].self_attn.q_proj.weight.copy_(q) + hf_model.model.layers[i].self_attn.k_proj.weight.copy_(k) + hf_model.model.layers[i].self_attn.v_proj.weight.copy_(v) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].self_attn.o_proj.weight.copy_( + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + gate_proj, up_proj = torch.split( + nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight, + split_size_or_sections=[nanotron_llama_config.intermediate_size, nanotron_llama_config.intermediate_size], + ) + assert gate_proj.shape == hf_model.model.layers[i].mlp.gate_proj.weight.shape + assert up_proj.shape == hf_model.model.layers[i].mlp.up_proj.weight.shape + + with torch.no_grad(): + hf_model.model.layers[i].mlp.gate_proj.weight.copy_(gate_proj) + hf_model.model.layers[i].mlp.up_proj.weight.copy_(up_proj) + + ## Down Proj + assert ( + hf_model.model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].mlp.down_proj.weight.copy_( + nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].post_attention_layernorm.weight.copy_( + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + with torch.no_grad(): + hf_model.model.norm.weight.copy_(nanotron_model.model.final_layer_norm.pp_block.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + with torch.no_grad(): + hf_model.lm_head.weight.copy_(nanotron_model.model.lm_head.pp_block.weight) + + log_rank("Copied weights from Nanotron model to HF model!", logger=logger, level=logging.INFO, rank=0) + # Store weights + log_rank("Storing HF model Checkpoint and Tokenizer!", logger=logger, level=logging.INFO, rank=0) + hf_model.save_pretrained(args.hugging_face_checkpoint_path, from_pt=True) + # 
Store tokenizer + tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) + tokenizer.save_pretrained(args.hugging_face_checkpoint_path) + + log_rank( + f"Checkpoint conversion finished, check {args.hugging_face_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/tools/llama3/generate_hf_predictions.py b/tools/llama3/generate_hf_predictions.py new file mode 100644 index 00000000..b16774a4 --- /dev/null +++ b/tools/llama3/generate_hf_predictions.py @@ -0,0 +1,81 @@ +""" +torchrun --nproc-per-node 1 tools/llama3/generate_hf_predictions.py --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct +""" +import argparse +import os + +import numpy as np +import torch +from sklearn.metrics import accuracy_score +from transformers import AutoModelForCausalLM, AutoTokenizer + +TXT = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello! Which is the capital of France? What can I visit over there if I go for a week vacation?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nBonjour! The capital of France is Paris, also known as the City of Light. Paris is a stunning city with a rich history, art, fashion, and cuisine. If you're planning a week-long vacation in Paris, you'll have plenty of time to explore its iconic landmarks, museums, and neighborhoods. Here's a suggested itinerary to get you started: Day 1-2: Iconic Landmarks The Eiffel Tower (Tour Eiffel): The iron lady offers breathtaking views of the city. You can take the stairs or elevator to the top. The Louvre Museum (Musée du Louvre): Home to the Mona Lisa, Venus de Milo, and many other famous artworks. Arc de Triomphe: A monumental arch honoring the soldiers who fought and died for France. Champs-Élysées: A famous avenue lined with cafes, shops, and theaters. Day 3: Montmartre and Sacré-Cœur Explore the charming neighborhood of Montmartre, known for its bohemian vibe, street artists, and stunning views. Visit the Basilique du Sacré-Cœur, a beautiful white church perched on a hill." 
+SEQ_LENGTH = 512 # For truncating the TXT if GPU can't fit too many tokens + +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + + model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + device_map="auto", + ).eval() + + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to(DEVICE) + inputs = tokens[:, :-1] + + with torch.no_grad(): + output = model(inputs) + + predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + term_cols = int(os.get_terminal_size().columns / 3) + + for predicted_token in predicted_tokens: + + print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) + next_tokens = torch.softmax(output.logits[0, predicted_token, :], -1) + topk_next_tokens = torch.topk(next_tokens, 10) + + print( + *[ + f"[HF Model] Next token: {idx.item()}, probability: {prob}" + for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) + ], + sep="\n", + ) + + # Compute accuracy + predictions = np.argmax(output.logits.cpu(), axis=2).flatten().tolist() + labels = tokens.cpu().flatten()[1:].tolist() + print(f"\nAccuracy: {accuracy_score(labels, predictions)}") + # Results + ## [TP=1] HF 8B: 0.8308823529411765 + ## [TP=2]HF 70B: 0.8860294117647058 + ## [TP=1] HF -> Nanotron -> HF 8B: 0.8308823529411765 + ## [TP=2] HF -> Nanotron -> HF 70B: 0.8860294117647058 + ## [TP=1 --> TP=2] HF -> Nanotron -> Dummy Finetune to change TP=2 -> HF 8B: 0.8308823529411765 + + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/tools/llama3/generate_nanotron_predictions.py b/tools/llama3/generate_nanotron_predictions.py new file mode 100644 index 00000000..fbede799 --- /dev/null +++ b/tools/llama3/generate_nanotron_predictions.py @@ -0,0 +1,132 @@ +""" +torchrun --nproc-per-node 2 tools/llama3/generate_nanotron_predictions.py --tp 2 --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B +""" +import argparse +import os +from pathlib import Path + +import nanotron.distributed as dist +import numpy as np +import torch +from nanotron.config import Config, ParallelismArgs, get_config_from_file +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine +from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +from sklearn.metrics import accuracy_score +from transformers import AutoTokenizer + +TXT = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello! Which is the capital of France? 
What can I visit over there if I go for a week vacation?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nBonjour! The capital of France is Paris, also known as the City of Light. Paris is a stunning city with a rich history, art, fashion, and cuisine. If you're planning a week-long vacation in Paris, you'll have plenty of time to explore its iconic landmarks, museums, and neighborhoods. Here's a suggested itinerary to get you started: Day 1-2: Iconic Landmarks The Eiffel Tower (Tour Eiffel): The iron lady offers breathtaking views of the city. You can take the stairs or elevator to the top. The Louvre Museum (Musée du Louvre): Home to the Mona Lisa, Venus de Milo, and many other famous artworks. Arc de Triomphe: A monumental arch honoring the soldiers who fought and died for France. Champs-Élysées: A famous avenue lined with cafes, shops, and theaters. Day 3: Montmartre and Sacré-Cœur Explore the charming neighborhood of Montmartre, known for its bohemian vibe, street artists, and stunning views. Visit the Basilique du Sacré-Cœur, a beautiful white church perched on a hill." +SEQ_LENGTH = 512 # For truncating the TXT if GPU can't fit too many tokens + +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory containing a Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="Nanotron Parallelism") + group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the Nanotron Checkpoint") + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs( + dp=1, + pp=1, + tp=args.tp, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + assert ( + parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE + and parallel_config.tp_linear_async_communication is False + ) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + RANK = dist.get_rank(parallel_context.world_pg) + + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + + model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_config.model.model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, # TODO Check with different parallelism if cpu is available + ) + + mark_tied_parameters(model=model, parallel_context=parallel_context) + sanity_check(root_module=model) + + # Load checkpoint directly in memory and then only keep the state dictionary + load_weights(model=model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path)) + + tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) + tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to(DEVICE) + inputs = {"input_ids": tokens[:, :-1], "input_mask": torch.ones((1, SEQ_LENGTH), device=DEVICE)} + + model.eval() + + with torch.no_grad(): + output = 
model.model(**inputs) + + if not RANK: + predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + term_cols = int(os.get_terminal_size().columns / 3) + + for predicted_token in predicted_tokens: + + print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) + next_tokens = torch.softmax(output.transpose(0, 1)[0, predicted_token, :], -1) + topk_next_tokens = torch.topk(next_tokens, 10) + + print( + *[ + f"[Nanotron Model] Next token: {idx.item()}, probability: {prob}" + for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) + ], + sep="\n", + ) + + # Compute accuracy + predictions = np.argmax(output.transpose(0, 1).cpu(), axis=2).flatten().tolist() + labels = tokens.cpu().flatten()[1:].tolist() + print(f"\nAccuracy: {accuracy_score(labels, predictions)}") + # Results + ## Nanotron 8B, TP 1: 0.8272058823529411 + ## Nanotron 8B, TP 2: 0.7720588235294118 + ## Nanotron 70B, TP 2: 0.8272058823529411 + + +if __name__ == "__main__": + _args = get_args() + main(_args)
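
A closing note on the two `generate_*_predictions.py` scripts above: both score the model with the same shift-by-one accuracy check, where the logits at position t are compared against the input token at position t+1. Below is a minimal, self-contained sketch of that check, with random logits standing in for real model outputs.

```python
import numpy as np
import torch
from sklearn.metrics import accuracy_score

# Dummy stand-ins: one sequence of 6 token ids over a 10-token vocabulary.
tokens = torch.tensor([[3, 7, 1, 4, 4, 9]])  # what the tokenizer would return
logits = torch.randn(1, 5, 10)               # model output for tokens[:, :-1]: [batch, seq_len, vocab]

# Greedy prediction at every position, flattened to a list of token ids.
predictions = np.argmax(logits.numpy(), axis=2).flatten().tolist()
# The target for position t is the token the model should have predicted next, i.e. tokens[t + 1].
labels = tokens.flatten()[1:].tolist()
print(f"Token-level accuracy: {accuracy_score(labels, predictions)}")
```

The only extra step in the Nanotron script is the `output.transpose(0, 1)` call, since the model returns logits sequence-first before the same computation is applied.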