diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index d923df18d..4bbb87ece 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -140,6 +140,8 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor self.zero_point_domain, output_dtype=output_dtype, ) + from torchao.dtypes.uintx import TensorCoreTiledLayout + if isinstance(self._layout, TensorCoreTiledLayout): # need to return to original shape if tensor was padded # in preprocessing diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 514b90930..fe7644d92 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -101,11 +101,11 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): AffineQuantizedTensor._quantized_linear_op = _quantized_linear_op -# # following are a list of (dispatch_condition, implementation) functions that takes the following args: -# # input_tensor: dimension is (M1, M2, ..., in_features) -# # weight_tensor: dimension is (out_features, in_features) -# # bias: dimension is (out_features,) -# # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches +# _register_aqt_quantized_linear_dispatches function has a list of (dispatch_condition, implementation) functions, defined in their dtype layout classes, that take the following args: +# input_tensor: dimension is (M1, M2, ..., in_features) +# weight_tensor: dimension is (out_features, in_features) +# bias: dimension is (out_features,) +# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py index cad0aafcc..ed171634c 
100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -11,9 +11,11 @@ register_layout, ) from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout +from torchao.kernel import ( + int_scaled_matmul, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, - int_scaled_matmul, ) from torchao.utils import fill_defaults