save/load ckpt, inference #1
Merged
Changes to the repository ignore rules:

@@ -1,3 +1,5 @@
 .DS_Store
 __pycache__
 .ipynb_checkpoints
+.idea/*
+transformers
New file (a shell script that launches inference):

@@ -0,0 +1,32 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/zhifu.gzf/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
#python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
python src/llama_recipes/pipeline/inference.py \
--model_name echat \
--use_peft --peft_method lora \
--quantization \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset \
--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
--batching_strategy custom \
--custom_dataset.max_words 1024 \
--num_epochs 1 \
--batch_size_training 2 \
--output_dir $output_dir \
--ckpt_path "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt" \
--peft_ckpt "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0"
Changes to the model definition (the slam_model class):

@@ -1,3 +1,4 @@
+import os
 import types
 import torch
 import soundfile as sf
@@ -14,7 +15,8 @@
 from llama_recipes.utils.config_utils import generate_peft_config
 from llama_recipes.utils.train_utils import print_model_size

+from peft import PeftModel, PeftConfig
 from torch.nn import CrossEntropyLoss


 def setup_model(tokenizer, train_config, model_config, **kwargs):
     return slam_model(tokenizer, train_config, model_config, **kwargs)
@@ -104,6 +106,10 @@ def setup_llm(train_config, model_config, **kwargs):
     peft_config = generate_peft_config(train_config, kwargs)
     model = get_peft_model(model, peft_config)
     model.print_trainable_parameters()

+    if kwargs.get("peft_ckpt", None):
+        print("loading ckpt from: ", kwargs.get("peft_ckpt"))
+        model = PeftModel.from_pretrained(model, kwargs.get("peft_ckpt"))
+
     return model
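Note that `peft_ckpt` is expected to be a directory in the format `save_pretrained` produces for a PEFT-wrapped model (`adapter_config.json` plus the adapter weights). A minimal sketch of that round trip, using the paths from the script above and a default LoRA config as assumptions:

```python
from transformers import LlamaForCausalLM
from peft import LoraConfig, get_peft_model, PeftModel

base = LlamaForCausalLM.from_pretrained("/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf")

# training side: wrap the base model with LoRA, train, then save adapters only
lora_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))
lora_model.save_pretrained("/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0")

# inference side: re-attach the saved adapters, as setup_llm does above
restored = PeftModel.from_pretrained(base, "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0")
```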
@@ -128,8 +134,13 @@ def __init__(
         self.llm = setup_llm(train_config, model_config, **kwargs)

         # projector
-        self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0] ,self.llm.config.hidden_size)
+        self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0], self.llm.config.hidden_size)
+        ckpt_path = kwargs.get("ckpt_path", None)
+        # ckpt_path = kwargs.get("ckpt_path", "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt")
+        if ckpt_path is not None:
+            print("loading ckpt from: ", ckpt_path)
+            ckpt_dict = torch.load(ckpt_path, map_location="cpu")
+            self.load_state_dict(ckpt_dict, strict=False)
         # tokenizer
         self.tokenizer = tokenizer

Review comment (on the ckpt_path line): load checkpoint will move to model_factory: setup_model
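Because the load uses `strict=False`, the `model.pt` checkpoint only has to hold the parameters actually trained here, such as the projector; missing encoder/LLM keys are simply ignored. A hedged sketch of a matching save side (the helper name is hypothetical):

```python
import torch

def save_projector_ckpt(model, path):
    # keep only the projector weights; the frozen Whisper encoder and the
    # PEFT-managed LLM are restored from their own checkpoints
    partial = {k: v.cpu() for k, v in model.state_dict().items()
               if "speech_encoder_projector" in k}
    torch.save(partial, path)  # later loaded via load_state_dict(..., strict=False)
```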
@@ -152,18 +163,90 @@ def forward(self,
         speech_encoder_outs = None
         if speech_mel is not None:
             speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1))
-            speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs)
+            speech_encoder_outs = self.speech_encoder_projector.to(speech_encoder_outs.device)(speech_encoder_outs)

Review comment (on the projector .to() line): need fix, move to finetune.py: main
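What the reviewer is pointing at: `.to()` inside `forward` runs on every step. The conventional fix, and what the new inference entry point later in this diff already does, is one-time placement in `main`:

```python
# one-time device placement in the driver script (mirrors inference.py below),
# so forward() needs no per-call .to() moves
model, tokenizer = model_factory(train_config, model_config, **kwargs)
model.to(kwargs.get("device", "cuda"))
model.eval()
```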
         input_ids[input_ids == -1] = 0
         if hasattr(self.llm.model, "embed_tokens"):
             inputs_embeds = self.llm.model.embed_tokens(input_ids)
-        else:
+        elif hasattr(self.llm.model.model, "embed_tokens"):
             inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
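The depth of the `.model` nesting here depends on how many wrappers (PEFT, quantization preparation) sit around the base LlamaForCausalLM, hence the widened `hasattr` chain. A hedged alternative that avoids probing entirely is the standard Hugging Face accessor:

```python
# get_input_embeddings() resolves the embedding module through PeftModel
# and base-model wrappers, so no hasattr chain is needed
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
```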
         batch_size, token_num, dims = inputs_embeds.shape
         _, l, _ = speech_encoder_outs.shape
         speech_encoder_outs_pad = F.pad(speech_encoder_outs, (0, 0, 0, token_num - l, 0, 0), value=0.0)
         inputs_embeds = speech_encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None])

-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        model_outputs = self.llm.to(speech_encoder_outs.device)(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)

         return model_outputs
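To make the padding-and-mask arithmetic above concrete, here is a small self-contained example with toy shapes (all sizes illustrative): the projected speech features are right-padded along the time axis to the text length, then `speech_mask` picks, per position, whether the speech or the text embedding enters the LLM.

```python
import torch
import torch.nn.functional as F

inputs_embeds = torch.randn(1, 6, 4)   # batch 1, 6 token positions, hidden 4
speech_outs   = torch.randn(1, 2, 4)   # 2 projected speech frames (l = 2)
speech_mask   = torch.tensor([[1, 1, 0, 0, 0, 0]], dtype=torch.bool)

# right-pad the speech features from l up to token_num positions
token_num, l = inputs_embeds.size(1), speech_outs.size(1)
speech_pad = F.pad(speech_outs, (0, 0, 0, token_num - l), value=0.0)

# positions where the mask is True take speech features, the rest keep text
mixed = speech_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None])
print(mixed.shape)  # torch.Size([1, 6, 4])
```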
+    @torch.no_grad()
+    def generate(
+        self,
+        wav_path=None,
+        generation_config=None,
+        logits_processor=None,
+        stopping_criteria=None,
+        prefix_allowed_tokens_fn=None,
+        synced_gpus=None,
+        assistant_model=None,
+        streamer=None,
+        negative_prompt_ids=None,
+        negative_prompt_attention_mask=None,
+        **kwargs,
+    ):
+        device = kwargs.get("device", "cuda")
+        assert os.path.exists(wav_path)
+        speech_raw = whisper.load_audio(wav_path)
+        # speech_raw = whisper.pad_or_trim(speech_raw)
+        speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1, 0)[None, :, :].to(device)
+
+        speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1))
+        speech_encoder_outs = self.speech_encoder_projector.to(speech_encoder_outs.device)(speech_encoder_outs)
+
+        prompt = """
+        Please provide an emotional response based on the emotional speech you hear.
+        Remember to format your answer as follows: <|EMOTION|><|REPLY|>.
+        <|EMOTION|> is a standalone adjective.
+        <|REPLY|> is a reply based on the speech.
+        """
+        prompt = "USER: {}\n ASSISTANT:".format(prompt)
+        prompt_ids = self.tokenizer.encode(prompt)  # FIX(GZF)
+        prompt_length = len(prompt_ids)
+        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)
+
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+        inputs_embeds = torch.cat((speech_encoder_outs, inputs_embeds[None, :, :]), dim=1)  # [speech, prompt]
+
+        atts = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(inputs_embeds.device)
+
+        # generate
+        output = self.llm.generate(
+            inputs_embeds=inputs_embeds,
+            max_length=kwargs.get("max_length", 200),
+            num_beams=kwargs.get("num_beams", 1),
+            do_sample=kwargs.get("do_sample", True),
+            min_length=kwargs.get("min_length", 1),
+            top_p=kwargs.get("top_p", 0.9),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+            length_penalty=kwargs.get("length_penalty", 1.0),
+            temperature=kwargs.get("temperature", 1.0),
+            attention_mask=atts,
+            bos_token_id=self.tokenizer.bos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id
+        )
+
+        output_text = self.tokenizer.batch_decode(output, add_special_tokens=False, skip_special_tokens=True)
+
+        return output_text
New file (the fire-based inference entry point invoked by the script above, src/llama_recipes/pipeline/inference.py):

@@ -0,0 +1,42 @@
import fire
import random
import torch
# import argparse
from llama_recipes.models.slam_model import slam_model
# config
from llama_recipes.configs import fsdp_config as FSDP_CONFIG
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.configs import model_config as MODEL_CONFIG
from llama_recipes.utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
    get_dataloader_kwargs,
)
from llama_recipes.pipeline.model_factory import model_factory


def main(**kwargs):

    # Update the configuration for the training and sharding process
    train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG()
    update_config((train_config, fsdp_config, model_config), **kwargs)

    # Set the seeds for reproducibility
    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    model, tokenizer = model_factory(train_config, model_config, **kwargs)
    model.to(kwargs.get("device", "cuda"))
    model.eval()

    print("=====================================")
    # wav_path = input("Your Wav Path:\n")
    # prompt = input("Your Prompt:\n")
    wav_path = "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M001.wav"
    print(model.generate(wav_path))


if __name__ == "__main__":
    fire.Fire(main)
Review comment (on the tokenizer.encode line in generate): need fix for tokenizer.encode (will add a bos token)
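That comment refers to `prompt_ids = self.tokenizer.encode(prompt)` in the new `generate` method: Llama tokenizers prepend a BOS token by default, and because the prompt embeddings are concatenated after the speech features, that BOS would land in the middle of the sequence. A hedged sketch of the likely fix:

```python
# encode the prompt without special tokens so no BOS is injected between
# the speech embeddings and the prompt text (add_special_tokens is the
# standard Hugging Face tokenizer argument for this)
prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
```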