Skip to content

Commit

Permalink
Simplify full broadcast (apache#7423)
Browse files Browse the repository at this point in the history
* convert argwhere(full(const)) to reshape(arange())

* Add IsWildcard syntatic sugar

* add a simplify expression to fold full into broadcast ops

* Allow constant folding of full-like ops after SimplifyExpr

* fix a bug with the Attr Pattern matching

* remove skip_list
  • Loading branch information
Matthew Brookhart authored and trevor-m committed Mar 2, 2021
1 parent 4cb1a2b commit fbbd3f7
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 37 deletions.
2 changes: 2 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,8 @@ class DominatorPattern : public DFPattern {
DFPattern IsVar(const String& name);
/*! \brief Syntatic Sugar for creating a ConstantPattern */
DFPattern IsConstant();
/*! \brief Syntatic Sugar for creating a WildcardPattern */
DFPattern IsWildcard();
/*! \brief Syntatic Sugar for creating a ExprPattern */
DFPattern IsExpr(const Expr& expr);
/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op*/
Expand Down
8 changes: 7 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,12 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
if (Op::HasAttrMap(attr_name)) {
auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
if (op_map.count(op)) {
matches = MatchRetValue(attr_value, op_map[op]);
matches &= MatchRetValue(attr_value, op_map[op]);
} else {
matches = false;
}
} else {
matches = false;
}
}
} else if (auto* op = expr.as<CallNode>()) {
Expand Down Expand Up @@ -196,6 +200,8 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
break;
}
}
} else {
matches = false;
}
return matches;
}
Expand Down
1 change: 1 addition & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ DFPattern DFPattern::HasShape(const Array<PrimExpr> shape) {
}
DFPattern IsVar(const String& name) { return VarPattern(name); }
DFPattern IsConstant() { return ConstantPattern(make_object<ConstantPatternNode>()); }
DFPattern IsWildcard() { return WildcardPattern(make_object<WildcardPatternNode>()); }
DFPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
DFPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
DFPattern IsTuple(const Array<DFPattern>& fields) { return TuplePattern(fields); }
Expand Down
6 changes: 6 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,

Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value);

Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);

Expr MakeShapeOf(Expr data, DataType dtype);

Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_MAKE_OP_H_
6 changes: 4 additions & 2 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,12 +430,14 @@ Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& in
return {topi::shape(inputs[0], param->dtype)};
}

TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) {
Expr MakeShapeOf(Expr data, DataType dtype) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("shape_of");
return Call(op, {data}, Attrs(attrs), {});
});
}

TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed(MakeShapeOf);

RELAY_REGISTER_OP("shape_of")
.describe(R"code(Returns a tensor representing the shape of a tensor.
Expand Down
5 changes: 0 additions & 5 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,6 @@ class ConstantFolder : public MixedModeMutator {
}
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");

std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};

