Commit 065b917

inference

LauraGPT committed Nov 25, 2023
2 parents 14365c2 + a972bf5

Showing 4 changed files with 12 additions and 16 deletions.
2 changes: 1 addition & 1 deletion scripts/finetune.sh
@@ -21,7 +21,7 @@ python src/llama_recipes/pipeline/finetune.py \
 --encoder_path $speech_encoder_path \
 --encoder_projector linear \
 --dataset custom_dataset \
---custom_dataset.file src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset \
+--custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \
 --custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \
 --batching_strategy custom \
 --custom_dataset.max_words 1024 \
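(For orientation: the --custom_dataset.file value is a "path.py:function" spec naming the dataset factory. A minimal sketch of how such a spec can be resolved follows; load_dataset_factory is a hypothetical helper written for illustration, not code from this repository.)

# Hypothetical resolver for a "path.py:function" spec, for illustration only;
# the loader llama-recipes actually uses may differ.
import importlib.util

def load_dataset_factory(spec: str):
    path, func_name = spec.split(":")
    module_spec = importlib.util.spec_from_file_location("custom_dataset", path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

get_audio_dataset = load_dataset_factory(
    "src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset")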
src/llama_recipes/datasets/echat_dataset.py

@@ -13,7 +13,7 @@
 import whisper


-class AudioDataset(Dataset):
+class EChatDataset(Dataset):
     def __init__(
         self,
         dataset_config,
@@ -67,7 +67,8 @@ def __getitem__(self, index):
         prompt = self.prompt_template.format(prompt)
         answer = self.answer_template.format(dialog_list[sentence_id+1]['emotion'], dialog_list[sentence_id+1]['trans'])

-        prompt_ids = self.tokenizer.encode(prompt) # FIX(GZF)
+        prompt_ids = self.tokenizer.encode(prompt)
+
         prompt_length = len(prompt_ids)
         speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
         speech_pseudo = torch.full((speech_length,), -1)
@@ -81,8 +82,9 @@
         example_ids = torch.cat((speech_pseudo, example_ids)) # [speech,prompt,answer,eos]

         labels_ids = copy.deepcopy(example_ids) # [speech,prompt,answer,eos]
-        labels_ids[:speech_length + prompt_length] = -1 #[-1,-1,answer,eos]; FIX(zhifu): speech_length + prompt_length->speech_length + prompt_length+1
+        labels_ids[:speech_length + prompt_length] = -1 #[-1,-1,answer,eos];
         example_mask = example_ids.ge(-1) #FIX(GZF): [True,True,True,True]
+
         label_mask = labels_ids.ge(0) #[False,False,True,True]
         example_ids[~example_mask] = 0 #[speech,prompt,answer,eos]
         labels_ids[~label_mask] = self.IGNORE_INDEX #[-100,answer,eos,-100]
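The hunk above drops the stale FIX notes: the loss is masked over both the speech pseudo tokens and the prompt, so only the answer and EOS are supervised. A self-contained sketch of the scheme, with made-up lengths, assuming IGNORE_INDEX = -100 as elsewhere in the file:

import copy
import torch

IGNORE_INDEX = -100                                # matches self.IGNORE_INDEX
speech_length = 4                                  # made-up example lengths
prompt_ids = torch.tensor([11, 12, 13])
answer_ids = torch.tensor([21, 22, 2])             # answer tokens + eos

speech_pseudo = torch.full((speech_length,), -1)   # placeholders for speech frames
example_ids = torch.cat((speech_pseudo, prompt_ids, answer_ids))  # [speech,prompt,answer,eos]

labels_ids = copy.deepcopy(example_ids)
labels_ids[:speech_length + len(prompt_ids)] = -1  # no loss on speech or prompt

example_mask = example_ids.ge(-1)                  # all True: every position is attended
label_mask = labels_ids.ge(0)                      # True only on answer/eos
example_ids[~example_mask] = 0                     # no-op here; keeps ids non-negative
labels_ids[~label_mask] = IGNORE_INDEX             # [-100,...,-100, answer, eos]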
@@ -142,8 +144,8 @@ def collator(self, samples):
         labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX)
                               for s in samples])
         attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False)
-                              for s in samples]) #FIX(GZF): attention_mask
+                              for s in samples])

         speech_mel_max_length = max([s['speech_mel'].shape[0] for s in samples])
         speech_mel = torch.stack([self.pad(s['speech_mel'], speech_mel_max_length, 0)
                                   for s in samples])
@@ -162,6 +164,6 @@


 def get_audio_dataset(dataset_config, tokenizer, split):
-    dataset = AudioDataset(dataset_config, tokenizer, split)
+    dataset = EChatDataset(dataset_config, tokenizer, split)

     return dataset
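The collator right-pads every field of a batch to the longest sample. The pad helper itself is not shown in this diff; the sketch below assumes the (tensor, target_length, fill_value) signature implied by the calls above.

import torch

def pad(t: torch.Tensor, length: int, value):
    # Assumed behavior of the repo's helper: right-pad dim 0 to `length` with `value`.
    if t.shape[0] >= length:
        return t[:length]
    pad_shape = (length - t.shape[0],) + tuple(t.shape[1:])
    return torch.cat((t, torch.full(pad_shape, value, dtype=t.dtype)))

samples = [{'input_ids': torch.tensor([1, 2, 3])},
           {'input_ids': torch.tensor([4, 5])}]
input_ids_max_length = max(s['input_ids'].shape[0] for s in samples)
input_ids = torch.stack([pad(s['input_ids'], input_ids_max_length, 0)
                         for s in samples])       # tensor([[1, 2, 3], [4, 5, 0]])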
10 changes: 2 additions & 8 deletions src/llama_recipes/models/slam_model.py
@@ -108,10 +108,6 @@ def setup_llm(train_config, model_config, **kwargs):
         model.print_trainable_parameters()

     if kwargs.get("peft_ckpt", None):
-        # import pdb;
-        # pdb.set_trace()
-        # config = PeftConfig.from_pretrained(kwargs.get("peft_ckpt"))
-        # model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
         print("loading ckpt from: ", kwargs.get("peft_ckpt"))
         model = PeftModel.from_pretrained(model, kwargs.get("peft_ckpt"))
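The debug scaffolding (pdb and the commented PeftConfig/AutoModelForSeq2SeqLM experiments) is removed; what remains is a plain PEFT resume. For reference, a minimal equivalent, where the base model name and checkpoint path are placeholders:

from transformers import LlamaForCausalLM
from peft import PeftModel

base = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model
model = PeftModel.from_pretrained(base, "/path/to/peft_ckpt")        # placeholder path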

@@ -135,17 +131,13 @@ def __init__(
         self.speech_encoder.eval()

         # llama
-        # peft_ckpt = "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0"
         self.llm = setup_llm(train_config, model_config, **kwargs)

         # projector
         self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0], self.llm.config.hidden_size)
         ckpt_path = kwargs.get("ckpt_path", None)
-        # ckpt_path = kwargs.get("ckpt_path", "/nfs/zhifu.gzf/models/llama-2-hf-finetune/echat/0/model.pt")
         if ckpt_path is not None:
             # load ckpt
-            # import pdb;
-            # pdb.set_trace()
             print("loading ckpt from: ", ckpt_path)
             ckpt_dict = torch.load(ckpt_path, map_location="cpu")
             self.load_state_dict(ckpt_dict, strict=False)
@@ -180,6 +172,7 @@ def forward(self,
             inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
         else:
             inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
         batch_size, token_num, dims = inputs_embeds.shape
         _, l, _ = speech_encoder_outs.shape
         speech_encoder_outs_pad = F.pad(speech_encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0)
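The added blank line aside, this hunk pads the speech encoder output along the time axis so it matches the token sequence length. How the padded features then replace the -1 pseudo positions is outside the visible hunk; the torch.where merge below is an assumption, shown only to make the shapes concrete.

import torch
import torch.nn.functional as F

batch_size, token_num, dims, l = 2, 10, 8, 4       # made-up shapes
inputs_embeds = torch.randn(batch_size, token_num, dims)
speech_encoder_outs = torch.randn(batch_size, l, dims)
input_ids = torch.randint(0, 100, (batch_size, token_num))
input_ids[:, :l] = -1                              # speech pseudo positions (see dataset diff)

# Same call as the hunk: pad dim 1 (time) from l up to token_num.
speech_encoder_outs_pad = F.pad(speech_encoder_outs, (0, 0, 0, token_num - l, 0, 0), value=0.0)

# Assumed merge rule: speech features at pseudo positions, token embeddings elsewhere.
merged = torch.where((input_ids == -1).unsqueeze(-1),
                     speech_encoder_outs_pad, inputs_embeds)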
@@ -256,3 +249,4 @@ def generate(
         output_text = self.tokenizer.batch_decode(output, add_special_tokens=False, skip_special_tokens=True)

         return output_text
+
2 changes: 1 addition & 1 deletion src/llama_recipes/pipeline/model_factory.py
@@ -3,6 +3,6 @@
 def model_factory(train_config, model_config, **kwargs):

     tokenizer = setup_tokenizer(train_config, model_config, **kwargs)
-    model = setup_model(tokenizer, train_config, model_config, **kwargs)
+    model = setup_model(tokenizer, train_config, model_config, **kwargs).cuda()

     return model, tokenizer
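With .cuda() applied inside the factory, callers now get a GPU-resident model directly and inference scripts no longer need to move it themselves. A hypothetical call site:

model, tokenizer = model_factory(train_config, model_config)  # configs as elsewhere
assert next(model.parameters()).is_cuda                       # holds after this commit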
