From 9578b69a6a0161cb2fe51612cc38b1d16eeb296f Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Sun, 19 May 2024 09:16:09 +0000 Subject: [PATCH 01/18] Add HF Generation script --- generate_hf_predictions.py | 59 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 generate_hf_predictions.py diff --git a/generate_hf_predictions.py b/generate_hf_predictions.py new file mode 100644 index 00000000..5f7ec395 --- /dev/null +++ b/generate_hf_predictions.py @@ -0,0 +1,59 @@ +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +TXT = "The prologue of Romeo and Juliet calls the title characters “star-crossed lovers”—and the stars do seem to conspire against these young lovers. Romeo is a Montague, and Juliet a Capulet. Their families are enmeshed in a feud, but the moment they meet—when Romeo and his friends attend a party at Juliets house in disguise—the two fall in love and quickly decide that they want to be married. A friar secretly marries them, hoping to end the feud. Romeo and his companions almost immediately encounter Juliets cousin Tybalt, who challenges Romeo. When Romeo refuses to fight, Romeos friend Mercutio accepts the challenge and is killed. Romeo then kills Tybalt and is banished. He spends that night with Juliet and then leaves for Mantua. Juliets father forces her into a marriage with Count Paris. To avoid this marriage, Juliet takes a potion, given her by the friar, that makes her appear dead. The friar will send Romeo word to be at her family tomb when she awakes. The plan goes awry, and Romeo learns instead that she is dead. In the tomb, Romeo kills himself. Juliet wakes, sees his body, and commits suicide. Their deaths appear finally to end the feud." +SEQ_LENGTH = 256 # For truncating the TXT if GPU can't fit too many tokens + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # TODO Refractor with HF pipeline or .generate()? + hf_model = ( + AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + .to("cuda") + .eval() + ) + + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to("cuda") + inputs_hf = tokens[:, :-1] + + with torch.no_grad(): + output_hf = hf_model(inputs_hf) + + predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + + for predicted_token in predicted_tokens: + next_tokens_hf = torch.softmax(output_hf.logits[0, predicted_token, :], -1) + hf_topk_next_tokens = torch.topk(next_tokens_hf, 10) + + print( + *[ + f"[HF Model] Next token: {idx.item()}, probability: {prob}" + for idx, prob in zip(hf_topk_next_tokens.indices, hf_topk_next_tokens.values) + ], + sep="\n", + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) From a7f918c35753e3e6f202a2babf3b3b0931e0bafa Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Sun, 19 May 2024 09:36:30 +0000 Subject: [PATCH 02/18] Add Nanotron Generation Script --- generate_hf_predictions.py | 16 ++-- generate_nanotron_predictions.py | 122 +++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 6 deletions(-) create mode 100644 generate_nanotron_predictions.py diff --git a/generate_hf_predictions.py b/generate_hf_predictions.py index 5f7ec395..5fd5bc3f 100644 --- a/generate_hf_predictions.py +++ b/generate_hf_predictions.py @@ -1,4 +1,5 @@ import argparse +import os import torch from transformers import AutoModelForCausalLM, AutoTokenizer @@ -24,7 +25,7 @@ def get_args(): def main(args): # TODO Refractor with HF pipeline or .generate()? - hf_model = ( + model = ( AutoModelForCausalLM.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) @@ -34,21 +35,24 @@ def main(args): tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to("cuda") - inputs_hf = tokens[:, :-1] + inputs = tokens[:, :-1] with torch.no_grad(): - output_hf = hf_model(inputs_hf) + output = model(inputs) predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + term_cols = int(os.get_terminal_size().columns / 3) for predicted_token in predicted_tokens: - next_tokens_hf = torch.softmax(output_hf.logits[0, predicted_token, :], -1) - hf_topk_next_tokens = torch.topk(next_tokens_hf, 10) + + print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) + next_tokens = torch.softmax(output.logits[0, predicted_token, :], -1) + topk_next_tokens = torch.topk(next_tokens, 10) print( *[ f"[HF Model] Next token: {idx.item()}, probability: {prob}" - for idx, prob in zip(hf_topk_next_tokens.indices, hf_topk_next_tokens.values) + for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) ], sep="\n", ) diff --git a/generate_nanotron_predictions.py b/generate_nanotron_predictions.py new file mode 100644 index 00000000..ff613f4d --- /dev/null +++ b/generate_nanotron_predictions.py @@ -0,0 +1,122 @@ +import argparse +import os + +import torch +from nanotron.config import Config, ParallelismArgs, get_config_from_file +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine +from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +from transformers import AutoTokenizer + +# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of parallelism +DP = 1 +PP = 1 + +TXT = "The prologue of Romeo and Juliet calls the title characters “star-crossed lovers”—and the stars do seem to conspire against these young lovers. Romeo is a Montague, and Juliet a Capulet. Their families are enmeshed in a feud, but the moment they meet—when Romeo and his friends attend a party at Juliets house in disguise—the two fall in love and quickly decide that they want to be married. A friar secretly marries them, hoping to end the feud. Romeo and his companions almost immediately encounter Juliets cousin Tybalt, who challenges Romeo. When Romeo refuses to fight, Romeos friend Mercutio accepts the challenge and is killed. Romeo then kills Tybalt and is banished. He spends that night with Juliet and then leaves for Mantua. Juliets father forces her into a marriage with Count Paris. To avoid this marriage, Juliet takes a potion, given her by the friar, that makes her appear dead. The friar will send Romeo word to be at her family tomb when she awakes. The plan goes awry, and Romeo learns instead that she is dead. In the tomb, Romeo kills himself. Juliet wakes, sees his body, and commits suicide. Their deaths appear finally to end the feud." +SEQ_LENGTH = 256 # For truncating the TXT if GPU can't fit too many tokens + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory containing a Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="Nanotron Parallelism") + group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the Nanotron Checkpoint") + + group = parser.add_argument_group(title="Tokenizer") + group.add_argument( + "--tokenizer-name-or-path", + type=str, + required=True, + help="A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.", + ) + + args = parser.parse_args() + + return args + + +def main(args): + + parallel_config = ParallelismArgs( + dp=DP, + pp=PP, + tp=args.tp, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + assert ( + parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE + and parallel_config.tp_linear_async_communication is False + ) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + + model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_config.model.model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=torch.bfloat16, + device=torch.device("cuda"), # TODO Check with different parallelism + ) + + mark_tied_parameters(model=model, parallel_context=parallel_context) + sanity_check(root_module=model) + + # Load checkpoint directly in memory and then only keep the state dictionary + load_weights(model=model, parallel_context=parallel_context, root_folder=args.nanotron_checkpoint_path) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) + tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to("cuda") + inputs = {"input_ids": tokens[:, :-1], "input_mask": torch.ones((1, SEQ_LENGTH), device="cuda")} + + model.eval() + + with torch.no_grad(): + output = model(inputs) + + predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + term_cols = int(os.get_terminal_size().columns / 3) + + for predicted_token in predicted_tokens: + + print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) + next_tokens = torch.softmax(output.transpose(0, 1)[0, predicted_token, :], -1) + topk_next_tokens = torch.topk(next_tokens, 10) + + print( + *[ + f"[Nanotron Model] Next token: {idx.item()}, probability: {prob}" + for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) + ], + sep="\n", + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) From 7728482ff717470fdf544724b2fd595bc441c846 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Sun, 19 May 2024 11:11:09 +0000 Subject: [PATCH 03/18] Add HF to Nanotron conversion script --- convert_hf_to_nanotron.py | 248 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 248 insertions(+) create mode 100644 convert_hf_to_nanotron.py diff --git a/convert_hf_to_nanotron.py b/convert_hf_to_nanotron.py new file mode 100644 index 00000000..3b66cef4 --- /dev/null +++ b/convert_hf_to_nanotron.py @@ -0,0 +1,248 @@ +""" +torchrun --nproc-per-node 1 convert_hf_to_nanotron.py --tp 1 --nanotron-checkpoint-path n_c/second --pretrained-model-name-or-path /mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct +""" +import argparse +import json +from dataclasses import asdict +from pathlib import Path + +import torch +import yaml +from nanotron.config import Config, GeneralArgs, ModelArgs, ParallelismArgs, TokenizerArgs +from nanotron.config.models_config import ExistingCheckpointInit +from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine +from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.serialize import TrainingMetadata, save_meta, save_weights +from nanotron.serialize.metadata import DataStageMetadata +from nanotron.trainer import mark_tied_parameters +from transformers import AutoModelForCausalLM + +# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of parallelism +DP = 1 +PP = 1 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="Nanotron Parallelism") + group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the Nanotron Checkpoint") + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Load Llama3-8B HF model + hf_model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ).to("cuda") + hf_config = hf_model.config + + # Set Nanotron LlamaConfig + nanotron_llama_config = LlamaConfigNanotron( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + intermediate_size=hf_config.intermediate_size, + is_llama_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + pretraining_tp=hf_config.pretraining_tp, + rms_norm_eps=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + rope_interleaved=False, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size, + ) + + # Init Llama3-8B Nanotron model + parallel_config = ParallelismArgs( + dp=DP, + pp=PP, + tp=args.tp, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + assert ( + parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE + and parallel_config.tp_linear_async_communication is False + ) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + nanotron_model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_llama_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=torch.bfloat16, + device=torch.device("cuda"), + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # Copy params from HF to Nanotron + # Token embeddings + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_( + hf_model.model.embed_tokens.weight + ) + + # Decoder layers + for i in range(nanotron_llama_config.num_hidden_layers): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_( + hf_model.model.layers[i].input_layernorm.weight + ) + + # Self attn + ## QKV + tmp_qkv_proj = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.weight, + hf_model.model.layers[i].self_attn.k_proj.weight, + hf_model.model.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_( + hf_model.model.layers[i].self_attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + tmp_gate_up_proj = torch.cat( + [ + hf_model.model.layers[i].mlp.gate_proj.weight, + hf_model.model.layers[i].mlp.up_proj.weight, + ], + dim=0, + ) + + assert tmp_gate_up_proj.shape == nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj) + + ## Down Proj + assert ( + hf_model.model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.copy_( + hf_model.model.layers[i].mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + hf_model.model.layers[i].post_attention_layernorm.weight + ) + + # Last layer norm + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + with torch.no_grad(): + nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) + + # LM_Head + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + with torch.no_grad(): + nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) + + # Store weights + nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) + + # Store metadata + training_metadata = TrainingMetadata( + last_train_step=0, + consumed_train_samples=0, + data_stages=[DataStageMetadata(name="Empty", consumed_train_samples=0, start_training_step=0)], + ) + save_meta( + root_folder=nanotron_checkpoint_path, parallel_context=parallel_context, training_metadata=training_metadata + ) + + # Store Config and Model Config files + with open(nanotron_checkpoint_path / "config.yaml", "w") as f: + config = Config( + general=GeneralArgs(project="conversion", run="Llama3-8B"), + parallelism=parallel_config, + model=ModelArgs( + init_method=ExistingCheckpointInit(nanotron_checkpoint_path), + model_config=nanotron_llama_config, + ), + tokenizer=TokenizerArgs(args.pretrained_model_name_or_path), + ) + print("Saving config ...") + yaml.dump(config.as_dict(), f) + + with open(nanotron_checkpoint_path / "model_config.json", "w") as f: + print("Saving model config ...") + json.dump(asdict(nanotron_llama_config), f) + + +if __name__ == "__main__": + _args = get_args() + main(_args) From 4107ed4a2ce8cf7cf05aa27ad9201265421314cb Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Sun, 19 May 2024 11:12:05 +0000 Subject: [PATCH 04/18] Add Nanorton to HF conversion script --- convert_nanotron_to_hf.py | 213 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 convert_nanotron_to_hf.py diff --git a/convert_nanotron_to_hf.py b/convert_nanotron_to_hf.py new file mode 100644 index 00000000..363e0c51 --- /dev/null +++ b/convert_nanotron_to_hf.py @@ -0,0 +1,213 @@ +""" +torchrun --nproc-per-node 1 convert_nanotron_to_hf.py --tp 1 --nanotron-checkpoint-path n_c/second --hugging-face-checkpoint-path hf_c/second +""" +import argparse +import os +from dataclasses import asdict +from pathlib import Path + +import torch +from nanotron.config import Config, ParallelismArgs, get_config_from_file +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine +from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.models.llama import LlamaConfig as LlamaConfigHF + +# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of parallelism +DP = 1 +PP = 1 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory with a Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="Nanotron Parallelism") + group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the Nanotron Checkpoint") + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--hugging-face-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted checkpoint", + ) + # TODO Add push to hub + + args = parser.parse_args() + + return args + + +def main(args): + # Load Nanotron checkpoint config + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + nanotron_llama_config = nanotron_config.model.model_config + + # Init Llama3-8B Nanotron model + parallel_config = ParallelismArgs( + dp=DP, + pp=PP, + tp=args.tp, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + assert ( + parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE + and parallel_config.tp_linear_async_communication is False + ) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + nanotron_model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_config.model.model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=torch.bfloat16, + device=torch.device("cuda"), + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # Load Nanotron Checkpoint + load_weights( + model=nanotron_model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path) + ) + + # Build empty HF Model + ## TODO This takes pretty long time + hf_model = AutoModelForCausalLM.from_config( + config=LlamaConfigHF(**asdict(nanotron_llama_config)), + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + ).to("cuda") + + # Copy params from Nanotron to HF + # Token embeddings + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + with torch.no_grad(): + hf_model.model.embed_tokens.weight.copy_( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight + ) + + # Decoder layers + for i in range(nanotron_config.model.model_config.num_hidden_layers): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].input_layernorm.weight.copy_( + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight + ) + + # Self attn + # Split Nanotrn qkv projection into q, k, v + q, k, v = torch.split( + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight, + [ + nanotron_llama_config.num_attention_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + nanotron_llama_config.num_key_value_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + nanotron_llama_config.num_key_value_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + ], + ) + assert q.shape == hf_model.model.layers[i].self_attn.q_proj.weight.shape + assert k.shape == hf_model.model.layers[i].self_attn.k_proj.weight.shape + assert v.shape == hf_model.model.layers[i].self_attn.v_proj.weight.shape + + with torch.no_grad(): + hf_model.model.layers[i].self_attn.q_proj.weight.copy_(q) + hf_model.model.layers[i].self_attn.k_proj.weight.copy_(k) + hf_model.model.layers[i].self_attn.v_proj.weight.copy_(v) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].self_attn.o_proj.weight.copy_( + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + gate_proj, up_proj = torch.split( + nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight, + split_size_or_sections=[nanotron_llama_config.intermediate_size, nanotron_llama_config.intermediate_size], + ) + assert gate_proj.shape == hf_model.model.layers[i].mlp.gate_proj.weight.shape + assert up_proj.shape == hf_model.model.layers[i].mlp.up_proj.weight.shape + + with torch.no_grad(): + hf_model.model.layers[i].mlp.gate_proj.weight.copy_(gate_proj) + hf_model.model.layers[i].mlp.up_proj.weight.copy_(up_proj) + + ## Down Proj + assert ( + hf_model.model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].mlp.down_proj.weight.copy_( + nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].post_attention_layernorm.weight.copy_( + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight + ) + + # Last layer norm + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + with torch.no_grad(): + hf_model.model.norm.weight.copy_(nanotron_model.model.final_layer_norm.pp_block.weight) + + # LM_Head + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + with torch.no_grad(): + hf_model.lm_head.weight.copy_(nanotron_model.model.lm_head.pp_block.weight) + + # Store weights + hf_model.save_pretrained(args.hugging_face_checkpoint_path, from_pt=True) + # Store tokenizer + tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) + tokenizer.save_pretrained(args.hugging_face_checkpoint_path) + + +if __name__ == "__main__": + _args = get_args() + main(_args) From 2411d435ca37867fb6db7888b8e86035f753441c Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Sun, 19 May 2024 11:15:27 +0000 Subject: [PATCH 05/18] Moved scripts to tools llama3 folder --- .../llama3/convert_hf_to_nanotron.py | 0 .../llama3/convert_nanotron_to_hf.py | 0 .../llama3/generate_hf_predictions.py | 3 +++ .../llama3/generate_nanotron_predictions.py | 8 ++++++-- 4 files changed, 9 insertions(+), 2 deletions(-) rename convert_hf_to_nanotron.py => tools/llama3/convert_hf_to_nanotron.py (100%) rename convert_nanotron_to_hf.py => tools/llama3/convert_nanotron_to_hf.py (100%) rename generate_hf_predictions.py => tools/llama3/generate_hf_predictions.py (94%) rename generate_nanotron_predictions.py => tools/llama3/generate_nanotron_predictions.py (94%) diff --git a/convert_hf_to_nanotron.py b/tools/llama3/convert_hf_to_nanotron.py similarity index 100% rename from convert_hf_to_nanotron.py rename to tools/llama3/convert_hf_to_nanotron.py diff --git a/convert_nanotron_to_hf.py b/tools/llama3/convert_nanotron_to_hf.py similarity index 100% rename from convert_nanotron_to_hf.py rename to tools/llama3/convert_nanotron_to_hf.py diff --git a/generate_hf_predictions.py b/tools/llama3/generate_hf_predictions.py similarity index 94% rename from generate_hf_predictions.py rename to tools/llama3/generate_hf_predictions.py index 5fd5bc3f..d484bff5 100644 --- a/generate_hf_predictions.py +++ b/tools/llama3/generate_hf_predictions.py @@ -1,3 +1,6 @@ +""" +torchrun --nproc-per-node 1 generate_hf_predictions.py --pretrained-model-name-or-path hf_c/second --tokenizer-name-or-path /mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct +""" import argparse import os diff --git a/generate_nanotron_predictions.py b/tools/llama3/generate_nanotron_predictions.py similarity index 94% rename from generate_nanotron_predictions.py rename to tools/llama3/generate_nanotron_predictions.py index ff613f4d..3671007d 100644 --- a/generate_nanotron_predictions.py +++ b/tools/llama3/generate_nanotron_predictions.py @@ -1,5 +1,9 @@ +""" +torchrun --nproc-per-node 1 generate_nanotron_predictions.py --tp 1 --nanotron-checkpoint-path n_c/second --tokenizer-name-or-path /mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct +""" import argparse import os +from pathlib import Path import torch from nanotron.config import Config, ParallelismArgs, get_config_from_file @@ -88,7 +92,7 @@ def main(args): sanity_check(root_module=model) # Load checkpoint directly in memory and then only keep the state dictionary - load_weights(model=model, parallel_context=parallel_context, root_folder=args.nanotron_checkpoint_path) + load_weights(model=model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path)) tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to("cuda") @@ -97,7 +101,7 @@ def main(args): model.eval() with torch.no_grad(): - output = model(inputs) + output = model.model(**inputs) predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models term_cols = int(os.get_terminal_size().columns / 3) From 372fa02eae1e60c8c9c2fce83d939ded8aad90c4 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Sun, 19 May 2024 11:18:28 +0000 Subject: [PATCH 06/18] Pushed FA2 mod and rope configs fix --- src/nanotron/config/models_config.py | 4 +++ src/nanotron/models/llama.py | 53 +++++++--------------------- 2 files changed, 17 insertions(+), 40 deletions(-) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index ba4559cf..2630e1d6 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -47,6 +47,10 @@ class LlamaConfig: pretraining_tp: int = 1 rms_norm_eps: float = 1e-6 rope_scaling: Optional[dict] = None + rope_theta: float = 10000.0 + rope_interleaved: bool = ( + True # The default value has been True, but for loading Llama3 checkpoints you have to set it to False + ) tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 32aab9cd..2072a789 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, Optional, Union import torch from torch import nn @@ -188,35 +188,21 @@ def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArg @checkpoint_method(attr_name="checkpoint_attention") def forward( self, - query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim] - key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim] - value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim] - q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size) - kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) + query_states: torch.Tensor, # [batch_size, q_length, n_local_q_heads, inner_dim] + key_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] + value_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] ): - from flash_attn.flash_attn_interface import flash_attn_varlen_func - - # TODO @thomasw21: Compute once, instead of computing for each layers. - cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) - torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) - - # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not - # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. - causal = False if q_sequence_mask.shape[1] == 1 else True + from flash_attn.flash_attn_interface import flash_attn_func # NOTE: this scale is for µTransfer, # in SP, we use sqrt(1/d_h) softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None - attn_output = flash_attn_varlen_func( + # For now we are assuming that we use causual mask. No magic here + causal = True + attn_output = flash_attn_func( q=query_states, k=key_states, v=value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_sequence_mask.shape[1], - max_seqlen_k=kv_sequence_mask.shape[1], dropout_p=0.0, softmax_scale=softmax_scale, causal=causal, @@ -323,7 +309,9 @@ def __init__( ) # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) - self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, interleaved=True) + self.flash_rotary_embedding = FlashRotaryEmbedding( + dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta + ) self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, @@ -565,29 +553,14 @@ def forward( # [batch_size, seq_length, num_heads, d_qk] key_states, value_states = torch.split(key_value_states, 1, dim=2) - q_sequence_mask = sequence_mask - kv_sequence_mask = sequence_mask - kv_length = key_states.shape[1] - # [batch_size, seq_length, num_heads, d_qk] - # Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func` - query_states = query_states.view( - batch_size * q_length, self.n_local_q_heads, self.d_qk - ) # [batch_size * q_length, self.n_heads, d_qk] - - key_states = key_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_qk - ) # [batch_size * kv_length, self.n_heads, d_qk] - value_states = value_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_v - ) # [batch_size * kv_length, self.n_heads, d_v] + key_states = key_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_qk) + value_states = value_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_v) attention_output = self.attention( query_states=query_states, key_states=key_states, value_states=value_states, - q_sequence_mask=q_sequence_mask, - kv_sequence_mask=kv_sequence_mask, ) attention_output = ( From b348460911ca0c24f3f8b469b8222fb8732787c9 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Sun, 19 May 2024 12:36:14 +0000 Subject: [PATCH 07/18] Added logging --- tools/llama3/convert_hf_to_nanotron.py | 82 +++++++++++++------ tools/llama3/convert_nanotron_to_hf.py | 44 +++++++--- tools/llama3/generate_hf_predictions.py | 20 +++-- tools/llama3/generate_nanotron_predictions.py | 27 +++--- 4 files changed, 108 insertions(+), 65 deletions(-) diff --git a/tools/llama3/convert_hf_to_nanotron.py b/tools/llama3/convert_hf_to_nanotron.py index 3b66cef4..15f42705 100644 --- a/tools/llama3/convert_hf_to_nanotron.py +++ b/tools/llama3/convert_hf_to_nanotron.py @@ -1,5 +1,5 @@ """ -torchrun --nproc-per-node 1 convert_hf_to_nanotron.py --tp 1 --nanotron-checkpoint-path n_c/second --pretrained-model-name-or-path /mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct +torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --tp 1 --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B --pretrained-model-name-or-path /mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct """ import argparse import json @@ -8,9 +8,11 @@ import torch import yaml +from nanotron import logging from nanotron.config import Config, GeneralArgs, ModelArgs, ParallelismArgs, TokenizerArgs from nanotron.config.models_config import ExistingCheckpointInit from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron +from nanotron.logging import log_rank from nanotron.models import build_model from nanotron.models.llama import LlamaForTraining from nanotron.parallel import ParallelContext @@ -20,12 +22,17 @@ from nanotron.serialize import TrainingMetadata, save_meta, save_weights from nanotron.serialize.metadata import DataStageMetadata from nanotron.trainer import mark_tied_parameters -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer -# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of parallelism +logger = logging.get_logger(__name__) + +# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of model parallelism DP = 1 PP = 1 +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + def get_args(): parser = argparse.ArgumentParser() @@ -54,10 +61,36 @@ def get_args(): def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs( + dp=DP, + pp=PP, + tp=args.tp, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + assert ( + parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE + and parallel_config.tp_linear_async_communication is False + ) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + # Load Llama3-8B HF model + log_rank( + f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) hf_model = AutoModelForCausalLM.from_pretrained( - args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ).to("cuda") + args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + ).to(DEVICE) hf_config = hf_model.config # Set Nanotron LlamaConfig @@ -85,25 +118,7 @@ def main(args): ) # Init Llama3-8B Nanotron model - parallel_config = ParallelismArgs( - dp=DP, - pp=PP, - tp=args.tp, - pp_engine=AllForwardAllBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, - tp_linear_async_communication=False, - ) - assert ( - parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE - and parallel_config.tp_linear_async_communication is False - ) - - parallel_context = ParallelContext( - data_parallel_size=parallel_config.dp, - pipeline_parallel_size=parallel_config.pp, - tensor_parallel_size=parallel_config.tp, - ) - + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) nanotron_model = build_model( model_builder=lambda: LlamaForTraining( config=nanotron_llama_config, @@ -112,14 +127,15 @@ def main(args): random_states=None, ), parallel_context=parallel_context, - dtype=torch.bfloat16, - device=torch.device("cuda"), + dtype=TORCH_DTYPE, + device=DEVICE, ) mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) sanity_check(root_module=nanotron_model) # Copy params from HF to Nanotron + log_rank("Copyng weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) # Token embeddings assert ( nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape @@ -210,11 +226,13 @@ def main(args): with torch.no_grad(): nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) + log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) # Store weights nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) # Store metadata + log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0) training_metadata = TrainingMetadata( last_train_step=0, consumed_train_samples=0, @@ -223,6 +241,9 @@ def main(args): save_meta( root_folder=nanotron_checkpoint_path, parallel_context=parallel_context, training_metadata=training_metadata ) + # Store Tokenizer into Nanotron Checkpoint folder + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokenizer.save_pretrained(nanotron_checkpoint_path) # Store Config and Model Config files with open(nanotron_checkpoint_path / "config.yaml", "w") as f: @@ -233,7 +254,7 @@ def main(args): init_method=ExistingCheckpointInit(nanotron_checkpoint_path), model_config=nanotron_llama_config, ), - tokenizer=TokenizerArgs(args.pretrained_model_name_or_path), + tokenizer=TokenizerArgs(nanotron_checkpoint_path), ) print("Saving config ...") yaml.dump(config.as_dict(), f) @@ -242,6 +263,13 @@ def main(args): print("Saving model config ...") json.dump(asdict(nanotron_llama_config), f) + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + if __name__ == "__main__": _args = get_args() diff --git a/tools/llama3/convert_nanotron_to_hf.py b/tools/llama3/convert_nanotron_to_hf.py index 363e0c51..c0bb1b1b 100644 --- a/tools/llama3/convert_nanotron_to_hf.py +++ b/tools/llama3/convert_nanotron_to_hf.py @@ -1,5 +1,5 @@ """ -torchrun --nproc-per-node 1 convert_nanotron_to_hf.py --tp 1 --nanotron-checkpoint-path n_c/second --hugging-face-checkpoint-path hf_c/second +torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --tp 1 --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B --hugging-face-checkpoint-path hf_checkpoints/ConvertedNanotronLlama38B """ import argparse import os @@ -7,7 +7,9 @@ from pathlib import Path import torch +from nanotron import logging from nanotron.config import Config, ParallelismArgs, get_config_from_file +from nanotron.logging import log_rank from nanotron.models import build_model from nanotron.models.llama import LlamaForTraining from nanotron.parallel import ParallelContext @@ -19,10 +21,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.models.llama import LlamaConfig as LlamaConfigHF -# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of parallelism +logger = logging.get_logger(__name__) + +# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of model parallelism DP = 1 PP = 1 +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + def get_args(): parser = argparse.ArgumentParser() @@ -52,13 +59,7 @@ def get_args(): def main(args): - # Load Nanotron checkpoint config - nanotron_config = get_config_from_file( - os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None - ) - nanotron_llama_config = nanotron_config.model.model_config - - # Init Llama3-8B Nanotron model + # Init Nanotron Parallel Utilities parallel_config = ParallelismArgs( dp=DP, pp=PP, @@ -78,6 +79,20 @@ def main(args): tensor_parallel_size=parallel_config.tp, ) + # Load Nanotron checkpoint config + log_rank( + f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}", + logger=logger, + level=logging.INFO, + rank=0, + ) + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + nanotron_llama_config = nanotron_config.model.model_config + + # Init Llama3-8B Nanotron model + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) nanotron_model = build_model( model_builder=lambda: LlamaForTraining( config=nanotron_config.model.model_config, @@ -86,8 +101,8 @@ def main(args): random_states=None, ), parallel_context=parallel_context, - dtype=torch.bfloat16, - device=torch.device("cuda"), + dtype=TORCH_DTYPE, + device=DEVICE, ) mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) @@ -102,11 +117,12 @@ def main(args): ## TODO This takes pretty long time hf_model = AutoModelForCausalLM.from_config( config=LlamaConfigHF(**asdict(nanotron_llama_config)), - torch_dtype=torch.bfloat16, + torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2", - ).to("cuda") + ).to(DEVICE) # Copy params from Nanotron to HF + log_rank("Copyng weights from Nanotron model to HF model...", logger=logger, level=logging.INFO, rank=0) # Token embeddings assert ( nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape @@ -201,7 +217,9 @@ def main(args): with torch.no_grad(): hf_model.lm_head.weight.copy_(nanotron_model.model.lm_head.pp_block.weight) + log_rank("Copied weights from Nanotron model to HF model!", logger=logger, level=logging.INFO, rank=0) # Store weights + log_rank("Storing HF model Checkpoint and Tokenizer!", logger=logger, level=logging.INFO, rank=0) hf_model.save_pretrained(args.hugging_face_checkpoint_path, from_pt=True) # Store tokenizer tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) diff --git a/tools/llama3/generate_hf_predictions.py b/tools/llama3/generate_hf_predictions.py index d484bff5..12b52f2a 100644 --- a/tools/llama3/generate_hf_predictions.py +++ b/tools/llama3/generate_hf_predictions.py @@ -1,5 +1,5 @@ """ -torchrun --nproc-per-node 1 generate_hf_predictions.py --pretrained-model-name-or-path hf_c/second --tokenizer-name-or-path /mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct +torchrun --nproc-per-node 1 tools/llama3/generate_hf_predictions.py --pretrained-model-name-or-path hf_checkpoints/ConvertedNanotronLlama38B """ import argparse import os @@ -10,6 +10,9 @@ TXT = "The prologue of Romeo and Juliet calls the title characters “star-crossed lovers”—and the stars do seem to conspire against these young lovers. Romeo is a Montague, and Juliet a Capulet. Their families are enmeshed in a feud, but the moment they meet—when Romeo and his friends attend a party at Juliets house in disguise—the two fall in love and quickly decide that they want to be married. A friar secretly marries them, hoping to end the feud. Romeo and his companions almost immediately encounter Juliets cousin Tybalt, who challenges Romeo. When Romeo refuses to fight, Romeos friend Mercutio accepts the challenge and is killed. Romeo then kills Tybalt and is banished. He spends that night with Juliet and then leaves for Mantua. Juliets father forces her into a marriage with Count Paris. To avoid this marriage, Juliet takes a potion, given her by the friar, that makes her appear dead. The friar will send Romeo word to be at her family tomb when she awakes. The plan goes awry, and Romeo learns instead that she is dead. In the tomb, Romeo kills himself. Juliet wakes, sees his body, and commits suicide. Their deaths appear finally to end the feud." SEQ_LENGTH = 256 # For truncating the TXT if GPU can't fit too many tokens +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + def get_args(): parser = argparse.ArgumentParser() @@ -28,16 +31,15 @@ def get_args(): def main(args): # TODO Refractor with HF pipeline or .generate()? - model = ( - AutoModelForCausalLM.from_pretrained( - args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) - .to("cuda") - .eval() - ) + model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + device=DEVICE, + ).eval() tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) - tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to("cuda") + tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to(DEVICE) inputs = tokens[:, :-1] with torch.no_grad(): diff --git a/tools/llama3/generate_nanotron_predictions.py b/tools/llama3/generate_nanotron_predictions.py index 3671007d..dc77acc9 100644 --- a/tools/llama3/generate_nanotron_predictions.py +++ b/tools/llama3/generate_nanotron_predictions.py @@ -1,5 +1,5 @@ """ -torchrun --nproc-per-node 1 generate_nanotron_predictions.py --tp 1 --nanotron-checkpoint-path n_c/second --tokenizer-name-or-path /mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct +torchrun --nproc-per-node 1 tools/llama3/generate_nanotron_predictions.py --tp 1 --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B """ import argparse import os @@ -17,13 +17,16 @@ from nanotron.trainer import mark_tied_parameters from transformers import AutoTokenizer -# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of parallelism +# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of model parallelism DP = 1 PP = 1 TXT = "The prologue of Romeo and Juliet calls the title characters “star-crossed lovers”—and the stars do seem to conspire against these young lovers. Romeo is a Montague, and Juliet a Capulet. Their families are enmeshed in a feud, but the moment they meet—when Romeo and his friends attend a party at Juliets house in disguise—the two fall in love and quickly decide that they want to be married. A friar secretly marries them, hoping to end the feud. Romeo and his companions almost immediately encounter Juliets cousin Tybalt, who challenges Romeo. When Romeo refuses to fight, Romeos friend Mercutio accepts the challenge and is killed. Romeo then kills Tybalt and is banished. He spends that night with Juliet and then leaves for Mantua. Juliets father forces her into a marriage with Count Paris. To avoid this marriage, Juliet takes a potion, given her by the friar, that makes her appear dead. The friar will send Romeo word to be at her family tomb when she awakes. The plan goes awry, and Romeo learns instead that she is dead. In the tomb, Romeo kills himself. Juliet wakes, sees his body, and commits suicide. Their deaths appear finally to end the feud." SEQ_LENGTH = 256 # For truncating the TXT if GPU can't fit too many tokens +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + def get_args(): parser = argparse.ArgumentParser() @@ -38,21 +41,13 @@ def get_args(): group = parser.add_argument_group(title="Nanotron Parallelism") group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the Nanotron Checkpoint") - group = parser.add_argument_group(title="Tokenizer") - group.add_argument( - "--tokenizer-name-or-path", - type=str, - required=True, - help="A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.", - ) - args = parser.parse_args() return args def main(args): - + # Init Nanotron Parallel Utilities parallel_config = ParallelismArgs( dp=DP, pp=PP, @@ -84,8 +79,8 @@ def main(args): random_states=None, ), parallel_context=parallel_context, - dtype=torch.bfloat16, - device=torch.device("cuda"), # TODO Check with different parallelism + dtype=TORCH_DTYPE, + device=DEVICE, # TODO Check with different parallelism if cpu is available ) mark_tied_parameters(model=model, parallel_context=parallel_context) @@ -94,9 +89,9 @@ def main(args): # Load checkpoint directly in memory and then only keep the state dictionary load_weights(model=model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path)) - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) - tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to("cuda") - inputs = {"input_ids": tokens[:, :-1], "input_mask": torch.ones((1, SEQ_LENGTH), device="cuda")} + tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) + tokens = tokenizer(TXT, return_tensors="pt", truncation=True, max_length=(SEQ_LENGTH + 1))["input_ids"].to(DEVICE) + inputs = {"input_ids": tokens[:, :-1], "input_mask": torch.ones((1, SEQ_LENGTH), device=DEVICE)} model.eval() From de81b53a27e89f3cf5361bfa3eda6fc5078bf9db Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Sun, 19 May 2024 23:44:28 +0000 Subject: [PATCH 08/18] Cleaned scripts --- tools/llama3/convert_hf_to_nanotron.py | 53 ++++++---------- tools/llama3/convert_nanotron_to_hf.py | 53 ++++++++-------- tools/llama3/generate_hf_predictions.py | 23 +++++-- tools/llama3/generate_nanotron_predictions.py | 61 +++++++++++-------- 4 files changed, 98 insertions(+), 92 deletions(-) diff --git a/tools/llama3/convert_hf_to_nanotron.py b/tools/llama3/convert_hf_to_nanotron.py index 15f42705..4c185f01 100644 --- a/tools/llama3/convert_hf_to_nanotron.py +++ b/tools/llama3/convert_hf_to_nanotron.py @@ -1,5 +1,5 @@ """ -torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --tp 1 --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B --pretrained-model-name-or-path /mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct +torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct """ import argparse import json @@ -8,11 +8,9 @@ import torch import yaml -from nanotron import logging from nanotron.config import Config, GeneralArgs, ModelArgs, ParallelismArgs, TokenizerArgs from nanotron.config.models_config import ExistingCheckpointInit from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron -from nanotron.logging import log_rank from nanotron.models import build_model from nanotron.models.llama import LlamaForTraining from nanotron.parallel import ParallelContext @@ -22,15 +20,10 @@ from nanotron.serialize import TrainingMetadata, save_meta, save_weights from nanotron.serialize.metadata import DataStageMetadata from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer -logger = logging.get_logger(__name__) - -# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of model parallelism -DP = 1 -PP = 1 - -DEVICE = torch.device("cuda") +DEVICE = torch.device("cpu") TORCH_DTYPE = torch.bfloat16 @@ -44,9 +37,6 @@ def get_args(): help="A path to a directory to store the converted Nanotron Checkpoint", ) - group = parser.add_argument_group(title="Nanotron Parallelism") - group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the Nanotron Checkpoint") - group = parser.add_argument_group(title="HuggingFace Model") group.add_argument( "--pretrained-model-name-or-path", @@ -63,9 +53,9 @@ def get_args(): def main(args): # Init Nanotron Parallel Utilities parallel_config = ParallelismArgs( - dp=DP, - pp=PP, - tp=args.tp, + dp=1, + pp=1, + tp=1, pp_engine=AllForwardAllBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, @@ -82,12 +72,7 @@ def main(args): ) # Load Llama3-8B HF model - log_rank( - f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}", - logger=logger, - level=logging.INFO, - rank=0, - ) + print(f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}") hf_model = AutoModelForCausalLM.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" ).to(DEVICE) @@ -118,7 +103,7 @@ def main(args): ) # Init Llama3-8B Nanotron model - log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + print("Init empty Nanotron Llama3 Model") nanotron_model = build_model( model_builder=lambda: LlamaForTraining( config=nanotron_llama_config, @@ -135,8 +120,9 @@ def main(args): sanity_check(root_module=nanotron_model) # Copy params from HF to Nanotron - log_rank("Copyng weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + print("Copyng weights from HF model to Nanotron model...") # Token embeddings + print("Copyng Token Embeddings...") assert ( nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape == hf_model.model.embed_tokens.weight.shape @@ -147,7 +133,11 @@ def main(args): ) # Decoder layers - for i in range(nanotron_llama_config.num_hidden_layers): + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copyng Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): # Input layer norm assert ( hf_model.model.layers[i].input_layernorm.weight.shape @@ -217,22 +207,24 @@ def main(args): ) # Last layer norm + print("Copyng Final Layer Norm...") assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape with torch.no_grad(): nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) # LM_Head + print("Copyng LM Head...") assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape with torch.no_grad(): nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) - log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + print("Copied weights from HF model to Nanotron model!") # Store weights nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) # Store metadata - log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0) + print("Storing Nanotron model Configs and Metadata!") training_metadata = TrainingMetadata( last_train_step=0, consumed_train_samples=0, @@ -263,12 +255,7 @@ def main(args): print("Saving model config ...") json.dump(asdict(nanotron_llama_config), f) - log_rank( - f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", - logger=logger, - level=logging.INFO, - rank=0, - ) + print(f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}") if __name__ == "__main__": diff --git a/tools/llama3/convert_nanotron_to_hf.py b/tools/llama3/convert_nanotron_to_hf.py index c0bb1b1b..cbcb73e6 100644 --- a/tools/llama3/convert_nanotron_to_hf.py +++ b/tools/llama3/convert_nanotron_to_hf.py @@ -1,5 +1,5 @@ """ -torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --tp 1 --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B --hugging-face-checkpoint-path hf_checkpoints/ConvertedNanotronLlama38B +torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B --hugging-face-checkpoint-path hf_checkpoints/ConvertedNanotronLlama38B """ import argparse import os @@ -7,9 +7,7 @@ from pathlib import Path import torch -from nanotron import logging from nanotron.config import Config, ParallelismArgs, get_config_from_file -from nanotron.logging import log_rank from nanotron.models import build_model from nanotron.models.llama import LlamaForTraining from nanotron.parallel import ParallelContext @@ -18,16 +16,11 @@ from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode from nanotron.serialize import load_weights from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.models.llama import LlamaConfig as LlamaConfigHF -logger = logging.get_logger(__name__) - -# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of model parallelism -DP = 1 -PP = 1 - -DEVICE = torch.device("cuda") +DEVICE = torch.device("cpu") TORCH_DTYPE = torch.bfloat16 @@ -41,9 +34,6 @@ def get_args(): help="A path to a directory with a Nanotron Checkpoint", ) - group = parser.add_argument_group(title="Nanotron Parallelism") - group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the Nanotron Checkpoint") - group = parser.add_argument_group(title="HuggingFace Model") group.add_argument( "--hugging-face-checkpoint-path", @@ -61,9 +51,9 @@ def get_args(): def main(args): # Init Nanotron Parallel Utilities parallel_config = ParallelismArgs( - dp=DP, - pp=PP, - tp=args.tp, + dp=1, + pp=1, + tp=1, pp_engine=AllForwardAllBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, @@ -80,19 +70,14 @@ def main(args): ) # Load Nanotron checkpoint config - log_rank( - f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}", - logger=logger, - level=logging.INFO, - rank=0, - ) + print(f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}") nanotron_config = get_config_from_file( os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None ) nanotron_llama_config = nanotron_config.model.model_config # Init Llama3-8B Nanotron model - log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + print("Init empty Nanotron Llama3 Model") nanotron_model = build_model( model_builder=lambda: LlamaForTraining( config=nanotron_config.model.model_config, @@ -109,21 +94,23 @@ def main(args): sanity_check(root_module=nanotron_model) # Load Nanotron Checkpoint + print("Loading Nanotron Llama3 Model...") load_weights( model=nanotron_model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path) ) # Build empty HF Model - ## TODO This takes pretty long time - hf_model = AutoModelForCausalLM.from_config( + print("Init empty HF Llama3 Model") + hf_model = AutoModelForCausalLM.from_config( # WARN This takes a long time config=LlamaConfigHF(**asdict(nanotron_llama_config)), torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2", ).to(DEVICE) # Copy params from Nanotron to HF - log_rank("Copyng weights from Nanotron model to HF model...", logger=logger, level=logging.INFO, rank=0) + print("Copyng weights from Nanotron model to HF model...") # Token embeddings + print("Copyng Token Embeddings...") assert ( nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape == hf_model.model.embed_tokens.weight.shape @@ -134,7 +121,11 @@ def main(args): ) # Decoder layers - for i in range(nanotron_config.model.model_config.num_hidden_layers): + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copyng Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): # Input layer norm assert ( hf_model.model.layers[i].input_layernorm.weight.shape @@ -208,23 +199,27 @@ def main(args): ) # Last layer norm + print("Copyng Final Layer Norm...") assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape with torch.no_grad(): hf_model.model.norm.weight.copy_(nanotron_model.model.final_layer_norm.pp_block.weight) # LM_Head + print("Copyng LM Head...") assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape with torch.no_grad(): hf_model.lm_head.weight.copy_(nanotron_model.model.lm_head.pp_block.weight) - log_rank("Copied weights from Nanotron model to HF model!", logger=logger, level=logging.INFO, rank=0) + print("Copied weights from Nanotron model to HF model!") # Store weights - log_rank("Storing HF model Checkpoint and Tokenizer!", logger=logger, level=logging.INFO, rank=0) + print("Storing HF model Checkpoint and Tokenizer!") hf_model.save_pretrained(args.hugging_face_checkpoint_path, from_pt=True) # Store tokenizer tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) tokenizer.save_pretrained(args.hugging_face_checkpoint_path) + print(f"Checkpoint conversion finished, check {args.hugging_face_checkpoint_path}") + if __name__ == "__main__": _args = get_args() diff --git a/tools/llama3/generate_hf_predictions.py b/tools/llama3/generate_hf_predictions.py index 12b52f2a..b16774a4 100644 --- a/tools/llama3/generate_hf_predictions.py +++ b/tools/llama3/generate_hf_predictions.py @@ -1,14 +1,16 @@ """ -torchrun --nproc-per-node 1 tools/llama3/generate_hf_predictions.py --pretrained-model-name-or-path hf_checkpoints/ConvertedNanotronLlama38B +torchrun --nproc-per-node 1 tools/llama3/generate_hf_predictions.py --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct """ import argparse import os +import numpy as np import torch +from sklearn.metrics import accuracy_score from transformers import AutoModelForCausalLM, AutoTokenizer -TXT = "The prologue of Romeo and Juliet calls the title characters “star-crossed lovers”—and the stars do seem to conspire against these young lovers. Romeo is a Montague, and Juliet a Capulet. Their families are enmeshed in a feud, but the moment they meet—when Romeo and his friends attend a party at Juliets house in disguise—the two fall in love and quickly decide that they want to be married. A friar secretly marries them, hoping to end the feud. Romeo and his companions almost immediately encounter Juliets cousin Tybalt, who challenges Romeo. When Romeo refuses to fight, Romeos friend Mercutio accepts the challenge and is killed. Romeo then kills Tybalt and is banished. He spends that night with Juliet and then leaves for Mantua. Juliets father forces her into a marriage with Count Paris. To avoid this marriage, Juliet takes a potion, given her by the friar, that makes her appear dead. The friar will send Romeo word to be at her family tomb when she awakes. The plan goes awry, and Romeo learns instead that she is dead. In the tomb, Romeo kills himself. Juliet wakes, sees his body, and commits suicide. Their deaths appear finally to end the feud." -SEQ_LENGTH = 256 # For truncating the TXT if GPU can't fit too many tokens +TXT = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello! Which is the capital of France? What can I visit over there if I go for a week vacation?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nBonjour! The capital of France is Paris, also known as the City of Light. Paris is a stunning city with a rich history, art, fashion, and cuisine. If you're planning a week-long vacation in Paris, you'll have plenty of time to explore its iconic landmarks, museums, and neighborhoods. Here's a suggested itinerary to get you started: Day 1-2: Iconic Landmarks The Eiffel Tower (Tour Eiffel): The iron lady offers breathtaking views of the city. You can take the stairs or elevator to the top. The Louvre Museum (Musée du Louvre): Home to the Mona Lisa, Venus de Milo, and many other famous artworks. Arc de Triomphe: A monumental arch honoring the soldiers who fought and died for France. Champs-Élysées: A famous avenue lined with cafes, shops, and theaters. Day 3: Montmartre and Sacré-Cœur Explore the charming neighborhood of Montmartre, known for its bohemian vibe, street artists, and stunning views. Visit the Basilique du Sacré-Cœur, a beautiful white church perched on a hill." +SEQ_LENGTH = 512 # For truncating the TXT if GPU can't fit too many tokens DEVICE = torch.device("cuda") TORCH_DTYPE = torch.bfloat16 @@ -30,12 +32,12 @@ def get_args(): def main(args): - # TODO Refractor with HF pipeline or .generate()? + model = AutoModelForCausalLM.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2", - device=DEVICE, + device_map="auto", ).eval() tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) @@ -62,6 +64,17 @@ def main(args): sep="\n", ) + # Compute accuracy + predictions = np.argmax(output.logits.cpu(), axis=2).flatten().tolist() + labels = tokens.cpu().flatten()[1:].tolist() + print(f"\nAccuracy: {accuracy_score(labels, predictions)}") + # Results + ## [TP=1] HF 8B: 0.8308823529411765 + ## [TP=2]HF 70B: 0.8860294117647058 + ## [TP=1] HF -> Nanotron -> HF 8B: 0.8308823529411765 + ## [TP=2] HF -> Nanotron -> HF 70B: 0.8860294117647058 + ## [TP=1 --> TP=2] HF -> Nanotron -> Dummy Finetune to change TP=2 -> HF 8B: 0.8308823529411765 + if __name__ == "__main__": _args = get_args() diff --git a/tools/llama3/generate_nanotron_predictions.py b/tools/llama3/generate_nanotron_predictions.py index dc77acc9..fbede799 100644 --- a/tools/llama3/generate_nanotron_predictions.py +++ b/tools/llama3/generate_nanotron_predictions.py @@ -1,10 +1,12 @@ """ -torchrun --nproc-per-node 1 tools/llama3/generate_nanotron_predictions.py --tp 1 --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B +torchrun --nproc-per-node 2 tools/llama3/generate_nanotron_predictions.py --tp 2 --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B """ import argparse import os from pathlib import Path +import nanotron.distributed as dist +import numpy as np import torch from nanotron.config import Config, ParallelismArgs, get_config_from_file from nanotron.models import build_model @@ -15,14 +17,11 @@ from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode from nanotron.serialize import load_weights from nanotron.trainer import mark_tied_parameters +from sklearn.metrics import accuracy_score from transformers import AutoTokenizer -# TODO Currentyly just sopporting Llama8B that doesn't needs any kind of model parallelism -DP = 1 -PP = 1 - -TXT = "The prologue of Romeo and Juliet calls the title characters “star-crossed lovers”—and the stars do seem to conspire against these young lovers. Romeo is a Montague, and Juliet a Capulet. Their families are enmeshed in a feud, but the moment they meet—when Romeo and his friends attend a party at Juliets house in disguise—the two fall in love and quickly decide that they want to be married. A friar secretly marries them, hoping to end the feud. Romeo and his companions almost immediately encounter Juliets cousin Tybalt, who challenges Romeo. When Romeo refuses to fight, Romeos friend Mercutio accepts the challenge and is killed. Romeo then kills Tybalt and is banished. He spends that night with Juliet and then leaves for Mantua. Juliets father forces her into a marriage with Count Paris. To avoid this marriage, Juliet takes a potion, given her by the friar, that makes her appear dead. The friar will send Romeo word to be at her family tomb when she awakes. The plan goes awry, and Romeo learns instead that she is dead. In the tomb, Romeo kills himself. Juliet wakes, sees his body, and commits suicide. Their deaths appear finally to end the feud." -SEQ_LENGTH = 256 # For truncating the TXT if GPU can't fit too many tokens +TXT = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello! Which is the capital of France? What can I visit over there if I go for a week vacation?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nBonjour! The capital of France is Paris, also known as the City of Light. Paris is a stunning city with a rich history, art, fashion, and cuisine. If you're planning a week-long vacation in Paris, you'll have plenty of time to explore its iconic landmarks, museums, and neighborhoods. Here's a suggested itinerary to get you started: Day 1-2: Iconic Landmarks The Eiffel Tower (Tour Eiffel): The iron lady offers breathtaking views of the city. You can take the stairs or elevator to the top. The Louvre Museum (Musée du Louvre): Home to the Mona Lisa, Venus de Milo, and many other famous artworks. Arc de Triomphe: A monumental arch honoring the soldiers who fought and died for France. Champs-Élysées: A famous avenue lined with cafes, shops, and theaters. Day 3: Montmartre and Sacré-Cœur Explore the charming neighborhood of Montmartre, known for its bohemian vibe, street artists, and stunning views. Visit the Basilique du Sacré-Cœur, a beautiful white church perched on a hill." +SEQ_LENGTH = 512 # For truncating the TXT if GPU can't fit too many tokens DEVICE = torch.device("cuda") TORCH_DTYPE = torch.bfloat16 @@ -49,8 +48,8 @@ def get_args(): def main(args): # Init Nanotron Parallel Utilities parallel_config = ParallelismArgs( - dp=DP, - pp=PP, + dp=1, + pp=1, tp=args.tp, pp_engine=AllForwardAllBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, @@ -67,6 +66,8 @@ def main(args): tensor_parallel_size=parallel_config.tp, ) + RANK = dist.get_rank(parallel_context.world_pg) + nanotron_config = get_config_from_file( os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None ) @@ -98,22 +99,32 @@ def main(args): with torch.no_grad(): output = model.model(**inputs) - predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models - term_cols = int(os.get_terminal_size().columns / 3) - - for predicted_token in predicted_tokens: - - print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) - next_tokens = torch.softmax(output.transpose(0, 1)[0, predicted_token, :], -1) - topk_next_tokens = torch.topk(next_tokens, 10) - - print( - *[ - f"[Nanotron Model] Next token: {idx.item()}, probability: {prob}" - for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) - ], - sep="\n", - ) + if not RANK: + predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + term_cols = int(os.get_terminal_size().columns / 3) + + for predicted_token in predicted_tokens: + + print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) + next_tokens = torch.softmax(output.transpose(0, 1)[0, predicted_token, :], -1) + topk_next_tokens = torch.topk(next_tokens, 10) + + print( + *[ + f"[Nanotron Model] Next token: {idx.item()}, probability: {prob}" + for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) + ], + sep="\n", + ) + + # Compute accuracy + predictions = np.argmax(output.transpose(0, 1).cpu(), axis=2).flatten().tolist() + labels = tokens.cpu().flatten()[1:].tolist() + print(f"\nAccuracy: {accuracy_score(labels, predictions)}") + # Results + ## Nanotron 8B, TP 1: 0.8272058823529411 + ## Nanotron 8B, TP 2: 0.7720588235294118 + ## Nanotron 70B, TP 2: 0.8272058823529411 if __name__ == "__main__": From a28c53289950adcaa0f6fe2914c762921929d66e Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 22 May 2024 13:01:01 +0000 Subject: [PATCH 09/18] Added Nanotron logging --- tools/llama3/convert_hf_to_nanotron.py | 57 +++++++++++++------------ tools/llama3/convert_nanotron_to_hf.py | 59 ++++++++++++++------------ 2 files changed, 61 insertions(+), 55 deletions(-) diff --git a/tools/llama3/convert_hf_to_nanotron.py b/tools/llama3/convert_hf_to_nanotron.py index 4c185f01..0032bf9a 100644 --- a/tools/llama3/convert_hf_to_nanotron.py +++ b/tools/llama3/convert_hf_to_nanotron.py @@ -8,21 +8,23 @@ import torch import yaml -from nanotron.config import Config, GeneralArgs, ModelArgs, ParallelismArgs, TokenizerArgs +from nanotron import logging +from nanotron.config import Config, GeneralArgs, LoggingArgs, ModelArgs, ParallelismArgs, TokenizerArgs from nanotron.config.models_config import ExistingCheckpointInit from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron +from nanotron.logging import log_rank, set_ranks_logging_level from nanotron.models import build_model from nanotron.models.llama import LlamaForTraining from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import sanity_check -from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine -from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode from nanotron.serialize import TrainingMetadata, save_meta, save_weights from nanotron.serialize.metadata import DataStageMetadata from nanotron.trainer import mark_tied_parameters from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer +logger = logging.get_logger(__name__) + DEVICE = torch.device("cpu") TORCH_DTYPE = torch.bfloat16 @@ -52,18 +54,7 @@ def get_args(): def main(args): # Init Nanotron Parallel Utilities - parallel_config = ParallelismArgs( - dp=1, - pp=1, - tp=1, - pp_engine=AllForwardAllBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, - tp_linear_async_communication=False, - ) - assert ( - parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE - and parallel_config.tp_linear_async_communication is False - ) + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) parallel_context = ParallelContext( data_parallel_size=parallel_config.dp, @@ -71,8 +62,15 @@ def main(args): tensor_parallel_size=parallel_config.tp, ) + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + # Load Llama3-8B HF model - print(f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}") + log_rank( + f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) hf_model = AutoModelForCausalLM.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" ).to(DEVICE) @@ -103,7 +101,7 @@ def main(args): ) # Init Llama3-8B Nanotron model - print("Init empty Nanotron Llama3 Model") + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) nanotron_model = build_model( model_builder=lambda: LlamaForTraining( config=nanotron_llama_config, @@ -120,9 +118,9 @@ def main(args): sanity_check(root_module=nanotron_model) # Copy params from HF to Nanotron - print("Copyng weights from HF model to Nanotron model...") + log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) # Token embeddings - print("Copyng Token Embeddings...") + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) assert ( nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape == hf_model.model.embed_tokens.weight.shape @@ -135,7 +133,7 @@ def main(args): # Decoder layers for i in tqdm( range(nanotron_llama_config.num_hidden_layers), - desc="Copyng Hidden Layers", + desc="Copying Hidden Layers", total=nanotron_llama_config.num_hidden_layers, ): # Input layer norm @@ -207,24 +205,24 @@ def main(args): ) # Last layer norm - print("Copyng Final Layer Norm...") + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape with torch.no_grad(): nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) # LM_Head - print("Copyng LM Head...") + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape with torch.no_grad(): nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) - print("Copied weights from HF model to Nanotron model!") + log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) # Store weights nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) # Store metadata - print("Storing Nanotron model Configs and Metadata!") + log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0) training_metadata = TrainingMetadata( last_train_step=0, consumed_train_samples=0, @@ -248,14 +246,19 @@ def main(args): ), tokenizer=TokenizerArgs(nanotron_checkpoint_path), ) - print("Saving config ...") + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) yaml.dump(config.as_dict(), f) with open(nanotron_checkpoint_path / "model_config.json", "w") as f: - print("Saving model config ...") + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) json.dump(asdict(nanotron_llama_config), f) - print(f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}") + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) if __name__ == "__main__": diff --git a/tools/llama3/convert_nanotron_to_hf.py b/tools/llama3/convert_nanotron_to_hf.py index cbcb73e6..0254ed4a 100644 --- a/tools/llama3/convert_nanotron_to_hf.py +++ b/tools/llama3/convert_nanotron_to_hf.py @@ -7,19 +7,21 @@ from pathlib import Path import torch -from nanotron.config import Config, ParallelismArgs, get_config_from_file +from nanotron import logging +from nanotron.config import Config, LoggingArgs, ParallelismArgs, get_config_from_file +from nanotron.logging import log_rank, set_ranks_logging_level from nanotron.models import build_model from nanotron.models.llama import LlamaForTraining from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import sanity_check -from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine -from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode from nanotron.serialize import load_weights from nanotron.trainer import mark_tied_parameters from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.models.llama import LlamaConfig as LlamaConfigHF +logger = logging.get_logger(__name__) + DEVICE = torch.device("cpu") TORCH_DTYPE = torch.bfloat16 @@ -41,7 +43,6 @@ def get_args(): required=True, help="A path to a directory to store the converted checkpoint", ) - # TODO Add push to hub args = parser.parse_args() @@ -50,18 +51,7 @@ def get_args(): def main(args): # Init Nanotron Parallel Utilities - parallel_config = ParallelismArgs( - dp=1, - pp=1, - tp=1, - pp_engine=AllForwardAllBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, - tp_linear_async_communication=False, - ) - assert ( - parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE - and parallel_config.tp_linear_async_communication is False - ) + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) parallel_context = ParallelContext( data_parallel_size=parallel_config.dp, @@ -69,15 +59,23 @@ def main(args): tensor_parallel_size=parallel_config.tp, ) + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + # Load Nanotron checkpoint config - print(f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}") + log_rank( + f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}", + logger=logger, + level=logging.INFO, + rank=0, + ) nanotron_config = get_config_from_file( os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None ) nanotron_llama_config = nanotron_config.model.model_config # Init Llama3-8B Nanotron model - print("Init empty Nanotron Llama3 Model") + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( model_builder=lambda: LlamaForTraining( config=nanotron_config.model.model_config, @@ -94,13 +92,13 @@ def main(args): sanity_check(root_module=nanotron_model) # Load Nanotron Checkpoint - print("Loading Nanotron Llama3 Model...") + log_rank("Loading Nanotron Llama3 Model...", logger=logger, level=logging.INFO, rank=0) load_weights( model=nanotron_model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path) ) # Build empty HF Model - print("Init empty HF Llama3 Model") + log_rank("Init empty HF Llama3 Model", logger=logger, level=logging.INFO, rank=0) hf_model = AutoModelForCausalLM.from_config( # WARN This takes a long time config=LlamaConfigHF(**asdict(nanotron_llama_config)), torch_dtype=TORCH_DTYPE, @@ -108,9 +106,9 @@ def main(args): ).to(DEVICE) # Copy params from Nanotron to HF - print("Copyng weights from Nanotron model to HF model...") + log_rank("Copying weights from Nanotron model to HF model...", logger=logger, level=logging.INFO, rank=0) # Token embeddings - print("Copyng Token Embeddings...") + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) assert ( nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape == hf_model.model.embed_tokens.weight.shape @@ -123,7 +121,7 @@ def main(args): # Decoder layers for i in tqdm( range(nanotron_llama_config.num_hidden_layers), - desc="Copyng Hidden Layers", + desc="Copying Hidden Layers", total=nanotron_llama_config.num_hidden_layers, ): # Input layer norm @@ -199,26 +197,31 @@ def main(args): ) # Last layer norm - print("Copyng Final Layer Norm...") + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape with torch.no_grad(): hf_model.model.norm.weight.copy_(nanotron_model.model.final_layer_norm.pp_block.weight) # LM_Head - print("Copyng LM Head...") + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape with torch.no_grad(): hf_model.lm_head.weight.copy_(nanotron_model.model.lm_head.pp_block.weight) - print("Copied weights from Nanotron model to HF model!") + log_rank("Copied weights from Nanotron model to HF model!", logger=logger, level=logging.INFO, rank=0) # Store weights - print("Storing HF model Checkpoint and Tokenizer!") + log_rank("Storing HF model Checkpoint and Tokenizer!", logger=logger, level=logging.INFO, rank=0) hf_model.save_pretrained(args.hugging_face_checkpoint_path, from_pt=True) # Store tokenizer tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) tokenizer.save_pretrained(args.hugging_face_checkpoint_path) - print(f"Checkpoint conversion finished, check {args.hugging_face_checkpoint_path}") + log_rank( + f"Checkpoint conversion finished, check {args.hugging_face_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) if __name__ == "__main__": From 3e169c5afc80b3494c1065bd4ef3079dc2b657de Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 22 May 2024 13:44:01 +0000 Subject: [PATCH 10/18] Added README --- tools/llama3/README.md | 19 +++++++++++++++++++ tools/llama3/convert_hf_to_nanotron.py | 4 ++-- tools/llama3/convert_nanotron_to_hf.py | 2 +- 3 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 tools/llama3/README.md diff --git a/tools/llama3/README.md b/tools/llama3/README.md new file mode 100644 index 00000000..57a31b5e --- /dev/null +++ b/tools/llama3/README.md @@ -0,0 +1,19 @@ +# Llama3 Weight conversion tool +This directory contains the scripts to convert the Llama3 checkpoints from HuggingFace to Nanotron and vice versa. + +- Convert from HuggingFace to Nanotron + +`torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct` +- Convert from Nanotron to HuggingFace + +`torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama3-8B --hugging-face-checkpoint-path hf_checkpoints/Converted-Nanotron-Llama-3-8B` + +In summary, we will do the following: +- Initialize the HuggingFace model with the pretrained weights. The model definition is [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py). +- Initialize a Nanotron model with empty weights. The model definition is [here](https://github.com/huggingface/nanotron/blob/main/src/nanotron/models/llama.py). +- Copy the parameters layer by layer from one model to the other. +- Store the Nanotron model along with the tokenizer. + +When comparing the HuggingFace implementation with the Nanotron implementation, the main difference lies in the Q, K & V matrices and in the MLP projections. In the HuggingFace implementation, these matrices are separated [[1]](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L415), [[2]](https://github.com/huggingface/transformers/blob/1518508467d96b3866fc4ebcb7a5b3a2e0df2aa4/src/transformers/models/llama/modeling_llama.py#L194), while in the Nanotron implementation, they are concatenated [[1b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L310), [[2b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L149). It is crucial to pay attention to these details to convert the models correctly. + +To perform the conversion, we will need at least **1 GPU**, although the operations will be carried out on the **CPU**. We will convert the models with a parallel configuration of DP = PP = TP = 1, but it should be noted that the checkpoints generated by Nanotron are topology agnostic. diff --git a/tools/llama3/convert_hf_to_nanotron.py b/tools/llama3/convert_hf_to_nanotron.py index 0032bf9a..e30610a3 100644 --- a/tools/llama3/convert_hf_to_nanotron.py +++ b/tools/llama3/convert_hf_to_nanotron.py @@ -1,5 +1,5 @@ """ -torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct +torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct """ import argparse import json @@ -238,7 +238,7 @@ def main(args): # Store Config and Model Config files with open(nanotron_checkpoint_path / "config.yaml", "w") as f: config = Config( - general=GeneralArgs(project="conversion", run="Llama3-8B"), + general=GeneralArgs(project="Nanotron", run="Llama3"), parallelism=parallel_config, model=ModelArgs( init_method=ExistingCheckpointInit(nanotron_checkpoint_path), diff --git a/tools/llama3/convert_nanotron_to_hf.py b/tools/llama3/convert_nanotron_to_hf.py index 0254ed4a..c5fb1940 100644 --- a/tools/llama3/convert_nanotron_to_hf.py +++ b/tools/llama3/convert_nanotron_to_hf.py @@ -1,5 +1,5 @@ """ -torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B --hugging-face-checkpoint-path hf_checkpoints/ConvertedNanotronLlama38B +torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --hugging-face-checkpoint-path hf_checkpoints/Converted-Nanotron-Llama-3-8B """ import argparse import os From 31c12e86f83052a85caedcf978ae6373ede43cda Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 13 Jun 2024 16:19:57 +0000 Subject: [PATCH 11/18] add LlamaRotary because generation is not good otherwise --- src/nanotron/models/llama.py | 98 +++++++++++++++++++++++++++++++++--- 1 file changed, 92 insertions(+), 6 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 2072a789..c2c07614 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -117,6 +117,73 @@ def forward( return x_out.type(dtype) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +### llama +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim: int, end: int, theta: float = 500000.0): + super().__init__() + self.dim = dim + self.end = end + self.theta = theta + self.init_rotary_embeddings() + + def init_rotary_embeddings(self): + inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda") / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward( + self, + x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk] + position_ids: Optional[torch.LongTensor], # [batch_size, seq_length] + ): + # x: [bs, num_attention_heads, seq_len, head_size] + # print("rotary") + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + class GLUActivation(nn.Module): def __init__(self, act_fn_name: str): super().__init__() @@ -303,10 +370,17 @@ def __init__( contiguous_chunks=qkv_contiguous_chunks, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. - self.rotary_embedding = RotaryEmbedding( - dim=self.d_qk, - end=config.max_position_embeddings, - ) + if config.rope_interleaved: + self.rotary_embedding = RotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + ) + else: + self.rotary_embedding = LlamaRotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + ) + self.rope_interleaved = config.rope_interleaved # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) self.flash_rotary_embedding = FlashRotaryEmbedding( @@ -336,6 +410,7 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] + position_ids: Optional[torch.LongTensor] = None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -390,8 +465,19 @@ def forward( # Compute rotary embeddings # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache old_rotary_embed_end = self.rotary_embedding.end - query_states = self.rotary_embedding(query_states, position_ids=position_ids) - key_states = self.rotary_embedding(key_states, position_ids=position_ids) + + # Rotate half rotary_embedding + # cos, sin = self.rotary_embedding(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # interleaved + if self.rope_interleaved: + query_states = self.rotary_embedding(query_states, position_ids=position_ids) + key_states = self.rotary_embedding(key_states, position_ids=position_ids) + # llama rotary position embedding + else: + cos, sin = self.rotary_embedding(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if "key" not in store: # First inference iteration (Prefill) From c0c74aaa890a776ce4d0182d9b89544520fe41d1 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 17 Jun 2024 10:28:55 +0000 Subject: [PATCH 12/18] generate refacto is now working --- run_generate.py | 211 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 166 insertions(+), 45 deletions(-) diff --git a/run_generate.py b/run_generate.py index f389770d..689f428d 100644 --- a/run_generate.py +++ b/run_generate.py @@ -7,7 +7,6 @@ torchrun --nproc_per_node=4 run_generate.py ---ckpt-path checkpoints/test/4 ``` """ - import argparse import os from pathlib import Path @@ -23,6 +22,7 @@ ) from nanotron.generation.decode import ( GenerationInput, + GenerationInputs, TokenizerConfig, decode_text, decode_tokenized, @@ -50,6 +50,9 @@ except ImportError: AutoTokenizer = None +from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model +from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState + logger = logging.get_logger(__name__) @@ -57,8 +60,8 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") parser.add_argument("--dp", type=int, default=1) - parser.add_argument("--pp", type=int, default=0) - parser.add_argument("--tp", type=int, default=0) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--tp", type=int, default=1) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") return parser.parse_args() @@ -73,9 +76,9 @@ def main(): tokenizer_path = config.tokenizer.tokenizer_name_or_path parallel_config = ParallelismArgs( - dp=args.dp or config.parallelism.dp, - pp=args.pp or config.parallelism.pp, - tp=args.tp or config.parallelism.tp, + dp=args.dp, + pp=args.pp, + tp=args.tp, pp_engine=OneForwardOneBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, @@ -166,52 +169,57 @@ def main(): dummy_inputs = [ "The future of AI is", # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", - "def fib(n)", + # "def fib(n)", # 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.', # "Advancements in technology will lead to", # "Tomorrow's world is shaped by", ] - outputs = decode_text( - input_iter=(GenerationInput(text=text) for text in dummy_inputs), - tokenizer=tokenizer, - # TODO @thomasw21: From ModelWithLoss extract the model. - model=model.model, - parallel_context=parallel_context, - max_new_tokens=args.max_new_tokens, - max_micro_batch_size=2, - generation_config=GenerationArgs(sampler="greedy", use_cache=True), - tokenizer_config=TokenizerConfig(max_input_length=None), - is_bench=os.environ.get("USE_BENCH", "0") == "1", - ) - for output in outputs: - input_ids = output.input_ids - generated_ids = output.generation_ids - if isinstance(input_ids, TensorPointer): - assert isinstance(generated_ids, TensorPointer) - continue - assert isinstance(generated_ids, torch.Tensor) - - log_rank( - f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", - logger=logger, - level=logging.INFO, - rank=0, - ) - - log_rank( - f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}", - logger=logger, - level=logging.INFO, - rank=0, + if os.environ.get("REFACTO", "0") == "1": + refactor_decode_text(args, parallel_context, model, tokenizer, dummy_inputs) + # print("==================================================") + else: + outputs = decode_text( + input_iter=(GenerationInput(text=text) for text in dummy_inputs), + tokenizer=tokenizer, + # TODO @thomasw21: From ModelWithLoss extract the model. + model=model.model, + parallel_context=parallel_context, + max_new_tokens=args.max_new_tokens, + max_micro_batch_size=2, + generation_config=GenerationArgs(sampler="greedy", use_cache=False), + tokenizer_config=TokenizerConfig(max_input_length=None), + is_bench=os.environ.get("USE_BENCH", "0") == "1", ) - log_rank( - "--------------------------------------------------", - logger=logger, - level=logging.INFO, - rank=0, - ) + for output in outputs: + input_ids = output.input_ids + generated_ids = output.generation_ids + if isinstance(input_ids, TensorPointer): + assert isinstance(generated_ids, TensorPointer) + continue + assert isinstance(generated_ids, torch.Tensor) + + log_rank( + f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + log_rank( + f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + log_rank( + "--------------------------------------------------", + logger=logger, + level=logging.INFO, + rank=0, + ) else: outputs = decode_tokenized( input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"), @@ -247,5 +255,118 @@ def main(): dist.barrier() +def run_one_inference_step(model, batch, parallel_context, device): + if dist.get_world_size(group=parallel_context.pp_pg) == 1: + return model.model(batch.input_ids, batch.input_masks) + + pipeline_state = PipelineEvalBatchState() + with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state): + + batch_size = batch.input_ids.shape[0] + seq_len = batch.input_ids.shape[1] + + # Preallocate memory for output logits. + logits = None + if parallel_context.is_pipeline_last_stage: + logits = torch.empty((seq_len, batch_size, model.config.vocab_size), dtype=torch.float32, device=device) + + batch2use = GenerationInputs( + input_ids=batch.input_ids + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + input_masks=batch.input_masks + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + ) + + output_tensor = model.model(batch2use.input_ids, batch2use.input_masks) + + # TODO: Check if we need to send only 2 + nb_send = len(pipeline_state.microbatches_activations_to_send) + assert nb_send <= 2 + for _ in range(nb_send): + pipeline_state.run_communication() + + # Copy logits. + if parallel_context.is_pipeline_last_stage: + logits = output_tensor + + # Wait for all the communication to complete. + dist.barrier(group=parallel_context.pp_pg) + + return logits + + +def refactor_decode_text(args, parallel_context, model, tokenizer, dummy_inputs): + device = torch.cuda.current_device() + tokenized_prompts = tokenizer( + dummy_inputs, + return_tensors="pt", + return_attention_mask=True, + padding=True, + ) + + tokenized_prompts["input_ids"] = tokenized_prompts["input_ids"].to(device) + tokenized_prompts["attention_mask"] = tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device) + + for _ in range(args.max_new_tokens): + batch_prompts = GenerationInputs( + input_ids=tokenized_prompts["input_ids"], + input_masks=tokenized_prompts["attention_mask"], + ) + + logits = run_one_inference_step(model, batch_prompts, parallel_context, device) + + # Sample new token + if parallel_context.is_pipeline_last_stage: + assert logits is not None + # TODO(fmom): dont transpose if it is mamba. Add if "logits_are_batch_first" flag + logits = logits.transpose(0, 1) + # TODO: Choose between more sampler + next_token = torch.argmax(logits[:, -1], dim=-1) + tokenized_prompts["input_ids"] = torch.cat( + [tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1 + ) + tokenized_prompts["attention_mask"] = torch.cat( + [ + tokenized_prompts["attention_mask"], + torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.int64, device=device), + ], + dim=-1, + ) + else: + tokenized_prompts["input_ids"] = torch.zeros( + (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), + dtype=torch.int64, + device=device, + ) + tokenized_prompts["attention_mask"] = torch.zeros( + (tokenized_prompts["attention_mask"].shape[0], tokenized_prompts["attention_mask"].shape[1] + 1), + dtype=torch.int64, + device=device, + ) + + dist.broadcast( + tokenized_prompts["input_ids"], + src=parallel_context.pipeline_parallel_last_rank, + group=parallel_context.pp_pg, + ) + dist.broadcast( + tokenized_prompts["attention_mask"], + src=parallel_context.pipeline_parallel_last_rank, + group=parallel_context.pp_pg, + ) + + if parallel_context.is_pipeline_last_stage: + for i, prompt in enumerate(dummy_inputs): + tokenized_outputs = tokenized_prompts["input_ids"][ + i, tokenized_prompts["input_ids"].shape[1] - args.max_new_tokens : + ] + outputs = tokenizer.decode(tokenized_outputs, clean_up_tokenization_spaces=False) + + print(f"Input: {prompt}") + print(f"Output: {outputs}") + + if __name__ == "__main__": main() From 1503d9e58a37eabc9cdddc42f53d3417e2224714 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 17 Jun 2024 13:31:02 +0000 Subject: [PATCH 13/18] fix bug with multiple process group + add sampler --- run_generate.py | 69 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 14 deletions(-) diff --git a/run_generate.py b/run_generate.py index 689f428d..aac09155 100644 --- a/run_generate.py +++ b/run_generate.py @@ -20,6 +20,7 @@ ParallelismArgs, get_config_from_file, ) +from nanotron.distributed import get_global_rank from nanotron.generation.decode import ( GenerationInput, GenerationInputs, @@ -27,6 +28,7 @@ decode_text, decode_tokenized, ) +from nanotron.generation.sampler import BasicSampler, GreedySampler, SamplerType, TopKSampler, TopPSampler from nanotron.logging import log_rank, set_ranks_logging_level from nanotron.models import build_model from nanotron.parallel import ParallelContext @@ -50,9 +52,12 @@ except ImportError: AutoTokenizer = None +import lovely_tensors as lt from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState +lt.monkey_patch() + logger = logging.get_logger(__name__) @@ -176,7 +181,14 @@ def main(): ] if os.environ.get("REFACTO", "0") == "1": - refactor_decode_text(args, parallel_context, model, tokenizer, dummy_inputs) + refactor_decode_text( + args, + parallel_context, + model, + tokenizer, + dummy_inputs, + generation_config=GenerationArgs(sampler="greedy", use_cache=False), + ) # print("==================================================") else: outputs = decode_text( @@ -282,6 +294,7 @@ def run_one_inference_step(model, batch, parallel_context, device): output_tensor = model.model(batch2use.input_ids, batch2use.input_masks) # TODO: Check if we need to send only 2 + nb_send = len(pipeline_state.microbatches_activations_to_send) assert nb_send <= 2 for _ in range(nb_send): @@ -292,13 +305,26 @@ def run_one_inference_step(model, batch, parallel_context, device): logits = output_tensor # Wait for all the communication to complete. - dist.barrier(group=parallel_context.pp_pg) + dist.barrier(group=parallel_context.world_pg) return logits -def refactor_decode_text(args, parallel_context, model, tokenizer, dummy_inputs): +def refactor_decode_text(args, parallel_context, model, tokenizer, dummy_inputs, generation_config): device = torch.cuda.current_device() + + if generation_config: + if isinstance(generation_config.sampler, str): + sampler_type = SamplerType(generation_config.sampler.upper()) + else: + sampler_type = generation_config.sampler + else: + sampler_type = SamplerType.GREEDY + + # TODO: add batch inference + # TODO: add decoded_tokenize + # TODO: add benchmark + tokenized_prompts = tokenizer( dummy_inputs, return_tensors="pt", @@ -319,14 +345,26 @@ def refactor_decode_text(args, parallel_context, model, tokenizer, dummy_inputs) # Sample new token if parallel_context.is_pipeline_last_stage: - assert logits is not None + assert logits is not None and isinstance(logits, torch.Tensor) + # TODO(fmom): dont transpose if it is mamba. Add if "logits_are_batch_first" flag logits = logits.transpose(0, 1) - # TODO: Choose between more sampler - next_token = torch.argmax(logits[:, -1], dim=-1) - tokenized_prompts["input_ids"] = torch.cat( - [tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1 - ) + + # TODO: Use cache + if sampler_type == SamplerType.GREEDY: + sampler = GreedySampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_K: + sampler = TopKSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_P: + sampler = TopPSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.BASIC: + sampler = BasicSampler(pg=parallel_context.tp_pg) + else: + raise NotImplementedError(f"Sampler type {sampler_type} is not implemented") + + next_token = sampler(sharded_logits=logits[:, -1]) + + tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token], dim=-1) tokenized_prompts["attention_mask"] = torch.cat( [ tokenized_prompts["attention_mask"], @@ -335,6 +373,7 @@ def refactor_decode_text(args, parallel_context, model, tokenizer, dummy_inputs) dim=-1, ) else: + # Extend the tokenized prompts to receive the new token tokenized_prompts["input_ids"] = torch.zeros( (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), dtype=torch.int64, @@ -346,26 +385,28 @@ def refactor_decode_text(args, parallel_context, model, tokenizer, dummy_inputs) device=device, ) + # Broadcast the new token to all the pipeline stages dist.broadcast( tokenized_prompts["input_ids"], - src=parallel_context.pipeline_parallel_last_rank, + src=get_global_rank(group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank), group=parallel_context.pp_pg, ) dist.broadcast( tokenized_prompts["attention_mask"], - src=parallel_context.pipeline_parallel_last_rank, + src=get_global_rank(group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank), group=parallel_context.pp_pg, ) - if parallel_context.is_pipeline_last_stage: + if dist.get_rank() == 0: for i, prompt in enumerate(dummy_inputs): tokenized_outputs = tokenized_prompts["input_ids"][ i, tokenized_prompts["input_ids"].shape[1] - args.max_new_tokens : ] outputs = tokenizer.decode(tokenized_outputs, clean_up_tokenization_spaces=False) - print(f"Input: {prompt}") - print(f"Output: {outputs}") + # Convert with log_rank + log_rank(f"Input: {prompt}", logger=logger, level=logging.INFO, rank=0) + log_rank(f"Output: {outputs}", logger=logger, level=logging.INFO, rank=0) if __name__ == "__main__": From 76ac8ca2c0c3fa6cac78c422358fcfa11fdd8b76 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 17 Jun 2024 15:27:03 +0000 Subject: [PATCH 14/18] add changes to support cache --- run_generate.py | 348 ++++++++++++++++++++++-------------------------- 1 file changed, 159 insertions(+), 189 deletions(-) diff --git a/run_generate.py b/run_generate.py index aac09155..0fa42cd5 100644 --- a/run_generate.py +++ b/run_generate.py @@ -24,10 +24,12 @@ from nanotron.generation.decode import ( GenerationInput, GenerationInputs, + GenerationStates, TokenizerConfig, decode_text, - decode_tokenized, + run_one_inference_step, ) +from nanotron.generation.generate_store import Store from nanotron.generation.sampler import BasicSampler, GreedySampler, SamplerType, TopKSampler, TopPSampler from nanotron.logging import log_rank, set_ranks_logging_level from nanotron.models import build_model @@ -52,11 +54,6 @@ except ImportError: AutoTokenizer = None -import lovely_tensors as lt -from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model -from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState - -lt.monkey_patch() logger = logging.get_logger(__name__) @@ -68,6 +65,7 @@ def get_args(): parser.add_argument("--pp", type=int, default=1) parser.add_argument("--tp", type=int, default=1) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") + parser.add_argument("--use-cache", action="store_true", help="Use cache for generation") return parser.parse_args() @@ -171,6 +169,7 @@ def main(): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.padding_side = "left" tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? + dummy_inputs = [ "The future of AI is", # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", @@ -180,16 +179,161 @@ def main(): # "Tomorrow's world is shaped by", ] + log_rank(f"Using cache for generation: {args.use_cache}", logger=logger, level=logging.INFO, rank=0) + + # This doesn't support micro-batches and batch inference yet if os.environ.get("REFACTO", "0") == "1": - refactor_decode_text( - args, - parallel_context, - model, - tokenizer, + + device = torch.cuda.current_device() + generation_config = GenerationArgs(sampler="greedy", use_cache=args.use_cache) + logits_are_batch_first = True + + if generation_config: + if isinstance(generation_config.sampler, str): + sampler_type = SamplerType(generation_config.sampler.upper()) + else: + sampler_type = generation_config.sampler + else: + sampler_type = SamplerType.GREEDY + + tokenized_prompts = tokenizer( dummy_inputs, - generation_config=GenerationArgs(sampler="greedy", use_cache=False), + return_tensors="pt", + return_attention_mask=True, + padding=True, + ) + tokenized_prompts["input_ids"] = tokenized_prompts["input_ids"].to(device) + tokenized_prompts["attention_mask"] = tokenized_prompts["attention_mask"].to( + dtype=torch.bool, device=device ) - # print("==================================================") + + store = Store() + batch_prompts = None + + for i in range(args.max_new_tokens): + + if generation_config.use_cache: + + batch_prompts = GenerationStates( + new_input_ids=tokenized_prompts["input_ids"], + new_input_mask=tokenized_prompts["attention_mask"], + store=store, + generation_ids=[tokenized_prompts["input_ids"]] + if i == 0 + else batch_prompts.generation_ids + [tokenized_prompts["input_ids"]], + generation_mask=[tokenized_prompts["attention_mask"]] + if i == 0 + else batch_prompts.generation_mask + [tokenized_prompts["attention_mask"]], + ) + else: + batch_prompts = GenerationInputs( + input_ids=tokenized_prompts["input_ids"], + input_masks=tokenized_prompts["attention_mask"], + ) + + logits = run_one_inference_step( + model, batch_prompts, parallel_context, device, use_cache=generation_config.use_cache, store=store + ) + + # Sample new token + if parallel_context.is_pipeline_last_stage: + assert logits is not None and isinstance(logits, torch.Tensor) + + # Get sampler + if sampler_type == SamplerType.GREEDY: + sampler = GreedySampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_K: + sampler = TopKSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_P: + sampler = TopPSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.BASIC: + sampler = BasicSampler(pg=parallel_context.tp_pg) + else: + raise NotImplementedError(f"Sampler type {sampler_type} is not implemented") + + if logits_are_batch_first: + logits = logits.transpose(0, 1) + + # Predict next token + next_token = sampler(sharded_logits=logits[:, -1]) + + if generation_config.use_cache: + tokenized_prompts["input_ids"] = next_token + tokenized_prompts["attention_mask"] = torch.ones( + (next_token.shape[0], 1), dtype=torch.bool, device=device + ) + else: + tokenized_prompts["input_ids"] = torch.cat( + [tokenized_prompts["input_ids"], next_token], dim=-1 + ) + tokenized_prompts["attention_mask"] = torch.cat( + [ + tokenized_prompts["attention_mask"], + torch.ones( + (tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.bool, device=device + ), + ], + dim=-1, + ) + else: + # Extend the tokenized prompts to receive the new token + if generation_config.use_cache: + tokenized_prompts["input_ids"] = torch.zeros( + (tokenized_prompts["input_ids"].shape[0], 1), + dtype=torch.int64, + device=device, + ) + tokenized_prompts["attention_mask"] = torch.zeros( + (tokenized_prompts["attention_mask"].shape[0], 1), + dtype=torch.bool, + device=device, + ) + else: + tokenized_prompts["input_ids"] = torch.zeros( + (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), + dtype=torch.int64, + device=device, + ) + tokenized_prompts["attention_mask"] = torch.zeros( + ( + tokenized_prompts["attention_mask"].shape[0], + tokenized_prompts["attention_mask"].shape[1] + 1, + ), + dtype=torch.bool, + device=device, + ) + + # Broadcast the new token to all the pipeline stages + dist.broadcast( + tokenized_prompts["input_ids"], + src=get_global_rank( + group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank + ), + group=parallel_context.pp_pg, + ) + dist.broadcast( + tokenized_prompts["attention_mask"], + src=get_global_rank( + group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank + ), + group=parallel_context.pp_pg, + ) + + if dist.get_rank() == 0: + for i, prompt in enumerate(dummy_inputs): + if generation_config.use_cache: + tokenized_outputs = torch.cat( + [tokens.view(1, -1) for tokens in batch_prompts.generation_ids], dim=1 + ) + outputs = tokenizer.decode(tokenized_outputs[0], clean_up_tokenization_spaces=False) + else: + tokenized_outputs = tokenized_prompts["input_ids"][ + i, tokenized_prompts["input_ids"].shape[1] - args.max_new_tokens : + ] + outputs = tokenizer.decode(tokenized_outputs, clean_up_tokenization_spaces=False) + + log_rank(f"Input: {prompt}", logger=logger, level=logging.INFO, rank=0) + log_rank(f"Output: {outputs}", logger=logger, level=logging.INFO, rank=0) else: outputs = decode_text( input_iter=(GenerationInput(text=text) for text in dummy_inputs), @@ -198,8 +342,8 @@ def main(): model=model.model, parallel_context=parallel_context, max_new_tokens=args.max_new_tokens, - max_micro_batch_size=2, - generation_config=GenerationArgs(sampler="greedy", use_cache=False), + max_micro_batch_size=1, + generation_config=GenerationArgs(sampler="greedy", use_cache=args.use_cache), tokenizer_config=TokenizerConfig(max_input_length=None), is_bench=os.environ.get("USE_BENCH", "0") == "1", ) @@ -232,182 +376,8 @@ def main(): level=logging.INFO, rank=0, ) - else: - outputs = decode_tokenized( - input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"), - input_mask=torch.ones(1, 1).to(dtype=torch.bool, device="cuda"), - model=model.model, - parallel_context=parallel_context, - generation_config=GenerationArgs(sampler="greedy", use_cache=True), - max_micro_batch_size=1, - max_new_tokens=12, - returns_logits=False, - ) - for output in outputs: - input_ids = output.input_ids - generated_ids = output.generation_ids - if isinstance(input_ids, TensorPointer): - assert isinstance(generated_ids, TensorPointer) - continue - assert isinstance(generated_ids, torch.Tensor) - log_rank( - f"generation: {generated_ids[len(input_ids) :]}", - logger=logger, - level=logging.INFO, - rank=0, - ) - - log_rank( - "--------------------------------------------------", - logger=logger, - level=logging.INFO, - rank=0, - ) - dist.barrier() -def run_one_inference_step(model, batch, parallel_context, device): - if dist.get_world_size(group=parallel_context.pp_pg) == 1: - return model.model(batch.input_ids, batch.input_masks) - - pipeline_state = PipelineEvalBatchState() - with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state): - - batch_size = batch.input_ids.shape[0] - seq_len = batch.input_ids.shape[1] - - # Preallocate memory for output logits. - logits = None - if parallel_context.is_pipeline_last_stage: - logits = torch.empty((seq_len, batch_size, model.config.vocab_size), dtype=torch.float32, device=device) - - batch2use = GenerationInputs( - input_ids=batch.input_ids - if parallel_context.is_pipeline_first_stage - else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), - input_masks=batch.input_masks - if parallel_context.is_pipeline_first_stage - else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), - ) - - output_tensor = model.model(batch2use.input_ids, batch2use.input_masks) - - # TODO: Check if we need to send only 2 - - nb_send = len(pipeline_state.microbatches_activations_to_send) - assert nb_send <= 2 - for _ in range(nb_send): - pipeline_state.run_communication() - - # Copy logits. - if parallel_context.is_pipeline_last_stage: - logits = output_tensor - - # Wait for all the communication to complete. - dist.barrier(group=parallel_context.world_pg) - - return logits - - -def refactor_decode_text(args, parallel_context, model, tokenizer, dummy_inputs, generation_config): - device = torch.cuda.current_device() - - if generation_config: - if isinstance(generation_config.sampler, str): - sampler_type = SamplerType(generation_config.sampler.upper()) - else: - sampler_type = generation_config.sampler - else: - sampler_type = SamplerType.GREEDY - - # TODO: add batch inference - # TODO: add decoded_tokenize - # TODO: add benchmark - - tokenized_prompts = tokenizer( - dummy_inputs, - return_tensors="pt", - return_attention_mask=True, - padding=True, - ) - - tokenized_prompts["input_ids"] = tokenized_prompts["input_ids"].to(device) - tokenized_prompts["attention_mask"] = tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device) - - for _ in range(args.max_new_tokens): - batch_prompts = GenerationInputs( - input_ids=tokenized_prompts["input_ids"], - input_masks=tokenized_prompts["attention_mask"], - ) - - logits = run_one_inference_step(model, batch_prompts, parallel_context, device) - - # Sample new token - if parallel_context.is_pipeline_last_stage: - assert logits is not None and isinstance(logits, torch.Tensor) - - # TODO(fmom): dont transpose if it is mamba. Add if "logits_are_batch_first" flag - logits = logits.transpose(0, 1) - - # TODO: Use cache - if sampler_type == SamplerType.GREEDY: - sampler = GreedySampler(pg=parallel_context.tp_pg) - elif sampler_type == SamplerType.TOP_K: - sampler = TopKSampler(pg=parallel_context.tp_pg) - elif sampler_type == SamplerType.TOP_P: - sampler = TopPSampler(pg=parallel_context.tp_pg) - elif sampler_type == SamplerType.BASIC: - sampler = BasicSampler(pg=parallel_context.tp_pg) - else: - raise NotImplementedError(f"Sampler type {sampler_type} is not implemented") - - next_token = sampler(sharded_logits=logits[:, -1]) - - tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token], dim=-1) - tokenized_prompts["attention_mask"] = torch.cat( - [ - tokenized_prompts["attention_mask"], - torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.int64, device=device), - ], - dim=-1, - ) - else: - # Extend the tokenized prompts to receive the new token - tokenized_prompts["input_ids"] = torch.zeros( - (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), - dtype=torch.int64, - device=device, - ) - tokenized_prompts["attention_mask"] = torch.zeros( - (tokenized_prompts["attention_mask"].shape[0], tokenized_prompts["attention_mask"].shape[1] + 1), - dtype=torch.int64, - device=device, - ) - - # Broadcast the new token to all the pipeline stages - dist.broadcast( - tokenized_prompts["input_ids"], - src=get_global_rank(group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank), - group=parallel_context.pp_pg, - ) - dist.broadcast( - tokenized_prompts["attention_mask"], - src=get_global_rank(group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank), - group=parallel_context.pp_pg, - ) - - if dist.get_rank() == 0: - for i, prompt in enumerate(dummy_inputs): - tokenized_outputs = tokenized_prompts["input_ids"][ - i, tokenized_prompts["input_ids"].shape[1] - args.max_new_tokens : - ] - outputs = tokenizer.decode(tokenized_outputs, clean_up_tokenization_spaces=False) - - # Convert with log_rank - log_rank(f"Input: {prompt}", logger=logger, level=logging.INFO, rank=0) - log_rank(f"Output: {outputs}", logger=logger, level=logging.INFO, rank=0) - - if __name__ == "__main__": main() From bf6846eed76c52fcb8af23c820099d76eb38d043 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 17 Jun 2024 15:37:26 +0000 Subject: [PATCH 15/18] refacto use_cache to unify with no_cache --- run_generate.py | 86 +++++++++++++++++++------------------------------ 1 file changed, 34 insertions(+), 52 deletions(-) diff --git a/run_generate.py b/run_generate.py index 0fa42cd5..1971b0e9 100644 --- a/run_generate.py +++ b/run_generate.py @@ -181,7 +181,8 @@ def main(): log_rank(f"Using cache for generation: {args.use_cache}", logger=logger, level=logging.INFO, rank=0) - # This doesn't support micro-batches and batch inference yet + # NOTE: This doesn't support micro-batches and batch inference yet + if os.environ.get("REFACTO", "0") == "1": device = torch.cuda.current_device() @@ -213,17 +214,17 @@ def main(): for i in range(args.max_new_tokens): if generation_config.use_cache: - + # Prepare the batch prompts batch_prompts = GenerationStates( - new_input_ids=tokenized_prompts["input_ids"], - new_input_mask=tokenized_prompts["attention_mask"], - store=store, - generation_ids=[tokenized_prompts["input_ids"]] + new_input_ids=tokenized_prompts["input_ids"] if i == 0 - else batch_prompts.generation_ids + [tokenized_prompts["input_ids"]], - generation_mask=[tokenized_prompts["attention_mask"]] + else tokenized_prompts["input_ids"][:, -1].unsqueeze(0), + new_input_mask=tokenized_prompts["attention_mask"] if i == 0 - else batch_prompts.generation_mask + [tokenized_prompts["attention_mask"]], + else tokenized_prompts["attention_mask"][:, -1].unsqueeze(0), + store=store, + generation_ids=tokenized_prompts["input_ids"], + generation_mask=tokenized_prompts["attention_mask"], ) else: batch_prompts = GenerationInputs( @@ -257,51 +258,32 @@ def main(): # Predict next token next_token = sampler(sharded_logits=logits[:, -1]) - if generation_config.use_cache: - tokenized_prompts["input_ids"] = next_token - tokenized_prompts["attention_mask"] = torch.ones( - (next_token.shape[0], 1), dtype=torch.bool, device=device - ) - else: - tokenized_prompts["input_ids"] = torch.cat( - [tokenized_prompts["input_ids"], next_token], dim=-1 - ) - tokenized_prompts["attention_mask"] = torch.cat( - [ - tokenized_prompts["attention_mask"], - torch.ones( - (tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.bool, device=device - ), - ], - dim=-1, - ) + # Extend the tokenized prompts to insert the new token + tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token], dim=-1) + tokenized_prompts["attention_mask"] = torch.cat( + [ + tokenized_prompts["attention_mask"], + torch.ones( + (tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.bool, device=device + ), + ], + dim=-1, + ) else: # Extend the tokenized prompts to receive the new token - if generation_config.use_cache: - tokenized_prompts["input_ids"] = torch.zeros( - (tokenized_prompts["input_ids"].shape[0], 1), - dtype=torch.int64, - device=device, - ) - tokenized_prompts["attention_mask"] = torch.zeros( - (tokenized_prompts["attention_mask"].shape[0], 1), - dtype=torch.bool, - device=device, - ) - else: - tokenized_prompts["input_ids"] = torch.zeros( - (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), - dtype=torch.int64, - device=device, - ) - tokenized_prompts["attention_mask"] = torch.zeros( - ( - tokenized_prompts["attention_mask"].shape[0], - tokenized_prompts["attention_mask"].shape[1] + 1, - ), - dtype=torch.bool, - device=device, - ) + tokenized_prompts["input_ids"] = torch.zeros( + (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), + dtype=torch.int64, + device=device, + ) + tokenized_prompts["attention_mask"] = torch.zeros( + ( + tokenized_prompts["attention_mask"].shape[0], + tokenized_prompts["attention_mask"].shape[1] + 1, + ), + dtype=torch.bool, + device=device, + ) # Broadcast the new token to all the pipeline stages dist.broadcast( From be2da504e50edd64118c97acbaa3d027d3c58ce5 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 17 Jun 2024 16:06:21 +0000 Subject: [PATCH 16/18] make use_cache work with multi parallelism --- src/nanotron/generation/decode.py | 61 +++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 6ab71fad..dc021085 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -772,6 +772,67 @@ def generator(): ) +@torch.inference_mode() +def run_one_inference_step(model, batch, parallel_context, device, use_cache, store): + if dist.get_world_size(group=parallel_context.pp_pg) == 1: + if use_cache: + with attach_store(model=model, store=store): + return model.model(batch.new_input_ids, batch.new_input_mask) + return model.model(batch.input_ids, batch.input_masks) + + pipeline_state = PipelineEvalBatchState() + with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state): + batch_size = batch.new_input_ids.shape[0] if use_cache else batch.input_ids.shape[0] + seq_len = batch.new_input_ids.shape[1] if use_cache else batch.input_ids.shape[1] + + # Preallocate memory for output logits. + logits = None + if parallel_context.is_pipeline_last_stage: + logits = torch.empty((seq_len, batch_size, model.config.vocab_size), dtype=torch.float32, device=device) + + if use_cache: + batch2use = GenerationStates( + new_input_ids=batch.new_input_ids + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + new_input_mask=batch.new_input_mask + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + store=store, + generation_ids=batch.generation_ids, + generation_mask=batch.generation_mask, + ) + with attach_store(model=model, store=store): + output_tensor = model.model(batch2use.new_input_ids, batch2use.new_input_mask) + else: + batch2use = GenerationInputs( + input_ids=batch.input_ids + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + input_masks=batch.input_masks + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + ) + + output_tensor = model.model(batch2use.input_ids, batch2use.input_masks) + + nb_send = len(pipeline_state.microbatches_activations_to_send) + assert nb_send <= 2 + for _ in range(nb_send): + # Send activations to the next stage + # Send attention_mask to the next stage + pipeline_state.run_communication() + + # Copy logits. + if parallel_context.is_pipeline_last_stage: + logits = output_tensor + + # Wait for all the communication to complete. + dist.barrier(group=parallel_context.world_pg) + + return logits + + # Distributed utilities def broadcast_tensors( tensors: List[Union[torch.Tensor, TensorPointer]], group_src: int, group: Optional[ProcessGroup] = None From c66b4ba6b05356ba20ef990432c2fcfe877a9d69 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 17 Jun 2024 19:33:18 +0000 Subject: [PATCH 17/18] clean generate --- run_generate.py | 280 ++++++++++++++++++++---------------------------- 1 file changed, 114 insertions(+), 166 deletions(-) diff --git a/run_generate.py b/run_generate.py index 1971b0e9..ad116d73 100644 --- a/run_generate.py +++ b/run_generate.py @@ -8,7 +8,6 @@ ``` """ import argparse -import os from pathlib import Path import torch @@ -22,11 +21,8 @@ ) from nanotron.distributed import get_global_rank from nanotron.generation.decode import ( - GenerationInput, GenerationInputs, GenerationStates, - TokenizerConfig, - decode_text, run_one_inference_step, ) from nanotron.generation.generate_store import Store @@ -38,7 +34,6 @@ from nanotron.parallel.pipeline_parallel.engine import ( OneForwardOneBackwardPipelineEngine, ) -from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.random import ( RandomStates, @@ -181,183 +176,136 @@ def main(): log_rank(f"Using cache for generation: {args.use_cache}", logger=logger, level=logging.INFO, rank=0) - # NOTE: This doesn't support micro-batches and batch inference yet + # NOTE: This doesn't support micro-batches and batch inference + device = torch.cuda.current_device() + generation_config = GenerationArgs(sampler="greedy", use_cache=args.use_cache) + logits_are_batch_first = True - if os.environ.get("REFACTO", "0") == "1": - - device = torch.cuda.current_device() - generation_config = GenerationArgs(sampler="greedy", use_cache=args.use_cache) - logits_are_batch_first = True + if generation_config: + if isinstance(generation_config.sampler, str): + sampler_type = SamplerType(generation_config.sampler.upper()) + else: + sampler_type = generation_config.sampler + else: + sampler_type = SamplerType.GREEDY - if generation_config: - if isinstance(generation_config.sampler, str): - sampler_type = SamplerType(generation_config.sampler.upper()) - else: - sampler_type = generation_config.sampler + tokenized_prompts = tokenizer( + dummy_inputs, + return_tensors="pt", + return_attention_mask=True, + padding=True, + ) + tokenized_prompts["input_ids"] = tokenized_prompts["input_ids"].to(device) + tokenized_prompts["attention_mask"] = tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device) + + store = Store() + batch_prompts = None + + for i in range(args.max_new_tokens): + + if generation_config.use_cache: + # Prepare the batch prompts + batch_prompts = GenerationStates( + new_input_ids=tokenized_prompts["input_ids"] + if i == 0 + else tokenized_prompts["input_ids"][:, -1].unsqueeze(0), + new_input_mask=tokenized_prompts["attention_mask"] + if i == 0 + else tokenized_prompts["attention_mask"][:, -1].unsqueeze(0), + store=store, + generation_ids=tokenized_prompts["input_ids"], + generation_mask=tokenized_prompts["attention_mask"], + ) else: - sampler_type = SamplerType.GREEDY + batch_prompts = GenerationInputs( + input_ids=tokenized_prompts["input_ids"], + input_masks=tokenized_prompts["attention_mask"], + ) - tokenized_prompts = tokenizer( - dummy_inputs, - return_tensors="pt", - return_attention_mask=True, - padding=True, - ) - tokenized_prompts["input_ids"] = tokenized_prompts["input_ids"].to(device) - tokenized_prompts["attention_mask"] = tokenized_prompts["attention_mask"].to( - dtype=torch.bool, device=device + logits = run_one_inference_step( + model, batch_prompts, parallel_context, device, use_cache=generation_config.use_cache, store=store ) - store = Store() - batch_prompts = None - - for i in range(args.max_new_tokens): - - if generation_config.use_cache: - # Prepare the batch prompts - batch_prompts = GenerationStates( - new_input_ids=tokenized_prompts["input_ids"] - if i == 0 - else tokenized_prompts["input_ids"][:, -1].unsqueeze(0), - new_input_mask=tokenized_prompts["attention_mask"] - if i == 0 - else tokenized_prompts["attention_mask"][:, -1].unsqueeze(0), - store=store, - generation_ids=tokenized_prompts["input_ids"], - generation_mask=tokenized_prompts["attention_mask"], - ) + # Sample new token + if parallel_context.is_pipeline_last_stage: + assert logits is not None and isinstance(logits, torch.Tensor) + + # Get sampler + if sampler_type == SamplerType.GREEDY: + sampler = GreedySampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_K: + sampler = TopKSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_P: + sampler = TopPSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.BASIC: + sampler = BasicSampler(pg=parallel_context.tp_pg) else: - batch_prompts = GenerationInputs( - input_ids=tokenized_prompts["input_ids"], - input_masks=tokenized_prompts["attention_mask"], - ) - - logits = run_one_inference_step( - model, batch_prompts, parallel_context, device, use_cache=generation_config.use_cache, store=store + raise NotImplementedError(f"Sampler type {sampler_type} is not implemented") + + if logits_are_batch_first: + logits = logits.transpose(0, 1) + + # Predict next token + next_token = sampler(sharded_logits=logits[:, -1]) + + # Extend the tokenized prompts to insert the new token + tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token], dim=-1) + tokenized_prompts["attention_mask"] = torch.cat( + [ + tokenized_prompts["attention_mask"], + torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.bool, device=device), + ], + dim=-1, ) - - # Sample new token - if parallel_context.is_pipeline_last_stage: - assert logits is not None and isinstance(logits, torch.Tensor) - - # Get sampler - if sampler_type == SamplerType.GREEDY: - sampler = GreedySampler(pg=parallel_context.tp_pg) - elif sampler_type == SamplerType.TOP_K: - sampler = TopKSampler(pg=parallel_context.tp_pg) - elif sampler_type == SamplerType.TOP_P: - sampler = TopPSampler(pg=parallel_context.tp_pg) - elif sampler_type == SamplerType.BASIC: - sampler = BasicSampler(pg=parallel_context.tp_pg) - else: - raise NotImplementedError(f"Sampler type {sampler_type} is not implemented") - - if logits_are_batch_first: - logits = logits.transpose(0, 1) - - # Predict next token - next_token = sampler(sharded_logits=logits[:, -1]) - - # Extend the tokenized prompts to insert the new token - tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token], dim=-1) - tokenized_prompts["attention_mask"] = torch.cat( - [ - tokenized_prompts["attention_mask"], - torch.ones( - (tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.bool, device=device - ), - ], - dim=-1, - ) - else: - # Extend the tokenized prompts to receive the new token - tokenized_prompts["input_ids"] = torch.zeros( - (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), - dtype=torch.int64, - device=device, - ) - tokenized_prompts["attention_mask"] = torch.zeros( - ( - tokenized_prompts["attention_mask"].shape[0], - tokenized_prompts["attention_mask"].shape[1] + 1, - ), - dtype=torch.bool, - device=device, - ) - - # Broadcast the new token to all the pipeline stages - dist.broadcast( - tokenized_prompts["input_ids"], - src=get_global_rank( - group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank - ), - group=parallel_context.pp_pg, + else: + # Extend the tokenized prompts to receive the new token + tokenized_prompts["input_ids"] = torch.zeros( + (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), + dtype=torch.int64, + device=device, ) - dist.broadcast( - tokenized_prompts["attention_mask"], - src=get_global_rank( - group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank + tokenized_prompts["attention_mask"] = torch.zeros( + ( + tokenized_prompts["attention_mask"].shape[0], + tokenized_prompts["attention_mask"].shape[1] + 1, ), - group=parallel_context.pp_pg, + dtype=torch.bool, + device=device, ) - if dist.get_rank() == 0: - for i, prompt in enumerate(dummy_inputs): - if generation_config.use_cache: - tokenized_outputs = torch.cat( - [tokens.view(1, -1) for tokens in batch_prompts.generation_ids], dim=1 - ) - outputs = tokenizer.decode(tokenized_outputs[0], clean_up_tokenization_spaces=False) - else: - tokenized_outputs = tokenized_prompts["input_ids"][ - i, tokenized_prompts["input_ids"].shape[1] - args.max_new_tokens : - ] - outputs = tokenizer.decode(tokenized_outputs, clean_up_tokenization_spaces=False) - - log_rank(f"Input: {prompt}", logger=logger, level=logging.INFO, rank=0) - log_rank(f"Output: {outputs}", logger=logger, level=logging.INFO, rank=0) - else: - outputs = decode_text( - input_iter=(GenerationInput(text=text) for text in dummy_inputs), - tokenizer=tokenizer, - # TODO @thomasw21: From ModelWithLoss extract the model. - model=model.model, - parallel_context=parallel_context, - max_new_tokens=args.max_new_tokens, - max_micro_batch_size=1, - generation_config=GenerationArgs(sampler="greedy", use_cache=args.use_cache), - tokenizer_config=TokenizerConfig(max_input_length=None), - is_bench=os.environ.get("USE_BENCH", "0") == "1", + # Broadcast the new token to all the pipeline stages + dist.broadcast( + tokenized_prompts["input_ids"], + src=get_global_rank( + group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank + ), + group=parallel_context.pp_pg, + ) + dist.broadcast( + tokenized_prompts["attention_mask"], + src=get_global_rank( + group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank + ), + group=parallel_context.pp_pg, ) - for output in outputs: - input_ids = output.input_ids - generated_ids = output.generation_ids - if isinstance(input_ids, TensorPointer): - assert isinstance(generated_ids, TensorPointer) - continue - assert isinstance(generated_ids, torch.Tensor) - - log_rank( - f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", - logger=logger, - level=logging.INFO, - rank=0, - ) + # Decode the generated text + if dist.get_rank() == 0: + for i, prompt in enumerate(dummy_inputs): + if generation_config.use_cache: + tokenized_outputs = torch.cat( + [tokens.view(1, -1) for tokens in batch_prompts.generation_ids], dim=1 + ) + outputs = tokenizer.decode(tokenized_outputs[0], clean_up_tokenization_spaces=False) + else: + tokenized_outputs = tokenized_prompts["input_ids"][ + i, tokenized_prompts["input_ids"].shape[1] - args.max_new_tokens : + ] + outputs = tokenizer.decode(tokenized_outputs, clean_up_tokenization_spaces=False) - log_rank( - f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}", - logger=logger, - level=logging.INFO, - rank=0, - ) + log_rank(f"Input: {prompt}", logger=logger, level=logging.INFO, rank=0) + log_rank(f"Output: {outputs}", logger=logger, level=logging.INFO, rank=0) - log_rank( - "--------------------------------------------------", - logger=logger, - level=logging.INFO, - rank=0, - ) dist.barrier() From 1c17bbbf9ddb3eec2bb8ec0b44b7eeb568e2f88a Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Tue, 18 Jun 2024 18:07:22 +0000 Subject: [PATCH 18/18] add pipeline stage attribute --- src/nanotron/parallel/context.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/nanotron/parallel/context.py b/src/nanotron/parallel/context.py index e04e26f5..aba26dfa 100644 --- a/src/nanotron/parallel/context.py +++ b/src/nanotron/parallel/context.py @@ -1,5 +1,5 @@ import os -from typing import Literal, Tuple, Annotated +from typing import Literal, Tuple import numpy as np import torch @@ -62,6 +62,20 @@ def __init__( self._init_parallel_groups() + self.pipeline_parallel_last_rank = self.pipeline_parallel_size - 1 + self.is_pipeline_first_stage = self.pp_pg.rank() == 0 + self.is_pipeline_last_stage = self.pp_pg.rank() == self.pipeline_parallel_last_rank + self.pipeline_parallel_next_rank = ( + None + if self.is_pipeline_last_stage + else int(self.world_rank_matrix[self.tp_pg.rank(), self.pp_pg.rank() + 1, self.dp_pg.rank()]) + ) + self.pipeline_parallel_prev_rank = ( + None + if self.is_pipeline_first_stage + else int(self.world_rank_matrix[self.tp_pg.rank(), self.pp_pg.rank() - 1, self.dp_pg.rank()]) + ) + def _init_parallel_groups(self): """Initialize 3D parallelism's all process groups.""" dist.barrier() @@ -152,4 +166,4 @@ def get_global_rank( :return: numpy.int64, The global rank. """ - return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank] \ No newline at end of file + return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank]