Fix generation for bsz > 1 #1250
Was this the kind of thing you had in mind? https://github.com/pytorch/torchtune/blob/1129f9e3a246628c991c246d81dbead62d3437a3/torchtune/modules/rlhf/_generation.py Granted, there are a couple of changes I've been meaning to make (only generating the full mask once and extending it for each new token in the batch, and you'll probably have a more intelligent way of generating the masks themselves).
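For reference, a minimal sketch of that idea: build the full mask once from a padding mask, then extend it by one row and column per generated token. This is not the linked torchtune code; the function names, the boolean convention (True = may attend), and the self-attention diagonal for padded rows are all assumptions here.

```python
import torch


def get_causal_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    """Build a per-sample causal mask once, from a [b, s] bool padding mask (True = real token)."""
    _, seq_len = padding_mask.shape
    # lower-triangular causal mask shared by the whole batch
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=padding_mask.device)
    )
    mask = causal[None, :, :] & padding_mask[:, None, :]  # [b, s, s]
    # let padded query positions attend to themselves so no row is fully masked out
    diag = torch.eye(seq_len, dtype=torch.bool, device=padding_mask.device)
    return mask | diag[None, :, :]


def extend_mask(mask: torch.Tensor) -> torch.Tensor:
    """Grow a [b, s, s] mask to [b, s+1, s+1] after one newly generated token."""
    b, s, _ = mask.shape
    # existing query rows never attend to the not-yet-generated token
    new_col = torch.zeros(b, s, 1, dtype=torch.bool, device=mask.device)
    # the new token sees everything the previous last position saw, plus itself
    new_row = torch.cat(
        [mask[:, -1:, :], torch.ones(b, 1, 1, dtype=torch.bool, device=mask.device)], dim=-1
    )
    return torch.cat([torch.cat([mask, new_col], dim=-1), new_row], dim=1)
```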
Yep, this is pretty much it! I take it you're not utilizing the KV cache for this generation though, right?
Nah. It was also on my TODO list of possible optimizations, and I briefly spoke to Rafi about it, but we agreed it would be kind of a pain in the ass to set up caching for custom masks.
Left padded:
Left padded mask:
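The concrete tensors aren't reproduced above; as an illustration with assumed values, a left-padded batch and its left-padded mask could look like:

```python
import torch

pad_id = 0  # assumed pad token id

# left-padded batch, [b, s] = [2, 3]
tokens = torch.tensor([
    [pad_id, 3, 4],
    [5,      6, 7],
])

padding_mask = tokens != pad_id                          # [2, 3], True for real tokens
causal = torch.tril(torch.ones(3, 3, dtype=torch.bool))  # [3, 3]
# left-padded mask: causal, keys restricted to non-pad tokens, diagonal kept so the
# padded query row isn't fully masked out
mask = (causal[None] & padding_mask[:, None, :]) | torch.eye(3, dtype=torch.bool)

# mask[0]:
# [[ True, False, False],   <- pad position attends only to itself
#  [False,  True, False],
#  [False,  True,  True]]
```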
I assume batched generation in the Eleuther eval recipe satisfies the latter?
Our modules only work with generation under one of two conditions: batch_size = 1, or every sample in the batch has the same length. The main culprit is this line of code:
torchtune/torchtune/modules/transformer.py, line 167 at 288ff44
For a batch where the samples have different lengths and are padded to a common sequence length, a proper mask needs to be built per sample: one of size [b x s x s], which here is [3 x 4 x 4].
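Since the original batch and mask aren't reproduced above, here's a sketch with assumed token values and an assumed left-padding layout, showing a [3 x 4 x 4] mask and how it differs from a single batch-agnostic square causal mask:

```python
import torch
import torch.nn.functional as F

pad_id = 0  # assumed pad token id and assumed token values

# left-padded batch, [b, s] = [3, 4], with per-sample lengths 4, 3, 2
tokens = torch.tensor([
    [1, 2, 3, 4],
    [pad_id, 5, 6, 7],
    [pad_id, pad_id, 8, 9],
])

# a single [s, s] causal mask shared across the batch (roughly the current behavior)
# lets real tokens in the shorter sequences attend to their left-pad positions
square_causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))

# a proper mask of size [b x s x s] = [3 x 4 x 4]: causal, keyed off each sample's padding,
# with the diagonal kept so padded query rows aren't fully masked out
padding_mask = tokens != pad_id                                                         # [3, 4]
mask = (square_causal[None] & padding_mask[:, None, :]) | torch.eye(4, dtype=torch.bool)  # [3, 4, 4]

# the bool mask broadcasts over heads once given a head dim
q = k = v = torch.randn(3, 8, 4, 16)  # [b, heads, s, head_dim], assumed sizes
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None])
print(mask.shape)  # torch.Size([3, 4, 4])
```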
This will be a fairly involved change that touches several utils and modules. The general changes needed will be:
- A mask param on model.forward() that accepts a per-sample [b x s x s] mask (see the sketch below)
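As a rough illustration of that change, here is a hypothetical sketch (the class names, toy attention layer, and embedding are stand-ins, not torchtune's actual modules): a forward() that accepts an optional per-sample mask and falls back to a shared square causal mask when none is given.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchDecoderLayer(nn.Module):
    # toy single-head self-attention layer that accepts a per-sample [b, s, s] bool mask
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # add a head dim so the [b, s, s] mask broadcasts over heads
        out = F.scaled_dot_product_attention(
            q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1), attn_mask=mask.unsqueeze(1)
        )
        return x + out.squeeze(1)


class SketchDecoder(nn.Module):
    # hypothetical sketch of the proposed mask param on forward(); not torchtune's module
    def __init__(self, vocab_size: int, dim: int, n_layers: int) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList(SketchDecoderLayer(dim) for _ in range(n_layers))

    def forward(self, tokens: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        b, s = tokens.shape
        if mask is None:
            # fall back to today's behavior: one square causal mask shared by the whole batch
            mask = torch.tril(
                torch.ones(s, s, dtype=torch.bool, device=tokens.device)
            ).expand(b, s, s)
        h = self.embed(tokens)
        for layer in self.layers:
            h = layer(h, mask=mask)
        return h


# usage: the caller builds the per-sample mask (as above) and passes it in
model = SketchDecoder(vocab_size=100, dim=16, n_layers=2)
tokens = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
padding_mask = tokens != 0
mask = (torch.tril(torch.ones(4, 4, dtype=torch.bool)) & padding_mask[:, None, :]) | torch.eye(4, dtype=torch.bool)
out = model(tokens, mask=mask)  # [b, s, dim] hidden states in this toy sketch
```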
This was originally found and reported by @iankur