
feat(server): reduce memory requirement #214

Merged
merged 4 commits into huggingface:main from tgi/mem_reduction on Apr 24, 2023

Conversation

njhill (Contributor) commented Apr 21, 2023

Currently the memory requirement is much higher than it needs to be due to the way that past key values are pruned and concatenated.

This PR changes CausalLM to do these operations layer by layer so that the corresponding tensors can be freed incrementally. It should roughly halve the memory requirement and so allow much larger batch sizes and/or sequence lengths.
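
A minimal sketch of the layer-by-layer idea (illustrative only, not the actual CausalLM code; the function and variable names are hypothetical, and it assumes the cache is held as a mutable Python list of per-layer `(key, value)` PyTorch tensors with the batch dimension first):

```python
import torch

def filter_past_layerwise(past_key_values, keep_indices):
    """Prune the KV cache one layer at a time so each old layer's tensors
    can be freed before the next layer is processed, instead of keeping
    the entire old cache alive until the whole pruned copy is built."""
    for i, (key, value) in enumerate(past_key_values):
        # Copy out only the batch rows that are still in the batch.
        past_key_values[i] = (key[keep_indices], value[keep_indices])
        # Drop the last references to the old layer tensors immediately,
        # so their memory can be reused while the remaining layers are pruned.
        del key, value
    return past_key_values

# Toy usage: keep rows 0 and 2 of a 2-layer cache shaped [batch, heads, seq, dim].
past = [(torch.randn(3, 8, 16, 64), torch.randn(3, 8, 16, 64)) for _ in range(2)]
past = filter_past_layerwise(past, torch.tensor([0, 2]))
```

With the all-at-once approach the full pruned copy coexists with the full old cache, so peak usage is roughly doubled; freeing per layer keeps the extra footprint to a single layer's tensors.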

Also changed the pruning to remove any unused padding at the front of the batch (for example, when the longest input sequence in the batch gets pruned).
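
A similarly hedged sketch of the front-padding trim (again illustrative names only, assuming left-padded caches of shape `[batch, heads, seq_len, head_dim]`, where `old_max_length` and `new_max_length` stand for the longest input length before and after pruning):

```python
def trim_front_padding(past_key_values, old_max_length, new_max_length):
    """If the longest input sequence left the batch, the leading positions
    of every cached key/value are padding that no remaining request attends
    to, so they can be sliced off layer by layer."""
    excess = old_max_length - new_max_length
    if excess <= 0:
        return past_key_values
    for i, (key, value) in enumerate(past_key_values):
        # .contiguous() materialises the smaller tensors so the original,
        # larger allocations can actually be released.
        past_key_values[i] = (
            key[:, :, excess:, :].contiguous(),
            value[:, :, excess:, :].contiguous(),
        )
        del key, value
    return past_key_values
```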

These changes are also more or less required for the non-flash-attention implementation of #210 (which I can push soon), to ensure that assumptions about memory usage based on point-in-time batch contents remain valid.

@OlivierDehaene if you approve of this I can add the corresponding Seq2SeqLM changes.

OlivierDehaene (Member) commented

@njhill, tests are failing but I like the idea.

njhill changed the title from "feat(server): reduce GPU memory requirement" to "feat(server): reduce memory requirement" on Apr 21, 2023
njhill (Contributor, Author) commented Apr 21, 2023

@OlivierDehaene thanks, let me fix the tests soon.

njhill (Contributor, Author) commented Apr 24, 2023

@OlivierDehaene tests are fixed, and I've also added the same changes for seq2seq_lm.

OlivierDehaene (Member) left a comment

Cheers!

OlivierDehaene merged commit 4a7dd40 into huggingface:main on Apr 24, 2023
njhill deleted the tgi/mem_reduction branch on April 24, 2023 at 13:07