feat(router): Dynamic batch sizing #210
Conversation
Nice! I need to think about your implementation and maybe play with it, but I think it's a good idea.
Force-pushed from c89260f to ba1aae3
Thanks @OlivierDehaene, I've now rebased it.
@OlivierDehaene @Narsil continuing the discussion from #246, I've pushed a new commit here to abstract the batch "weight" calculations to cover the non-flash attention case too. We use this for example with flan-t5-xxl, where we set the max batch weight to […]. I too am not that happy about how complex the changes are, but I'm sure they can still be simplified/restructured a bit, and the queue-jumping logic could be removed or improved as discussed.
If I understand correctly, currently we send the same batch to all workers, and then each worker runs the same tokenization repeatedly, even though the model is sharded. I wonder if we could let the router split the batch and send different samples to different workers, so that we can avoid running the tokenizer on the same samples in different workers. To be honest, I don't understand why we don't split the batch like FSDP or DeepSpeed does. Would it be possible to further reduce memory usage from the split batches?
Closing as stale. Thanks for the contribution, Nick; happy to take some back now that we're back on Apache! Cheers.
Motivation
Currently, to avoid OOM you must set a "worst case" max batch size based on the desired max sequence length. This means that (a) throughput is unnecessarily limited when there are many shorter sequences and (b) you have to be pretty conservative about the max context length offered.
These changes introduce a maximum batch "weight" parameter which in the flash attention case corresponds to a maximum total number of tokens in the batch. The idea is that this is roughly proportional to the memory requirement.
- The weight calculation takes into account the requests' input lengths and `max_new_tokens` values.
- If `max_batch_weight` is not set, it just infers this from the `max_batch_size` and `max_total_tokens` args. In this case it should behave roughly the same as it does now, so it could hopefully be a "non-breaking" change for existing configurations.
- It turns out to be simpler to configure for a particular model/GPU. The precise values for `max_batch_size` and `max_sequence_length` no longer matter much; they can both be set quite high. You just need to determine one number (the max weight / total tokens), which is easy to do with minimal experimentation.
- We have been using this successfully for a while now, and it means we can support a much higher throughput / volume of users with the same hardware while offering larger context lengths. For example, we have a deployment of GPT-NeoX 20B on one 80GB A100 with the max batch size set to 256 and the max sequence length (`max_total_tokens`) set to 8192. The actual batch size flexes automatically as needed. Our `max_batch_weight` setting for this is 10k.
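To make the idea concrete, here is a minimal sketch of how such a token-budget check could work in the flash-attention case. This is an illustration only, not the actual router code; the struct, field, and function names below are hypothetical, and the real implementation also has to deal with queue ordering and the non-flash (padded) case.

```rust
/// Hypothetical sketch of the batch "weight" idea (flash-attention case): a request's
/// weight is the total number of tokens it may occupy, and requests are admitted
/// greedily while the running sum stays under the configured budget.
struct BatchingConfig {
    max_batch_size: usize,
    max_total_tokens: usize,
    /// Optional cap on the total token count of a batch.
    max_batch_weight: Option<usize>,
}

struct QueuedRequest {
    input_length: usize,
    max_new_tokens: usize,
}

impl BatchingConfig {
    /// If no explicit weight is configured, fall back to the "worst case" budget
    /// implied by the existing parameters, preserving the current behaviour.
    fn effective_max_weight(&self) -> usize {
        self.max_batch_weight
            .unwrap_or(self.max_batch_size * self.max_total_tokens)
    }

    /// A request's contribution to the batch weight: prompt tokens plus the
    /// maximum number of tokens it may generate.
    fn request_weight(req: &QueuedRequest) -> usize {
        req.input_length + req.max_new_tokens
    }

    /// Greedily admit queued requests while both the size and weight budgets allow it.
    fn select_batch<'a>(&self, queue: &'a [QueuedRequest]) -> Vec<&'a QueuedRequest> {
        let budget = self.effective_max_weight();
        let mut weight = 0;
        let mut batch = Vec::new();
        for req in queue {
            let w = Self::request_weight(req);
            if batch.len() == self.max_batch_size || weight + w > budget {
                break;
            }
            weight += w;
            batch.push(req);
        }
        batch
    }
}
```

With numbers like the GPT-NeoX example above, the nominal max batch size (256) and max sequence length (8192) stop being the binding constraint; the 10k token budget is what actually determines how many requests end up batched together.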
Details/caveats
- I've only included the implementation for the flash attention case so far. The additions to generalize to the regular attention case aren't very big (we run non-flash-attention models with this too), but I thought this was probably complicated enough to start with. It will need to support the general case of course before actually being included.
- `next_batch` should return immediately before getting into the more complex logic.
- The `next_batch` function now takes the current entries map instead of a min and max, so the tests in `queue.rs` would need updating; I just removed them for now.
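For illustration, here is a rough sketch of what the changed `next_batch` shape might look like under the same token-budget idea. The types and field names are hypothetical and not the actual `queue.rs` code; the point is just that passing the in-flight entries lets the queue compute how much budget is left.

```rust
use std::collections::BTreeMap;

// Hypothetical request/entry shapes, for illustration only.
struct Entry {
    input_length: usize,
    max_new_tokens: usize,
}

/// Entries currently in flight, keyed by request id.
type Entries = BTreeMap<u64, Entry>;

struct Queue {
    pending: Vec<(u64, Entry)>,
    max_batch_weight: usize,
}

impl Queue {
    /// Sketch of a `next_batch` that receives the currently running entries instead
    /// of min/max batch-size bounds: the weight already committed to in-flight
    /// requests shrinks the budget available to the new batch.
    fn next_batch(&mut self, current_entries: &Entries) -> Option<Entries> {
        // Fast path: nothing queued, return immediately.
        if self.pending.is_empty() {
            return None;
        }

        // Budget remaining after accounting for requests that are already running.
        let used: usize = current_entries
            .values()
            .map(|e| e.input_length + e.max_new_tokens)
            .sum();
        let mut remaining = self.max_batch_weight.saturating_sub(used);

        let mut batch = Entries::new();
        while !self.pending.is_empty() {
            let weight = {
                let (_, entry) = &self.pending[0];
                entry.input_length + entry.max_new_tokens
            };
            if weight > remaining {
                break;
            }
            remaining -= weight;
            let (id, entry) = self.pending.remove(0);
            batch.insert(id, entry);
        }

        if batch.is_empty() {
            None
        } else {
            Some(batch)
        }
    }
}
```

The early return at the top corresponds to the caveat above about `next_batch` returning immediately in the simple case, before any of the weight accounting runs.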