
save/load ckpt, inference #1

Merged · 5 commits · Nov 27, 2023

2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
.DS_Store
__pycache__
.ipynb_checkpoints
.idea/*
transformers
16 changes: 9 additions & 7 deletions scripts/finetune.sh
@@ -1,17 +1,17 @@
#!/bin/bash
# export PYTHONPATH=/root/whisper:$PYTHONPATH
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth
speech_encoder_path=/home/oss/maziyang.mzy/models/Whisper/base.pt
llm_path=/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune
speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/zhifu.gzf/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
python src/llama_recipes/pipeline/finetune.py \
#python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
python src/llama_recipes/pipeline/finetune.py \
--model_name echat \
--use_peft --peft_method lora \
--quantization \
@@ -27,4 +27,6 @@ python src/llama_recipes/pipeline/finetune.py \
--custom_dataset.max_words 1024 \
--num_epochs 1 \
--batch_size_training 2 \
--output_dir $output_dir
--output_dir $output_dir \
--ckpt_path "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt" \
--peft_ckpt "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0"
32 changes: 32 additions & 0 deletions scripts/inference.sh
@@ -0,0 +1,32 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/zhifu.gzf/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
#python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
python src/llama_recipes/pipeline/inference.py \
--model_name echat \
--use_peft --peft_method lora \
--quantization \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset \
--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
--batching_strategy custom \
--custom_dataset.max_words 1024 \
--num_epochs 1 \
--batch_size_training 2 \
--output_dir $output_dir \
--ckpt_path "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt" \
--peft_ckpt "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0"
20 changes: 12 additions & 8 deletions src/llama_recipes/datasets/echat_dataset.py
@@ -66,22 +66,25 @@ def __getitem__(self, index):

prompt = self.prompt_template.format(prompt)
answer = self.answer_template.format(dialog_list[sentence_id+1]['emotion'], dialog_list[sentence_id+1]['trans'])

prompt_ids = self.tokenizer.encode(prompt)

prompt_length = len(prompt_ids)
speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
speech_pseudo = torch.full((speech_length,),-1)

example = prompt + answer
example_ids = self.tokenizer.encode(example) # [prompt,answer]
example_ids = self.tokenizer.encode(answer) # FIX(GZF): [answer]
Collaborator review comment: this still needs a fix, since tokenizer.encode() will prepend a bos token to the answer ids.
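One possible shape of that fix, sketched here only as an illustration (it assumes the tokenizer is a Hugging Face LlamaTokenizer; this is not what the PR merges):

# Sketch: encode the answer without special tokens so no extra bos is prepended.
answer_ids = self.tokenizer.encode(answer, add_special_tokens=False)
# Alternative, if add_special_tokens cannot be used: drop a leading bos manually.
# answer_ids = self.tokenizer.encode(answer)
# if answer_ids and answer_ids[0] == self.tokenizer.bos_token_id:
#     answer_ids = answer_ids[1:]
example_ids = prompt_ids + answer_ids            # [prompt,answer]
example_ids.append(self.tokenizer.eos_token_id)  # [prompt,answer,eos]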

example_ids = prompt_ids + example_ids
example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos]
example_ids = torch.tensor(
example_ids, dtype=torch.int64
)
example_ids = torch.cat((speech_pseudo, example_ids)) # [speech,prompt,answer,eos]

labels_ids = copy.deepcopy(example_ids) # [speech,prompt,answer,eos]
labels_ids[:speech_length + prompt_length] = -1 #[-1,-1,answer,eos]
example_mask = example_ids.ge(-1) #[True,True,True,True]
labels_ids[:speech_length + prompt_length] = -1 #[-1,-1,answer,eos];
example_mask = example_ids.ge(-1) #FIX(GZF): [True,True,True,True]

label_mask = labels_ids.ge(0) #[False,False,True,True]
example_ids[~example_mask] = 0 #[speech,prompt,answer,eos]
labels_ids[~label_mask] = self.IGNORE_INDEX #[-100,answer,eos,-100]
@@ -91,7 +94,8 @@ def __getitem__(self, index):
"labels": labels_ids,
"attention_mask": example_mask,
'speech_mel': speech_mel,
'speech_length': speech_length
'speech_length': speech_length,

}


@@ -139,16 +143,16 @@ def collator(self, samples):
for s in samples])
labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX)
for s in samples])
attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False)
attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False)
for s in samples])

speech_mel_max_length = max([s['speech_mel'].shape[0] for s in samples])
speech_mel = torch.stack([self.pad(s['speech_mel'], speech_mel_max_length, 0)
for s in samples])

speech_mask = torch.zeros_like(attention_mask)
for line, sample in enumerate(samples):
speech_mask[line, :sample['speech_length']] = 1
speech_mask[line, :sample['speech_length']] = 1 #FIX(GZF): sample['speech_length']+1

return {
'input_ids': input_ids,
3 changes: 2 additions & 1 deletion src/llama_recipes/model_checkpointing/__init__.py
@@ -8,5 +8,6 @@
save_optimizer_checkpoint,
save_model_and_optimizer_sharded,
load_model_sharded,
load_sharded_model_single_gpu
load_sharded_model_single_gpu,
save_model_checkpoint_peft,
)
19 changes: 17 additions & 2 deletions src/llama_recipes/model_checkpointing/checkpoint_handler.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from pathlib import Path
from datetime import datetime
import torch
@@ -160,7 +160,22 @@ def save_model_checkpoint(

print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")


def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch))
os.makedirs(save_dir, exist_ok=True)
model.llm.save_pretrained(save_dir)

save_full_path = os.path.join(save_dir, "model.pt")
cpu_state = model.state_dict()
project_dict = {}
for key in cpu_state.keys():
if "_projector" in key:
project_dict[key] = cpu_state[key]
torch.save(project_dict, save_full_path)

print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
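For orientation, the artifacts written by this function line up with the new script arguments as follows (paths taken from this PR's scripts, shown as comments only):

# With output_dir=/nfs/zhifu.gzf/models/llama-2-hf-finetune, model_name=echat, epoch=0:
#   /nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0          -> --peft_ckpt  (LoRA adapter, via model.llm.save_pretrained)
#   /nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt -> --ckpt_path  (projector weights, via torch.save)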



def load_model_checkpoint(model, rank, cfg):
"""load local checkpoint to rank0 cpu
97 changes: 90 additions & 7 deletions src/llama_recipes/models/slam_model.py
@@ -1,3 +1,4 @@
import os
import types
import torch
import soundfile as sf
@@ -14,7 +15,8 @@

from llama_recipes.utils.config_utils import generate_peft_config
from llama_recipes.utils.train_utils import print_model_size

from peft import PeftModel, PeftConfig
from torch.nn import CrossEntropyLoss

def setup_model(tokenizer, train_config, model_config, **kwargs):
return slam_model(tokenizer, train_config, model_config, **kwargs)
@@ -104,6 +106,10 @@ def setup_llm(train_config, model_config, **kwargs):
peft_config = generate_peft_config(train_config, kwargs)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

if kwargs.get("peft_ckpt", None):
print("loading ckpt from: ", kwargs.get("peft_ckpt"))
model = PeftModel.from_pretrained(model, kwargs.get("peft_ckpt"))

return model

@@ -128,8 +134,13 @@ def __init__(
self.llm = setup_llm(train_config, model_config, **kwargs)

# projector
self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0] ,self.llm.config.hidden_size)

self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0], self.llm.config.hidden_size)
ckpt_path = kwargs.get("ckpt_path", None)
Collaborator review comment: checkpoint loading will be moved to setup_model in model_factory.
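A rough sketch of what that refactor could look like, reusing the loading code from this diff (the final placement in model_factory may differ):

# Sketch only: load the projector checkpoint once the model is constructed,
# instead of inside slam_model.__init__.
def setup_model(tokenizer, train_config, model_config, **kwargs):
    model = slam_model(tokenizer, train_config, model_config, **kwargs)
    ckpt_path = kwargs.get("ckpt_path", None)
    if ckpt_path is not None:
        print("loading ckpt from: ", ckpt_path)
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")  # projector weights saved by save_model_checkpoint_peft
        model.load_state_dict(ckpt_dict, strict=False)
    return model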

# ckpt_path = kwargs.get("ckpt_path", "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt")
if ckpt_path is not None:
print("loading ckpt from: ", ckpt_path)
ckpt_dict = torch.load(ckpt_path, map_location="cpu")
self.load_state_dict(ckpt_dict, strict=False)
# tokenizer
self.tokenizer = tokenizer

@@ -152,18 +163,90 @@ def forward(self,
speech_encoder_outs = None
if speech_mel is not None:
speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1))
speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs)
speech_encoder_outs = self.speech_encoder_projector.to(speech_encoder_outs.device)(speech_encoder_outs)
Collaborator review comment: needs a fix; the device placement here should move to main in finetune.py.
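A minimal sketch of that suggestion, with the device move done once in finetune.py's main rather than on every forward call (finetune.py is not shown in this diff, so the surrounding names are assumed):

# In finetune.py main (sketch): place the whole model, projector included, on the device up front.
model, tokenizer = model_factory(train_config, model_config, **kwargs)
model.to(kwargs.get("device", "cuda"))
# forward() could then drop the per-call .to(...) and keep:
# speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs)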


input_ids[input_ids == -1] = 0
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(input_ids)
else:
elif hasattr(self.llm.model.model, "embed_tokens"):
inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)

batch_size, token_num, dims = inputs_embeds.shape
_, l, _ = speech_encoder_outs.shape
speech_encoder_outs_pad = F.pad(speech_encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0)
inputs_embeds = speech_encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None])

model_outputs = self.llm.to(speech_encoder_outs.device)(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)

return model_outputs

@torch.no_grad()
def generate(
self,
wav_path = None,
generation_config = None,
logits_processor = None,
stopping_criteria = None,
prefix_allowed_tokens_fn = None,
synced_gpus = None,
assistant_model = None,
streamer = None,
negative_prompt_ids = None,
negative_prompt_attention_mask = None,
**kwargs,
):

device = kwargs.get("device", "cuda")
assert os.path.exists(wav_path)
speech_raw = whisper.load_audio(wav_path)
# speech_raw = whisper.pad_or_trim(speech_raw)
speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0)[None, :, :].to(device)

speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1))
speech_encoder_outs = self.speech_encoder_projector.to(speech_encoder_outs.device)(speech_encoder_outs)

prompt="""
Please provide an emotional response based on the emotional speech you hear.
Remember to format your answer as follows: <|EMOTION|><|REPLY|>.
<|EMOTION|> is a standalone adjective.
<|REPLY|> is a reply based on the speech.
"""
prompt = "USER: {}\n ASSISTANT:".format(prompt)
prompt_ids = self.tokenizer.encode(prompt) # FIX(GZF)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)

if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
elif hasattr(self.llm.model.model, "embed_tokens"):
inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)

inputs_embeds = torch.cat((speech_encoder_outs, inputs_embeds[None, :, :]), dim=1) # [speech,prompt]

atts = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(inputs_embeds.device)

# generate
output = self.llm.generate(
inputs_embeds=inputs_embeds,
max_length=kwargs.get("max_length", 200),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", True),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 0.9),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=atts,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id
)

output_text = self.tokenizer.batch_decode(output, add_special_tokens=False, skip_special_tokens=True)

model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
return output_text

return model_outputs
42 changes: 42 additions & 0 deletions src/llama_recipes/pipeline/inference.py
@@ -0,0 +1,42 @@
import fire
import random
import torch
# import argparse
from llama_recipes.models.slam_model import slam_model
# config
from llama_recipes.configs import fsdp_config as FSDP_CONFIG
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.configs import model_config as MODEL_CONFIG
from llama_recipes.utils.config_utils import (
update_config,
generate_peft_config,
generate_dataset_config,
get_dataloader_kwargs,
)
from llama_recipes.pipeline.model_factory import model_factory

def main(**kwargs):

# Update the configuration for the training and sharding process
train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG()
update_config((train_config, fsdp_config, model_config), **kwargs)

# Set the seeds for reproducibility
torch.cuda.manual_seed(train_config.seed)
torch.manual_seed(train_config.seed)
random.seed(train_config.seed)

model, tokenizer = model_factory(train_config, model_config, **kwargs)
model.to(kwargs.get("device", "cuda"))
model.eval()

print("=====================================")
# wav_path = input("Your Wav Path:\n")
# prompt = input("Your Prompt:\n")
wav_path = "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M001.wav"
print(model.generate(wav_path))



if __name__ == "__main__":
fire.Fire(main)
7 changes: 5 additions & 2 deletions src/llama_recipes/utils/train_utils.py
@@ -18,7 +18,7 @@
from transformers import LlamaTokenizer


from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_model_checkpoint_peft
from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace

@@ -144,7 +144,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"we are about to save the PEFT modules")
else:
print(f"we are about to save the PEFT modules")
model.save_pretrained(train_config.output_dir)
# model.save_pretrained(train_config.output_dir)
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch
)
if train_config.enable_fsdp:
if rank==0:
print(f"PEFT modules are saved in {train_config.output_dir} directory")