@@ -65,8 +65,6 @@ def inputs(self) -> list[torch.Tensor]:
 class MatcherRMSNorm(MatcherCustomOp):
     def __init__(self, epsilon: float, enabled: Optional[bool] = None):
         if enabled is None:
-            # TODO either pass config to enabled or set it globally
-            # (global during pass init seems reasonable)
             enabled = RMSNorm.enabled()

         super().__init__(enabled)
@@ -83,7 +81,6 @@ def forward_custom(
         self,
         input: torch.Tensor,
         weight: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         result = torch.empty_like(input)
         _, result = auto_functionalized(
@@ -100,28 +97,15 @@ def forward_native(
         self,
         input: torch.Tensor,
         weight: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        x = input.to(torch.float32)
-        if residual is not None:
-            x = x + residual
-            residual = x.to(self.model_dtype)
-
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-
-        x = x * torch.rsqrt(variance + self.epsilon)
-        x = x.to(self.model_dtype)
-        if weight is not None:
-            x = x * weight
-
-        return x if residual is None else (x, residual)
+        return RMSNorm.forward_static(
+            input, self.epsilon, input.size(-1), self.model_dtype, weight
+        )


 class MatcherFusedAddRMSNorm(MatcherCustomOp):
     def __init__(self, epsilon: float, enabled: Optional[bool] = None):
         if enabled is None:
-            # TODO either pass config to enabled or set it globally
-            # (global during pass init seems reasonable)
             enabled = RMSNorm.enabled()

         super().__init__(enabled)
@@ -157,19 +141,9 @@ def forward_native(
         weight: torch.Tensor,
         residual: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        x = input.to(torch.float32)
-        if residual is not None:
-            x = x + residual
-            residual = x.to(self.model_dtype)
-
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-
-        x = x * torch.rsqrt(variance + self.epsilon)
-        x = x.to(self.model_dtype)
-        if weight is not None:
-            x = x * weight
-
-        return x if residual is None else (x, residual)
+        return RMSNorm.forward_static(
+            input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
+        )


 class MatcherQuant:
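The change above deduplicates the two hand-rolled forward_native bodies into a single call to RMSNorm.forward_static. As a reference for what the removed code computed (and what forward_static is assumed to implement, given the positional arguments at the call sites), a minimal stand-alone sketch could look like the following; rms_norm_static is a hypothetical name used only for illustration:

from typing import Optional, Union

import torch


def rms_norm_static(
    x: torch.Tensor,
    epsilon: float,
    hidden_size: int,  # mirrors input.size(-1) at the call sites; unused in this pure-PyTorch reference
    dtype: torch.dtype,
    weight: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Reference RMSNorm with an optional fused residual add, matching the removed bodies."""
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual          # residual add happens in float32
        residual = x.to(dtype)    # updated residual is returned in the model dtype

    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + epsilon)
    x = x.to(dtype)
    if weight is not None:
        x = x * weight

    return x if residual is None else (x, residual)

Calling it as rms_norm_static(input, eps, input.size(-1), dtype, weight) or with a trailing residual argument mirrors the two call sites introduced in the diff.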