
save/load ckpt, inference #1

Merged: 5 commits into debug-mzy-20231020, Nov 27, 2023

Conversation

LauraGPT
Collaborator

No description provided.

prompt_length = len(prompt_ids)
speech_length = (speech_mel.shape[0] + 1) // 2  # ad hoc for Whisper: 2x downsample from mel frames to encoder features (sanity check below)
speech_pseudo = torch.full((speech_length,), -1)

example = prompt + answer
-example_ids = self.tokenizer.encode(example)  # [prompt, answer]
+example_ids = self.tokenizer.encode(answer)  # FIX(GZF): [answer]
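The `(speech_mel.shape[0] + 1) // 2` length matches a Whisper-style stride-2 convolution in the encoder. A minimal sanity check of that formula, assuming a Conv1d with kernel 3, stride 2, padding 1 (the conv below is illustrative, not the actual encoder):

```python
import torch
import torch.nn as nn

# Whisper-style downsampling conv: kernel 3, stride 2, padding 1.
conv2 = nn.Conv1d(80, 80, kernel_size=3, stride=2, padding=1)

for n in (98, 99, 100, 3000):
    mel = torch.randn(1, 80, n)        # (batch, n_mels, frames)
    out_len = conv2(mel).shape[-1]     # frames after 2x downsampling
    assert out_len == (n + 1) // 2, (n, out_len)
```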

This still needs a fix: tokenizer.encode will also prepend a BOS token to the encoded answer.
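One way the fix could look, assuming `self.tokenizer` is a Hugging Face PreTrainedTokenizer (an assumption; the tokenizer class isn't shown in this hunk):

```python
# Option 1: skip special tokens so no BOS is prepended to the answer ids.
example_ids = self.tokenizer.encode(answer, add_special_tokens=False)

# Option 2: encode as before and drop a leading BOS if one was inserted.
example_ids = self.tokenizer.encode(answer)
if example_ids and example_ids[0] == self.tokenizer.bos_token_id:
    example_ids = example_ids[1:]
```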

-self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0] ,self.llm.config.hidden_size)
+self.speech_encoder_projector = nn.Linear(self.speech_encoder.ln_post.normalized_shape[0], self.llm.config.hidden_size)
+ckpt_path = kwargs.get("ckpt_path", None)

Checkpoint loading will be moved into model_factory: setup_model.
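A rough sketch of what that refactor could look like; `build_model`, the config arguments, and `strict=False` are illustrative assumptions, not the final API:

```python
import torch

def setup_model(train_config, model_config, **kwargs):
    # build_model is a placeholder for the existing construction of
    # speech encoder + projector + LLM; only the ckpt handling is the point here.
    model = build_model(train_config, model_config, **kwargs)
    ckpt_path = kwargs.get("ckpt_path", None)
    if ckpt_path is not None:
        ckpt = torch.load(ckpt_path, map_location="cpu")  # load on CPU, move to device later
        model.load_state_dict(ckpt, strict=False)         # allow partial (e.g. projector-only) checkpoints
    return model
```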

@@ -152,18 +163,90 @@ def forward(self,
speech_encoder_outs = None
if speech_mel is not None:
    speech_encoder_outs = self.speech_encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1))
-   speech_encoder_outs = self.speech_encoder_projector(speech_encoder_outs)
+   speech_encoder_outs = self.speech_encoder_projector.to(speech_encoder_outs.device)(speech_encoder_outs)

Needs a fix: this device placement should move to finetune.py: main.
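A rough sketch of that intent: place the whole model on the target device once in finetune.py: main instead of calling .to(...) inside forward. `setup_model` and the device selection below are assumptions for illustration:

```python
import torch

def main(**kwargs):
    model = setup_model(**kwargs)  # factory sketched above; loads ckpt_path if given
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)               # encoder, projector and LLM move together, once
    # ... training loop: each batch is moved to the same device before model(...) is called
```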

@ddlBoJack merged commit 1d8c66e into debug-mzy-20231020 on Nov 27, 2023