Skip to content

Commit d63cbe9

Browse files
committed
runs without errors
1 parent 26a5cd2 commit d63cbe9

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,32 @@ def _(func, types, args, kwargs):
262262
@implements([torch.matmul, aten.mm.default])
263263
def _(func, types, args, kwargs):
264264
input_tensor, weight_tensor = args[0], args[1]
265-
print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape} (before transpose)")
265+
print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose)")
266266
return _float8_linear_impl(input_tensor, weight_tensor.t())
267267

268268

269+
@implements([aten.addmm_.default])
270+
def _(func, types, args, kwargs):
271+
output_tensor, input_tensor, weight_tensor = (
272+
args[0],
273+
args[1],
274+
args[2] if len(args) > 2 else None,
275+
)
276+
print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose), output_tensor = {output_tensor.shape}")
277+
out = _float8_linear_impl(input_tensor, weight_tensor.t())
278+
return output_tensor.copy_(out)
279+
280+
281+
@implements(aten.copy_.default)
282+
def _(func, types, args, kwargs):
283+
# For now, just support copying from a Float8Tensor to a Float8Tensor
284+
assert len(args) == 2
285+
assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
286+
args[0].qdata.copy_(args[1].qdata, **kwargs)
287+
args[0].scale.copy_(args[1].scale, **kwargs)
288+
return args[0]
289+
290+
269291
def _float8_linear_impl(
270292
input_tensor: torch.Tensor,
271293
weight_tensor: torch.Tensor,
@@ -310,10 +332,12 @@ def _float8_linear_impl(
310332
wq = weight_tensor.qdata
311333
x_scale = input_tensor.scale
312334
w_scale = weight_tensor.scale
313-
if _is_rowwise_scaled(weight_tensor):
335+
if True: #_is_rowwise_scaled(weight_tensor):
314336
assert _is_rowwise_scaled(input_tensor), (
315337
"Input tensor must be rowwise block size"
316338
)
339+
print(f" * fbgemm op input = {xq.shape}, weight = {wq.shape}, input_scale = {x_scale.shape}, weight_scale = {w_scale.shape}")
340+
wq = wq.contiguous()
317341
res = torch.ops.fbgemm.f8f8bf16_rowwise(
318342
xq,
319343
wq,
@@ -323,6 +347,8 @@ def _float8_linear_impl(
323347
use_fast_accum=mm_config.use_fast_accum,
324348
).reshape(out_shape)
325349
else:
350+
print("weight_tensor failed _is_rowwise_scaled, SHOULDN'T BE HERE!!!!!!")
351+
breakpoint()
326352
assert _is_tensorwise_scaled(weight_tensor)
327353
assert _is_tensorwise_scaled(input_tensor)
328354
res = torch.ops.fbgemm.f8f8bf16(
@@ -727,10 +753,11 @@ def _(func, types, args, kwargs):
727753
def _(func, types, args, kwargs):
728754
assert len(args) == 1
729755
self = args[0]
756+
assert len(self.block_size) == 2
730757
new_tensor = self.__class__(
731758
self.qdata.t(),
732759
self.scale.t(),
733-
self.block_size,
760+
(self.block_size[1], self.block_size[0]),
734761
self.mm_config,
735762
self.act_quant_kwargs,
736763
self.kernel_preference,

0 commit comments

Comments
 (0)