@@ -248,8 +248,8 @@ def from_hp(
248248implements_torch_function = Float8Tensor .implements_torch_function
249249
250250
251- @implements ([ aten .linear .default ] )
252- @implements_torch_function ([ torch .nn .functional .linear ] )
251+ @implements (aten .linear .default )
252+ @implements_torch_function (torch .nn .functional .linear )
253253def _ (func , types , args , kwargs ):
254254 input_tensor , weight_tensor , bias = (
255255 args [0 ],
@@ -259,35 +259,24 @@ def _(func, types, args, kwargs):
259259 return _float8_linear_impl (input_tensor , weight_tensor , bias )
260260
261261
262- @implements ([torch .matmul , aten .mm .default ])
262+ @implements (aten .mm .default )
263+ @implements_torch_function (torch .matmul )
263264def _ (func , types , args , kwargs ):
264265 input_tensor , weight_tensor = args [0 ], args [1 ]
265- print (f"input = { input_tensor .shape } , weight = { weight_tensor .shape } , weight.block_size = { weight_tensor .block_size } (before transpose)" )
266266 return _float8_linear_impl (input_tensor , weight_tensor .t ())
267267
268268
269- @implements ([ aten .addmm_ .default ] )
269+ @implements (aten .addmm_ .default )
270270def _ (func , types , args , kwargs ):
271271 output_tensor , input_tensor , weight_tensor = (
272272 args [0 ],
273273 args [1 ],
274274 args [2 ] if len (args ) > 2 else None ,
275275 )
276- print (f"input = { input_tensor .shape } , weight = { weight_tensor .shape } , weight.block_size = { weight_tensor .block_size } (before transpose), output_tensor = { output_tensor .shape } " )
277276 out = _float8_linear_impl (input_tensor , weight_tensor .t ())
278277 return output_tensor .copy_ (out )
279278
280279
281- @implements (aten .copy_ .default )
282- def _ (func , types , args , kwargs ):
283- # For now, just support copying from a Float8Tensor to a Float8Tensor
284- assert len (args ) == 2
285- assert isinstance (args [0 ], Float8Tensor ) and isinstance (args [1 ], Float8Tensor )
286- args [0 ].qdata .copy_ (args [1 ].qdata , ** kwargs )
287- args [0 ].scale .copy_ (args [1 ].scale , ** kwargs )
288- return args [0 ]
289-
290-
291280def _float8_linear_impl (
292281 input_tensor : torch .Tensor ,
293282 weight_tensor : torch .Tensor ,
@@ -332,11 +321,11 @@ def _float8_linear_impl(
332321 wq = weight_tensor .qdata
333322 x_scale = input_tensor .scale
334323 w_scale = weight_tensor .scale
335- if True : #_is_rowwise_scaled(weight_tensor):
324+ # TODO: fix this?
325+ if True : # _is_rowwise_scaled(weight_tensor):
336326 assert _is_rowwise_scaled (input_tensor ), (
337327 "Input tensor must be rowwise block size"
338328 )
339- print (f" * fbgemm op input = { xq .shape } , weight = { wq .shape } , input_scale = { x_scale .shape } , weight_scale = { w_scale .shape } " )
340329 wq = wq .contiguous ()
341330 res = torch .ops .fbgemm .f8f8bf16_rowwise (
342331 xq ,
@@ -347,8 +336,6 @@ def _float8_linear_impl(
347336 use_fast_accum = mm_config .use_fast_accum ,
348337 ).reshape (out_shape )
349338 else :
350- print ("weight_tensor failed _is_rowwise_scaled, SHOULDN'T BE HERE!!!!!!" )
351- breakpoint ()
352339 assert _is_tensorwise_scaled (weight_tensor )
353340 assert _is_tensorwise_scaled (input_tensor )
354341 res = torch .ops .fbgemm .f8f8bf16 (
@@ -746,21 +733,18 @@ def _(func, types, args, kwargs):
746733 self .mm_config ,
747734 self .act_quant_kwargs ,
748735 self .kernel_preference ,
749- self .dtype
736+ self .dtype ,
750737 )
751738 return return_and_correct_aliasing (func , args , kwargs , new_tensor )
752739
753740
754741# This is called during _apply() to see if we can shallow
755742# copy the content of one tensor into another. For now,
756743# we only allow shallow copy if both tensors are `Float8Tensor`
757- @implements (torch ._has_compatible_shallow_copy_type )
744+ @implements_torch_function (torch ._has_compatible_shallow_copy_type )
758745def _ (func , types , args , kwargs ):
759746 assert len (args ) == 2
760- return (
761- isinstance (args [0 ], Float8Tensor ) and
762- isinstance (args [1 ], Float8Tensor )
763- )
747+ return isinstance (args [0 ], Float8Tensor ) and isinstance (args [1 ], Float8Tensor )
764748
765749
766750@implements (aten .t .default )
@@ -775,7 +759,7 @@ def _(func, types, args, kwargs):
775759 self .mm_config ,
776760 self .act_quant_kwargs ,
777761 self .kernel_preference ,
778- self .dtype
762+ self .dtype ,
779763 )
780764 return return_and_correct_aliasing (func , args , kwargs , new_tensor )
781765
0 commit comments