
mixed-precision quantization milestone1: naive_intNwo + eval/benchmark framework #531

Merged (12 commits, Aug 1, 2024)
32 changes: 32 additions & 0 deletions test/quantization/test_mixed_precision.py
@@ -0,0 +1,32 @@
import unittest

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
from torchao.quantization.utils import compute_error
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

_CUDA_IS_AVAILABLE = torch.cuda.is_available()

class TestWeightOnlyQuantNaive(unittest.TestCase):

    def test_quantization_intNwo(self):
        # skip int4wo for now since it is under development in torchao
        for quantization_bit in [2, 3, 5, 6, 8]:
            for symmetric in [False, True]:
                with self.subTest(quantization_bit=quantization_bit, symmetric=symmetric):
                    for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]:
                        x = torch.randn(*x_shape, dtype=torch.bfloat16)
                        m = nn.Sequential(nn.Linear(32, 80)).bfloat16()
                        y_ref = m(x)
                        quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric))
                        y_wo = m(x)
                        sqnr = compute_error(y_ref, y_wo)
                        # SQNR (in dB) can be approximated by 6.02 * n, where n is the quantization bit width.
                        # We set the threshold to 44 dB for 8-bit (6.02 * 8 = 48.16 dB, leaving some headroom)
                        # and lower it by 6.02 dB for each bit removed, e.g. 44.0 - 6 * 6.02 = 7.88 dB for 2-bit.
                        expected_sqnr_threshold = 44.0 - (8 - quantization_bit) * 6.02
                        self.assertGreater(sqnr, expected_sqnr_threshold, f"sqnr: {sqnr} is too low")


if __name__ == '__main__':
    unittest.main()
Empty file.
@@ -0,0 +1 @@
from .naive_intNwo import intN_weight_only
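With this re-export in place, the quantizer can presumably also be imported from the package directly (a usage note based on the module path used in the test above, not a line from this PR):

from torchao.quantization.prototype.mixed_precision.scripts import intN_weight_only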
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn

from naive_intNwo import intN_weight_only
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.models.huggingface import HFLM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight
from torchao._models._eval import TransformerEvalWrapper

from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)

from torchao.quantization.quant_api import autoquant


torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True


def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length, sensi_bit, non_sensi_bit, quant_sym, group_size):

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

    if quantization == "autoquant":
        model = autoquant(model.to(device=device))

    # naive uniform-precision quantization of all layers
    elif quantization in ["2", "3", "4", "5", "6", "8"]:
        quantize_(model.to(device=device), intN_weight_only(n=int(quantization), group_size=group_size, symmetric=quant_sym))

    # mixed-precision quantization for Llama3
    elif quantization == "MP_llama3":

        # filter for sensitive layers (the first 3 and last 2 layers of Llama3)
        def filter_fn_sen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # filter for non-sensitive layers (the other 27 layers of Llama3)
        def filter_fn_nonsen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and not any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # quantize the sensitive layers
        if sensi_bit != 16:
            quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_sen)

        # quantize the less-sensitive layers (the model is already on the device when the 4-bit sensitive pass ran)
        if sensi_bit == 4:
            quantize_(model, intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
        else:
            quantize_(model.to(device=device), intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)

    if compile:
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

    with torch.no_grad():
        result = evaluate(
            HFLM(
                pretrained=model,
                tokenizer=tokenizer,
                batch_size=batch_size,
                max_length=max_length),
            get_task_dict(tasks),
            limit=limit,
        )

    for task, res in result["results"].items():
        print(f"{task}: {res}")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
    parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
    parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eval (EleutherAI) tasks to evaluate; usage: --tasks task1 task2')
    parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
    parser.add_argument('-q', '--quantization', default="None", choices=["2", "3", "4", "5", "6", "8", "MP_llama3", "None"], help='Which quantization technique to apply: one of ["2", "3", "4", "5", "6", "8"] for uniform quantization, "MP_llama3" for mixed-precision quantization of Llama3 (set sensi_bit and non_sensi_bit accordingly), or "None" for no quantization')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation; note int8wo and int4wo work best with small batch sizes, while int8dq works better with large batch sizes')
    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
    parser.add_argument('--sensi_bit', type=int, default=16, choices=[16, 8, 6, 5, 4, 3], help='Bit setting for sensitive layers')
    parser.add_argument('--non_sensi_bit', type=int, default=8, choices=[8, 6, 5, 4, 3, 2], help='Bit setting for non-sensitive layers')
    # use a store_true flag rather than type=bool: argparse would treat any non-empty string (including "False") as True
    parser.add_argument('--quant_sym', action='store_true', help='Use symmetric quantization (asymmetric by default)')
    parser.add_argument('--group_size', type=int, default=32, help='Group size to perform quantization on')
    args = parser.parse_args()
    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size)
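For reference, a minimal sketch of calling run_evaluation from the script above directly in Python; the repo ID, task, and limit values are illustrative placeholders rather than settings from this PR, and the call assumes torch and the function are already in scope:

# Sketch: evaluate a Llama3 checkpoint with mixed-precision quantization
# (8-bit sensitive layers, 4-bit non-sensitive layers); values are placeholders.
run_evaluation(
    repo_id="meta-llama/Meta-Llama-3-8B",
    tasks=["wikitext"],
    limit=100,
    device="cuda",
    precision=torch.bfloat16,
    quantization="MP_llama3",
    compile=False,
    batch_size=1,
    max_length=None,
    sensi_bit=8,
    non_sensi_bit=4,
    quant_sym=False,
    group_size=32,
)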
@@ -0,0 +1,60 @@
import torch

from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)

from torchao.quantization import int8_weight_only, int4_weight_only
from torchao.quantization.quant_api import _get_linear_subclass_inserter

def intN_weight_only(group_size=32, n=8, symmetric=False):
    '''
    Apply int N-bit weight-only quantization to a linear layer.
    Args:
        `group_size`: controls the granularity of quantization; smaller sizes are more fine-grained, choices are [512, 256, 128, 64, 32]
        `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
        `symmetric`: whether to use symmetric quantization (asymmetric by default)
    Usage:
        from torchao.quantization import quantize_
        quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
    '''
    # for asymmetric quantization
    def apply_intN_weight_only_quant_asym(weight):
        # avoid circular dependency
        from torchao.dtypes import to_affine_quantized
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.uint8
        quant_min = 0
        quant_max = 2**n - 1
        eps = 1e-6
        preserve_zero = True
        zero_point_dtype = torch.int64
        zero_point_domain = ZeroPointDomain.INT
        # preserve_zero and zero_point_domain are currently left at their defaults in to_affine_quantized
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)

    # for symmetric quantization
    def apply_intN_weight_only_quant_sym(weight):
        # avoid circular dependency
        from torchao.dtypes import to_affine_quantized
        mapping_type = MappingType.SYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.int8
        eps = 1e-6
        zero_point_dtype = torch.int64
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

    assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
    if n == 8:
        return int8_weight_only()
    elif n == 4:
        return int4_weight_only(group_size=group_size)
    elif symmetric:
        return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
    else:
        return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
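Below is a minimal sketch (not part of this PR) of using intN_weight_only together with filter functions to mix bit widths across layers, in the spirit of the MP_llama3 path in the evaluation script above. The toy two-layer model and the fully-qualified names "0" and "1" are illustrative assumptions, and a recent torchao build is assumed.

import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

# toy model: quantize the first linear layer to 8 bits and the second to 6 bits
model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32)).bfloat16()
quantize_(model, intN_weight_only(n=8, group_size=32), lambda m, fqn: isinstance(m, nn.Linear) and fqn == "0")
quantize_(model, intN_weight_only(n=6, group_size=32), lambda m, fqn: isinstance(m, nn.Linear) and fqn == "1")

# run the mixed-precision model on a random input
x = torch.randn(16, 32, dtype=torch.bfloat16)
y = model(x)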
