@@ -673,10 +673,10 @@ def __init__(
673673 self .rmsnorm_matcher = MatcherRMSNorm (epsilon )
674674
675675 def get_inputs (self ):
676- input = torch .empty ([1 , 8 , 4 ], device = self .device , dtype = self .dtype )
677- weight = torch .empty ([4 ], device = self .device , dtype = self .dtype )
676+ input , weight = self .rmsnorm_matcher .inputs ()
678677
679- return [input , weight ]
678+ # input goes through allreduce first, always 16-bit
679+ return [input .to (self .dtype ), weight ]
680680
681681 def register (self , pm_pass : PatternMatcherPass ):
682682 def pattern (input : torch .Tensor , weight : torch .Tensor ):
@@ -728,14 +728,10 @@ def __init__(
728728 self .rmsnorm_matcher = MatcherFusedAddRMSNorm (epsilon )
729729
730730 def get_inputs (self ):
731- input = torch .empty ([4 , 4 ], device = self .device , dtype = self .dtype )
732- residual = torch .empty ([4 , 4 ], device = self .device , dtype = self .dtype )
733- weight = torch .empty ([4 , 4 ], device = self .device , dtype = self .dtype )
734- return [
735- residual ,
736- input ,
737- weight ,
738- ]
731+ input , residual , weight = self .rmsnorm_matcher .inputs ()
732+
733+ # input goes through allreduce first, always 16-bit
734+ return [residual , input .to (self .dtype ), weight ]
739735
740736 def register (self , pm_pass : PatternMatcherPass ):
741737 def pattern (residual : torch .Tensor , input : torch .Tensor , weight : torch .Tensor ):
@@ -802,10 +798,11 @@ def __init__(
802798
803799 def register (self , pm_pass : PatternMatcherPass ):
804800 def get_inputs ():
805- input = torch .zeros ([1 , 8 , 4 ], device = self .device , dtype = self .dtype )
806- weight = torch .empty ([4 ], device = self .device , dtype = self .dtype )
807- scale = torch .tensor (1.0 , device = self .device , dtype = torch .float32 )
808- return [input , weight , scale ]
801+ input , weight = self .rmsnorm_matcher .inputs ()
802+ _ , scale = self .quant_matcher .inputs ()
803+
804+ # input goes through allreduce first, always 16-bit
805+ return [input .to (self .dtype ), weight , scale ]
809806
810807 def pattern (
811808 input : torch .Tensor ,
@@ -871,18 +868,11 @@ def __init__(
871868
872869 def register (self , pm_pass : PatternMatcherPass ):
873870 def get_inputs ():
874- input = torch .empty ([4 , 4 ], device = self .device , dtype = self .dtype )
871+ input , residual , weight = self .rmsnorm_matcher .inputs ()
872+ _ , scale = self .quant_matcher .inputs ()
875873
876- residual = torch .empty ([4 , 4 ], device = self .device , dtype = self .dtype )
877- weight = torch .empty ([4 , 4 ], device = self .device , dtype = self .dtype )
878- scale = torch .empty ([1 , 1 ], device = self .device , dtype = torch .float32 )
879-
880- return [
881- residual ,
882- input ,
883- weight ,
884- scale ,
885- ]
874+ # input goes through allreduce first, always 16-bit
875+ return [residual , input .to (self .dtype ), weight , scale ]
886876
887877 def pattern (
888878 residual : torch .Tensor ,
0 commit comments