Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Simplify the square of a binomial #14580

Merged
merged 1 commit into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,87 @@ class SimplifyAdd : public DFPatternRewrite {
DFPattern y_;
};

/*!
 * \brief Simplifying a * x * x + b * x * y + c * y * y to a * (x + p * y) * (x + q * y)
 *
 * p and q are (b +/- sqrt(b * b - 4 * a * c)) / (2 * a), so that p + q = b / a and
 * p * q = c / a. The rewrite is only applied when it is known to preserve the value of the
 * expression; in every other case the original expression is returned unchanged.
 */
class SimplifyBinomial : public DFPatternRewrite {
 public:
  SimplifyBinomial() {
    x_ = IsWildcard();
    y_ = IsWildcard();
    a_ = IsConstant();
    b_ = IsConstant();
    c_ = IsConstant();
    DFPattern add = IsOp("add");
    DFPattern mul = IsOp("multiply");
    // Each term is accepted with its constant factor in several positions, or with no constant
    // factor at all (in which case the coefficient is taken to be 1 in Callback).
    DFPattern x_sq = mul({a_, mul({x_, x_})}) || mul({x_, mul({a_, x_})}) || mul({x_, x_});
    DFPattern xy = mul({b_, mul({x_, y_})}) || mul({x_, mul({b_, y_})}) ||
                   mul({y_, mul({b_, x_})}) || mul({x_, y_});
    DFPattern y_sq = mul({c_, mul({y_, y_})}) || mul({y_, mul({c_, y_})}) || mul({y_, y_});

    // The three terms may be summed in any grouping of a left-nested add.
    pattern_ = add({add({xy, x_sq}), y_sq}) || add({add({xy, y_sq}), x_sq}) ||
               add({add({x_sq, y_sq}), xy});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    // Coefficients can only be decoded for the element types handled below. For any other (or
    // non-tensor) result type keep the expression as is, rather than silently assuming that
    // every coefficient equals 1.
    const auto* tensor_type = pre->checked_type_.as<TensorTypeNode>();
    if (tensor_type == nullptr) return pre;
    const DataType dtype = tensor_type->dtype;
    const bool is_int32 = dtype == DataType::Int(32, 1);
    const bool is_float32 = dtype == DataType::Float(32, 1);
    const bool is_float64 = dtype == DataType::Float(64, 1);
    if (!is_int32 && !is_float32 && !is_float64) return pre;

    auto x = node_map[x_][0];
    auto y = node_map[y_][0];
    double a_val = 1;
    double b_val = 1;
    double c_val = 1;
    double* vals[] = {&a_val, &b_val, &c_val};
    DFPattern nodes[] = {a_, b_, c_};
    for (int i = 0; i < 3; i++) {
      // A coefficient whose pattern did not take part in the match keeps its default of 1.
      if (node_map.count(nodes[i]) == 0) continue;
      // Keep the folded expression alive while its payload is being read.
      Expr folded = transform::FoldConstantExpr(node_map[nodes[i]][0]);
      const auto* constant = folded.as<ConstantNode>();
      // Only a scalar of the result dtype can be summarised by a single number; a tensor
      // constant may hold a different coefficient per element.
      if (constant == nullptr || !constant->is_scalar() ||
          DataType(constant->data->dtype) != dtype) {
        return pre;
      }
      const void* raw = constant->data->data;
      if (is_int32) {
        *vals[i] = static_cast<const int32_t*>(raw)[0];
      } else if (is_float32) {
        *vals[i] = static_cast<const float*>(raw)[0];
      } else {
        *vals[i] = static_cast<const double*>(raw)[0];
      }
    }

    // The roots are divided by the leading coefficient. When the y * y coefficient is 1 and the
    // x * x coefficient is larger, exchange the roles of x and y so that the division is by 1.
    if (c_val == 1 && a_val > 1) {
      auto temp_exp = x;
      x = y;
      y = temp_exp;
      double temp_val = a_val;
      a_val = c_val;
      c_val = temp_val;
    }

    // With a == 0 the expression is not quadratic in x and the roots below are undefined.
    if (a_val == 0) return pre;

    double sub_value = b_val * b_val - 4 * a_val * c_val;
    // No real roots: the expression cannot be factorised.
    if (sub_value < 0) return pre;
    // A (near) zero discriminant is treated as a double root so both factors share one node.
    bool same_multiplicands = sub_value < 10e-5;

    double discriminant = std::sqrt(sub_value);
    double p_val = (b_val + discriminant) / (2 * a_val);
    double q_val = (b_val - discriminant) / (2 * a_val);
    // An integer constant cannot represent a fractional root; truncating it would change the
    // value of the expression, so leave such inputs untouched.
    if (is_int32 && (std::floor(p_val) != p_val || std::floor(q_val) != q_val)) return pre;

    Expr first_val = MakeConstantScalar(dtype, p_val);
    Expr second_val = same_multiplicands ? first_val : MakeConstantScalar(dtype, q_val);

    Expr first_multiplicand = Call(Op::Get("add"), {x, Call(Op::Get("multiply"), {y, first_val})});
    Expr second_multiplicand =
        same_multiplicands ? first_multiplicand
                           : Call(Op::Get("add"), {x, Call(Op::Get("multiply"), {y, second_val})});
    Expr a = MakeConstantScalar(dtype, a_val);
    return Call(Op::Get("multiply"),
                {a, Call(Op::Get("multiply"), {first_multiplicand, second_multiplicand})});
  }

