@@ -268,13 +268,15 @@ def _(func, types, args, kwargs):
 
 @implements(aten.addmm_.default)
 def _(func, types, args, kwargs):
-    output_tensor, input_tensor, weight_tensor = (
+    bias_tensor, input_tensor, weight_tensor = (
         args[0],
         args[1],
-        args[2] if len(args) > 2 else None,
+        args[2],
     )
+    assert kwargs.get("alpha", 1) == 1, "only alpha=1 is supported"
+    assert kwargs.get("beta", 1) == 1, "only beta=1 is supported"
     out = _float8_mm_impl(input_tensor, weight_tensor)
-    return output_tensor.copy_(out)
+    return bias_tensor.add_(out)
 
 
 def _float8_mm_impl(
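For context on the rewritten handler above: aten.addmm_ computes self = beta * self + alpha * (mat1 @ mat2) in place, so with the asserted alpha == beta == 1 the op reduces to adding the matmul result into the bias tensor. A minimal standalone sketch (not part of this commit) of the equivalence that the new bias_tensor.add_(out) line relies on:

import torch

# addmm_(self, mat1, mat2) with the default beta=1, alpha=1 is
# self = self + mat1 @ mat2, performed in place on self.
bias = torch.zeros(2, 3)
mat1 = torch.randn(2, 4)
mat2 = torch.randn(4, 3)

reference = bias.clone().addmm_(mat1, mat2)   # in-place addmm with beta=alpha=1
equivalent = bias.clone().add_(mat1 @ mat2)   # what the Float8 handler emulates
assert torch.allclose(reference, equivalent)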
@@ -708,51 +710,6 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 
-@implements(torch.ops.aten.to.dtype_layout)
-def _(func, types, args, kwargs):
-    # only support kwargs for now
-    assert len(args) == 1
-    self = args[0]
-    # only support dtype, layout, and device for now
-    for k in kwargs.keys():
-        assert k in ["dtype", "layout", "device"]
-    # only support same dtype and layout
-    # different dtype and layout has undefined behavior
-    if "dtype" in kwargs:
-        assert kwargs["dtype"] == self.dtype
-    if "layout" in kwargs:
-        assert kwargs["layout"] == self.layout
-    # if device is the same, treat this like a no-op
-    device = kwargs.get("device")
-    if device == self.device:
-        return self
-    # otherwise, move all inner tensors to the new device
-    new_tensor = self.__class__(
-        func(self.qdata, device=device),
-        func(self.scale, device=device),
-        self.block_size,
-        self.mm_config,
-        self.act_quant_kwargs,
-        self.kernel_preference,
-        self.dtype,
-    )
-    return return_and_correct_aliasing(func, args, kwargs, new_tensor)
-
-
-# This is called during _apply() to see if we can shallow
-# copy the content of one tensor into another. For now,
-# we only allow shallow copy if both tensors are `Float8Tensor`
-# and have the same shape.
-@implements_torch_function(torch._has_compatible_shallow_copy_type)
-def _(func, types, args, kwargs):
-    assert len(args) == 2
-    return (
-        isinstance(args[0], Float8Tensor)
-        and isinstance(args[1], Float8Tensor)
-        and args[0].shape == args[1].shape
-    )
-
-
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     assert len(args) == 1
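Both hunks touch handlers registered through the @implements decorator. As a rough, assumed sketch (not torchao's actual implementation), the registry pattern these handlers plug into looks like this: each decorated function is stored in a table keyed by the aten op, and the tensor subclass's __torch_dispatch__ looks the op up and forwards (func, types, args, kwargs) to it.

import torch

_DISPATCH_TABLE = {}

def implements(op):
    # Hypothetical registry decorator; torchao's real helper may differ.
    def decorator(fn):
        _DISPATCH_TABLE[op] = fn
        return fn
    return decorator

class MyQuantTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _DISPATCH_TABLE:
            return _DISPATCH_TABLE[func](func, types, args, kwargs)
        raise NotImplementedError(f"{func} is not handled by {cls.__name__}")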