diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 37cb263c489d..a04199f6a5b1 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -292,6 +292,7 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("right_shift", False, broadcast_shape_func) register_shape_func("sqrt", False, elemwise_shape_func) +register_shape_func("rsqrt", False, elemwise_shape_func) register_shape_func("negative", False, elemwise_shape_func) register_shape_func("exp", False, elemwise_shape_func) register_shape_func("tan", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 4ef342a26b0b..f7752e41b056 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -856,7 +856,8 @@ def __init__(self): added_eps = is_op("add")(mp1, eps) deno = is_op("sqrt")(added_eps) div_out = is_op("divide")(diff, deno) - weighted = is_op("multiply")(div_out, self.gamma) + div_out2 = diff * is_op("rsqrt")(added_eps) + weighted = is_op("multiply")(div_out | div_out2, self.gamma) added_bias = is_op("add")(weighted, self.beta) self.pattern = added_bias diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 04d0edb26d75..a6751933a88c 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -685,6 +685,29 @@ class SimplifyConsecutiveAdd : public DFPatternRewrite { DFPattern const2_; }; +class SimplifyRSqrt : public DFPatternRewrite { + public: + SimplifyRSqrt() { + x_ = IsWildcard(); + numerator_ = IsWildcard(); + auto sqrt = IsOp("sqrt"); + pattern_ = IsOp("divide")({numerator_, sqrt({x_})}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + static const Op& op = Op::Get("rsqrt"); + auto x = node_map[x_][0]; + auto numerator = node_map[numerator_][0]; + return Call(Op::Get("multiply"), {numerator, Call(op, {x})}); + } + + private: + /*! \brief Pattern input */ + DFPattern x_; + DFPattern numerator_; +}; + Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { // the rewrites will be applied in the given order, and repeated until fixed point DFPatternRewriteComposer composer; @@ -694,6 +717,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 162ac6e73ddb..837b15a48dc1 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -584,5 +584,24 @@ def expected(): assert tvm.ir.structural_equal(zzl, after) +def test_simplify_rsqrt(): + shape = (32, 1, 1) + x = relay.var("x", shape=shape, dtype="float32") + + def before(c): + return relay.const(c) / relay.sqrt(x) + + def expected(c): + if c == 1: + return relay.rsqrt(x) + else: + return relay.const(c) * relay.rsqrt(x) + + for c in [1.0, 2.0, 2.5]: + opt = run_opt_pass(before(c), transform.SimplifyExpr()) + after = run_opt_pass(expected(c), transform.InferType()) + assert tvm.ir.structural_equal(opt, after) + + if __name__ == "__main__": pytest.main([__file__])