rag e2e first commit #2
Conversation
from transformers import AutoModel

class AutoModelForSentenceEmbedding(torch.nn.Module):
I would add the causal language model to this class as well. Obviously we can change the class name.
Reason: in the future, if we want to use Accelerate with DeepSpeed, it currently won't work with two separate models.
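For reference, a minimal sketch of what that combined class could look like, holding both the retriever encoder and the causal LM in a single nn.Module so Accelerate/DeepSpeed only has to prepare one model. The class and argument names here are placeholders, not from this PR:

import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

class RagE2EModel(torch.nn.Module):  # placeholder name
    def __init__(self, retriever_name, generator_name, normalize=True):
        super().__init__()
        # Both sub-models live in one nn.Module, so a single
        # accelerator.prepare(model) call covers the whole pipeline.
        self.retriever = AutoModel.from_pretrained(retriever_name)
        self.generator = AutoModelForCausalLM.from_pretrained(generator_name)
        self.tokenizer = AutoTokenizer.from_pretrained(retriever_name)
        self.normalize = normalize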
        self.normalize = normalize
        self.tokenizer = tokenizer

    def forward(self, **kwargs):
If you add a causal language model to the init function above, you can easily add another parameter to the forward function, like model_type, and get the output from either the retriever or the generator.
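Continuing the sketch above, the suggested forward dispatch might look like this; the branch names are assumptions, not from this PR:

    def forward(self, model_type, **kwargs):
        # Route the call to one sub-model based on the suggested
        # model_type flag; the accepted values here are assumed.
        if model_type == "retriever":
            return self.retriever(**kwargs)
        if model_type == "generator":
            return self.generator(**kwargs)
        raise ValueError(f"unknown model_type: {model_type!r}")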
interesting idea
if is_diffusers_available():
    from .models import (
        DDPOPipelineOutput,
I can change this naming later, so don't worry.
    logprobs_logits, doc_logprobs, query_token_length
)

loss = get_nll(marginalized_log_probs, input_tensors[:, 1:])
Let's take the mean loss here, reducing it to shape batch_size * 1.
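A sketch of that reduction, assuming get_nll returns a per-token NLL of shape (batch_size, seq_len) with padding already masked out:

import torch

nll = torch.rand(8, 128)                       # stand-in for the get_nll(...) output
per_example = nll.mean(dim=-1, keepdim=True)   # mean per example -> (batch_size, 1)
loss = per_example.mean()                      # scalar, ready for loss.backward()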
# Prepare everything with our `accelerator`.
# see https://github.com/huggingface/accelerate/issues/253#issuecomment-1253231210
r_model, c_model = accelerator.prepare(r_model, c_model)
perfect!
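For context, the snippet above prepares both models in a single prepare() call; extended to a full training setup it might look like the sketch below. The optimizer and dataloader names are assumptions, not from this PR:

from accelerate import Accelerator

accelerator = Accelerator()
# One prepare() call wraps both models plus the optimizer and dataloader
# consistently, instead of preparing each model separately.
r_model, c_model, optimizer, train_dataloader = accelerator.prepare(
    r_model, c_model, optimizer, train_dataloader
)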
Good job. Everything seems great.
correct