mixed-precision quantization milestone1: naive_intNwo + eval/benchmark framework #531
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/531
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit e516f0b with merge base 00b76c4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @Hanxian97, thanks for the PR. I think you can remove all the files except for mp_quant_eval.py, naive_intNwo.py, and test_naive_intNwo.py. Left a few minor comments other than that.
```
wait
done
```
Hi @Hanxian97, I feel we don't want to push these experiment scripts to torchao. Can you remove them from the PR? (OK to keep in your own separate branch for now)
Thanks for the comment. I have removed the experiment scripts and kept only mp_quant_eval.py, naive_intNwo.py, and test_naive_intNwo.py.
```python
    ZeroPointDomain,
)

def intN_weight_only_asym(group_size=32, n=8):
```
Can you add a short docstring to describe what this is doing? Maybe add an example to use this with the quantize_ API? (same for intN_weight_only_sym)
also, should we limit this to n = [2, 3, 4, 5, 6, 8] for now? (throw error otherwise)
Added the docstring and an assertion to limit n to [2, 3, 4, 5, 6, 8] only.
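For reference, a minimal sketch of what the requested docstring, example, and assertion might look like, assuming the intN_weight_only_asym signature shown in the diff above (the exact wording here is illustrative, not taken from the PR):

```python
def intN_weight_only_asym(group_size=32, n=8):
    """
    Apply naive n-bit asymmetric integer weight-only quantization to linear layers.

    Example usage with the torchao quantize_ API:
        from torchao.quantization import quantize_
        quantize_(model, intN_weight_only_asym(group_size=32, n=6))
    """
    # restrict to the bit widths discussed in this review thread
    assert n in [2, 3, 4, 5, 6, 8], f"n = {n}-bit quantization is not supported yet"
    ...
```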
@@ -0,0 +1,95 @@
```python
import torch
```
might wanna call this file something else since you're about to do the real sensitivity analysis
Removed this file for now and will commit the real sensitivity analysis in milestone2
```python
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

if quantization == "int8dq":
    quantize_(model.to(device=device), int8_dynamic_activation_int4_weight())
```
this seems wrong? On main it's int8_dynamic_activation_int8_weight (line 52 at 0e6c122):

```python
quantize_(model, int8_dynamic_activation_int8_weight())
```

Actually we can probably just delete this case?
removed this for now since we will not use this
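For reference, if the case were kept rather than deleted, the line quoted from main suggests it would map to the matching int8-weight API (a sketch; the surrounding quantization/model/device variables come from the snippet above):

```python
# match the "int8dq" name to the int8-weight API used on main,
# instead of the int4-weight variant flagged in the review
if quantization == "int8dq":
    quantize_(model.to(device=device), int8_dynamic_activation_int8_weight())
```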
@@ -0,0 +1,27 @@
```python
import torch
```
by the way I think we need to move this to torchao/test if we want it to run as part of CI
Moved test_naive_intNwo.py under test/quantization now.
```python
target_dtype = torch.int8
quant_min = 0
quant_max = 2**n - 1
```
should target_dtype be torch.uint8 for this?
Fixed this.
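The corrected settings, per the review: with a range starting at 0 the stored values are unsigned, so uint8 is the natural container dtype.

```python
# quant_min = 0 means the stored values are unsigned, so use uint8 storage
target_dtype = torch.uint8
quant_min = 0
quant_max = 2**n - 1
```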
```python
if sensi_bit == 8:
    quantize_(model.to(device=device), int8_weight_only(), filter_fn_sen)
elif sensi_bit == 4:
    quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_sen)
```
you could merge this logic into intN_weight_only_asym, I think
merged them into intN_weight_only now
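A minimal sketch of what the merged dispatch could look like, assuming the 8-bit and 4-bit cases fall back to the existing torchao APIs named in this thread:

```python
from torchao.quantization import int4_weight_only, int8_weight_only

def intN_weight_only(group_size=32, n=8):
    # reuse the existing torchao paths for the bit widths they already cover
    if n == 8:
        return int8_weight_only()
    if n == 4:
        return int4_weight_only(group_size=group_size)
    # fall through to the naive n-bit affine quantization implemented in this PR
    ...
```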
```python
bit_zeropoint = 2  # Example value, please adjust as needed
bit_scale = 2  # Example value, please adjust as needed
```
are these bytes?
Yes these are in bytes. Have fixed this. Thanks!
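A sketch of the size estimate being discussed, with all per-group overheads in bytes (the 2-byte scale and zero point follow the snippet above; the function name and exact formula are illustrative, not taken from the PR):

```python
def estimate_quantized_size_gb(num_elements, n_bit, group_size,
                               bit_scale=2, bit_zeropoint=2):
    # packed n-bit weights
    weight_bytes = num_elements * n_bit / 8
    # one scale and one zero point stored per quantization group
    num_groups = num_elements / group_size
    overhead_bytes = num_groups * (bit_scale + bit_zeropoint)
    return (weight_bytes + overhead_bytes) / 1024**3
```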
```python
    return total_size_gb

# Example usage
num_elements = 250945664  # number of elements per Llama3 linear layer
```
can this be calculated from the model instead of hardcoded? also I feel a better integration is just to fix and extend get_model_size_in_bytes (line 188 at 5787e9e):

```python
def get_model_size_in_bytes(model, ignore_embeddings=False):
```
Yes, this is a temporary solution for Llama3. Thanks for the suggestion! I will try to generalize it by extending get_model_size_in_bytes.
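A hedged sketch of deriving the element count from the model instead of hardcoding it, in the spirit of extending get_model_size_in_bytes (the helper name is illustrative):

```python
import torch.nn as nn

def count_linear_weight_elements(model):
    # total number of weight elements across all linear layers,
    # rather than a hardcoded Llama3-specific constant
    return sum(m.weight.numel() for m in model.modules() if isinstance(m, nn.Linear))
```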
```python
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True

def intN_weight_only(group_size=32, n=8):
```
I'd suggest naming this in more detail, since you have different dtypes and asymmetric/symmetric; in this case it's uintN_asymmetric_weight_only (or probably pass asymmetric/symmetric around as an argument)
I passed asymmetric/symmetric as an argument and merged them into intN_weight_only
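A sketch of the agreed-on shape, passing symmetry as an argument and mapping it onto torchao's MappingType enum (the parameter name is illustrative):

```python
from torchao.quantization.quant_primitives import MappingType

def intN_weight_only(group_size=32, n=8, symmetric=False):
    # select the quantization mapping from the flag instead of
    # keeping separate _sym/_asym functions
    mapping_type = MappingType.SYMMETRIC if symmetric else MappingType.ASYMMETRIC
    ...
```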
```python
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
```
why is this FLOAT?
Thanks for pointing it out. Just changed it to INT.
(force-pushed from aadda53 to e9f56d4)
Approving to unblock. Thanks!
```python
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
```
I think this should be ZeroPointDomain.INT. FLOAT is mainly for the optimized int4 tinygemm kernel right now
Thanks for pointing it out. Just changed it to INT.
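The corrected setting, per the review:

```python
# ZeroPointDomain.FLOAT is mainly for the optimized int4 tinygemm kernel;
# the naive intN path keeps the zero point in the integer domain
zero_point_domain = ZeroPointDomain.INT
```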
@@ -0,0 +1,46 @@
```python
import torch
```
maybe call this test_mixed_precision.py to match your prototype folder and for your future test cases as well?
Renamed it.
```python
    test_weight_only_quant(i, False)
    print(f"Test passed for {i}-bit using naive intNwo asymmetric quantization implementation")
except Exception as e:
    print(f"Exception handled in test loop for {i}-bit asymmetric quantization. Details: {e}")
```
might want to actually raise this exception too? Otherwise it'll be hard to catch the test when it fails
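A sketch of the suggested fix, re-raising so the test runner actually sees the failure:

```python
try:
    test_weight_only_quant(i, False)
    print(f"Test passed for {i}-bit using naive intNwo asymmetric quantization implementation")
except Exception as e:
    print(f"Exception in test loop for {i}-bit asymmetric quantization. Details: {e}")
    raise  # re-raise so the failure is surfaced to CI instead of swallowed
```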
```python
import os
import sys

# append the path to the naive_intNwo.py file
sys.path.append(os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    "torchao/quantization/prototype/mixed_precision/scripts",
))
```
Does it work if you just add an empty __init__.py to torchao/quantization/prototype/mixed_precision? Then you won't need this line anymore?
Added __init__.py and removed the path append.
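With the empty __init__.py files in place, the test can import the module as a regular package instead of mutating sys.path (the exact module path below is assumed from the folder named in this thread):

```python
# no sys.path.append needed once the package has __init__.py files
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
```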
(force-pushed from 75b55c2 to ec36a94, then from ec36a94 to f4fccf3)
mixed-precision quantization milestone1: naive_intNwo + eval/benchmark framework (#531)
* milestone1: naive_intNwo + eval/benchmark
* remove experiment scripts
* remove exp files
* use default ZeroPointDomain.INT for int2/3/5/6
* renamed test_naive_intNwo.py to test_mixed_precision.py
* updated intNwo with _get_linear_subclass_inserter
* adjust sqnr threshold according to bit width
* fixed test for int4wo and add __init__.py
* skip test_aq_int8_weight_only_quant_3_subclass due to seg fault on nightly
* edit the sqnr threshold
* add unittest
* correct import path
Summary:
This is a prototype for mixed-precision quantization. It consists of a naive implementation of integer 2/3/5/6-bit quantization. Along with the existing int4wo and int8wo in torchao, it provides an evaluation framework leveraging lm_eval for mixed-precision quantization on Llama3.
Test Plan:
To test the naive implementation of the quantization APIs: python test/quantization/test_naive_intNwo.py