Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Relay] Add Rsqrt to SimplifyExpr (apache#12363)
Browse files Browse the repository at this point in the history
* Add Rsqrt to SimplifyExpr

* fix unit tests
  • Loading branch information
Matthew Brookhart authored and xinetzone committed Nov 25, 2022
1 parent 3286329 commit 049189a
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("right_shift", False, broadcast_shape_func)

register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("rsqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func)
register_shape_func("exp", False, elemwise_shape_func)
register_shape_func("tan", False, elemwise_shape_func)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/op/contrib/dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,8 @@ def __init__(self):
added_eps = is_op("add")(mp1, eps)
deno = is_op("sqrt")(added_eps)
div_out = is_op("divide")(diff, deno)
weighted = is_op("multiply")(div_out, self.gamma)
div_out2 = diff * is_op("rsqrt")(added_eps)
weighted = is_op("multiply")(div_out | div_out2, self.gamma)
added_bias = is_op("add")(weighted, self.beta)
self.pattern = added_bias

Expand Down
24 changes: 24 additions & 0 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,29 @@ class SimplifyConsecutiveAdd : public DFPatternRewrite {
DFPattern const2_;
};

class SimplifyRSqrt : public DFPatternRewrite {
public:
SimplifyRSqrt() {
x_ = IsWildcard();
numerator_ = IsWildcard();
auto sqrt = IsOp("sqrt");
pattern_ = IsOp("divide")({numerator_, sqrt({x_})});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
static const Op& op = Op::Get("rsqrt");
auto x = node_map[x_][0];
auto numerator = node_map[numerator_][0];
return Call(Op::Get("multiply"), {numerator, Call(op, {x})});
}

private:
/*! \brief Pattern input */
DFPattern x_;
DFPattern numerator_;
};

Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
// the rewrites will be applied in the given order, and repeated until fixed point
DFPatternRewriteComposer composer;
Expand All @@ -694,6 +717,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
composer.AddRewrite<SimplifyRSqrt>();
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
composer.AddRewrite<SimplifyTranspose>();
Expand Down
19 changes: 19 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,5 +584,24 @@ def expected():
assert tvm.ir.structural_equal(zzl, after)


def test_simplify_rsqrt():
shape = (32, 1, 1)
x = relay.var("x", shape=shape, dtype="float32")

def before(c):
return relay.const(c) / relay.sqrt(x)

def expected(c):
if c == 1:
return relay.rsqrt(x)
else:
return relay.const(c) * relay.rsqrt(x)

for c in [1.0, 2.0, 2.5]:
opt = run_opt_pass(before(c), transform.SimplifyExpr())
after = run_opt_pass(expected(c), transform.InferType())
assert tvm.ir.structural_equal(opt, after)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 049189a

Please sign in to comment.