 private:
  /*! \brief Pattern input */
  DFPattern a_;
  DFPattern b_;
  DFPattern c_;
  DFPattern x_;
  DFPattern y_;
};

/*! \brief Simplifying x/sqrt to x*rsqrt */
class SimplifyRSqrt : public DFPatternRewrite {
public:
Expand Down Expand Up @@ -966,6 +1047,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<SimplifyDQArgSort>();
composer.AddRewrite<SimplifyClipAndConsecutiveCast>();
composer.AddRewrite<SimplifyCastClip>();
composer.AddRewrite<SimplifyBinomial>();
return RewritePatterns(composer.MakeCallbacks(), expr, mod);
}

Expand Down
99 changes: 97 additions & 2 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from math import sqrt
import pytest
import tvm
from tvm import relay
Expand Down Expand Up @@ -374,7 +375,7 @@ def check(x, y=None, do_nothing=False):
x = relay.var("x", shape=shape, dtype=dtype)
x = run_opt_pass(x, transform.InferType())

for (op, op_like, id_op, const) in [
for op, op_like, id_op, const in [
(relay.zeros, relay.zeros_like, relay.add, relay.const(0, dtype)),
(relay.ones, relay.ones_like, relay.multiply, relay.const(1, dtype)),
]:
Expand All @@ -389,7 +390,7 @@ def check(x, y=None, do_nothing=False):
check(id_op(x, op([2] + shape, dtype)), do_nothing=True)
check(id_op(op([2] + shape, dtype), x), do_nothing=True)

for (op, op_like, id_op, const) in [
for op, op_like, id_op, const in [
(relay.zeros, relay.zeros_like, relay.subtract, relay.const(0, dtype)),
(relay.ones, relay.ones_like, relay.divide, relay.const(1, dtype)),
]:
Expand Down Expand Up @@ -744,5 +745,99 @@ def expected():
assert tvm.ir.structural_equal(opt, ref)


def test_binomials():
    """Check that SimplifyExpr factorises a*x*x + b*x*y + c*y*y into a product of binomials."""

    def check_simple_fold(origin_exprs, expect_exprs):
        # Every spelling of the input must simplify to at least one accepted form.
        for origin_expr in origin_exprs:
            simple_expr = run_opt_pass(origin_expr, transform.SimplifyExpr())
            assert any(
                tvm.ir.structural_equal(
                    simple_expr, run_opt_pass(expected, transform.EliminateCommonSubexpr())
                )
                for expected in expect_exprs
            )

    def gen_expected_expressions(x, y, a, b, c, dtype):
        # Mirror the pass: exchange the roles of x and y when only a is larger than one.
        if c == 1 and a > 1:
            a, c = c, a
            x, y = y, x

        det = b * b - 4 * a * c
        if det < 0:
            # No real roots, so the expression is expected to stay a plain sum.
            return gen_expressions(x, y, a, b, c)

        root = sqrt(det)
        p_val = (b + root) / (2 * a)
        q_val = (b - root) / (2 * a)
        p = relay.const(p_val, dtype=dtype)
        q = relay.const(q_val, dtype=dtype)

        def binomial_forms(value, const):
            # A unit factor is expected to be dropped from the binomial.
            if value == 1:
                return [x + y, y + x]
            return [x + const * y, const * y + x, x + y * const, y * const + x]

        first_exp = binomial_forms(p_val, p)
        second_exp = binomial_forms(q_val, q)
        final_exp = []
        for f in first_exp:
            for s in second_exp:
                final_exp.append(f * s)
                if p_val != q_val:
                    final_exp.append(s * f)
        return final_exp

    def gen_expressions(x, y, a, b, c):
        # Enumerate every operand order of each term, then every order of the three-term sum.
        if a == 1:
            first_exp = [x * x]
        else:
            first_exp = [a * x * x, x * a * x, x * x * a]

        if b == 1:
            second_exp = [x * y, y * x]
        else:
            second_exp = [b * x * y, x * b * y, x * y * b, b * y * x, y * b * x, y * x * b]

        if c == 1:
            third_exp = [y * y]
        else:
            third_exp = [c * y * y, y * c * y, y * y * c]

        return [
            total
            for f in first_exp
            for s in second_exp
            for t in third_exp
            for total in (f + s + t, f + t + s, s + f + t, s + t + f, t + f + s, t + s + f)
        ]

    n = 5
    coefficient_cases = [(1, 2, 1), (6, 5, 1), (1, 1, 1), (1, 4, 4)]
    for dtype in ["int32", "float32", "float64"]:
        x = relay.var("x", shape=(n,), dtype=dtype)
        y = relay.var("y", shape=(n,), dtype=dtype)

        for a_val, b_val, c_val in coefficient_cases:
            a = relay.const(a_val, dtype=dtype)
            b = relay.const(b_val, dtype=dtype)
            c = relay.const(c_val, dtype=dtype)
            origin_exprs = gen_expressions(x, y, a, b, c)
            expect_expr = gen_expected_expressions(x, y, a_val, b_val, c_val, dtype)
            check_simple_fold(origin_exprs, expect_expr)


# When this file is executed directly as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()