feat(router): Dynamic batch sizing #210

Closed
wants to merge 7 commits

Conversation

@njhill (Contributor) commented Apr 20, 2023

Motivation

Currently, to avoid OOM you must set a "worst case" max batch size based on the desired max sequence length. This means that (a) throughput is unnecessarily limited when there are many shorter sequences, and (b) you have to be fairly conservative about the max context length offered.

These changes introduce a maximum batch "weight" parameter which in the flash attention case corresponds to a maximum total number of tokens in the batch. The idea is that this is roughly proportional to the memory requirement.

  • Batches are filled taking both the input lengths and max new tokens of existing and new requests into account
  • A "projection" is done when evaluating each next request in the queue for admission into the batch, to determine whether the max batch weight would ever be exceeded in future, assuming the worst case of all requests running to their max_new_tokens values (see the sketch after this list)
  • As long as they did not arrive too far behind, smaller requests can jump ahead of larger ones if the larger one doesn't fit in the batch but the later smaller one does
  • You can optionally set a separate "max prefill weight" to limit how many tokens can be prefilled at once. This is to help avoid long delays where no tokens are produced.
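
To make the projection concrete, here is a minimal sketch of the admission check, assuming hypothetical names (`QueuedRequest`, `would_exceed_weight`) rather than the PR's actual types: each request's worst-case footprint is its prompt length plus its max_new_tokens, and a candidate is admitted only if the projected total stays within the weight limit.

```rust
// Illustrative sketch only, not the PR's actual code; names and fields are hypothetical.
struct QueuedRequest {
    input_length: u32,
    max_new_tokens: u32,
}

/// Worst-case token footprint of a request: its prompt plus every token it may still generate.
fn worst_case_tokens(r: &QueuedRequest) -> u32 {
    r.input_length + r.max_new_tokens
}

/// Returns true if admitting `candidate` could ever push the batch past `max_batch_weight`,
/// assuming all requests run to their max_new_tokens limits.
fn would_exceed_weight(
    batch: &[QueuedRequest],
    candidate: &QueuedRequest,
    max_batch_weight: u32,
) -> bool {
    let projected: u32 =
        batch.iter().map(worst_case_tokens).sum::<u32>() + worst_case_tokens(candidate);
    projected > max_batch_weight
}
```

In the flash attention case the weight is just a token count, so this reduces to the "maximum total number of tokens in the batch" bound described above.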

If max_batch_weight is not set, it is inferred from the max_batch_size and max_total_tokens args. In this case it should behave roughly the same as it does now, so this could hopefully be a "non-breaking" change for existing configurations.
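
As a rough sketch of that fallback (the function below is illustrative, not the PR's code), the inferred weight is simply the worst case already implied by the existing settings:

```rust
// Illustrative only: if no explicit weight is configured, fall back to the worst case
// implied by max_batch_size and max_total_tokens.
fn effective_max_batch_weight(
    max_batch_weight: Option<usize>,
    max_batch_size: usize,
    max_total_tokens: usize,
) -> usize {
    max_batch_weight.unwrap_or(max_batch_size * max_total_tokens)
}
```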

This turns out to be simpler to configure for a particular model/GPU. The precise values for max_batch_size and max_sequence_length no longer matter much; they can both be set quite high. You just need to determine one number (the max weight / total tokens), which is easy to do with minimal experimentation.

We have been using this successfully for a while now and it means we can support a much higher throughput / volume of users with the same hardware while offering larger context lengths. For example, we have a deployment of GPT-NeoX 20B on one 80GB A100 with the max batch size set to 256 and the max sequence length (max_total_tokens) set to 8192. The actual batch size flexes automatically as needed. Our max_batch_weight setting for this is 10k.

Details/caveats

  • I have ported this from my internal fork and have not yet tested this branch
  • I've only included the implementation for the flash attention case so far. The additions to generalize to the regular attention case aren't very big (we run non-flash-attention models with this too), but I thought this was probably complicated enough to start with. It will of course need to support the general case before actually being included.
  • Some of the logic (e.g. related to the extra fields in the queue state) is to cut down on the overhead of repeated analysis - most calls to next_batch should return immediately before getting into the more complex logic.
  • Since the queue's next_batch function now takes the current entries map instead of a min and max, the tests in queue.rs would need updating, so I just removed them for now.
  • Though it should be fully functional, please consider it still WIP; if you are interested, more cleanup/rework can be done

@OlivierDehaene (Member) commented

Nice!
I think that makes a lot more sense than the current naive algorithm and is easier to represent mentally.

I need to think about your implementation and maybe play with it, but I think it's a good idea.

@njhill (Contributor, Author) commented Apr 20, 2023

Thanks @OlivierDehaene, I've now rebased it.

@njhill (Contributor, Author) commented Apr 27, 2023

@OlivierDehaene @Narsil continuing the discussion from #246, I've pushed a new commit here to abstract the batch "weight" calculations to cover the non-flash attention case too. We use this, for example, with flan-t5-xxl, where we set the max batch weight to 800 * vram_remaining_in_mb_after_loading_model.
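
For reference, that rule of thumb as code (the function name and the example free-VRAM figure are hypothetical; only the 800x multiplier comes from the sentence above):

```rust
// Illustrative heuristic for the non-flash case (e.g. flan-t5-xxl).
fn non_flash_max_batch_weight(vram_remaining_mb_after_loading_model: usize) -> usize {
    800 * vram_remaining_mb_after_loading_model
}

// e.g. non_flash_max_batch_weight(20_000) == 16_000_000 for a hypothetical ~20 GB of free VRAM
```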

I too am not that happy about how complex the changes are, but I'm sure they can still be simplified/restructured a bit, and the queue-jumping logic could be removed or improved as discussed.

@Atry (Contributor) commented May 27, 2023

If I understand correctly, we currently send the same batch to all workers, and then each worker runs the same tokenization repeatedly, even though the model is sharded.

I wonder if we could let the router split the batch and send different samples to different workers, so that we can avoid running the tokenizer on the same samples across workers.

To be honest, I don't understand why we don't split the batch like FSDP or DeepSpeed does. Would it be possible to further reduce memory usage by splitting the batch?

@Narsil (Collaborator) commented Apr 30, 2024

Closing as stale.

Thanks for the contribution Nick, happy to take some back now that we're back on Apache!

Cheers.

@Narsil closed this Apr 30, 2024