Skip to content

Commit

Permalink
Add Rsqrt to SimplifyExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Aug 10, 2022
1 parent 22ba659 commit 92a5e00
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,29 @@ class SimplifyConsecutiveAdd : public DFPatternRewrite {
DFPattern const2_;
};

class SimplifyRSqrt : public DFPatternRewrite {
public:
SimplifyRSqrt() {
x_ = IsWildcard();
numerator_ = IsWildcard();
auto sqrt = IsOp("sqrt");
pattern_ = IsOp("divide")({numerator_, sqrt({x_})});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
static const Op& op = Op::Get("rsqrt");
auto x = node_map[x_][0];
auto numerator = node_map[numerator_][0];
return Call(Op::Get("multiply"), {numerator, Call(op, {x})});
}

private:
/*! \brief Pattern input */
DFPattern x_;
DFPattern numerator_;
};

Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
// the rewrites will be applied in the given order, and repeated until fixed point
DFPatternRewriteComposer composer;
Expand All @@ -694,6 +717,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
composer.AddRewrite<SimplifyRSqrt>();
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
composer.AddRewrite<SimplifyTranspose>();
Expand Down
19 changes: 19 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,5 +584,24 @@ def expected():
assert tvm.ir.structural_equal(zzl, after)


def test_simplify_rsqrt():
shape = (32, 1, 1)
x = relay.var("x", shape=shape, dtype="float32")

def before(c):
return relay.const(c) / relay.sqrt(x)

def expected(c):
if c == 1:
return relay.rsqrt(x)
else:
return relay.const(c) * relay.rsqrt(x)

for c in [1.0, 2.0, 2.5]:
opt = run_opt_pass(before(c), transform.SimplifyExpr())
after = run_opt_pass(expected(c), transform.InferType())
assert tvm.ir.structural_equal(opt, after)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 92a5e00

Please sign in to comment.