
Fix generation for bsz > 1 #1250

Closed
joecummings opened this issue Aug 2, 2024 · 5 comments · Fixed by #1424
Labels: bug (Something isn't working)

Comments

@joecummings
Contributor

Our modules only work for generation in one of two cases: batch_size = 1, or every single sample in a batch has the same length. The main culprit is this line of code:

self.causal_mask = torch.tril(

For a batch that looks like the following:

My, name, is, Joe
Hello, world, <PAD>, <PAD>
Bye, <PAD>, <PAD>, <PAD>

A proper mask would look like:

1 0 0 0
1 1 0 0 
1 1 1 0
1 1 1 1

1 0 0 0
1 1 0 0
0 0 0 0
0 0 0 0

1 0 0 0
0 0 0 0
0 0 0 0
0 0 0 0

of size [b x s x s], which is [3 x 4 x 4]
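
For concreteness, here is a minimal sketch (my own, not the existing torchtune code) of how this padding-aware mask could be built from a boolean padding mask with `torch.tril` and broadcasting:

```python
import torch

# padding_mask[b, j] is True for real tokens and False for <PAD>, matching the
# right-padded batch above.
padding_mask = torch.tensor(
    [
        [True, True, True, True],     # My, name, is, Joe
        [True, True, False, False],   # Hello, world, <PAD>, <PAD>
        [True, False, False, False],  # Bye, <PAD>, <PAD>, <PAD>
    ]
)

seq_len = padding_mask.shape[1]
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # [s, s]

# Causal mask with pad queries (rows) and pad keys (columns) zeroed out: [b, s, s]
mask = causal[None] & padding_mask[:, :, None] & padding_mask[:, None, :]
print(mask.int())  # reproduces the [3 x 4 x 4] mask above
```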

This will be a fairly involved change that touches several utils and modules. The general changes needed will be:

  • Delete causal mask from KV cache, instead opting for this to come in via the mask param
  • Update generate utils to pass a mask into its call to model.forward()
  • Modify eleuther eval recipe to construct a proper causal mask to pass to the model during generation (rough sketch below)
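
As a rough illustration of the last two bullets (a sketch under assumptions, not the final design), the caller would build the padding-aware mask itself and hand it to the model:

```python
import torch

def greedy_next_token(model, tokens, padding_mask, input_pos):
    """Hypothetical helper: the caller constructs the padding-aware causal mask
    and passes it in, instead of the model pulling a causal mask out of its KV
    cache. Assumes a forward signature that accepts `mask` and `input_pos`
    keyword arguments, in the spirit of torchtune's TransformerDecoder.
    """
    _, seq_len = tokens.shape  # tokens: [b, s]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=tokens.device))
    mask = causal[None] & padding_mask[:, :, None] & padding_mask[:, None, :]  # [b, s, s]
    logits = model(tokens, mask=mask, input_pos=input_pos)  # [b, s, vocab]
    return logits[:, -1].argmax(dim=-1)  # greedy token at the last position of each sample
```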

This was originally found and reported by @iankur

@SalmanMohammadi
Collaborator

Was this the kind of thing you had in mind? https://github.com/pytorch/torchtune/blob/1129f9e3a246628c991c246d81dbead62d3437a3/torchtune/modules/rlhf/_generation.py

Granted, there are a couple of changes I've been meaning to make (only generating the full mask once and extending it for each token in the batch, and you'll probably have a more intelligent way of generating the masks themselves).

@joecummings
Contributor Author

Was this the kind of thing you had in mind? 1129f9e/torchtune/modules/rlhf/_generation.py

Yep, this is pretty much it! I take it that you're not utilizing the KV Cache for this generation though, right?

@SalmanMohammadi
Collaborator

Yep, this is pretty much it! I take it that you're not utilizing the KV Cache for this generation though, right?

Nah. It was also on my TODO list of possible optimizations, and I briefly spoke to Rafi about it, but we agreed it would be kind of a pain in the ass to set up caching for custom masks.

@joecummings joecummings self-assigned this Aug 2, 2024
@joecummings joecummings added the bug Something isn't working label Aug 2, 2024
@SalmanMohammadi SalmanMohammadi mentioned this issue Aug 2, 2024
@joecummings
Contributor Author

joecummings commented Aug 21, 2024

Left padded:

My, name, is, Joe
<PAD>, <PAD>, Hello, world
<PAD>, <PAD>, <PAD>, Bye

Left padded mask:

1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1

1 0 0 0
0 1 0 0
0 0 1 0 
0 0 1 1

1 0 0 0
0 1 0 0
0 0 1 0
0 0 0 1
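
Same sketch as before, adjusted for left padding (my assumption about the construction, not necessarily the merged implementation): pad positions keep their diagonal entry so no query row is all zeros.

```python
import torch

# padding_mask for the left-padded batch above (True = real token).
padding_mask = torch.tensor(
    [
        [True, True, True, True],     # My, name, is, Joe
        [False, False, True, True],   # <PAD>, <PAD>, Hello, world
        [False, False, False, True],  # <PAD>, <PAD>, <PAD>, Bye
    ]
)

seq_len = padding_mask.shape[1]
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
diagonal = torch.eye(seq_len, dtype=torch.bool)

# Causal attention restricted to real tokens, plus the diagonal so pad rows
# attend only to themselves: [b, s, s]
mask = (causal[None] & padding_mask[:, :, None] & padding_mask[:, None, :]) | diagonal[None]
print(mask.int())  # reproduces the [3 x 4 x 4] left-padded mask above
```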

@SalmanMohammadi
Collaborator

Our modules only work for generation in one of two cases: batch_size = 1, or every single sample in a batch has the same length. The main culprit is this line of code:

I assume batched generation in the eleuther eval recipe satisfies the latter?
I've just got iterative decoding + KV caching working for my batched RLHF generation utils - seeing > 10x speedups w/o compile (PPO go brrrr). Can chat about it later today if it's of interest.
