Fix an error in subclass impl #226

Merged · 1 commit · May 7, 2024
Changes from all commits
27 changes: 14 additions & 13 deletions torchao/quantization/subclass.py
@@ -139,21 +139,14 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
-        # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
-        # 1. we'll add cpu/cuda version (int4mm etc.)
-        # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
-        # cpu device + et laytout --> gives current 8da4w executorch representation
-        # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
-        # cuda device + some layout --> gives cuda kernel
-
         # two scenarios where we currently fall back to vanilla mm:
-        # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
-        # kernels in CPU as well, see the note above
+        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
+        # for consistency and to allow people to test
         # 2 - we're given non-floats - quantizing long to int8 is crazy
         if (
             func in [aten.mm.default, aten.addmm.default]
             and args[0].is_floating_point()
-            and args[0].device == torch.device("cpu")
+            and args[0].is_cuda
         ):
             if func == aten.addmm.default:
                 assert args[1].shape[-1] == args[2].shape[0], (
@@ -803,14 +796,21 @@ def _apply_fn_to_data(self, fn):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
+        # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
+        # 1. we'll add cpu/cuda version (int4mm etc.)
+        # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
+        # cpu device + et laytout --> gives current 8da4w executorch representation
+        # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
+        # cuda device + some layout --> gives cuda kernel
+
         # two scenarios where we currently fall back to vanilla mm:
-        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
-        # for consistency and to allow people to test
+        # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
+        # kernels in CPU as well, see the note above
         # 2 - we're given non-floats - quantizing long to int8 is crazy
         if (
             func in [aten.mm.default, aten.addmm.default]
             and args[0].is_floating_point()
-            and args[0].is_cuda
+            and args[0].device == torch.device("cpu")
         ):
             if func == aten.addmm.default:
                 assert args[1].shape[-1] == args[2].shape[0], (
@@ -833,6 +833,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 None if len(args) == 2 else args[2],
             )
             if weight_qtensor.input_quant_func is not None:
+                # dynamic quantization
                 input_tensor = weight_qtensor.input_quant_func(input_tensor)
                 input_tensor = input_tensor.dequantize()
             weight_tensor = weight_qtensor.dequantize()
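
For context, the pattern the two large hunks above correct is the usual `__torch_dispatch__` guard in a quantized weight subclass: the fast quantized kernel is taken only for floating-point inputs on the device that kernel actually supports, and everything else falls back to dequantize-plus-vanilla-mm. The sketch below is a minimal, self-contained illustration of that pattern under assumed names; it is not the torchao implementation, and the quantized fast path is left as a placeholder comment.

```python
import torch
from torch import Tensor

aten = torch.ops.aten


class QuantizedLinearWeightSketch(Tensor):
    """Toy int8 weight subclass used only to illustrate the dispatch guard."""

    @staticmethod
    def __new__(cls, int_data: Tensor, scales: Tensor):
        # wrapper subclass: the outer tensor advertises the dequantized dtype
        return Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scales.dtype, device=int_data.device
        )

    def __init__(self, int_data: Tensor, scales: Tensor):
        self.int_data = int_data
        self.scales = scales

    def dequantize(self) -> Tensor:
        # per-channel dequantization: float_weight = int8_weight * scale
        return self.int_data.to(self.scales.dtype) * self.scales.unsqueeze(-1)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if (
            func in (aten.mm.default, aten.addmm.default)
            and args[0].is_floating_point()
            and args[0].is_cuda  # the fast kernel in this sketch is CUDA-only
        ):
            # a real implementation would call the quantized kernel here
            # (e.g. torch._int_mm on self.int_data); omitted in this sketch
            pass
        # everything else: dequantize the subclass args and run the vanilla op,
        # mirroring the "fall back to vanilla mm" comments in the diff above
        plain_args = [a.dequantize() if isinstance(a, cls) else a for a in args]
        return func(*plain_args, **kwargs)
```

The bug being fixed is exactly this guard pointing at the wrong device: the subclass whose kernel is CUDA-only was checking for CPU, and the CPU-only 8da4w path was checking `is_cuda`.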
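The `# dynamic quantization` comment added in the last hunk refers to quantizing the activation at runtime via `weight_qtensor.input_quant_func` and immediately dequantizing it, so the vanilla-mm fallback sees the same rounding error the quantized kernel would introduce. Below is a toy per-token int8 version of such an input quant function; it is a simplified assumption for illustration, not the function torchao uses.

```python
import torch


def dynamic_int8_quantize(x: torch.Tensor):
    # per-row (per-token) symmetric int8 quantization with scales computed
    # from the runtime values of each input, not fixed at calibration time
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    q = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
    return q, scales


def dequantize(q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    return q.to(scales.dtype) * scales


x = torch.randn(4, 16)
q, s = dynamic_int8_quantize(x)
x_fake_quant = dequantize(q, s)  # what the fallback feeds into the vanilla mm
```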
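The multi-line note moved into the second `__torch_dispatch__` describes a future scheme where the device plus a layout tag selects the kernel (executorch vs. AVX vs. CUDA). A rough illustration of that idea follows; the names and the table are hypothetical, and torchao's eventual layout mechanism may look quite different.

```python
import torch

# stand-in kernels; real ones would operate on packed quantized data
def _run_8da4w_executorch(x, w): return x @ w.t()
def _run_8da4w_avx(x, w):        return x @ w.t()
def _run_int4_cuda(x, w):        return x @ w.t()

# hypothetical kernel table keyed by (device type, layout tag)
KERNELS = {
    ("cpu", "executorch"): _run_8da4w_executorch,
    ("cpu", "avx"): _run_8da4w_avx,
    ("cuda", "plain"): _run_int4_cuda,
}


def quantized_linear(x: torch.Tensor, w: torch.Tensor, layout: str) -> torch.Tensor:
    return KERNELS[(x.device.type, layout)](x, w)
```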