Skip to content

Commit 6bf856e

Browse files
authored
Add RMS Normalization variant (#2519)
Add RMS Normalization variant to support both orders for multiplying scale and normalized value. --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 0c83c0d commit 6bf856e

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

onnxscript/rewriter/onnx_fusions/_rms_normalization.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030

3131

3232
class RmsNormFusion(pattern.RewriteRuleClassBase):
33+
def __init__(self, name: str, mul_order: bool):
34+
super().__init__(name)
35+
self._mul_order = mul_order
36+
3337
def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
3438
x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
3539
x_square = op.Pow(x, 2.0)
@@ -39,7 +43,11 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
3943
reciprocal_rms = op.Reciprocal(rms)
4044
normalized = op.Mul(x, reciprocal_rms)
4145
normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
42-
return op.Mul(scale, normalized)
46+
# Workaround: limitation in pattern matcher doesn't support OrValue for return value (last node in pattern)
47+
if self._mul_order:
48+
return op.Mul(normalized, scale)
49+
else:
50+
return op.Mul(scale, normalized)
4351

4452
def check(
4553
self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
@@ -76,9 +84,11 @@ def rewrite(self, op, x, scale, epsilon, **_):
7684
)
7785

7886

79-
_rule = RmsNormFusion.rule()
80-
rms_normalization_rules = [_rule]
81-
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)
87+
_rule1 = RmsNormFusion.rule("RmsNormFusion1", mul_order=True)
88+
_rule2 = RmsNormFusion.rule("RmsNormFusion2", mul_order=False)
8289

90+
rms_normalization_rules = [_rule1, _rule2]
91+
92+
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)
8393

8494
fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset)

0 commit comments

Comments
 (0)