Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][PASS] FoldScaleAxis Forward #2020

Merged
merged 4 commits into from
Oct 28, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {

/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
Array<IndexExpr> axes;
// use axis to make the name numpy compatible.
Array<Integer> axis;

TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The axes to squeeze in the input tensor."
"If `axes = []`, all axis of dimension 1 get squeezed;"
TVM_ATTR_FIELD(axis)
.describe("The axis to squeeze in the input tensor."
"If `axis = None`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axes does not has dimension 1.")
.set_default(Array<IndexExpr>({}));
"It is an error if an axis does not has dimension 1.")
.set_default(NullValue<Array<Integer> >());
}
}; // struct SqueezeAttrs

Expand Down
26 changes: 26 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ class ExprNode : public RelayNode {
"field for this node";
return this->checked_type_;
}
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;

static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
Expand Down Expand Up @@ -388,6 +400,20 @@ class TupleGetItemNode : public ExprNode {

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

// implementataions
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->type_key();
return node;
}

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
14 changes: 11 additions & 3 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,14 @@ class ExprVisitor
class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public:
Expr Mutate(const Expr& expr);
/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) {
return this->VisitExpr(expr);
}
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override;
Expand All @@ -161,15 +168,16 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions.
/*!
* \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);

private:
protected:
/*! \brief Internal map used for memoization. */
tqchen marked this conversation as resolved.
Show resolved Hide resolved
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};
Expand Down
43 changes: 27 additions & 16 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,42 @@ class OpNode : public relay::ExprNode {
v->Visit("support_level", &support_level);
}

/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
bool IsPrimitiveOp() const {
if (is_primitive_ != -1) return is_primitive_ != 0;
tqchen marked this conversation as resolved.
Show resolved Hide resolved
is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
return is_primitive_ != 0;
}

static constexpr const char* _type_key = "relay.Op";
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);

private:
// friend class
friend class GenericOpMap;
friend class OpRegistry;
friend bool IsPrimitiveOp(const Expr&);
// Program internal unique index of operator.
// Used to help index the program.
uint32_t index_{0};
// whether this is a primitive op. -1 means unknown.
mutable int is_primitive_{-1};
// Internal function to compute if it is primitive op
bool IsPrimitiveOp_() const {
const auto& fn_ty = this->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
}
};

/*!
Expand Down Expand Up @@ -497,22 +523,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
*/
inline bool IsPrimitiveOp(const Expr& expr) {
const auto* op = expr.as<OpNode>();

if (!op) {
return false;
}

const auto& fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;

const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}

return true;
return op != nullptr && op->IsPrimitiveOp();
}

} // namespace relay
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .expr import Expr
from .ty import Type


def infer_type(expr, env=None):
"""Infer the type of expr under the context of env.

Expand All @@ -30,6 +31,23 @@ def infer_type(expr, env=None):
return _ir_pass.infer_type(expr, env)


def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense.

Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.

Returns
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
"""
return _ir_pass.forward_fold_scale_axis(expr)


def well_formed(expr):
"""Check that each Var is only bound once (well formed).

Expand Down Expand Up @@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs):
"""
return bool(_make._alpha_equal(lhs, rhs))


def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
Expand All @@ -170,6 +189,7 @@ def graph_equal(lhs, rhs):
"""
return bool(_make._graph_equal(lhs, rhs))


def structural_hash(value):
"""Hash a Relay expression structurally.

Expand Down
13 changes: 6 additions & 7 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,26 @@ def transpose(data, axes=None):
return _make.transpose(data, list(axes))


def squeeze(data, axes=None):
def squeeze(data, axis=None):
"""Squeeze axes in the array.

Parameters
----------
data : relay.Expr
data : tvm.relay.Expr
The input data to the operator.

axes : None or List[int]
axis : None or List[int]
Axes to remove.
If axes = [] or = None, remove all axis of dimensions 1.
If axes = None, remove all axis of dimensions 1.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the comment also

Copy link
Member Author

@tqchen tqchen Oct 28, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not get what do you mean

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axes should be changed to axis here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gotcha

Otherwise, remove all axis in axes.
If any axis in axes has dimension that does not equal 1, it is an error.

Returns
-------
result : relay.Expr
result : tvm.relay.Expr
The squeezed result.
"""
axes = axes or []
return _make.squeeze(data, list(axes))
return _make.squeeze(data, axis)


def reshape(data, newshape):
Expand Down
20 changes: 15 additions & 5 deletions src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,23 @@ class AlphaEqualHandler:
if (const CallNode* rhs = other.as<CallNode>()) {
if (!ExprEqual(lhs->op, rhs->op)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
if (lhs->type_args.size() != rhs->type_args.size()) return false;

// skip type_args check for primitive ops.
bool is_primitive = IsPrimitiveOp(lhs->op);
if (!is_primitive) {
if (lhs->type_args.size() != rhs->type_args.size()) {
return false;
}
}
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!ExprEqual(lhs->args[i], rhs->args[i])) return false;
if (!ExprEqual(lhs->args[i], rhs->args[i])) {
return false;
}
}
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;

if (!is_primitive) {
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
}
}
return AttrEqual(lhs->attrs, rhs->attrs);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
namespace tvm {
namespace relay {

Expr ExprMutator::Mutate(const Expr& expr) {
Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr);
if (it != this->memo_.end()) {
return it->second;
} else {
Expr new_expr = ExprMutator::VisitExpr(expr);
Expr new_expr = ExprFunctor::VisitExpr(expr);
memo_[expr] = new_expr;
return new_expr;
}
Expand Down
14 changes: 6 additions & 8 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,9 +761,9 @@ Examples::
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);

Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) {
Array<Integer> axis) {
auto attrs = make_node<SqueezeAttrs>();
attrs->axes = std::move(axes);
attrs->axis = std::move(axis);
static const Op& op = Op::Get("squeeze");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
Expand All @@ -785,8 +785,8 @@ bool SqueezeRel(const Array<Type>& types,
const auto* param = attrs.as<SqueezeAttrs>();
CHECK(param != nullptr);
std::vector<IndexExpr> result_shape;
// if axes is empty, squeeze all axes of dimension 1
if (param->axes.size() == 0) {
// if axes is None, squeeze all axes of dimension 1
if (!param->axis.defined()) {
for (const auto& e : data->shape) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
Expand All @@ -800,10 +800,8 @@ bool SqueezeRel(const Array<Type>& types,
for (const auto& e : data->shape) {
original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
}
for (const auto& e : param->axes) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
original_shape.at(*axis_ptr).second = false;
for (const auto& e : param->axis) {
original_shape.at(e->value).second = false;
}
for (const auto p : original_shape) {
if (p.second) {
Expand Down
Loading