From 4664bc1a21393ecc43d8f65e0727f31d4e3694fe Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 13 Dec 2024 13:29:59 -0800 Subject: [PATCH] Add back mistakenly deleted QAT BC import test Summary: The unused imports in this test were mistakenly deleted in https://github.com/pytorch/ao/pull/1359. This commit adds them back. Test Plan: python test/quantization/test_qat.py --- test/quantization/test_qat.py | 47 +++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 3a998635aa..82256ad7c8 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -1108,6 +1108,53 @@ def test_qat_prototype_bc(self): Just to make sure we can import all the old prototype paths. We will remove this test in the near future when we actually break BC. """ + from torchao.quantization.prototype.qat import ( + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + ComposableQATQuantizer, + Int8DynActInt4WeightQATLinear, + Int4WeightOnlyEmbeddingQATQuantizer, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATQuantizer, + ) + from torchao.quantization.prototype.qat._module_swap_api import ( + disable_4w_fake_quant_module_swap, + enable_4w_fake_quant_module_swap, + disable_8da4w_fake_quant_module_swap, + enable_8da4w_fake_quant_module_swap, + Int4WeightOnlyQATQuantizerModuleSwap, + Int8DynActInt4WeightQATQuantizerModuleSwap, + ) + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + to_affine_fake_quantized, + ) + from torchao.quantization.prototype.qat.api import ( + ComposableQATQuantizer, + FakeQuantizeConfig, + ) + from torchao.quantization.prototype.qat.embedding import ( + FakeQuantizedEmbedding, + Int4WeightOnlyEmbeddingQATQuantizer, + Int4WeightOnlyEmbedding, + Int4WeightOnlyQATEmbedding, + ) + from torchao.quantization.prototype.qat.fake_quantizer import ( + FakeQuantizer, + ) + from torchao.quantization.prototype.qat.linear import ( + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + FakeQuantizedLinear, + Int4WeightOnlyQATLinear, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATLinear, + Int8DynActInt4WeightQATQuantizer, + ) if __name__ == "__main__":