Skip to content

Commit

Permalink
Rebase and lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva committed Nov 13, 2024
1 parent 82477b9 commit 843716d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
2 changes: 2 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout

if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
Expand Down
10 changes: 5 additions & 5 deletions torchao/dtypes/affine_quantized_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
AffineQuantizedTensor._quantized_linear_op = _quantized_linear_op


# # following are a list of (dispatch_condition, implementation) functions that takes the following args:
# # input_tensor: dimension is (M1, M2, ..., in_features)
# # weight_tensor: dimension is (out_features, in_features)
# # bias: dimension is (out_features,)
# # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
# _register_aqt_quantized_linear_dispatches function has a list of (dispatch_condition, implementation) functions, defined in their dtype layout classes, that takes the following args:
# input_tensor: dimension is (M1, M2, ..., in_features)
# weight_tensor: dimension is (out_features, in_features)
# bias: dimension is (out_features,)
# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
def _register_aqt_quantized_linear_dispatches():
for dispatch_condition, impl in [
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
Expand Down
4 changes: 3 additions & 1 deletion torchao/dtypes/uintx/plain_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
register_layout,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
from torchao.kernel import (
int_scaled_matmul,
)
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
int_scaled_matmul,
)
from torchao.utils import fill_defaults

Expand Down

0 comments on commit 843716d

Please sign in to comment.