auto origin_args = call->args;
call = post.as<CallNode>();
// We don't constant fold function with zero arguments.
Expand All @@ -158,9 +156,6 @@ class ConstantFolder : public MixedModeMutator {
if (call->args.size() == 0) return post;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return post;
if (skip_list.count(op->name)) {
return post;
}
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return post;
// Try to evaluate shape_of op
Expand Down
111 changes: 98 additions & 13 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,38 @@
#include <tvm/support/logging.h>

#include "../op/tensor/transform.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

class SimplifyPattern {
public:
virtual Expr callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const = 0;

DFPattern pattern() const { return pattern_; }

protected:
/*! \brief Pattern for rewriting */
DFPattern pattern_;
};

/*!
* \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
* and merges into one reshape op.
*/
class SimplifyReshape {
class SimplifyReshape : public SimplifyPattern {
public:
SimplifyReshape() {
x_ = WildcardPattern(make_object<WildcardPatternNode>());
x_ = IsWildcard();
auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
pattern_ = reshape1({reshape2({x_})});
}

Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
Expr callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto x = node_map[x_][0];
bool const_shape = true;
Array<Integer> newshape;
Expand All @@ -63,13 +77,82 @@ class SimplifyReshape {
return post;
}

DFPattern pattern() const { return pattern_; }

private:
/*! \brief Pattern input */
DFPattern x_;
/*! \brief Pattern for consecutive reshape or reverse_reshape ops */
DFPattern pattern_;
};

/*!
* \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
*/
class FullElementwise : public SimplifyPattern {
public:
FullElementwise() {
x_ = IsWildcard();
data_ = IsWildcard();
value_ = IsConstant();

full_ = IsOp("full")({value_}) || IsOp("full_like")({data_, value_});
ones_ = IsOp("ones")({}) || IsOp("ones_like")({data_});
zeros_ = IsOp("zeros")({}) || IsOp("zeros_like")({data_});

Map<String, ObjectRef> attrs;
attrs.Set("TOpPattern", Integer(static_cast<int>(kBroadcast)));
DFPattern op = IsWildcard().HasAttr(attrs);
DFPattern full = full_ || ones_ || zeros_;
pattern_ = op({full, x_}) || op({x_, full});
}

Expr callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
const CallNode* call = pre.as<CallNode>();
ICHECK(call);
Type pre_type = pre->checked_type_;
ICHECK(pre_type.as<TensorTypeNode>());
auto dtype = pre_type.as<TensorTypeNode>()->dtype;
auto x = node_map[x_][0];
bool is_left = post.as<CallNode>()->args[1] == x;
Type x_type;
if (is_left) {
x_type = call->args[1]->checked_type_;
} else {
x_type = call->args[0]->checked_type_;
}

if (StructuralEqual()(x_type, pre_type)) {
Expr value;
if (node_map.count(full_)) {
value = node_map[value_][0];
ICHECK(IsConstScalar(value));
} else if (node_map.count(ones_)) {
value = MakeConstantScalar(dtype, 1);
} else if (node_map.count(zeros_)) {
value = MakeConstantScalar(dtype, 0);
} else {
ICHECK(false) << "Didn't find a full op while matching full + elementwise";
}
if (is_left) {
return Call(call->op, {value, x}, call->attrs, call->type_args, call->span);
} else {
return Call(call->op, {x, value}, call->attrs, call->type_args, call->span);
}
}
return post;
}

private:
/*! \brief binary argument */
DFPattern x_;
/*! \brief data ops get shape from */
DFPattern data_;
/*! \brief constant input */
DFPattern value_;
/*! \brief full op */
DFPattern full_;
/*! \brief ones op */
DFPattern ones_;
/*! \brief zeros op */
DFPattern zeros_;
};

/*!
Expand All @@ -78,22 +161,24 @@ class SimplifyReshape {
class ExprSimplifier {
public:
explicit ExprSimplifier(IRModule mod) : mod_(mod) {
auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
CreateCallback(SimplifyReshape());
CreateCallback(FullElementwise());
}
template <typename T>
void CreateCallback(const T& pattern) {
auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
Expr pre = args[0];
Expr post = args[1];
Map<DFPattern, Array<Expr>> node_map = args[2];
*rv = simplify_reshape_.callback(pre, post, node_map);
*rv = pattern.callback(pre, post, node_map);
};
callbacks_.push_back(
DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true));
callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
}

Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }

private:
IRModule mod_;
/*! \brief Simplify reshape pattern */
SimplifyReshape simplify_reshape_;
/*! \brief Callbacks for expr simplification */
Array<DFPatternCallback> callbacks_;
};
Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,8 @@ def test_no_match_op_attr():
x = relay.var("x")
y = relay.var("y")
assert not op_pat.match(x - y)
z = relay.var("z")
assert not op_pat.match(relay.Let(z, x + y, z))


def test_match_func_attr():
Expand Down
16 changes: 0 additions & 16 deletions tests/python/relay/test_pass_fold_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,22 +231,6 @@ def expected(dtype):
assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_full():
c_shape = (8, 9, 10)

def before():
dtype = "float32"
return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)

def expected():
# expect no changes
return before()

zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_batch_norm():
def expected():
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
Expand Down
65 changes: 65 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,70 @@ def symbolic():
assert tvm.ir.structural_equal(zz, after)


def test_simplify_full_elementwise():
def validate(shape, value, dtype):
def before_left(x, elem_op, full):
return elem_op(full, x)

def after_left(x, elem_op, value):
return elem_op(relay.const(value, dtype), x)

def before_right(x, elem_op, full):
return elem_op(x, full)

def after_right(x, elem_op, value):
return elem_op(x, relay.const(value, dtype))

x = relay.var("x", shape=shape, dtype=dtype)
elem_ops = [relay.add, relay.multiply, relay.subtract, relay.divide]
full_ops = []
if value == 0:
full_ops.append(relay.zeros(shape, dtype))
full_ops.append(relay.zeros_like(x))
if value == 1:
full_ops.append(relay.ones(shape, dtype))
full_ops.append(relay.ones_like(x))
else:
full_ops.append(relay.full(relay.const(value, dtype), shape))
full_ops.append(relay.full_like(x, relay.const(value, dtype)))
for op in elem_ops:
for full in full_ops:
z = before_left(x, op, full)
zz = run_opt_pass(z, transform.SimplifyExpr())
after = run_opt_pass(after_left(x, op, value), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

z = before_right(x, op, full)
zz = run_opt_pass(z, transform.SimplifyExpr())
after = run_opt_pass(after_right(x, op, value), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

# Test the case in which x is broadcast to full's shape
full_ops = []
if value == 0:
full_ops.append(relay.zeros(shape * 2, dtype))
if value == 1:
full_ops.append(relay.ones(shape * 2, dtype))
else:
full_ops.append(relay.full(relay.const(value, dtype), shape * 2))
for op in elem_ops:
for full in full_ops:
z = before_left(x, op, full)
zz = run_opt_pass(z, transform.SimplifyExpr())
after = run_opt_pass(before_left(x, op, full), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

z = before_right(x, op, full)
zz = run_opt_pass(z, transform.SimplifyExpr())
after = run_opt_pass(before_right(x, op, full), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

for shape in [[10], [10, 10], [10, 10, 10]]:
for dtype in ["float32", "int32"]:
for value in [0, 1, 2]:
validate(shape, value, dtype)


if __name__ == "__main__":
test_simplify_reshape()
test_simplify_full_elementwise()

0 comments on commit fbbd3f7

Please sign in to comment.