@@ -248,8 +248,8 @@ def from_hp(
 implements_torch_function = Float8Tensor.implements_torch_function
 
 
-@implements([aten.linear.default])
-@implements_torch_function([torch.nn.functional.linear])
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
@@ -259,35 +259,24 @@ def _(func, types, args, kwargs):
     return _float8_linear_impl(input_tensor, weight_tensor, bias)
 
 
-@implements([torch.matmul, aten.mm.default])
+@implements(aten.mm.default)
+@implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose)")
     return _float8_linear_impl(input_tensor, weight_tensor.t())
 
 
-@implements([aten.addmm_.default])
+@implements(aten.addmm_.default)
 def _(func, types, args, kwargs):
     output_tensor, input_tensor, weight_tensor = (
         args[0],
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose), output_tensor = {output_tensor.shape}")
     out = _float8_linear_impl(input_tensor, weight_tensor.t())
     return output_tensor.copy_(out)
 
 
-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    # For now, just support copying from a Float8Tensor to a Float8Tensor
-    assert len(args) == 2
-    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
-    args[0].qdata.copy_(args[1].qdata, **kwargs)
-    args[0].scale.copy_(args[1].scale, **kwargs)
-    return args[0]
-
-
 def _float8_linear_impl(
     input_tensor: torch.Tensor,
     weight_tensor: torch.Tensor,
@@ -332,11 +321,11 @@ def _float8_linear_impl(
             wq = weight_tensor.qdata
             x_scale = input_tensor.scale
             w_scale = weight_tensor.scale
-            if True: #_is_rowwise_scaled(weight_tensor):
+            # TODO: fix this?
+            if True:  # _is_rowwise_scaled(weight_tensor):
                 assert _is_rowwise_scaled(input_tensor), (
                     "Input tensor must be rowwise block size"
                 )
-                print(f"        * fbgemm op input = {xq.shape}, weight = {wq.shape}, input_scale = {x_scale.shape}, weight_scale = {w_scale.shape}")
                 wq = wq.contiguous()
                 res = torch.ops.fbgemm.f8f8bf16_rowwise(
                     xq,
@@ -347,8 +336,6 @@ def _float8_linear_impl(
                     use_fast_accum=mm_config.use_fast_accum,
                 ).reshape(out_shape)
             else:
-                print("weight_tensor failed _is_rowwise_scaled, SHOULDN'T BE HERE!!!!!!")
-                breakpoint()
                 assert _is_tensorwise_scaled(weight_tensor)
                 assert _is_tensorwise_scaled(input_tensor)
                 res = torch.ops.fbgemm.f8f8bf16(
@@ -746,21 +733,18 @@ def _(func, types, args, kwargs):
         self.mm_config,
         self.act_quant_kwargs,
         self.kernel_preference,
-        self.dtype
+        self.dtype,
     )
     return return_and_correct_aliasing(func, args, kwargs, new_tensor)
 
 
 # This is called during _apply() to see if we can shallow
 # copy the content of one tensor into another. For now,
 # we only allow shallow copy if both tensors are `Float8Tensor`
-@implements(torch._has_compatible_shallow_copy_type)
+@implements_torch_function(torch._has_compatible_shallow_copy_type)
 def _(func, types, args, kwargs):
     assert len(args) == 2
-    return (
-        isinstance(args[0], Float8Tensor) and
-        isinstance(args[1], Float8Tensor)
-    )
+    return isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
 
 
 @implements(aten.t.default)
@@ -775,7 +759,7 @@ def _(func, types, args, kwargs):
         self.mm_config,
         self.act_quant_kwargs,
         self.kernel_preference,
-        self.dtype
+        self.dtype,
     )
     return return_and_correct_aliasing(func, args, kwargs, new_tensor)
 