
mixed-precision quantization milestone1: naive_intNwo + eval/benchmark framework #531

Merged: 12 commits into main on Aug 1, 2024

Conversation

@Hanxian97 (Contributor) commented Jul 19, 2024

Summary:
This is a prototype for mixed-precision quantization. It consists of a naive implementation of 2/3/5/6-bit integer weight-only quantization. Together with the existing int4wo and int8wo in torchao, it provides an evaluation framework that leverages lm_eval for mixed-precision quantization of Llama3.

Test Plan:
To test the naive implementation of the quantization APIs: python test/quantization/test_naive_intNwo.py


pytorch-bot bot commented Jul 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/531

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e516f0b with merge base 00b76c4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 19, 2024
@Hanxian97 Hanxian97 marked this pull request as draft July 21, 2024 23:47
@andrewor14 (Contributor) left a comment

Hi @Hanxian97, thanks for the PR. I think you can remove all the files except for mp_quant_eval.py, naive_intNwo.py, and test_naive_intNwo.py. Left a few minor comments other than that.



Contributor

Hi @Hanxian97, I feel we don't want to push these experiment scripts to torchao. Can you remove them from the PR? (OK to keep them in your own separate branch for now)

Contributor Author

Thanks for the comment. I have removed the experiment scripts and kept only mp_quant_eval.py, naive_intNwo.py, and test_naive_intNwo.py.

ZeroPointDomain,
)

def intN_weight_only_asym(group_size=32, n=8):
Contributor

Can you add a short docstring to describe what this is doing? Maybe add an example of using this with the quantize_ API? (same for intN_weight_only_sym)

Contributor

Also, should we limit this to n = [2, 3, 4, 5, 6, 8] for now? (throw an error otherwise)

Contributor Author

Added the docstring and an assertion to limit n to [2, 3, 4, 5, 6, 8] only.
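For reference, a minimal sketch (not the exact code from this PR) of what such a docstring and bit-width assertion might look like for the asymmetric variant, assuming the torchao quantize_ API:

```python
def intN_weight_only_asym(group_size=32, n=8):
    """
    Apply n-bit asymmetric weight-only quantization to linear layers.

    Args:
        group_size: number of weight elements sharing one scale/zero point
        n: bit width to quantize to; only 2, 3, 4, 5, 6 and 8 are accepted

    Hypothetical usage with the quantize_ API:
        quantize_(model, intN_weight_only_asym(group_size=32, n=6))
    """
    assert n in [2, 3, 4, 5, 6, 8], f"Unsupported bit width: {n}"
    ...  # return the weight transform for the requested bit width
```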

@@ -0,0 +1,95 @@
import torch
Contributor

You might want to call this file something else, since you're about to do the real sensitivity analysis.

Contributor Author

Removed this file for now; the real sensitivity analysis will be committed in milestone 2.

model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

if quantization == "int8dq":
    quantize_(model.to(device=device), int8_dynamic_activation_int4_weight())
Contributor

This seems wrong? On main it's int8_dynamic_activation_int8_weight:

quantize_(model, int8_dynamic_activation_int8_weight())

Actually, we can probably just delete this case?

Contributor Author

Removed this for now since we will not use it.

@@ -0,0 +1,27 @@
import torch
Contributor

By the way, I think we need to move this to torchao/test if we want it to run as part of CI.

Contributor Author

Moved test_naive_intNwo.py under test/quantization now.

Comment on lines 14 to 29
target_dtype = torch.int8
quant_min = 0
quant_max = 2**n-1
Contributor

Should target_dtype be torch.uint8 for this?

Contributor Author

Fixed this.
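For context, a sketch of what the asymmetric quantization range looks like once the container dtype is switched to uint8; the values follow the snippet above and the exact code in the PR may differ:

```python
import torch

n = 6                       # bit width, e.g. 2/3/5/6
target_dtype = torch.uint8  # unsigned container for asymmetric quantization
quant_min = 0
quant_max = 2**n - 1        # 63 for 6-bit; the range fits in uint8 for n <= 8
```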

Comment on lines 62 to 65
if sensi_bit == 8:
    quantize_(model.to(device=device), int8_weight_only(), filter_fn_sen)
elif sensi_bit == 4:
    quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_sen)
Contributor

You could merge this logic into intN_weight_only_asym, I think.

Contributor Author

Merged them into intN_weight_only now.
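With the branches folded into intN_weight_only, the call site shown above could collapse to something like the following fragment; the names are taken from that snippet and this is only a sketch, not the PR's final code:

```python
# single call regardless of bit width; intN_weight_only dispatches internally
quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size), filter_fn_sen)
```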

Comment on lines 22 to 23
bit_zeropoint = 2 # Example value, please adjust as needed
bit_scale = 2 # Example value, please adjust as needed
Contributor

Are these in bytes?

Contributor Author

@Hanxian97 Hanxian97 Jul 24, 2024

Yes, these are in bytes. I have fixed this. Thanks!

return total_size_gb

# Example usage
num_elements = 250945664  # number of elements in one Llama3 linear layer
Contributor

Can this be calculated from the model instead of hardcoded? Also, I feel a better integration is to just fix and extend

def get_model_size_in_bytes(model, ignore_embeddings=False):

Contributor Author

Yes, this is a temporary solution for Llama3. Thanks for the suggestion! I will try to generalize it by extending get_model_size_in_bytes.
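A hedged sketch of how the hardcoded element count could instead be computed by walking the model, in the spirit of extending get_model_size_in_bytes; the helper name and the per-group metadata accounting are illustrative assumptions, not torchao's implementation:

```python
import torch.nn as nn

def quantized_linear_weight_size_in_bytes(model: nn.Module, bit_width: int,
                                          group_size: int = 32,
                                          bit_zeropoint: int = 2,
                                          bit_scale: int = 2) -> float:
    """Estimate the storage of all nn.Linear weights after n-bit groupwise quantization."""
    total_bytes = 0.0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            numel = module.weight.numel()
            total_bytes += numel * bit_width / 8                             # packed weights
            total_bytes += numel / group_size * (bit_zeropoint + bit_scale)  # per-group scale/zero point
    return total_bytes
```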

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True

def intN_weight_only(group_size=32, n=8):
Contributor

I'd suggest naming this in more detail, since you have different dtypes and asymmetric/symmetric variants; in this case it's uintN_asymmetric_weight_only (or perhaps pass asymmetric/symmetric as an argument).

Contributor Author

I passed asymmetric/symmetric as an argument and merged them into intN_weight_only.
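A minimal sketch of that interface under assumed names: an explicit symmetric flag, with the 8-bit and 4-bit cases falling back to the existing torchao configs and the remaining bit widths going through the naive path (the _naive_intN_transform helper below is a hypothetical stand-in for the code added in this PR):

```python
from torchao.quantization.quant_api import int4_weight_only, int8_weight_only

def _naive_intN_transform(group_size, n, symmetric):
    # placeholder for the naive 2/3/5/6-bit weight-only transform from this PR
    raise NotImplementedError

def intN_weight_only(group_size=32, n=8, symmetric=False):
    assert n in [2, 3, 4, 5, 6, 8], f"Unsupported bit width: {n}"
    if n == 8:
        return int8_weight_only()
    if n == 4:
        return int4_weight_only(group_size=group_size)
    # 2/3/5/6-bit: naive affine quantization, symmetric or asymmetric
    return _naive_intN_transform(group_size, n, symmetric)
```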

eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
Contributor

Why is this FLOAT?

Contributor Author

Thanks for pointing it out. Just changed it to INT.

@andrewor14 (Contributor) left a comment

Approving to unblock. Thanks!

eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
Contributor

I think this should be ZeroPointDomain.INT. FLOAT is mainly for the optimized int4 tinygemm kernel right now.

Contributor Author

Thanks for pointing it out. Just changed it to INT.
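For reference, a hedged sketch of what the INT zero-point-domain settings typically look like in torchao's affine quantization; only the domain change is confirmed in this thread, the other values here are assumptions:

```python
import torch
from torchao.quantization.quant_primitives import ZeroPointDomain

eps = 1e-6
zero_point_dtype = torch.int64           # integer zero points (assumption)
zero_point_domain = ZeroPointDomain.INT  # FLOAT is reserved for the int4 tinygemm kernel
```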

@@ -0,0 +1,46 @@
import torch
Contributor

Maybe call this test_mixed_precision.py, to match your prototype folder and to cover your future test cases as well?

Contributor Author

Renamed it.

@andrewor14 andrewor14 marked this pull request as ready for review July 24, 2024 17:14
    test_weight_only_quant(i, False)
    print(f"Test passed for {i}-bit using naive intNwo asymmetric quantization implementation")
except Exception as e:
    print(f"Exception handled in test loop for {i}-bit asymmetric quantization. Details: {e}")
Contributor

@andrewor14 andrewor14 Jul 25, 2024

You might want to actually re-raise this exception too? Otherwise it'll be hard to catch the test when it fails.
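A minimal sketch of that suggestion: keep the log line but re-raise so CI still sees the failure. It continues the snippet above; the surrounding loop and bit widths are assumed for illustration:

```python
for i in [2, 3, 5, 6]:
    try:
        test_weight_only_quant(i, False)
        print(f"Test passed for {i}-bit asymmetric quantization")
    except Exception as e:
        print(f"Failure for {i}-bit asymmetric quantization: {e}")
        raise  # re-raise so the test run actually fails
```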

import os
import sys
# append the path to the naive_intNwo.py file
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "torchao/quantization/prototype/mixed_precision/scripts"))
Contributor

Does it work if you just add an empty __init__.py to torchao/quantization/prototype/mixed_precision? Then you won't need this line anymore?

Contributor Author

Added __init__.py and removed the path append.
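With the package marker in place, the test can import the prototype directly; a hedged sketch, with the module path inferred from the sys.path line above and assuming the scripts directory is also a package:

```python
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
```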

@Hanxian97 Hanxian97 force-pushed the Hanxian_MixedPrecision branch 4 times, most recently from 75b55c2 to ec36a94 Compare July 30, 2024 16:05
@andrewor14 andrewor14 merged commit c023f71 into main Aug 1, 2024
13 checks passed
jainapurva pushed a commit that referenced this pull request Aug 7, 2024
mixed-precision quantization milestone1: naive_intNwo + eval/benchmark framework (#531)

* milestone1: naive_intNwo + eval/benchmark

* remove experiment scripts

* remove exp files

* use default ZeroPointDomain.INT for int2/3/5/6

* renamed test_naive_intNwo.py to test_mixed_precision.py

* updated intNwo with _get_linear_subclass_inserter

* adjust sqnr threshold according to bit width

* fixed test for int4wo and add __init__.py

* skip test_aq_int8_weight_only_quant_3_subclass due to seg fault on nightly

* edit the sqnr threshold

* add unittest

* correct import path