torch.compile() the quantization method #116
base: master
Conversation
Thank you @rationalism! Added a few comments.

@mobicham thank you! I don't see the comments though?

Oh, in the review, you don't see this? https://github.com/mobiusml/hqq/pull/116/files/631ea011d8432b8a76518b0adc072574969d8771

@mobicham I don't see any comments or review, no, sorry
> @mobicham I don't see any comments or review, no, sorry

I think you should see it now
```diff
@@ -218,15 +218,15 @@ def optimize_weights_proximal_legacy(
     scale = scale.to(dtype=dtype, device=device)
     zero = zero.to(dtype=dtype, device=device)

-    best_error = 1e4
+    best_error = torch.tensor(1e4, dtype=float32, device=device)
```
```python
best_error = torch.tensor(1e4, dtype=dtype, device=device)
```
```diff
     for i in range(iters):
         W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1])
         W_r = (W_q - zero) / scale
         W_e = shrink_lp_op(W_f - W_r, beta, lp_norm)
         zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True)
         beta *= kappa

-        current_error = float(torch.abs(W_f - W_r).mean())
+        current_error = torch.abs(W_f - W_r).mean()
```
```python
current_error = torch.abs(W_f - W_r).mean().float()
```
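For context, here is a small illustration of the difference this suggestion points at; the motivation (avoiding a host round-trip per iteration under torch.compile) is an assumption, since the thread only shows the suggested line:

```python
import torch

x = torch.randn(128, 128, device="cuda" if torch.cuda.is_available() else "cpu")

# float(...) copies the value back to the host as a Python scalar; inside
# torch.compile this is a scalar extraction, which normally causes a graph break
# unless torch._dynamo.config.capture_scalar_outputs is enabled.
err_host = float(torch.abs(x).mean())

# .float() only casts the 0-dim result to a float32 tensor on the same device,
# so no value has to leave the GPU on every iteration of the optimizer loop.
err_device = torch.abs(x).mean().float()
```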
```diff
@@ -72,6 +72,7 @@ class Quantizer:
     }

     @classmethod
+    @torch.compile
```
I used torch.compile back in the beginning for the quantization step, but the issue is that the graph would break since there are too many ops. It would also break on older systems, and the warm-up was sometimes quite slow.
For all those reasons, I removed it. That was a couple of months ago though, and torch has improved the compiling process a lot. However, I still think it should be turned off by default.

Maybe for the moment, we can remove it as a decorator and do it outside like this: `Quantizer.quantize = torch.compile(Quantizer.quantize)`

Have you tried compiling only the optimizer? `Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)`

If that works, I would suspect the memory increase is due to `shrink_lp_op()`. In which case, we can rewrite it to do the operations in-place (see the sketch below).
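To make the in-place idea concrete, here is a rough sketch of what such a rewrite could look like. It assumes `shrink_lp_op` has the usual lp-shrinkage form `sign(x) * relu(|x| - (1/beta) * |x|^(lp_norm - 1))`; the name `shrink_lp_op_` and the exact signature are hypothetical, not part of hqq:

```python
import torch
from torch import Tensor


def shrink_lp_op_(x: Tensor, beta: float, lp_norm: float) -> Tensor:
    # Hypothetical in-place variant: the input buffer (a temporary W_f - W_r in the
    # optimizer loop) is overwritten instead of allocating a fresh output tensor.
    sign = torch.sign(x)   # one extra buffer for the sign
    x.abs_()               # |x|, in place
    if lp_norm == 1:
        # soft-thresholding: |x| - 1/beta
        x.sub_(1.0 / beta)
    else:
        # |x| - (1/beta) * |x|^(lp_norm - 1); torch.pow still allocates one temporary
        x.sub_(torch.pow(x, lp_norm - 1).div_(beta))
    # relu(...) * sign(x), both applied in place
    return x.clamp_(min=0).mul_(sign)
```

Whether this actually removes the memory spike would need to be measured, since `torch.sign` and `torch.pow` still allocate intermediates.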
I just tried this one and it compiles without graph breaks:

```python
@torch.inference_mode()
def optimize_weights_proximal_legacy(
    tensor: Tensor,
    scale: Tensor,
    zero: Tensor,
    min_max: list,
    axis: int = 0,
    device: Union[str, None] = None,
    opt_params: dict = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20},
    verbose: bool = False,
) -> tuple:
    if device is None:
        device = tensor.device
    else:
        device = torch.device(device)

    dtype = float16 if (device.type == "cuda") else float32
    W_f = tensor.to(dtype=dtype, device=device)
    scale = scale.to(dtype=dtype, device=device)
    zero = zero.to(dtype=dtype, device=device)

    # Params
    lp_norm = torch.tensor(max(opt_params["lp_norm"], 0.1), dtype=dtype, device=device)
    beta = torch.tensor(opt_params["beta"], dtype=dtype, device=device)
    kappa = torch.tensor(opt_params["kappa"], dtype=dtype, device=device)
    iters = opt_params["iters"]

    best_error = torch.tensor(1e4, dtype=torch.float32, device=device)
    for i in range(iters):
        W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1])
        W_r = (W_q - zero) / scale
        W_e = shrink_lp_op(W_f - W_r, beta, lp_norm)
        zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True)
        beta *= kappa

        current_error = torch.abs(W_f - W_r).mean().float()
        if verbose:
            print(i, np.round(current_error, 6))
        if current_error < best_error:
            best_error = current_error
        else:
            break

    scale = scale.to(tensor.device)
    zero = zero.to(tensor.device)
    del W_f, W_q, W_r, W_e
    torch.cuda.empty_cache()

    W_q = torch.round(tensor * scale + zero).clamp(min_max[0], min_max[1])
    return W_q, scale, zero
```

then

```python
import torch

device = 'cuda:0'
backend = 'torchao_int4'  # "torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit)
compute_dtype = torch.float16 if backend == "bitblas" else torch.bfloat16
cache_dir = '.'
model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'

########################################################################
# Load model
from transformers import AutoModelForCausalLM, AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import *

# Load
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, torch_dtype=compute_dtype, attn_implementation="sdpa")

# Quantize
# torch._dynamo.config.capture_scalar_outputs = True
Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)
Quantizer.quantize = torch.compile(Quantizer.quantize)

quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
```

I will take a look at
Works with `Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)` or like this: `Quantizer.quantize = torch.compile(Quantizer.quantize)`. This way, if the compilation breaks, the user can simply skip this step. Then in:

```python
....
dtype = float16 if (device.type == "cuda") else float32
W_f = tensor.to(dtype=dtype, device=device)
scale = scale.to(dtype=dtype, device=device)
zero = zero.to(dtype=dtype, device=device)

# Params
lp_norm = torch.tensor(max(opt_params["lp_norm"], 0.1), dtype=dtype, device=device)
beta = torch.tensor(opt_params["beta"], dtype=dtype, device=device)
kappa = torch.tensor(opt_params["kappa"], dtype=dtype, device=device)
iters = opt_params["iters"]
....
```

Ideally we do this for the 🙏
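Building on the idea that compilation should be easy to skip, here is a minimal sketch of an opt-in helper; the function `compile_quantizer` and its `enable` flag are hypothetical, not part of hqq:

```python
import torch
from hqq.core.quantize import Quantizer


def compile_quantizer(enable: bool = True) -> None:
    # Hypothetical opt-in wrapper: compilation stays off unless the user asks for it,
    # so anyone who hits graph breaks or slow warm-up simply never calls this.
    if not enable:
        return
    # torch.compile wraps lazily; real compilation errors only surface on the first
    # call, so keeping this behind an explicit opt-in preserves the eager fallback.
    Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)
    Quantizer.quantize = torch.compile(Quantizer.quantize)
```

Calling `compile_quantizer()` right before `AutoHQQHFModel.quantize_model(...)` would then reproduce the monkey-patching in the script above without making torch.compile the default path.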
Decorate the quantize() method with torch.compile. This uses much less GPU VRAM, avoiding OOM when quantizing Llama-3.1-405B (which previously OOMed on my machine even when quantizing one layer at a time).
@mobicham