
torch.compile() the quantization method #116

Open
rationalism wants to merge 1 commit into base: master
Conversation

@rationalism commented Sep 10, 2024

Decorate the quantize() method with torch.compile. This uses much less GPU VRAM, avoiding the out-of-memory errors I previously hit when quantizing Llama-3.1-405B (it OOMed on my machine even when quantizing one layer at a time).
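
For context, the shape of the change is roughly the following (a sketch only, with the quantization body elided; the actual one-line diff appears in the review below):

import torch

class Quantizer:
    @classmethod
    @torch.compile  # decorator added by this PR
    def quantize(cls, tensor, **kwargs):
        # existing HQQ quantization logic (elided); under torch.compile the
        # elementwise ops can be fused, which is the likely reason peak VRAM drops
        ...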

cc @mobicham

@mobicham (Collaborator)

Thank you @rationalism! Added a few comments.

@rationalism (Author)

@mobicham thank you! I don't see the comments though?

@rationalism (Author)

@mobicham I don't see any comments or review, no, sorry

@mobicham (Collaborator) left a comment

> @mobicham I don't see any comments or review, no, sorry

I think you should see it now

@@ -218,15 +218,15 @@ def optimize_weights_proximal_legacy(
     scale = scale.to(dtype=dtype, device=device)
     zero = zero.to(dtype=dtype, device=device)

-    best_error = 1e4
+    best_error = torch.tensor(1e4, dtype=float32, device=device)
@mobicham (Collaborator):

best_error = torch.tensor(1e4, dtype=dtype, device=device)

     for i in range(iters):
         W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1])
         W_r = (W_q - zero) / scale
         W_e = shrink_lp_op(W_f - W_r, beta, lp_norm)
         zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True)
         beta *= kappa

-        current_error = float(torch.abs(W_f - W_r).mean())
+        current_error = torch.abs(W_f - W_r).mean()
@mobicham (Collaborator):

 current_error = torch.abs(W_f - W_r).mean().float()

@@ -72,6 +72,7 @@ class Quantizer:
     }

     @classmethod
+    @torch.compile
@mobicham (Collaborator):

I used torch.compile for the quantization step back in the beginning, but the issue was that the graph would break since there are too many ops. It would also break on older systems, and the warm-up was sometimes quite slow.
For all those reasons, I removed it. That was a couple of months ago though, and torch has improved the compilation process a lot since then. However, I still think it should be turned off by default.
Maybe for the moment we can remove it as a decorator and do the compilation outside, like this:

Quantizer.quantize = torch.compile(Quantizer.quantize)

Have you tried compiling only the optimizer?

Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)

If that works, I would suspect the extra memory usage comes from shrink_lp_op(), in which case we can rewrite it to do the operations in-place.
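
For illustration, here is a minimal sketch of what an in-place variant could look like, assuming the usual soft-thresholding form of the lp-norm shrinkage (check against the real shrink_lp_op before swapping this in):

import torch
from torch import Tensor

def shrink_lp_op_(x: Tensor, beta: float, lp_norm: float) -> Tensor:
    # Hypothetical in-place sketch: x is overwritten, so it must be a scratch
    # tensor (e.g. the freshly computed W_f - W_r), never one that is reused later.
    sign = torch.sign(x)
    x.abs_()
    if lp_norm == 1:
        x.sub_(1.0 / beta)                             # |x| - 1/beta
    else:
        x.sub_(x.pow(lp_norm - 1.0).mul_(1.0 / beta))  # |x| - (1/beta)|x|^(p-1)
    return x.clamp_(min=0).mul_(sign)                  # sign(x) * relu(...)

In optimize_weights_proximal_legacy, the W_f - W_r argument is already a fresh temporary, so mutating it in place would be safe there.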

@mobicham (Collaborator) commented Sep 12, 2024

I just tried this one and it compiles without graph breaks:

@torch.inference_mode()
def optimize_weights_proximal_legacy(
    tensor: Tensor,
    scale: Tensor,
    zero: Tensor,
    min_max: list,
    axis: int = 0,
    device: Union[str, None] = None,
    opt_params: dict = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20},
    verbose: bool = False,
) -> tuple:

    if device is None:
        device = tensor.device
    else:
        device = torch.device(device)

    dtype = float16 if (device.type == "cuda") else float32
    W_f = tensor.to(dtype=dtype, device=device)
    scale = scale.to(dtype=dtype, device=device)
    zero = zero.to(dtype=dtype, device=device)

    # Params
    lp_norm = torch.tensor(max(opt_params["lp_norm"], 0.1), dtype=dtype, device=device)
    beta    = torch.tensor(opt_params["beta"], dtype=dtype, device=device)
    kappa   = torch.tensor(opt_params["kappa"], dtype=dtype, device=device)
    iters   = opt_params["iters"]

    best_error = torch.tensor(1e4, dtype=torch.float32, device=device)
    for i in range(iters):
        W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1])
        W_r = (W_q - zero) / scale
        W_e = shrink_lp_op(W_f - W_r, beta, lp_norm)
        zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True)
        beta *= kappa

        current_error = torch.abs(W_f - W_r).mean().float()
        if verbose:
            print(i, np.round(float(current_error), 6))
        if current_error < best_error:
            best_error = current_error
        else:
            break

    scale = scale.to(tensor.device)
    zero = zero.to(tensor.device)
    del W_f, W_q, W_r, W_e
    torch.cuda.empty_cache()

    W_q = torch.round(tensor * scale + zero).clamp(min_max[0], min_max[1])
    return W_q, scale, zero

then

import torch
device        = 'cuda:0'
backend       = 'torchao_int4' #"torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit)
compute_dtype = torch.float16 if backend=="bitblas" else torch.bfloat16
cache_dir     = '.' 
model_id      = 'meta-llama/Meta-Llama-3-8B-Instruct'

########################################################################
#Load model
from transformers import AutoModelForCausalLM, AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import *

#Load
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
model     = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, torch_dtype=compute_dtype, attn_implementation="sdpa")

#Quantize
#torch._dynamo.config.capture_scalar_outputs = True
Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)
Quantizer.quantize = torch.compile(Quantizer.quantize)

quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)

I will take a look at Quantizer.quantize compilation now.

@mobicham (Collaborator) commented Sep 12, 2024

Works with Quantizer.quantize compiled as well! I suggest we do the following:
- We remove the @torch.compile decorator and do the compilation outside, either like this:

Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)

or like this

Quantizer.quantize = torch.compile(Quantizer.quantize)

This way if the compilation breaks the user can simply skip this step.
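
For example, the opt-in path could be guarded like this (a sketch, not part of the PR; suppress_errors is one way to fall back to eager execution if compilation fails at the first call):

import torch
import torch._dynamo
from hqq.core.quantize import Quantizer

# Opt-in: users who hit graph breaks or slow warm-up simply skip this block.
# torch.compile wraps lazily, so failures surface on the first call;
# suppress_errors tells dynamo to fall back to eager execution in that case.
torch._dynamo.config.suppress_errors = True
Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)
Quantizer.quantize = torch.compile(Quantizer.quantize)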

Then in optimize_weights..., we need to move the parameter setup after determining the dtype, so that the parameters have exactly the same dtype as the other tensors:

    ...
    dtype = float16 if (device.type == "cuda") else float32
    W_f = tensor.to(dtype=dtype, device=device)
    scale = scale.to(dtype=dtype, device=device)
    zero = zero.to(dtype=dtype, device=device)

    # Params
    lp_norm = torch.tensor(max(opt_params["lp_norm"], 0.1), dtype=dtype, device=device)
    beta    = torch.tensor(opt_params["beta"], dtype=dtype, device=device)
    kappa   = torch.tensor(opt_params["kappa"], dtype=dtype, device=device)
    iters   = opt_params["iters"]
    ...

Ideally we do this for the optimize_weights_proximal_v2 version as well.

🙏
