diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 6cdd9b148..1b53bf00d 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -138,6 +138,40 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
         model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
     )
 
+def _ref_change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
+    """
+    The deprecated implementation for the int8 weight only quant API, used as a reference for
+    numerics and performance
+    """
+    from torchao.quantization.quant_api import _is_linear
+    from torchao.quantization.quant_api import _get_subclass_inserter
+    from torchao.quantization.subclass import Int8WeightOnlyQuantizedLinearWeight
+
+    filter_fn = kwargs.pop("filter_fn", _is_linear)
+
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=True, **kwargs),
+        filter_fn,
+    )
+
+def _ref_change_linear_weights_to_int4_woqtensors(model, **kwargs):
+    """
+    The deprecated implementation for the int4 weight only quant API, used as a reference for
+    numerics and performance
+    """
+    from torchao.quantization.quant_api import _is_linear
+    from torchao.quantization.quant_api import _get_subclass_inserter
+    from torchao.quantization.subclass import Int4WeightOnlyQuantizedLinearWeight
+
+    filter_fn = kwargs.pop("filter_fn", _is_linear)
+
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs),
+        filter_fn,
+    )
+
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
@@ -489,7 +523,7 @@ def test_quantized_tensor_subclass_int4(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quantized_tensor_subclass_int8(self):
+    def test_quantized_tensor_subclass_int8_wo(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
@@ -501,12 +535,12 @@ def test_quantized_tensor_subclass_int8(self):
 
         # reference
         from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
-        change_linear_weights_to_int8_woqtensors(m_copy)
+        _ref_change_linear_weights_to_int8_woqtensors(m_copy)
 
         res = m(*example_inputs)
         ref = m_copy(*example_inputs)
 
-        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
+        self.assertTrue(torch.equal(res, ref))
 
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@@ -545,20 +579,20 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
 
         # make sure it compiles
         torch._export.aot_compile(m_unwrapped, example_inputs)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
-    def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
+
+    def _test_quantized_tensor_subclass_perf(self, api, ref_api, kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_ref = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
 
-        from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-        change_linear_weights_to_int8_dqtensors(m)
+        api(m, **kwargs)
         # reference
-        _ref_change_linear_weights_to_int8_dqtensors(m_ref)
+        ref_api(m_ref, **kwargs)
 
         res = m(*example_inputs)
         ref = m_ref(*example_inputs)
@@ -583,7 +617,27 @@ def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
 
         print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
         self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skip("This perf test is supposed to be run locally to sanity-check performance when the int8 dynamic quant implementation changes")
+    def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+        self._test_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    # @unittest.skip("This perf test is supposed to be run locally to sanity-check performance when the int8 weight only quant implementation changes")
+    def test_quantized_tensor_subclass_int8_wo_quant_perf(self):
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+        self._test_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skip("This perf test is supposed to be run locally to sanity-check performance when the int4 weight only quant implementation changes")
+    def test_quantized_tensor_subclass_int4_wo_quant_perf(self):
+        kwargs = {"groupsize": 32}
+        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+        self._test_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py
index f660a759c..64ca9e3c7 100644
--- a/torchao/dtypes/aqt.py
+++ b/torchao/dtypes/aqt.py
@@ -574,29 +574,30 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             scale_and_zero = weight_qtensor.layout_tensor.scale_and_zero
             return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scale_and_zero)
         elif (
-            is_cpu and
             weight_is_int8 and
             len(weight_qtensor.shape) == 2 and
             len(weight_qtensor.block_size) == 2 and
             weight_qtensor.block_size[0] == 1 and
             weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
+            weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
             weight_qtensor.layout == "plain"
         ):
             # TODO: enable cpu and mps efficient path
             # per channel int8 weight only quantizated mm
-            w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t().contiguous()
+            w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t()
+            scale = weight_qtensor.layout_tensor.scale
             orig_dtype = input_tensor.dtype
             y = (
                 torch.mm(
                     input_tensor.reshape(-1, input_tensor.shape[-1]),
                     w_vals_int8_t.to(input_tensor.dtype),
                 )
-                * weight_qtensor.scale
+                * scale
             )
             y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
             if bias is not None:
                 y += bias
-            return y.to(orig_dtype)
+            return y.to(orig_dtype)
             # is_cpu and is_mps only, some issue with is_contiguous() currently
             # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)
 
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 04019c209..9037d4f3e 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -206,13 +206,18 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
     Converts all linear weight tensors to the
    `Int8WeightOnlyQuantizedLinearWeight` tensor subclass,
     effectively applying the same form of quantization
-    as apply_dynamic_quant while not modifying the linear modules.
+    as apply_weight_only_int8_quant while not modifying the linear modules.
     """
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
-        _is_linear if filter_fn is None else filter_fn,
-    )
+
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(model, get_apply_int8wo_quant(), filter_fn)
+        unwrap_tensor_subclass(model, filter_fn)
+    else:
+        _replace_with_custom_fn_if_matches_filter(
+            model,
+            _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs),
+            _is_linear if filter_fn is None else filter_fn,
+        )
 
 
 def change_linear_weights_to_int4_woqtensors(model, **kwargs):
diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py
index 583ad36f7..0c90a8501 100644
--- a/tutorials/quantize_vit/run_vit_b_quant.py
+++ b/tutorials/quantize_vit/run_vit_b_quant.py
@@ -14,18 +14,27 @@
 input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')
 
 ## Quantization code - start
+# int8 act, int8 weight dynamic quantization
 torchao.apply_dynamic_quant(model)
-from torch._inductor import config as inductorconfig
-inductorconfig.force_fuse_int_mm_with_mul = True
+
+# int8 weight only quantization
+# torchao.quantization.change_linear_weights_to_int8_woqtensors(model)
 ## Quantization code - end
+
+## compilation configs
+torch._dynamo.config.automatic_dynamic_shapes = False
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.use_mixed_mm = True
+## compilation configs end
+
 model = torch.compile(model, mode='max-autotune', fullgraph=True)
 
 # Must run with no_grad when optimizing for inference
 with torch.no_grad():
     # warmup
-    benchmark_model(model, 5, input_tensor)
+    benchmark_model(model, 20, input_tensor)
     # benchmark
-    print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
+    print("elapsed_time: ", benchmark_model(model, 1000, input_tensor), " milliseconds")
     # Create a trace
     profiler_runner("quant.json.gz", benchmark_model, model, 5, input_tensor)
 
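For reference, a minimal usage sketch of the int8 weight-only API exercised by this patch; it is not part of the diff. The TinyLinear module, tensor shapes, and batch size below are illustrative assumptions, and a CUDA device, PyTorch 2.4+, and a torchao build containing this change are assumed, so change_linear_weights_to_int8_woqtensors takes the quantize() + unwrap_tensor_subclass() path before compilation.

# Hedged sketch, not part of the patch: TinyLinear and the shapes below are made up for illustration.
import torch
import torch.nn as nn

from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors


class TinyLinear(nn.Module):
    def __init__(self, d=1024):
        super().__init__()
        self.linear1 = nn.Linear(d, d, bias=False)
        self.linear2 = nn.Linear(d, d, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = TinyLinear().eval().to(torch.bfloat16).to("cuda")
example_input = torch.randn(20, 1024, dtype=torch.bfloat16, device="cuda")

# Swap each nn.Linear weight for an int8 weight-only quantized tensor subclass;
# on 2.4+ this routes through quantize() + unwrap_tensor_subclass() internally.
change_linear_weights_to_int8_woqtensors(model)

model = torch.compile(model, mode='max-autotune', fullgraph=True)
with torch.no_grad():
    out = model(example_input)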