@@ -262,10 +262,32 @@ def _(func, types, args, kwargs):
@implements([torch.matmul, aten.mm.default])
def _(func, types, args, kwargs):
    """Dispatch torch.matmul / aten.mm on a Float8Tensor weight.

    args[0] is the (high-precision) input, args[1] the Float8Tensor weight.
    The weight is transposed before delegating because _float8_linear_impl
    expects the linear-layer (out_features, in_features) layout, while mm
    receives the weight as (in_features, out_features).
    """
    input_tensor, weight_tensor = args[0], args[1]
    # Removed leftover debug print of shapes/block_size — it ran on every
    # matmul dispatch and spammed stdout in production paths.
    return _float8_linear_impl(input_tensor, weight_tensor.t())
267267
268268
@implements([aten.addmm_.default])
def _(func, types, args, kwargs):
    """Dispatch in-place addmm (out.addmm_(mat1, mat2)) with a Float8Tensor weight.

    args = (output_tensor, input_tensor, weight_tensor). Computes the float8
    linear product and writes it into ``output_tensor`` in place, returning
    ``output_tensor`` as aten.addmm_ requires.

    NOTE(review): full addmm_ semantics are ``beta * self + alpha * (mat1 @ mat2)``;
    this implementation overwrites ``output_tensor`` and ignores beta/alpha —
    confirm callers only reach this with beta=0-equivalent usage.
    """
    # addmm_ always carries three positional tensors; the previous
    # ``args[2] if len(args) > 2 else None`` guard would only defer the
    # failure to ``None.t()`` with a confusing AttributeError.
    assert len(args) >= 3, "aten.addmm_ expects (out, mat1, mat2)"
    output_tensor, input_tensor, weight_tensor = args[0], args[1], args[2]
    # Transpose: _float8_linear_impl expects (out_features, in_features).
    out = _float8_linear_impl(input_tensor, weight_tensor.t())
    return output_tensor.copy_(out)
279+
280+
@implements(aten.copy_.default)
def _(func, types, args, kwargs):
    """In-place copy between two Float8Tensors.

    For now, only Float8Tensor -> Float8Tensor copies are supported; the
    quantized payload and the scale are copied component-wise, and the
    destination tensor is returned as aten.copy_ requires.
    """
    assert len(args) == 2
    dst, src = args
    assert isinstance(dst, Float8Tensor) and isinstance(src, Float8Tensor)
    dst.qdata.copy_(src.qdata, **kwargs)
    dst.scale.copy_(src.scale, **kwargs)
    return dst
289+
290+
269291def _float8_linear_impl (
270292 input_tensor : torch .Tensor ,
271293 weight_tensor : torch .Tensor ,
@@ -310,10 +332,12 @@ def _float8_linear_impl(
310332 wq = weight_tensor .qdata
311333 x_scale = input_tensor .scale
312334 w_scale = weight_tensor .scale
313- if _is_rowwise_scaled (weight_tensor ):
335+ if True : # _is_rowwise_scaled(weight_tensor):
314336 assert _is_rowwise_scaled (input_tensor ), (
315337 "Input tensor must be rowwise block size"
316338 )
339+ print (f" * fbgemm op input = { xq .shape } , weight = { wq .shape } , input_scale = { x_scale .shape } , weight_scale = { w_scale .shape } " )
340+ wq = wq .contiguous ()
317341 res = torch .ops .fbgemm .f8f8bf16_rowwise (
318342 xq ,
319343 wq ,
@@ -323,6 +347,8 @@ def _float8_linear_impl(
323347 use_fast_accum = mm_config .use_fast_accum ,
324348 ).reshape (out_shape )
325349 else :
350+ print ("weight_tensor failed _is_rowwise_scaled, SHOULDN'T BE HERE!!!!!!" )
351+ breakpoint ()
326352 assert _is_tensorwise_scaled (weight_tensor )
327353 assert _is_tensorwise_scaled (input_tensor )
328354 res = torch .ops .fbgemm .f8f8bf16 (
@@ -727,10 +753,11 @@ def _(func, types, args, kwargs):
727753def _ (func , types , args , kwargs ):
728754 assert len (args ) == 1
729755 self = args [0 ]
756+ assert len (self .block_size ) == 2
730757 new_tensor = self .__class__ (
731758 self .qdata .t (),
732759 self .scale .t (),
733- self .block_size ,
760+ ( self .block_size [ 1 ], self . block_size [ 0 ]) ,
734761 self .mm_config ,
735762 self .act_quant_kwargs ,
736763 self .kernel_preference ,
0 commit comments