diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6_llm.py similarity index 82% rename from benchmarks/benchmark_fp6.py rename to benchmarks/benchmark_fp6_llm.py index f4d7f2269..b6b99c0eb 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6_llm.py @@ -1,16 +1,16 @@ import torch import pandas as pd import torch.nn.functional as F -# from torchao.prototype.quant_llm import QuantLlmLinearWeight -from torchao.dtypes import to_affine_quantized_fpx -from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType +from torchao.prototype.quant_llm import QuantLlmLinearWeight from torchao.utils import benchmark_torch_function_in_microseconds from tqdm import tqdm def benchmark(m: int, k: int, n: int): - float_data = torch.randn(n, k, dtype=torch.half, device="cuda") - fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2)) + fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda") + scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5 + fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2) + fp16_weight = fp6_weight.dequantize(torch.half) fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py index db7a6a9b7..82f867d96 100644 --- a/scripts/hf_eval.py +++ b/scripts/hf_eval.py @@ -20,7 +20,6 @@ int8_dynamic_activation_int8_weight, quantize_, autoquant, - fpx_weight_only, ) from torchao.sparsity import ( sparsify_, @@ -60,8 +59,6 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars elif quantization == "int4wo": # note cannot quantize this model on cpu and run it on cuda at this time quantize_(model.to(device=device), int4_weight_only()) - elif quantization == "fp6": - quantize_(model, fpx_weight_only(3, 2)) elif quantization == "autoquant": model = autoquant(model.to(device=device)) @@ -82,7 +79,7 @@ def all_linear(mod, name): return False torch.sparse.semi_structured._FORCE_CUTLASS = False sparsify_(model, semi_sparse_weight(), filter_fn=all_linear) - + if sparsity and compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) @@ -114,7 +111,7 @@ def all_linear(mod, name): parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "fp6", "None"], help='Which quantization technique to apply') + parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply') parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--save', action='store_true', help='Whether to save the model.') diff --git a/test/dtypes/test_fpx.py b/test/prototype/test_quant_llm.py similarity index 60% rename from test/dtypes/test_fpx.py rename to test/prototype/test_quant_llm.py index 130bdadf3..610979674 100644 --- a/test/dtypes/test_fpx.py +++ b/test/prototype/test_quant_llm.py @@ -8,27 +8,23 @@ parametrize, run_tests, ) 
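A note on the benchmark setup above: FP6 stores each value in 6 bits, so a row of `k` elements packs into `k * 3 // 4` bytes of `uint8`, which is why `fp6_data` is allocated with that trailing dimension, and the scale is one `half` value per output channel. The rest of `benchmark()` is elided from this hunk; below is a hedged sketch of the comparison it presumably performs with the imported helper (the function name and return dict are illustrative, not the file's literal code):

```python
import torch.nn.functional as F
from torchao.utils import benchmark_torch_function_in_microseconds

def compare_latency(fp16_act, fp16_weight, fp6_weight):
    """Time F.linear with the dequantized fp16 weight vs. the packed FP6 weight."""
    fp16_us = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
    # fp6_weight is a QuantLlmLinearWeight, so F.linear dispatches to the quant_llm kernel
    fp6_us = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
    return {"fp16 latency (us)": fp16_us, "fp6 latency (us)": fp6_us, "speedup": fp16_us / fp6_us}
```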
-from torchao.dtypes.fpx import ( - FpxTensorCoreAQTLayout, - FpxTensorCoreLayoutType, +from torchao.prototype.quant_llm import ( + QuantLlmLinearWeight, + quant_llm_fpx_weight_only, + fp6_llm_weight_only, to_scaled_tc_fpx, from_scaled_tc_fpx, ) -from torchao.dtypes.fpx.fpx import _pack_tc_fpx, _pack_tc_fp6 +from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 -from torchao.quantization import ( - quantize_, - fpx_weight_only, -) - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.quantization.quant_api import quantize_ _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) _FPx_DTYPES = [(3, 2), (2, 2)] -class TestFpxTensorCoreAQTLayout(TestCase): +class TestQuantLlmLinearWeight(TestCase): @parametrize("device", _DEVICES) def test_pack_tc_fp6_correctness(self, device): x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device) @@ -73,40 +69,61 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("ebits,mbits", _FPx_DTYPES) def test_to_copy_device(self, ebits, mbits): - from torchao.quantization.quant_primitives import ( - choose_qparams_affine_fpx, - quantize_affine_fpx, - ) - x = torch.randn(256, 64) - scale = choose_qparams_affine_fpx(x, ebits, mbits) - x = quantize_affine_fpx(x, scale, ebits, mbits) - layout_type = FpxTensorCoreLayoutType(ebits, mbits) - fpx_layout_tensor = FpxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda() - assert fpx_layout_tensor.device.type == "cuda" - fpx_layout_tensor = fpx_layout_tensor.cpu() - assert fpx_layout_tensor.device.type == "cpu" + fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda() + assert fpx.device.type == "cuda" + fpx = fpx.cpu() + assert fpx.device.type == "cpu" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("leading_dims", [(4,), (2, 4)]) + @parametrize("bias", [False, True]) + def test_quant_llm_linear_weight(self, ebits, mbits, bias, leading_dims): + OC, IC = 256, 64 + device = "cuda" + + fp16_weight = torch.randn(OC, IC, device=device, dtype=torch.half) + fp16_bias = torch.randn(OC, device=device, dtype=torch.half) if bias else None + + fpx_weight = QuantLlmLinearWeight.from_float(fp16_weight, ebits, mbits) + + x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half) + out = torch.nn.functional.linear(x, fpx_weight, fp16_bias) + assert out.shape == leading_dims + (OC,) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+") @parametrize("ebits,mbits", _FPx_DTYPES) @parametrize("bias", [False, True]) - def test_fpx_weight_only(self, ebits, mbits, bias): + def test_quant_llm_quantize(self, ebits, mbits, bias): + N, OC, IC = 4, 256, 64 + device = "cuda" + + linear = torch.nn.Linear(IC, OC, bias=bias, device=device) + fpx_linear = copy.deepcopy(linear) + quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits)) + + x = torch.randn(N, IC, device=device, dtype=torch.half) + expected = fpx_linear(x) + actual = torch.compile(fpx_linear, fullgraph=True)(x) + torch.testing.assert_close(actual, expected) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def 
test_fp6_llm_quantize(self): N, OC, IC = 4, 256, 64 device = "cuda" - linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=torch.half) + linear = torch.nn.Linear(IC, OC, device=device) fpx_linear = copy.deepcopy(linear) - quantize_(fpx_linear, fpx_weight_only(ebits, mbits)) + quantize_(fpx_linear, fp6_llm_weight_only()) x = torch.randn(N, IC, device=device, dtype=torch.half) expected = fpx_linear(x) actual = torch.compile(fpx_linear, fullgraph=True)(x) - # somehow compile now changes the result a bit torch.testing.assert_close(actual, expected) -instantiate_parametrized_tests(TestFpxTensorCoreAQTLayout) +instantiate_parametrized_tests(TestQuantLlmLinearWeight) if __name__ == "__main__": diff --git a/test/test_ops.py b/test/test_ops.py index eb22f40ad..171089237 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -11,7 +11,7 @@ ) from torch.testing._internal.optests import opcheck from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff -from torchao.dtypes.fpx import from_scaled_tc_fpx +from torchao.prototype.quant_llm import from_scaled_tc_fpx from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24 import pytest @@ -318,7 +318,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] MARLIN_TEST_PARAMS = list(itertools.product( - MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS, + MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS )) @@ -399,7 +399,7 @@ def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors): workspace_24 = marlin_24_workspace(size_n) fn_inputs = ( - a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, + a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1], ) output = torchao.ops.marlin_24_gemm(*fn_inputs) diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index 70f4a411b..858196776 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -13,222 +13,223 @@ from torchao.quantization.utils import _lm_eval_available, _MultiInput -import lm_eval -try: # lm_eval version 0.4 - from lm_eval.evaluator import evaluate # pyre-ignore[21] - from lm_eval.models.huggingface import HFLM as eval_wrapper # pyre-ignore[21] - from lm_eval.tasks import get_task_dict # pyre-ignore[21] -except: # lm_eval version 0.3 - from lm_eval import base, evaluator, tasks - - eval_wrapper = base.BaseLM - get_task_dict = tasks.get_task_dict - evaluate = evaluator.evaluate - -class InputRecorder(eval_wrapper): - """ - This is a fake evaluation wrapper from the lm_eval library that just records the inputs - so that they can be used in calibration. - - If pad_calibration_inputs is enabled, the input recorder will take - each input and pad/truncate it down to the calibration_seq_length. - (if using padding you should set the embeddings for the pad_token to 0 - in the model) - - Note: after padding/truncation, input_prep_function is called to bring - it to the proper form to be inserted into a given model. - - If not, it will only truncate inputs to the desired length. 
- """ - - def __init__( - self, - tokenizer, - calibration_seq_length, - input_prep_func=None, - pad_calibration_inputs=False, - vocab_size=32000, - pad_token=0, - device="cpu", - ): - try: - super().__init__() - except TypeError: - # lm_eval 0.4.2 removed the default init - super().__init__("gpt2", device="cpu") - - self.tokenizer = tokenizer - self._device = torch.device(device) - self.vocab_size = vocab_size - self._max_seq_length = calibration_seq_length - self.calibration_seq_length = calibration_seq_length - - # need to take inps and convert to corrent input - # for model - self.input_prep_func = ( - input_prep_func if input_prep_func is not None - else lambda x: (x,) - ) - - self.pad_calibration_inputs = pad_calibration_inputs - self.pad_token = pad_token - - self.inputs = None - - @property - def eot_token_id(self): - try: - return self.tokenizer.eos_id() - except: - return self.tokenizer.eos_id - - @property - def max_length(self): - return self._max_seq_length - - @property - def max_gen_toks(self): - return 50 - - @property - def batch_size(self): - return 1 - - @property - def device(self): - return self._device - - def tok_encode(self, string: str, **kwargs): - # TODO: verify this for multi-batch as well - tokens = self.tokenizer.encode(string) - if hasattr(self.tokenizer, "bos_id"): +if _lm_eval_available: + import lm_eval + try: # lm_eval version 0.4 + from lm_eval.evaluator import evaluate # pyre-ignore[21] + from lm_eval.models.huggingface import HFLM as eval_wrapper # pyre-ignore[21] + from lm_eval.tasks import get_task_dict # pyre-ignore[21] + except: # lm_eval version 0.3 + from lm_eval import base, evaluator, tasks + + eval_wrapper = base.BaseLM + get_task_dict = tasks.get_task_dict + evaluate = evaluator.evaluate + + class InputRecorder(eval_wrapper): + """ + This is a fake evaluation wrapper from the lm_eval library that just records the inputs + so that they can be used in calibration. + + If pad_calibration_inputs is enabled, the input recorder will take + each input and pad/truncate it down to the calibration_seq_length. + (if using padding you should set the embeddings for the pad_token to 0 + in the model) + + Note: after padding/truncation, input_prep_function is called to bring + it to the proper form to be inserted into a given model. + + If not, it will only truncate inputs to the desired length. 
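The structural point of the re-indentation above: everything that needs `lm_eval` is now defined only when `_lm_eval_available` is true, so importing `torchao._models._eval` no longer fails in environments without the harness installed. A minimal sketch of the guard pattern, assuming `_lm_eval_available` is a simple find_spec check (its actual definition lives in `torchao.quantization.utils`):

```python
import importlib.util

# assumed definition for illustration; torchao.quantization.utils provides the real flag
_lm_eval_available = importlib.util.find_spec("lm_eval") is not None

if _lm_eval_available:
    import lm_eval
    try:  # lm_eval >= 0.4 layout
        from lm_eval.evaluator import evaluate
        from lm_eval.models.huggingface import HFLM as eval_wrapper
        from lm_eval.tasks import get_task_dict
    except ImportError:  # lm_eval 0.3 layout
        from lm_eval import base, evaluator, tasks
        eval_wrapper = base.BaseLM
        get_task_dict = tasks.get_task_dict
        evaluate = evaluator.evaluate

    class InputRecorder(eval_wrapper):
        """Only defined when lm_eval is importable; see the full body in the diff below."""
```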
+ """ + + def __init__( + self, + tokenizer, + calibration_seq_length, + input_prep_func=None, + pad_calibration_inputs=False, + vocab_size=32000, + pad_token=0, + device="cpu", + ): + try: + super().__init__() + except TypeError: + # lm_eval 0.4.2 removed the default init + super().__init__("gpt2", device="cpu") + + self.tokenizer = tokenizer + self._device = torch.device(device) + self.vocab_size = vocab_size + self._max_seq_length = calibration_seq_length + self.calibration_seq_length = calibration_seq_length + + # need to take inps and convert to corrent input + # for model + self.input_prep_func = ( + input_prep_func if input_prep_func is not None + else lambda x: (x,) + ) + + self.pad_calibration_inputs = pad_calibration_inputs + self.pad_token = pad_token + + self.inputs = None + + @property + def eot_token_id(self): try: - tokens = [self.tokenizer.bos_id()] + tokens + return self.tokenizer.eos_id() except: - tokens = [self.tokenizer.bos_id] + tokens - return tokens - - def tok_decode(self, tokens): - decoded = self.tokenizer.decode(tokens) - return decoded - - def add_input(self, args): - if self.inputs is None: - self.inputs = [_MultiInput([arg]) for arg in args] - else: - self.inputs = [ - multi.add_input(arg) for (multi, arg) in zip(self.inputs, args) - ] - - def record_inputs( - self, - calibration_tasks, - calibration_limit, - ): - try: - lm_eval.tasks.initialize_tasks() - except: - pass - - task_dict = get_task_dict(calibration_tasks) - print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) - - evaluate( + return self.tokenizer.eos_id + + @property + def max_length(self): + return self._max_seq_length + + @property + def max_gen_toks(self): + return 50 + + @property + def batch_size(self): + return 1 + + @property + def device(self): + return self._device + + def tok_encode(self, string: str, **kwargs): + # TODO: verify this for multi-batch as well + tokens = self.tokenizer.encode(string) + if hasattr(self.tokenizer, "bos_id"): + try: + tokens = [self.tokenizer.bos_id()] + tokens + except: + tokens = [self.tokenizer.bos_id] + tokens + return tokens + + def tok_decode(self, tokens): + decoded = self.tokenizer.decode(tokens) + return decoded + + def add_input(self, args): + if self.inputs is None: + self.inputs = [_MultiInput([arg]) for arg in args] + else: + self.inputs = [ + multi.add_input(arg) for (multi, arg) in zip(self.inputs, args) + ] + + def record_inputs( self, - task_dict, - limit=calibration_limit, - ) - return self - - def get_inputs(self): - return self.inputs - - def _model_call(self, inps): - inps = inps.squeeze(0) - T = len(inps) - if ( - # can't use inputs that are too short when padding disabled - (T < self.calibration_seq_length and not self.pad_calibration_inputs) - or - # can't use inputs that actually use token we use for padding - (self.pad_calibration_inputs and self.pad_token in inps) + calibration_tasks, + calibration_limit, ): - # give random output + try: + lm_eval.tasks.initialize_tasks() + except: + pass + + task_dict = get_task_dict(calibration_tasks) + print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) + + evaluate( + self, + task_dict, + limit=calibration_limit, + ) + return self + + def get_inputs(self): + return self.inputs + + def _model_call(self, inps): + inps = inps.squeeze(0) + T = len(inps) + if ( + # can't use inputs that are too short when padding disabled + (T < self.calibration_seq_length and not self.pad_calibration_inputs) + or + # can't use inputs that actually use token we use for padding + 
(self.pad_calibration_inputs and self.pad_token in inps) + ): + # give random output + return torch.randn( + (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device + ) + + # pad or truncate to the right size + if T >= self.calibration_seq_length: + inps = inps[: self.calibration_seq_length] + else: + inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T)) + + inps = inps.unsqueeze(0) + model_in = self.input_prep_func(inps) + + self.add_input(model_in) + + # output `something` with correct shape to keep eval going return torch.randn( (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device ) - # pad or truncate to the right size - if T >= self.calibration_seq_length: - inps = inps[: self.calibration_seq_length] - else: - inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T)) - - inps = inps.unsqueeze(0) - model_in = self.input_prep_func(inps) - - self.add_input(model_in) - - # output `something` with correct shape to keep eval going - return torch.randn( - (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device - ) - - def _model_generate(self, context, max_length, eos_token_id): - raise Exception("unimplemented") - -class TransformerEvalWrapper(InputRecorder): - """ - A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. - """ - def __init__( - self, - model, - tokenizer, - max_seq_length, - input_prep_func=None, - device="cuda" - ): - super().__init__(tokenizer, None) - self._model = model - # self.tokenizer = tokenizer - self._device = torch.device(device) - self._max_seq_length = max_seq_length - - # need to take inps and convert to corrent input - # for model - self.input_prep_func = ( - input_prep_func if input_prep_func is not None - else lambda x: (x,) - ) - - def _model_call(self, inps): - # TODO: make batches work - input = self.input_prep_func(inps) - - max_seq_length = min(max(inps.size()), self.max_length) - with torch.device(self._device): - self._model.setup_caches(self.batch_size, max_seq_length) - logits = self._model(*input) - return logits - - def _model_generate(self, context, max_length, eos_token_id): - raise Exception('unimplemented') - - def run_eval(self, tasks, limit): - try: - lm_eval.tasks.initialize_tasks() - except: - pass - - task_dict = get_task_dict(tasks) - print("Evaluating Model On: ", task_dict) - with torch.no_grad(): - result = evaluate( - self, - task_dict, - limit=limit, + def _model_generate(self, context, max_length, eos_token_id): + raise Exception("unimplemented") + + class TransformerEvalWrapper(InputRecorder): + """ + A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. 
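For orientation, `InputRecorder` above is a fake evaluator: `_model_call` records each (optionally padded or truncated) prompt via `add_input` and returns random logits so the harness keeps iterating. A hedged usage sketch for collecting GPTQ calibration data (the tokenizer and task name are illustrative placeholders):

```python
# tokenizer: any object with encode()/decode() and eos_id/bos_id, e.g. the gpt-fast
# sentencepiece wrapper used elsewhere in torchao/_models (construction elided)
recorder = InputRecorder(
    tokenizer,
    calibration_seq_length=2048,
    pad_calibration_inputs=False,
    device="cpu",
)
calibration_inputs = recorder.record_inputs(
    calibration_tasks=["wikitext"],  # illustrative lm_eval task name
    calibration_limit=8,
).get_inputs()
# each entry is a _MultiInput holding one positional model input across recorded samples
```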
+ """ + def __init__( + self, + model, + tokenizer, + max_seq_length, + input_prep_func=None, + device="cuda" + ): + super().__init__(tokenizer, None) + self._model = model + # self.tokenizer = tokenizer + self._device = torch.device(device) + self._max_seq_length = max_seq_length + + # need to take inps and convert to corrent input + # for model + self.input_prep_func = ( + input_prep_func if input_prep_func is not None + else lambda x: (x,) ) - for task, res in result["results"].items(): - print(f"{task}: {res}") - return result + + def _model_call(self, inps): + # TODO: make batches work + input = self.input_prep_func(inps) + + max_seq_length = min(max(inps.size()), self.max_length) + with torch.device(self._device): + self._model.setup_caches(self.batch_size, max_seq_length) + logits = self._model(*input) + return logits + + def _model_generate(self, context, max_length, eos_token_id): + raise Exception('unimplemented') + + def run_eval(self, tasks, limit): + try: + lm_eval.tasks.initialize_tasks() + except: + pass + + task_dict = get_task_dict(tasks) + print("Evaluating Model On: ", task_dict) + with torch.no_grad(): + result = evaluate( + self, + task_dict, + limit=limit, + ) + for task, res in result["results"].items(): + print(f"{task}: {res}") + return result diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index f673a966d..fc8634dd0 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -13,12 +13,8 @@ ) from torchao.quantization.quant_api import ( - quantize_, - int4_weight_only, - int8_weight_only, - int8_dynamic_activation_int8_weight, - fpx_weight_only, - unwrap_tensor_subclass, + quantize_, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass + ) from torchao._models._eval import TransformerEvalWrapper, InputRecorder @@ -72,8 +68,6 @@ def run_evaluation( groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" quantize_(model.to(device), int4_weight_only(group_size=groupsize)) - if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 94a18488b..20a2f401f 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -207,10 +207,10 @@ def main( int8_weight_only, int8_dynamic_activation_int8_weight, int4_weight_only, - fpx_weight_only, autoquant, unwrap_tensor_subclass - ) + ) + if "int8wo" in quantization: quantize_(model, int8_weight_only()) if "int8dq" in quantization: @@ -219,8 +219,6 @@ def main( groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" quantize_(model, int4_weight_only(group_size=groupsize)) - if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) if "autoquant" == quantization: model = autoquant(model, manual=True) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 80c15ae43..e4b47b822 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -4,7 +4,6 @@ from .affine_quantized_tensor import ( AffineQuantizedTensor, to_affine_quantized, - to_affine_quantized_fpx, 
to_affine_quantized_static, LayoutType, PlainLayoutType, @@ -18,7 +17,6 @@ "UInt4Tensor" "AffineQuantizedTensor", "to_affine_quantized", - "to_affine_quantized_fpx", "to_affine_quantized_static", "LayoutType", "PlainLayoutType", diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 7491b0ecf..6c36d98c4 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -11,9 +11,9 @@ MappingType, int_scaled_matmul, quantize_affine_hqq, - choose_qparams_affine_fpx, - quantize_affine_fpx, - dequantize_affine_fpx, +) +from torchao.quantization.utils import ( + pack_tinygemm_scales_and_zeros, ) from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.utils import ( @@ -33,7 +33,6 @@ TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_5, ) -from torchao.ops import quant_llm_linear aten = torch.ops.aten @@ -46,11 +45,6 @@ class AQTLayout(TorchAOBaseTensor): Base class for the layout tensor for `AffineQuantizedTensor` """ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Get the plain (unpacked) Tensor for the layout Tensor - - Returns int_data, scale and zero_point - Can be overwritten if other types of AQTLayout Tensor has different numbers of plain tensors - """ pass def get_layout_type(self) -> LayoutType: @@ -162,14 +156,8 @@ def __repr__(self): def dequantize(self, output_dtype=None): if output_dtype is None: output_dtype = self.dtype - - from torchao.dtypes.fpx import FpxTensorCoreLayoutType - if isinstance(self.layout_type, FpxTensorCoreLayoutType): - int_data, scale = self.layout_tensor.get_plain() - return dequantize_affine_fpx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) - else: - int_data, scale, zero_point = self.layout_tensor.get_plain() - return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) + int_data, scale, zero_point = self.layout_tensor.get_plain() + return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) @staticmethod def _quantized_linear_op(input_tensor, weight_tensor, bias): @@ -277,35 +265,6 @@ def from_float_static( dtype=input_float.dtype, ) - @classmethod - def from_float_fpx( - cls, - input_float: torch.Tensor, - layout_type: LayoutType = PlainLayoutType() - ): - from torchao.dtypes.fpx import FpxTensorCoreLayoutType - assert isinstance(layout_type, FpxTensorCoreLayoutType), f"Only FpxTensorCoreLayoutType is supported for fpx, got {layout_type}" - original_shape = input_float.shape - input_float = layout_type.pre_process(input_float) - # per axis quantization, where axis = 1 - block_size = list(input_float.shape) - block_size[1] = 1 - - ebits, mbits = layout_type.ebits, layout_type.mbits - # Note: these ops are hardcoded to have per axis quantization (axis=1) right now - scale = choose_qparams_affine_fpx(input_float, ebits, mbits) - fpx_unpacked = quantize_affine_fpx(input_float, scale, ebits, mbits) - fpx_packed = layout_type.post_process(fpx_unpacked) - - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(fpx_packed, scale, None, layout_type) - return cls( - layout_tensor, - block_size, - original_shape, - dtype=input_float.dtype - ) - @property def layout_type(self) -> LayoutType: return 
self.layout_tensor.layout_type @@ -496,7 +455,7 @@ def from_plain( cls, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: Optional[torch.Tensor], + zero_point: torch.Tensor, layout_type: LayoutType, ): assert isinstance(layout_type, PlainLayoutType) @@ -535,7 +494,7 @@ def from_plain( cls, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: Optional[torch.Tensor], + zero_point: torch.Tensor, layout_type: LayoutType, ): assert isinstance(layout_type, SemiSparseLayoutType) @@ -600,7 +559,7 @@ def from_plain( cls, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: Optional[torch.Tensor], + zero_point: torch.Tensor, layout_type: LayoutType ): @@ -614,7 +573,6 @@ def from_plain( packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) - from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) return cls(packed_weight, scale_and_zero, False, layout_type) @@ -906,55 +864,6 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): y += bias.to(m.dtype) return y -def _linear_f16_act_fpx_weight_check(input_tensor, weight_tensor, bias): - from torchao.dtypes.fpx import FpxTensorCoreLayoutType - return ( - # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - input_tensor.dtype == torch.float16 and - # weight is fpx Tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor.layout_type, FpxTensorCoreLayoutType) and - ( - # weight is using fp6 quantization - (weight_tensor.layout_type.ebits == 3 and - weight_tensor.layout_type.mbits == 2) or - (weight_tensor.layout_type.ebits == 2 and - weight_tensor.layout_type.mbits == 3) or - # weight is using fp5 quantization - (weight_tensor.layout_type.ebits == 2 and - weight_tensor.layout_type.mbits == 2) or - (weight_tensor.layout_type.ebits == 3 and - weight_tensor.layout_type.mbits == 1) - ) - ) - -def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): - from torchao.dtypes.fpx import _SPLIT_K_MAP - act = input_tensor - weight = weight_tensor - - out_dim, in_dim = weight.shape - act_reshaped = act.view(-1, in_dim).half() - - # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py - bsize = act_reshaped.shape[0] - splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 - - out = quant_llm_linear( - weight.layout_type.ebits, - weight.layout_type.mbits, - act_reshaped, - weight.layout_tensor.packed_fpx_data, - weight.layout_tensor.scale, - splitK=splitK, - ) - - if bias is not None: - out += bias - - return out.view(*act.shape[:-1], out_dim).to(act.dtype) def _register_quantized_linear_dispatches(): for dispatch_condition, impl in [ @@ -963,7 +872,6 @@ def _register_quantized_linear_dispatches(): (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl), ]: _register_quantized_linear_dispatch(dispatch_condition, impl) @@ -1071,7 +979,6 @@ def _(func, types, args, kwargs): to_affine_quantized = AffineQuantizedTensor.from_float 
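With the `_linear_f16_act_fpx_weight_check`/`_impl` pair dropped from the dispatch table, the FP16 x FPx path no longer flows through `AffineQuantizedTensor` at all; the prototype tensor subclass intercepts `F.linear` itself (see `quant_llm.py` later in this patch). The toy subclass below is not torchao API; it is only a sketch of that interception pattern, using a plain int8 payload so it stays self-contained:

```python
import torch
import torch.nn.functional as F
from torch import Tensor

class DequantOnLinearWeight(Tensor):
    """Toy wrapper subclass: holds quantized data + per-row scale and intercepts F.linear."""

    @staticmethod
    def __new__(cls, int_data: Tensor, scale: Tensor):
        return Tensor._make_wrapper_subclass(
            cls, int_data.shape, device=int_data.device, dtype=scale.dtype, requires_grad=False
        )

    def __init__(self, int_data: Tensor, scale: Tensor):
        self.int_data = int_data
        self.scale = scale

    def dequantize(self) -> Tensor:
        return self.int_data.to(self.scale.dtype) * self.scale.view(-1, 1)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            # a real implementation calls a fused kernel here (e.g. quant_llm_linear)
            return F.linear(x, w.dequantize(), bias)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

# usage: F.linear routes through __torch_function__ because the weight is a subclass
w = torch.randn(256, 64)
scale = (w.abs().amax(1) / 127).clamp(min=1e-6)
qw = DequantOnLinearWeight(torch.round(w / scale.view(-1, 1)).to(torch.int8), scale)
out = F.linear(torch.randn(4, 64), qw)  # shape (4, 256)
```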
to_affine_quantized_static = AffineQuantizedTensor.from_float_static -to_affine_quantized_fpx = AffineQuantizedTensor.from_float_fpx if TORCH_VERSION_AT_LEAST_2_5: # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` diff --git a/torchao/dtypes/fpx/__init__.py b/torchao/dtypes/fpx/__init__.py deleted file mode 100644 index af77685fa..000000000 --- a/torchao/dtypes/fpx/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .fpx import FpxTensorCoreLayoutType, FpxTensorCoreAQTLayout, to_scaled_tc_fpx, from_scaled_tc_fpx, _SPLIT_K_MAP diff --git a/torchao/dtypes/fpx/README.md b/torchao/prototype/quant_llm/README.md similarity index 89% rename from torchao/dtypes/fpx/README.md rename to torchao/prototype/quant_llm/README.md index 1de60b0cc..f0ecd38d5 100644 --- a/torchao/dtypes/fpx/README.md +++ b/torchao/prototype/quant_llm/README.md @@ -5,17 +5,15 @@ This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [F ## Usage ```python -from torchao.quantization import ( - quantize_, - fpx_weight_only, -) +from torchao.quantization.quant_api import quantize_ +from torchao.prototype.quant_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only model = ... model.half() # not necessary, but recommeneded to maintain accuracy +quantize_(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 E3M2 in-place # for generic FPx EyMz where x = 1 + y + z -# fp6 with ebits = 3 and mbits = 2 -quantize_(model, fpx_weight_only(3, 2)) +# quantize_(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead # fully compatible with torch.compile() model.compile(mode="max-autotune", fullgraph=True) @@ -25,7 +23,7 @@ It's also possible to pre-process the weight and call the kernel directly. ```python import torch -from torchao.dtypes.fpx import to_scaled_tc_fpx +from torchao.prototype.quant_llm import to_scaled_tc_fpx from torchao.ops import quant_llm_linear fp32_weight = torch.randn(1024, 512).cuda() diff --git a/torchao/prototype/quant_llm/__init__.py b/torchao/prototype/quant_llm/__init__.py new file mode 100644 index 000000000..4f1479c40 --- /dev/null +++ b/torchao/prototype/quant_llm/__init__.py @@ -0,0 +1 @@ +from .quant_llm import QuantLlmLinearWeight, fp6_llm_weight_only, quant_llm_fpx_weight_only, to_scaled_tc_fpx, from_scaled_tc_fpx diff --git a/torchao/dtypes/fpx/fpx.py b/torchao/prototype/quant_llm/quant_llm.py similarity index 64% rename from torchao/dtypes/fpx/fpx.py rename to torchao/prototype/quant_llm/quant_llm.py index 00ee84b65..f41bac9b2 100644 --- a/torchao/dtypes/fpx/fpx.py +++ b/torchao/prototype/quant_llm/quant_llm.py @@ -1,21 +1,13 @@ from functools import reduce -from typing import Tuple, Optional +from typing import Tuple import torch from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones from torchao.ops import quant_llm_linear -from torchao.dtypes.utils import ( - LayoutType, - _implements, - _dispatch__torch_function__, - _dispatch__torch_dispatch__, -) +from torchao.dtypes.utils import _implements, _dispatch__torch_function__, _dispatch__torch_dispatch__ from torchao.quantization.quant_api import _get_linear_subclass_inserter -from dataclasses import dataclass -from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls -from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten @@ -355,143 +347,119 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, 
scale=None) -> Te ] -# quantization api integrations - -@dataclass(frozen=True) -class FpxTensorCoreLayoutType(LayoutType): - """Layout type for FpxTensorCoreAQTLayout - """ - ebits: int - mbits: int - -@register_layout_cls(FpxTensorCoreLayoutType) -class FpxTensorCoreAQTLayout(AQTLayout): - """FpxTensorCoreAQTLayout represents a Tensor with dtype fpx(ebits=a, mbits=b), - it has a internal tensor field of "packed_fpx_data", which is packed from the - uint8 unpacked data (the output of `quantize_affine_fpx` operator) - - The packing is optimized for TensorCore, from the fp6-llm paper: https://arxiv.org/abs/2401.14112 - github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm - - At a high level packing is done by grouping bits into 1 bit fragments (shards), 2 bit fragments and - 4 bit fragments each fragments are packed separately and concatenated together. - For example for 6 bit dtype, we can extract the first 4 bits for all elements and pack them together - in a fragment, and extract the last 2 bits for all elements and pack them into fragment, in the end - we concatenate the fragments together. - - If original Tensor shape is (M, N), and the data is in nbit, the shape of the packed data will be - (M, N // 8 * nbit) - - FpxTensorCoreAQTLayout.from_plain takes an unpacked uint8 fpx Tensor of shape (M, N), with format of - (zero padding bits + sign bit + exponent bits + mantissa bits), e.g. 00SEEEMM for fp6_e3_m2 - it will then pack the weight and instantiate the FpxTensorCoreAQTLayout tensor - FpxTensorCoreAQTLayout.__init__() takes a packed fpx Tensor of shape (M, N // 8 * nbit) - """ - def __new__( - cls, - packed_fpx_data: torch.Tensor, - scale: torch.Tensor, - layout_type: LayoutType, - ): - assert packed_fpx_data.ndim == 2 - assert packed_fpx_data.dtype == torch.uint8 - shape = (packed_fpx_data.shape[0], packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8) - kwargs = {} - kwargs["device"] = packed_fpx_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_fpx_data.layout +class QuantLlmLinearWeight(Tensor): + implements = classmethod(_implements) + __torch_function__ = classmethod(_dispatch__torch_function__) + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + + @staticmethod + def __new__(cls, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): + assert fpx_data.ndim == 2 + assert fpx_data.dtype == torch.uint8 + shape = (fpx_data.shape[0], fpx_data.shape[1] // (1 + ebits + mbits) * 8) + + return Tensor._make_wrapper_subclass( + cls, + shape, + device=fpx_data.device, + requires_grad=False, ) - kwargs["dtype"] = packed_fpx_data.dtype - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_fpx_data: torch.Tensor, - scale: torch.Tensor, - layout_type: LayoutType, - ): - self.packed_fpx_data = packed_fpx_data + + def __init__(self, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): + self.fpx_data = fpx_data self.scale = scale - self.layout_type = layout_type + self.ebits = ebits + self.mbits = mbits def __tensor_flatten__(self): - return ["packed_fpx_data", "scale"], [self.layout_type] + return ["fpx_data", "scale"], [self.ebits, self.mbits] @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_fpx_data, scale = tensor_data_dict["packed_fpx_data"], tensor_data_dict["scale"] - layout_type, = 
tensor_attributes - return cls(packed_fpx_data, scale, layout_type) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_fpx_data = unpack_tc_fpx(self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits) - return unpacked_fpx_data, self.scale + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["fpx_data"], tensor_data_dict["scale"], *tensor_attributes) @classmethod - def from_plain( - cls, - unpacked_fpx_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - layout_type: LayoutType, - ): - """ - Format for `unpacked_fpx_data` will be: - zero padding bits | sign bit | exponent bits | mantissa bits - - For example for fp6_e3_m2, the format will be: `00SEEEMM`, where S is sign bit, E is exponent - bit, M is mantissa bit - """ - assert isinstance(layout_type, FpxTensorCoreLayoutType) - packed_fpx_data = pack_tc_fpx(unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits) - return cls(packed_fpx_data, scale, layout_type) + def from_float(cls, input_float: Tensor, ebits: int, mbits: int): + fpx_data, scale = to_scaled_tc_fpx(input_float, ebits, mbits) + return cls(fpx_data, scale, ebits, mbits) + + def dequantize(self, output_dtype=None): + output_dtype = output_dtype or torch.get_default_dtype() + return from_scaled_tc_fpx(self.fpx_data, self.ebits, self.mbits, self.scale).to(output_dtype) def __repr__(self): - unpacked_fpx_data, scale = self.get_plain() - layout_type = self.get_layout_type() - return f"{self.__class__.__name__}(unpacked_fpx_data={unpacked_fpx_data}, scale={scale}, layout_type={layout_type})" + dtype = f"fp{1 + self.ebits + self.mbits}_e{self.ebits}m{self.mbits}" + return ( + f"{self.__class__.__name__}(dtype={dtype}, shape={self.shape}, " + f"device={self.device}, requires_grad={self.requires_grad})" + ) def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.packed_fpx_data), + fn(self.fpx_data), fn(self.scale), - self.layout_type, + self.ebits, + self.mbits, ) - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.packed_fpx_data.to(device), - self.scale.to(device), - self.layout_type, - ) +@QuantLlmLinearWeight.implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + act = args[0] + weight = args[1] + bias = args[2] if len(args) >= 3 else None + assert isinstance(weight, QuantLlmLinearWeight) - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - elif func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - elif func is aten._to_copy.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))), - ) - - raise NotImplementedError( - f"FpxTensorCoreAQTLayout dispatch: attempting to run {func}, this is not supported" - ) + out_dim, in_dim = weight.shape + act_reshaped = act.view(-1, in_dim).half() + + # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py + bsize = act_reshaped.shape[0] + splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize 
<= 768 else 1 + + out = quant_llm_linear( + weight.ebits, + weight.mbits, + act_reshaped, + weight.fpx_data, + weight.scale, + splitK=splitK, + ) + + if bias is not None: + out += bias + + return out.view(*act.shape[:-1], out_dim).to(act.dtype) + + +@QuantLlmLinearWeight.implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) + + +@QuantLlmLinearWeight.implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)) + + +@QuantLlmLinearWeight.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + # only support device kwargs, ignore the rest + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))), + ) + + +def quant_llm_fpx_weight_only(ebits: int, mbits: int): + def apply_quant_llm(weight: Tensor) -> Tensor: + out_dim, in_dim = weight.shape + if (in_dim % 64 != 0) or (out_dim % 256 != 0): + return weight + return QuantLlmLinearWeight.from_float(weight, ebits, mbits) + return _get_linear_subclass_inserter(apply_quant_llm) - __torch_function__ = torch._C._disabled_torch_function_impl - def get_layout_type(self) -> LayoutType: - return self.layout_type +def fp6_llm_weight_only(): + return quant_llm_fpx_weight_only(3, 2) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 75e762ce3..2ac4a0c28 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -39,8 +39,6 @@ "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", "int8_weight_only", - "uintx_weight_only", - "fpx_weight_only", "LinearActivationQuantizedTensor", "to_linear_activation_quantized", ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index bb997b043..a1c4bb5d9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -21,6 +21,14 @@ import torch.nn.functional as F from typing import Any, Callable, Union, Dict, Optional +from torchao.dtypes.uintx.Uintx import UintxLayoutType +from torchao.dtypes import ( + to_affine_quantized, + TensorCoreTiledLayoutType, + PlainLayoutType, + AffineQuantizedTensor, + SemiSparseLayoutType +) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, @@ -29,7 +37,11 @@ from .subclass import ( QuantizedLinearWeightBase, ) -from torchao.dtypes import TensorCoreTiledLayoutType + +from .linear_activation_quantized_tensor import ( + LinearActivationQuantizedTensor, + to_linear_activation_quantized, +) from .quant_primitives import ( MappingType, @@ -60,8 +72,6 @@ "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", "int8_weight_only", - "uintx_weight_only", - "fpx_weight_only", ] from .GPTQ import ( @@ -192,10 +202,6 @@ def _is_linear(mod, *args): # adding weight tensor subclass isinstance check to make sure the weight is only quantized once # when it is shared by multiple linear modules - from torchao.dtypes import AffineQuantizedTensor - from .linear_activation_quantized_tensor import ( - LinearActivationQuantizedTensor, - ) return ( isinstance(mod, torch.nn.Linear) and hasattr(mod, "weight") @@ -347,15 +353,11 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: ) def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: - from torchao.dtypes import to_affine_quantized mapping_type = 
MappingType.ASYMMETRIC target_dtype = torch.int8 return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype) def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32): - from torchao.dtypes import to_affine_quantized - from .linear_activation_quantized_tensor import to_linear_activation_quantized - if weight.shape[-1] % group_size != 0: return weight @@ -391,7 +393,6 @@ def insert_subclass(lin): def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)): -# def int4_weight_only(group_size=128, layout_type=None): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel @@ -409,14 +410,7 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner size is more fine grained, choices are [256, 128, 64, 32] `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)` """ - from torchao.dtypes import TensorCoreTiledLayoutType - - if layout_type is None: - layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8) - def apply_int4_weight_only_quant(weight, use_hqq=False): - from torchao.dtypes import to_affine_quantized - if weight.shape[-1] % group_size != 0: return weight @@ -439,8 +433,6 @@ def int8_weight_only(): Applies int8 weight-only symmetric per-channel quantization to linear layers. """ def apply_int8wo_quant(weight): - from torchao.dtypes import to_affine_quantized - mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps @@ -451,8 +443,6 @@ def apply_int8wo_quant(weight): return _get_linear_subclass_inserter(apply_int8wo_quant) def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: - from torchao.dtypes import to_affine_quantized - mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 eps = 1e-5 @@ -461,19 +451,12 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) -def int8_dynamic_activation_int8_weight(layout_type=None): +def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers """ - from torchao.dtypes import PlainLayoutType - if layout_type is None: - layout_type = PlainLayoutType() - def apply_int8_dynamic_activation_int8_weight_quant(weight): - from torchao.dtypes import to_affine_quantized - from .linear_activation_quantized_tensor import to_linear_activation_quantized - in_features = weight.shape[1] # int8 dynamic quantization only has benefit when in_feature > 16 if in_features <= 16: @@ -503,7 +486,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. 
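Since `fpx_weight_only` no longer ships in `quant_api`, FP6/FPx now comes from the prototype namespace while the integer configs stay here; both are applied through the same `quantize_` entry point. A brief usage sketch (the model is a stand-in; its shape satisfies the FP6 kernel's in_features % 64 == 0 and out_features % 256 == 0 requirement):

```python
import torch
from torchao.quantization.quant_api import quantize_, int8_weight_only
from torchao.prototype.quant_llm import fp6_llm_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().half()

# integer weight-only quantization, unchanged by this patch
# quantize_(model, int8_weight_only())

# FP6 E3M2 weights, replacing the removed quantize_(model, fpx_weight_only(3, 2))
quantize_(model, fp6_llm_weight_only())
```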
""" - from torchao.dtypes import SemiSparseLayoutType return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()) @@ -523,10 +505,8 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1): ZeroPointDomain, ) from torchao.quantization.quant_api import _get_linear_subclass_inserter - from torchao.dtypes import to_affine_quantized def apply_uintx_weight_only_quant(weight): - from torchao.dtypes.uintx.Uintx import UintxLayoutType layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim) mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) @@ -543,28 +523,6 @@ def apply_uintx_weight_only_quant(weight): return _get_linear_subclass_inserter(apply_uintx_weight_only_quant) -def fpx_weight_only(ebits: int, mbits: int): - """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits - e.g. fp6_e3_m2, fp6_e2_m3, ... - - The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112 - github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm - - For more details for packing please see: :class:`~torchao.dtypes.fpx.FpxTensorCoreAQTLayout` - """ - def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: - from torchao.dtypes.fpx import FpxTensorCoreLayoutType - from torchao.dtypes import to_affine_quantized_fpx - - assert weight.dim() == 2, f"fpx only works for 2-d Tensor, got: {weight.dim()}" - out_dim, in_dim = weight.shape - if (in_dim % 64 != 0) or (out_dim % 256 != 0): - return weight - - layout_type = FpxTensorCoreLayoutType(ebits, mbits) - return to_affine_quantized_fpx(weight, layout_type) - return _get_linear_subclass_inserter(apply_quant_llm) - if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant]) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 7cac70440..3b73fd7fe 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -16,7 +16,6 @@ TORCH_VERSION_AT_LEAST_2_5, ) from torchao.utils import _register_custom_op -from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones __all__ = [ @@ -24,11 +23,8 @@ "int_scaled_matmul", "choose_qparams_affine", "choose_qparams_affine_with_min_max", - "choose_qparams_affine_fpx", "quantize_affine", "dequantize_affine", - "quantize_affine_fpx", - "dequantize_affine_fpx", "fake_quantize_affine", "fake_quantize_affine_cachemask", "quantize_affine_hqq", @@ -89,8 +85,6 @@ class ZeroPointDomain(Enum): ) -_ONES_TABLE = [_n_ones(i) for i in range(8)] - quant_lib = torch.library.Library("quant", "FRAGMENT") register_custom_op = _register_custom_op(quant_lib) @@ -704,6 +698,7 @@ def _choose_qparams_affine( return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype) + #HQQ ############################################################################ # Shrinking operator (proximal operator for the lp norm) @@ -871,32 +866,3 @@ def quantize_affine_hqq( torch.cuda.empty_cache() return W_q, scale, zero, shape - - -def choose_qparams_affine_fpx(tensor: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor: - # _n_ones() is not compatible with torch.compile() due to << operator - # https://github.com/pytorch/pytorch/issues/119152 - # exp_bias = _n_ones(ebits - 1) - # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) - - # workaround: global lookup table - exp_bias = _ONES_TABLE[ebits - 
1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) - - tensor = tensor.float() - scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal - return scale.half() - -def quantize_affine_fpx(tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor: - """Quantizes the float32 high precision floating point tensor to low precision floating point number and - converts the result to unpacked floating point format with the format of 00SEEEMM (for fp6_e3m2) where S means sign bit, e means exponent bit and m means mantissa bit - """ - tensor = tensor.float() - tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) - return tensor_fpx - -def dequantize_affine_fpx(tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int, output_dtype: torch.dtype = torch.float32) -> torch.Tensor: - tensor = _fpx_unpacked_to_f32(tensor, ebits, mbits) - tensor = tensor * scale.float().view(-1, 1) - tensor = tensor.to(dtype=output_dtype) - return tensor diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index ae4f48d9d..99ad0a4f6 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -9,7 +9,7 @@ from torch.utils._python_dispatch import TorchDispatchMode import torch.nn.utils.parametrize as parametrize from torchao.utils import find_multiple -from torchao.quantization.quant_primitives import ( +from .quant_primitives import ( MappingType, ZeroPointDomain, choose_qparams_affine,
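The `choose_qparams_affine_fpx` / `quantize_affine_fpx` / `dequantize_affine_fpx` helpers deleted above are subsumed by `to_scaled_tc_fpx` / `from_scaled_tc_fpx` behind the prototype tensor. A quick round-trip sanity check using only the APIs this patch introduces:

```python
import torch
from torchao.prototype.quant_llm import QuantLlmLinearWeight

w = torch.randn(256, 64, dtype=torch.half)
fp6_w = QuantLlmLinearWeight.from_float(w, ebits=3, mbits=2)  # FP6 E3M2, per-row scales
w_hat = fp6_w.dequantize(torch.half)

# FP6 keeps only 2 mantissa bits, so expect a coarse but bounded reconstruction error
print((w - w_hat).abs().max().item(), w.abs().max().item())
```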