diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 04d0edb26d753..a6751933a88c7 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -685,6 +685,29 @@ class SimplifyConsecutiveAdd : public DFPatternRewrite { DFPattern const2_; }; +class SimplifyRSqrt : public DFPatternRewrite { + public: + SimplifyRSqrt() { + x_ = IsWildcard(); + numerator_ = IsWildcard(); + auto sqrt = IsOp("sqrt"); + pattern_ = IsOp("divide")({numerator_, sqrt({x_})}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + static const Op& op = Op::Get("rsqrt"); + auto x = node_map[x_][0]; + auto numerator = node_map[numerator_][0]; + return Call(Op::Get("multiply"), {numerator, Call(op, {x})}); + } + + private: + /*! \brief Pattern input */ + DFPattern x_; + DFPattern numerator_; +}; + Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { // the rewrites will be applied in the given order, and repeated until fixed point DFPatternRewriteComposer composer; @@ -694,6 +717,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 162ac6e73ddb9..837b15a48dc17 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -584,5 +584,24 @@ def expected(): assert tvm.ir.structural_equal(zzl, after) +def test_simplify_rsqrt(): + shape = (32, 1, 1) + x = relay.var("x", shape=shape, dtype="float32") + + def before(c): + return relay.const(c) / relay.sqrt(x) + + def expected(c): + if c == 1: + return relay.rsqrt(x) + else: + return relay.const(c) * relay.rsqrt(x) + + for c in [1.0, 2.0, 2.5]: + opt = run_opt_pass(before(c), transform.SimplifyExpr()) + after = run_opt_pass(expected(c), transform.InferType()) + assert tvm.ir.structural_equal(opt, after) + + if __name__ == "__main__": pytest.main([__file__])