Merged
48 changes: 48 additions & 0 deletions test/test_utils.py
@@ -8,6 +8,7 @@
from unittest.mock import patch

import torch
import torch.nn.functional as F

from torchao.testing.utils import skip_if_no_cuda
from torchao.utils import TorchAOBaseTensor, torch_version_at_least
@@ -344,6 +345,53 @@ def __init__(
)
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

def test_implements_and_torch_function_together(self):
"""Ensure a function decorated with both @_implements and @_implements_torch_function works."""
counter = {"calls": 0}

class MyTensor(TorchAOBaseTensor):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr", "device"]

def __new__(cls, qdata: torch.Tensor, attr: str = "attr", device=None):
kwargs = {}
if device is None:
device = qdata.device
kwargs["device"] = device
kwargs["dtype"] = qdata.dtype
r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)
r.qdata = qdata
r.attr = attr
return r

def __init__(self, qdata: torch.Tensor, attr: str = "attr", device=None):
pass

implements = MyTensor.implements
implements_torch_function = MyTensor.implements_torch_function

@implements([torch.ops.aten.t.default])
@implements_torch_function([F.linear])
def fake_linear(func, types, args, kwargs):
counter["calls"] += 1

l = torch.nn.Linear(2, 3)
l.weight = torch.nn.Parameter(MyTensor(l.weight.detach(), "attr", None))
x = torch.randn(4, 2)

# Torch function path
F.linear(x, l.weight, l.bias)
self.assertEqual(
counter["calls"], 1, "Expected fake_linear to be called via F.linear"
)

# ATen path
mt = MyTensor(torch.randn(3, 4))
torch.ops.aten.t.default(mt)
self.assertEqual(
counter["calls"], 2, "Expected fake_linear to be called via aten.t.default"
)


if __name__ == "__main__":
unittest.main()
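
The test above exercises the mechanism the rest of this PR relies on: torch-level functions such as F.linear are intercepted through __torch_function__, while ATen overloads such as aten.t.default only ever reach __torch_dispatch__, so each kind of callable needs its own registration table and therefore its own decorator. The sketch below is not torchao's implementation; it is a minimal, self-contained illustration of the pattern, and DemoTensor, _ATEN_TABLE, and _TORCH_FN_TABLE are made-up names.

    import torch
    import torch.nn.functional as F

    # Illustrative registries: one for ATen overloads (consulted in __torch_dispatch__),
    # one for torch-level functions (consulted in __torch_function__).
    _ATEN_TABLE = {}
    _TORCH_FN_TABLE = {}


    def implements(ops):
        """Register a handler for one or more ATen overloads."""
        ops = ops if isinstance(ops, (list, tuple)) else [ops]

        def decorator(fn):
            for op in ops:
                _ATEN_TABLE[op] = fn
            return fn

        return decorator


    def implements_torch_function(funcs):
        """Register a handler for one or more torch-level functions."""
        funcs = funcs if isinstance(funcs, (list, tuple)) else [funcs]

        def decorator(fn):
            for f in funcs:
                _TORCH_FN_TABLE[f] = fn
            return fn

        return decorator


    class DemoTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, data):
            return torch.Tensor._make_wrapper_subclass(
                cls, data.shape, dtype=data.dtype, device=data.device
            )

        def __init__(self, data):
            self.inner = data  # plain tensor payload

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func in _TORCH_FN_TABLE:
                return _TORCH_FN_TABLE[func](func, types, args, kwargs)
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func in _ATEN_TABLE:
                return _ATEN_TABLE[func](func, types, args, kwargs)
            raise NotImplementedError(f"DemoTensor has no handler for {func}")


    @implements_torch_function(F.linear)
    def _demo_linear(func, types, args, kwargs):
        # Hit when F.linear sees a DemoTensor argument (torch-function path).
        x, w = args[0], args[1]
        bias = args[2] if len(args) > 2 else None
        return F.linear(x, w.inner, bias)


    @implements(torch.ops.aten.t.default)
    def _demo_t(func, types, args, kwargs):
        # Hit when aten.t.default reaches the dispatcher (torch-dispatch path).
        return args[0].inner.t()


    w = DemoTensor(torch.randn(3, 2))
    x = torch.randn(4, 2)
    print(F.linear(x, w).shape)               # handled by _demo_linear -> (4, 3)
    print(torch.ops.aten.t.default(w).shape)  # handled by _demo_t -> (2, 3)
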
6 changes: 4 additions & 2 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -262,9 +262,11 @@ def _register_aqt_quantized_linear_dispatches():
_register_aqt_quantized_linear_dispatches()

implements = AffineQuantizedTensor.implements
implements_torch_function = AffineQuantizedTensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements([aten.linear.default])
Review comment (Contributor), illustrated by the sketch after this file's diff:
nit: btw, we don't need a list if it's a single op, i.e. we can do @implements(aten.linear.default)

@implements_torch_function([torch.nn.functional.linear])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
@@ -296,7 +298,7 @@ def _(func, types, args, kwargs):
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements(torch.nn.functional.embedding)
@implements_torch_function(torch.nn.functional.embedding)
def _(func, types, args, kwargs):
if _embedding_q_dq_check(args, kwargs):
return _embedding_q_dq_impl(args, kwargs)
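
To make the reviewer's nit concrete: both implements and implements_torch_function accept either a bare op/function or a list, and both spellings appear throughout this diff; the list form only earns its keep when one handler covers several ops (e.g. [aten.detach.default, aten.alias.default] in gguf_quantized_tensor.py below). A hedged sketch on a throwaway subclass, structured after the new test above (ToyTensor and its stand-in handler are illustrative, not part of the PR):

    import torch
    import torch.nn.functional as F
    from torchao.utils import TorchAOBaseTensor

    aten = torch.ops.aten


    class ToyTensor(TorchAOBaseTensor):
        tensor_data_names = ["qdata"]
        tensor_attribute_names = ["attr", "device"]

        def __new__(cls, qdata, attr="attr", device=None):
            kwargs = {
                "device": device if device is not None else qdata.device,
                "dtype": qdata.dtype,
            }
            r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)
            r.qdata = qdata
            r.attr = attr
            return r

        def __init__(self, qdata, attr="attr", device=None):
            pass


    # A bare op (the spelling the reviewer prefers) and a one-element list register
    # the same handler; the body just falls back to dense ops for illustration.
    @ToyTensor.implements(aten.t.default)             # bare op
    @ToyTensor.implements_torch_function([F.linear])  # one-element list
    def _(func, types, args, kwargs):
        if func is aten.t.default:
            return args[0].qdata.t()
        x, w = args[0], args[1]
        bias = args[2] if len(args) > 2 else None
        return F.linear(x, w.qdata, bias)
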
3 changes: 2 additions & 1 deletion torchao/prototype/quantization/autoquant_v2.py
@@ -847,7 +847,8 @@ def from_float(cls, weight):
return cls(weight)


@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
@Float32Tensor.implements(aten.linear.default)
@Float32Tensor.implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
Original file line number Diff line number Diff line change
@@ -164,9 +164,11 @@ def to(self, *args, **kwargs):


implements = CodebookQuantizedTensor.implements
implements_torch_function = CodebookQuantizedTensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements([aten.linear.default])
@implements_torch_function([torch.nn.functional.linear])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
@@ -177,7 +179,8 @@ def _(func, types, args, kwargs):
return func(input_tensor, weight_tensor, bias)


@implements([torch.nn.functional.embedding, aten.embedding.default])
@implements([aten.embedding.default])
@implements_torch_function([torch.nn.functional.embedding])
def _(func, types, args, kwargs):
assert len(args) == 2
indices, weight_tensor = (
Original file line number Diff line number Diff line change
@@ -167,9 +167,10 @@ def from_codebook_quantized_tensor(


implements = CodebookQuantizedPackedTensor.implements
implements_torch_function = CodebookQuantizedPackedTensor.implements_torch_function


@implements([F.linear])
@implements_torch_function(F.linear)
def _(func, types, args, kwargs):
"""
Override for `torch.nn.functional.linear` specifically for the
4 changes: 3 additions & 1 deletion torchao/prototype/quantization/gguf/gguf_quantized_tensor.py
@@ -218,6 +218,7 @@ def from_float(cls, input_float, n_blocks_per_superblock, target_dtype):


implements = GGUFQuantizedTensor.implements
implements_torch_function = GGUFQuantizedTensor.implements_torch_function


@implements([aten.detach.default, aten.alias.default])
@@ -244,7 +245,8 @@ def _(func, types, args, kwargs):
)


@implements([torch.nn.functional.linear, aten.linear.default])
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
Original file line number Diff line number Diff line change
@@ -165,6 +165,7 @@ def from_intx_unpacked_to_int8_tensor(


implements = Int8LutTensor.implements
implements_torch_function = Int8LutTensor.implements_torch_function


def _linear_impl_2d(
@@ -202,7 +203,8 @@ def _linear_impl_2d(
return res


@implements([torch.nn.functional.linear, aten.linear.default])
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
4 changes: 2 additions & 2 deletions torchao/prototype/quantized_training/bitnet.py
@@ -140,7 +140,7 @@ def fsdp_post_all_gather(
return BitNetPacked2bitLinearWeight(data_i2, scale), all_gather_outputs


@BitNetTrainingLinearWeight.implements(F.linear)
@BitNetTrainingLinearWeight.implements_torch_function(F.linear)
def _(func, types, args, kwargs):
if torch.is_autocast_enabled("cuda"):
dtype = torch.get_autocast_gpu_dtype()
@@ -324,7 +324,7 @@ def dequantize(self, out_dtype=None):
return out


@BitNetPacked2bitLinearWeight.implements(F.linear)
@BitNetPacked2bitLinearWeight.implements_torch_function(F.linear)
def _(func, types, args, kwargs):
return _BitNetPacked2bitLinear.apply(*args, **kwargs)

3 changes: 2 additions & 1 deletion torchao/prototype/quantized_training/int8.py
@@ -174,9 +174,10 @@ def backward(ctx, grad_output):


implements = Int8QuantizedTrainingLinearWeight.implements
implements_torch_function = Int8QuantizedTrainingLinearWeight.implements_torch_function


@implements(torch.nn.functional.linear)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
return _Int8WeightOnlyLinear.apply(*args, **kwargs)

3 changes: 2 additions & 1 deletion torchao/quantization/autoquant.py
@@ -833,7 +833,8 @@ def from_float(cls, weight):
return cls(weight)


@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
@Float32Tensor.implements_torch_function(torch.nn.functional.linear)
@Float32Tensor.implements(aten.linear.default)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
4 changes: 3 additions & 1 deletion torchao/quantization/linear_activation_quantized_tensor.py
@@ -129,9 +129,11 @@ def _same_metadata(


implements = LinearActivationQuantizedTensor.implements
implements_torch_function = LinearActivationQuantizedTensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements([aten.linear.default])
@implements_torch_function([torch.nn.functional.linear])
def _(func, types, args, kwargs):
input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None)
weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None)
5 changes: 4 additions & 1 deletion torchao/quantization/linear_activation_scale.py
@@ -79,9 +79,12 @@ def from_float(


implements = WeightTensorWithLinearActivationScaleMetadata.implements
implements_torch_function = (
WeightTensorWithLinearActivationScaleMetadata.implements_torch_function
)


@implements(torch.nn.functional.linear)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
Original file line number Diff line number Diff line change
@@ -105,9 +105,12 @@ def to(self, *args, **kwargs):


implements = LinearActivationWeightObservedTensor.implements
implements_torch_function = (
LinearActivationWeightObservedTensor.implements_torch_function
)


@implements(torch.nn.functional.linear)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
3 changes: 2 additions & 1 deletion torchao/quantization/qat/affine_fake_quantized_tensor.py
@@ -264,9 +264,10 @@ def _create_new(self, new_value: torch.Tensor):


implements = _AffineFakeQuantizedTensor.implements
implements_torch_function = _AffineFakeQuantizedTensor.implements_torch_function


@implements(torch.nn.functional.linear)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
Original file line number Diff line number Diff line change
@@ -245,9 +245,11 @@ def from_hp(


implements = Float8Tensor.implements
implements_torch_function = Float8Tensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements([aten.linear.default])
@implements_torch_function([torch.nn.functional.linear])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
@@ -359,7 +361,7 @@ def _(func, types, args, kwargs):
)


@implements(torch.bmm)
@implements_torch_function(torch.bmm)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = (
args[0],
Original file line number Diff line number Diff line change
@@ -159,9 +159,11 @@ def from_hp(


implements = Int4MarlinSparseTensor.implements
implements_torch_function = Int4MarlinSparseTensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
from torchao.ops import marlin_24_gemm
from torchao.sparsity.marlin import marlin_24_workspace
Original file line number Diff line number Diff line change
@@ -187,9 +187,11 @@ def from_hp(


implements = Int4OpaqueTensor.implements
implements_torch_function = Int4OpaqueTensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
Original file line number Diff line number Diff line change
@@ -146,9 +146,11 @@ def from_hp(


implements = Int4PlainInt32Tensor.implements
implements_torch_function = Int4PlainInt32Tensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
Original file line number Diff line number Diff line change
@@ -221,9 +221,11 @@ def from_int4_tensor(


implements = Int4PreshuffledTensor.implements
implements_torch_function = Int4PreshuffledTensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements([aten.linear.default])
@implements_torch_function([torch.nn.functional.linear])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
@@ -256,7 +258,7 @@ def _(func, types, args, kwargs):
return res


@implements(torch.bmm)
@implements_torch_function(torch.bmm)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = (
args[0],
6 changes: 4 additions & 2 deletions torchao/quantization/quantize_/workflows/int4/int4_tensor.py
@@ -132,9 +132,11 @@ def from_hp(


implements = Int4Tensor.implements
implements_torch_function = Int4Tensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements([aten.linear.default])
@implements_torch_function([torch.nn.functional.linear])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
@@ -168,7 +170,7 @@ def _(func, types, args, kwargs):
return res


@implements(torch.bmm)
@implements_torch_function(torch.bmm)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = (
args[0],
Original file line number Diff line number Diff line change
@@ -236,9 +236,11 @@ def quant_2d(int_data_2d):


implements = Int4TilePackedTo4dTensor.implements
implements_torch_function = Int4TilePackedTo4dTensor.implements_torch_function


@implements([torch.nn.functional.linear, aten.linear.default])
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],