Commit 19500bf

clean up a bit
1 parent 86e14a1 commit 19500bf

2 files changed: 13 additions & 34 deletions

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 11 additions & 27 deletions
@@ -248,8 +248,8 @@ def from_hp(
 implements_torch_function = Float8Tensor.implements_torch_function


-@implements([aten.linear.default])
-@implements_torch_function([torch.nn.functional.linear])
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
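
The change above splits the combined list registration into one decorator per dispatch level: `implements` keys the handler by aten overload (consumed by `__torch_dispatch__`), while `implements_torch_function` keys it by the torch-level function (consumed by `__torch_function__`). A minimal sketch of this table-based pattern, using hypothetical standalone helpers rather than torchao's actual `Float8Tensor` classmethods:

import torch

aten = torch.ops.aten
_ATEN_OP_TABLE = {}   # consulted by __torch_dispatch__
_TORCH_FN_TABLE = {}  # consulted by __torch_function__

def implements(op):
    """Register a handler for an aten overload (dispatch level)."""
    def decorator(handler):
        _ATEN_OP_TABLE[op] = handler
        return handler
    return decorator

def implements_torch_function(fn):
    """Register a handler for a torch-level function (torch_function level)."""
    def decorator(handler):
        _TORCH_FN_TABLE[fn] = handler
        return handler
    return decorator

@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    ...  # one handler, reachable from either interception point
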
@@ -259,35 +259,24 @@ def _(func, types, args, kwargs):
     return _float8_linear_impl(input_tensor, weight_tensor, bias)


-@implements([torch.matmul, aten.mm.default])
+@implements(aten.mm.default)
+@implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose)")
     return _float8_linear_impl(input_tensor, weight_tensor.t())


-@implements([aten.addmm_.default])
+@implements(aten.addmm_.default)
 def _(func, types, args, kwargs):
     output_tensor, input_tensor, weight_tensor = (
         args[0],
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose), output_tensor = {output_tensor.shape}")
     out = _float8_linear_impl(input_tensor, weight_tensor.t())
     return output_tensor.copy_(out)


-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    # For now, just support copying from a Float8Tensor to a Float8Tensor
-    assert len(args) == 2
-    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
-    args[0].qdata.copy_(args[1].qdata, **kwargs)
-    args[0].scale.copy_(args[1].scale, **kwargs)
-    return args[0]
-
-
 def _float8_linear_impl(
     input_tensor: torch.Tensor,
     weight_tensor: torch.Tensor,
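
In both handlers above, the weight is transposed before calling `_float8_linear_impl` because linear semantics already include a transpose: `F.linear(x, w)` computes `x @ w.t()`, so `x @ w == F.linear(x, w.t())`. A quick sanity check with plain tensors (illustrative, not part of the commit):

import torch

x = torch.randn(4, 8)
w = torch.randn(8, 16)
# torch.matmul(x, w) and F.linear(x, w.t()) compute the same product:
assert torch.allclose(torch.matmul(x, w),
                      torch.nn.functional.linear(x, w.t()))
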
@@ -332,11 +321,11 @@ def _float8_linear_impl(
     wq = weight_tensor.qdata
     x_scale = input_tensor.scale
     w_scale = weight_tensor.scale
-    if True: #_is_rowwise_scaled(weight_tensor):
+    # TODO: fix this?
+    if True:  # _is_rowwise_scaled(weight_tensor):
         assert _is_rowwise_scaled(input_tensor), (
             "Input tensor must be rowwise block size"
         )
-        print(f"  * fbgemm op input = {xq.shape}, weight = {wq.shape}, input_scale = {x_scale.shape}, weight_scale = {w_scale.shape}")
         wq = wq.contiguous()
         res = torch.ops.fbgemm.f8f8bf16_rowwise(
             xq,
@@ -347,8 +336,6 @@ def _float8_linear_impl(
             use_fast_accum=mm_config.use_fast_accum,
         ).reshape(out_shape)
     else:
-        print("weight_tensor failed _is_rowwise_scaled, SHOULDN'T BE HERE!!!!!!")
-        breakpoint()
         assert _is_tensorwise_scaled(weight_tensor)
         assert _is_tensorwise_scaled(input_tensor)
         res = torch.ops.fbgemm.f8f8bf16(
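
`_is_rowwise_scaled` gates the branch above (currently short-circuited by the `if True:` TODO), and the fbgemm rowwise kernel consumes one scale per row of each operand. As a rough illustration of what rowwise float8 scaling computes, here is a hypothetical helper, not torchao's actual quantization path:

import torch

def quantize_rowwise_fp8(t: torch.Tensor):
    # One scale per row, chosen so each row's absmax maps to the
    # float8_e4m3fn maximum (448.0).
    amax = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / torch.finfo(torch.float8_e4m3fn).max
    qdata = (t / scale).to(torch.float8_e4m3fn)
    return qdata, scale  # dequantize with qdata.to(t.dtype) * scale
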
@@ -746,21 +733,18 @@ def _(func, types, args, kwargs):
         self.mm_config,
         self.act_quant_kwargs,
         self.kernel_preference,
-        self.dtype
+        self.dtype,
     )
     return return_and_correct_aliasing(func, args, kwargs, new_tensor)


 # This is called during _apply() to see if we can shallow
 # copy the content of one tensor into another. For now,
 # we only allow shallow copy if both tensors are `Float8Tensor`
-@implements(torch._has_compatible_shallow_copy_type)
+@implements_torch_function(torch._has_compatible_shallow_copy_type)
 def _(func, types, args, kwargs):
     assert len(args) == 2
-    return (
-        isinstance(args[0], Float8Tensor) and
-        isinstance(args[1], Float8Tensor)
-    )
+    return isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)


 @implements(aten.t.default)
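
The registration fix above is the substantive change in this hunk: `torch._has_compatible_shallow_copy_type` is a torch-level function rather than an aten overload, so it is intercepted by `__torch_function__` and belongs in the torch-function table; registered via `@implements`, it sat in the aten-op table and presumably never fired for this call.
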
@@ -775,7 +759,7 @@ def _(func, types, args, kwargs):
         self.mm_config,
         self.act_quant_kwargs,
         self.kernel_preference,
-        self.dtype
+        self.dtype,
     )
     return return_and_correct_aliasing(func, args, kwargs, new_tensor)

torchao/utils.py

Lines changed: 2 additions & 7 deletions
@@ -653,7 +653,6 @@ class MyTensor(torch.Tensor):
         ...
         __torch_function__ = classmethod(_dispatch__torch_function__)
     """
-    #print(f"dispatch__torch_function__ {func}, cls = {cls}")
     kwargs = {} if kwargs is None else kwargs
     if (
         hasattr(cls, "_TORCH_FN_TABLE")
@@ -662,11 +661,8 @@ class MyTensor(torch.Tensor):
     ):
         return cls._TORCH_FN_TABLE[cls][func](func, types, args, kwargs)
     with torch._C.DisableTorchFunctionSubclass():
-        try:
-            return func(*args, **kwargs)
-        except Exception as e:
-            print("func is ", func if func is not None else "n/a", "cls is ", cls if cls is not None else "n/a", "args are", args, "kwargs are ", kwargs)
-            raise e
+        return func(*args, **kwargs)
+


 def _dispatch__torch_dispatch__(cls, func, types, args, kwargs):
     """Use this util function for a common `__torch_dispatch__` implementation
@@ -676,7 +672,6 @@ class MyTensor(torch.Tensor):
         ...
         __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
     """
-    #print(f"dispatched to {func}, cls is {cls}, types is {types}, args is {args}, kwargs is {kwargs}")
     if (
         hasattr(cls, "_ATEN_OP_TABLE")
         and cls in cls._ATEN_OP_TABLE
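
Per the docstrings in the hunks above, subclasses adopt these utilities by assigning them as classmethods; with the try/except stripped, an unhandled function now falls straight through to the default behavior under `DisableTorchFunctionSubclass`. A minimal usage sketch (`MyTensor` is the docstring's illustrative name, not a real torchao class):

import torch
from torchao.utils import _dispatch__torch_function__, _dispatch__torch_dispatch__

class MyTensor(torch.Tensor):
    # Handlers registered in _TORCH_FN_TABLE / _ATEN_OP_TABLE are looked up
    # here; anything unregistered falls through to the default implementation.
    __torch_function__ = classmethod(_dispatch__torch_function__)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
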
