
Fix per_token slowdown #57

Merged
merged 2 commits into main on May 20, 2024

Conversation

@Satrat Satrat commented May 17, 2024

We had vectorized the scale and zero-point calculations for the min-max observer, but the memoryless observer was still using the un-vectorized code. This change moves `get_qparams_along_dim` to the base observer class so that all observers use it, and resolves some shape issues.

With this change, the token, channel, and tensor strategies can all share the same logic for calling quant/dequant, which also simplifies the forward pass code significantly.
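To illustrate the idea behind a shared `get_qparams_along_dim`, here is a minimal sketch (not the repository's actual implementation; the function body, helper names, and the numpy stand-in for torch are assumptions) of computing per-token scales and zero points in one vectorized reduction instead of a Python loop over tokens:

```python
import numpy as np

def get_qparams_along_dim(x, dim, num_bits=8):
    """Hypothetical sketch: compute one (scale, zero_point) pair per slice
    along `dim` by reducing min/max over all other dimensions at once."""
    reduce_dims = tuple(i for i in range(x.ndim) if i != dim)
    # Include 0 in the range so that zero is exactly representable
    min_vals = np.minimum(x.min(axis=reduce_dims), 0.0)
    max_vals = np.maximum(x.max(axis=reduce_dims), 0.0)
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    scales = (max_vals - min_vals) / (qmax - qmin)
    scales = np.where(scales == 0, 1e-8, scales)  # guard against div-by-zero
    zero_points = np.round(qmin - min_vals / scales).astype(np.int32)
    return scales, zero_points

# Per-token strategy: one qparam pair per row of the activation matrix
x = np.array([[0.0, 1.0], [0.0, 2.0]])
scales, zps = get_qparams_along_dim(x, dim=0)
```

Because the min/max reduction runs over the whole tensor in one call, the per-token, per-channel, and per-tensor strategies differ only in which dimension (if any) is kept, so they can share a single quant/dequant code path.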

Testing

Running a Llama 1.1B model with W8A8 dynamic per-token quantization:

Before change: 10 sec/iteration (0.1 iterations/sec)
After change: 4.5 iterations/sec

A roughly 45x speedup!

@Satrat Satrat requested review from bfineran, horheynm, dbogunowicz, dsikka and rahul-tuli and removed request for horheynm May 17, 2024 17:38
@Satrat Satrat requested a review from bfineran May 17, 2024 19:23
@bfineran bfineran merged commit f9d8d8b into main May 20, 2024
1 check passed
@bfineran bfineran deleted the sa/fix_token_speed branch May 20, 2024 14:34