Commit 615de5e

address comments
1 parent 43c7259 commit 615de5e

2 files changed (+35, −48 lines)


torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 5 additions & 48 deletions
@@ -268,13 +268,15 @@ def _(func, types, args, kwargs):
 
 @implements(aten.addmm_.default)
 def _(func, types, args, kwargs):
-    output_tensor, input_tensor, weight_tensor = (
+    bias_tensor, input_tensor, weight_tensor = (
         args[0],
         args[1],
-        args[2] if len(args) > 2 else None,
+        args[2],
     )
+    assert kwargs.get("alpha", 1) == 1, "only alpha=1 is supported"
+    assert kwargs.get("beta", 1) == 1, "only beta=1 is supported"
     out = _float8_mm_impl(input_tensor, weight_tensor)
-    return output_tensor.copy_(out)
+    return bias_tensor.add_(out)
 
 
 def _float8_mm_impl(
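
Note (illustration, not part of the diff): aten.addmm_(self, mat1, mat2, *, beta=1, alpha=1) updates self in place as beta * self + alpha * (mat1 @ mat2), so once alpha == 1 and beta == 1 are asserted, accumulating the matmul result into the bias argument with bias_tensor.add_(out) matches the original op. A minimal plain-tensor sketch of that equivalence:

import torch

acc = torch.zeros(3, 4)   # the `self`/bias argument that addmm_ updates in place
x = torch.randn(3, 8)     # mat1
w = torch.randn(8, 4)     # mat2

# addmm_ computes: self = beta * self + alpha * (mat1 @ mat2); with the defaults
# alpha=1 and beta=1 this is just an in-place accumulate of the matmul result.
expected = acc.clone().addmm_(x, w)
via_add_ = acc.clone().add_(x @ w)   # what the new handler effectively does

assert torch.allclose(expected, via_add_)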
@@ -708,51 +710,6 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 
-@implements(torch.ops.aten.to.dtype_layout)
-def _(func, types, args, kwargs):
-    # only support kwargs for now
-    assert len(args) == 1
-    self = args[0]
-    # only support dtype, layout, and device for now
-    for k in kwargs.keys():
-        assert k in ["dtype", "layout", "device"]
-    # only support same dtype and layout
-    # different dtype and layout has undefined behavior
-    if "dtype" in kwargs:
-        assert kwargs["dtype"] == self.dtype
-    if "layout" in kwargs:
-        assert kwargs["layout"] == self.layout
-    # if device is the same, treat this like a no-op
-    device = kwargs.get("device")
-    if device == self.device:
-        return self
-    # otherwise, move all inner tensors to the new device
-    new_tensor = self.__class__(
-        func(self.qdata, device=device),
-        func(self.scale, device=device),
-        self.block_size,
-        self.mm_config,
-        self.act_quant_kwargs,
-        self.kernel_preference,
-        self.dtype,
-    )
-    return return_and_correct_aliasing(func, args, kwargs, new_tensor)
-
-
-# This is called during _apply() to see if we can shallow
-# copy the content of one tensor into another. For now,
-# we only allow shallow copy if both tensors are `Float8Tensor`
-# and have the same shape.
-@implements_torch_function(torch._has_compatible_shallow_copy_type)
-def _(func, types, args, kwargs):
-    assert len(args) == 2
-    return (
-        isinstance(args[0], Float8Tensor)
-        and isinstance(args[1], Float8Tensor)
-        and args[0].shape == args[1].shape
-    )
-
-
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     assert len(args) == 1

torchao/utils.py

Lines changed: 30 additions & 0 deletions
@@ -507,6 +507,36 @@ def _implements_common_tensor_ops(cls):
     implements_torch_function = cls.implements_torch_function
     aten = torch.ops.aten
 
+    @implements(torch.ops.aten.to.dtype_layout)
+    def _(func, types, args, kwargs):
+        # only support kwargs for now
+        assert len(args) == 1
+        self = args[0]
+        # only support dtype, layout, and device for now
+        for k in kwargs.keys():
+            assert k in ["dtype", "layout", "device"]
+        # only support same dtype and layout
+        # different dtype and layout has undefined behavior
+        if "dtype" in kwargs:
+            assert kwargs["dtype"] == self.dtype
+        if "layout" in kwargs:
+            assert kwargs["layout"] == self.layout
+        # if device is the same, treat this like a no-op
+        device = kwargs.get("device")
+        if device == self.device:
+            return self
+        new_tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device))
+        return return_and_correct_aliasing(func, args, kwargs, new_tensor)
+
+    # This is called during _apply() to see if we can shallow
+    # copy the content of one tensor into another. For now,
+    # we only allow shallow copy if both tensors are of the
+    # same type and have the same shape.
+    @implements_torch_function(torch._has_compatible_shallow_copy_type)
+    def _(func, types, args, kwargs):
+        assert len(args) == 2
+        return type(args[0]) == type(args[1]) and args[0].shape == args[1].shape
+
     @implements_torch_function(
         [
             torch.Tensor.contiguous,
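
Illustration (not from this commit) of the two hooks the shared handlers above cover, using plain tensors to show where they get exercised, for example during nn.Module._apply() device moves:

import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

# torch._has_compatible_shallow_copy_type is what nn.Module._apply() consults before
# shallow-copying a converted tensor back into an existing parameter; the new handler
# answers True only for "same subclass and same shape".
print(torch._has_compatible_shallow_copy_type(a, b))  # True for two plain dense tensors

# aten.to.dtype_layout is the overload the new handler registers for; for a subclass it
# forwards the device move to every inner tensor via _apply_fn_to_data and returns the
# input unchanged when the target device already matches.
moved = torch.ops.aten.to.dtype_layout(a, device=torch.device("cpu"))
print(moved.device)  # cpu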
