Skip to content

Commit

Permalink
GetConstructor definition in module and change op comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
hypercubestart committed Mar 23, 2020
1 parent e8e90f9 commit 2f35ee8
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 30 deletions.
8 changes: 8 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@ class IRModuleNode : public Object {
*/
TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;

/*!
* \brief Find constructor of ADT using name
* \param adt name of the ADT the constructor belongs to
* \param cons name of the constructor
* \returns Constructor of ADT, error if not found
*/
TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const;

/*!
* \brief Look up a global function by its variable.
* \param var The global var to lookup.
Expand Down
12 changes: 12 additions & 0 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
return (*it).second;
}

Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const {
TypeData typeDef = this->LookupTypeDef(adt);
for (Constructor c : typeDef->constructors) {
if (cons.compare(c->name_hint) == 0) {
return c;
}
}

LOG(FATAL) << adt << " does not contain constructor " << cons;
throw std::runtime_error("Constructor Not Found.");
}

tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
std::vector<GlobalTypeVar> global_type_vars;
for (const auto& pair : global_type_var_map_) {
Expand Down
44 changes: 14 additions & 30 deletions src/relay/transforms/gradient_cell.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,6 @@
namespace tvm {
namespace relay {

/*!
* \brief Get constructor of GradCell TypeDef with name_hint
*
* module must have TypeDefinition of GradCell (defined in gradient.rly)
*/
Constructor getGradCellConstructor(IRModule module, std::string name_hint) {
TypeData gradCell = module->LookupTypeDef("GradCell");
for (Constructor c : gradCell->constructors) {
if (name_hint.compare(c->name_hint) == 0) {
return c;
}
}

LOG(FATAL) << "Constructor " << name_hint << "not found in GradCell typedata.";
throw std::runtime_error("Constructor not found in GradCell typedata");
}

/*!
* \brief Visitor to wrap inputs
*/
Expand All @@ -92,7 +75,7 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {

Expr wrapExpr(const Expr expr, const Type& type) {
if (type.as<TensorTypeNode>()) {
return CallNode::make(getGradCellConstructor(module_, "Raw"),
return CallNode::make(module_->GetConstructor("GradCell", "Raw"),
{expr}, Attrs(), {type});
} else if (auto* type_anno = type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
Expand Down Expand Up @@ -191,14 +174,15 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
}

Expr VisitExpr_(const ConstantNode* op) final {
return CallNode::make(getGradCellConstructor(module_, "Raw"),
return CallNode::make(module_->GetConstructor("GradCell", "Raw"),
{GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
}

Expr VisitExpr_(const CallNode* call_node) final {
// optimize operators
if (auto* op = (call_node->op).as<OpNode>()) {
if (op->name.compare("add") == 0 && call_node->args.size() == 2 &&
Expr op_expr = GetRef<Op>(op);
if (op_expr == Op::Get("add") && call_node->args.size() == 2 &&
AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
// case: "add" between two tensors of the same size
const auto addFunc = module_->GetGlobalVar("AddGradCell");
Expand All @@ -217,7 +201,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
args.push_back(VisitExpr(expr));
}
return CallNode::make(addFunc, args, Attrs(), {paramType});
} else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 &&
} else if (op_expr == Op::Get("multiply") && call_node->args.size() == 2 &&
AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
// case: "multiply" between two tensors of the same size
const auto multFunc = module_->GetGlobalVar("MultiplyGradCell");
Expand All @@ -237,17 +221,17 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
args.push_back(VisitExpr(expr));
}
return CallNode::make(multFunc, args, Attrs(), {paramType});
} else if (op->name.compare("ones") == 0) {
} else if (op_expr == Op::Get("ones")) {
// ones operator, use One constructor of GradCell
Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
{call_node->checked_type()}, {});
return CallNode::make(getGradCellConstructor(module_, "One"),
return CallNode::make(module_->GetConstructor("GradCell", "One"),
{func}, Attrs(), {call_node->checked_type()});
} else if (op->name.compare("zeros") == 0) {
} else if (op_expr == Op::Get("zeros")) {
// zeros operator, use Zero constructor of GradCell
Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
{call_node->checked_type()}, {});
return CallNode::make(getGradCellConstructor(module_, "Zero"),
return CallNode::make(module_->GetConstructor("GradCell", "Zero"),
{func}, Attrs(), {call_node->checked_type()});
}

Expand All @@ -264,18 +248,18 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {

const Expr tensorRes = CallNode::make(call_node->op, args);

if (op->name.compare("ones_like") == 0) {
if (op_expr == Op::Get("ones_like")) {
Expr onesFunction = Function({}, tensorRes,
{call_node->checked_type()}, Array<TypeVar>());
return CallNode::make(getGradCellConstructor(module_, "One"),
return CallNode::make(module_->GetConstructor("GradCell", "One"),
{onesFunction}, Attrs(), {call_node->checked_type()});
} else if (op->name.compare("zeros_like") == 0) {
} else if (op_expr == Op::Get("zeros_like")) {
Expr zerosFunction = Function({}, tensorRes,
{call_node->checked_type()}, Array<TypeVar>());
return CallNode::make(getGradCellConstructor(module_, "Zero"),
return CallNode::make(module_->GetConstructor("GradCell", "Zero"),
{zerosFunction}, Attrs(), {call_node->checked_type()});
}
return CallNode::make(getGradCellConstructor(module_, "Raw"), {tensorRes},
return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {tensorRes},
Attrs(), {call_node->checked_type()});
}
// call-> op is not a relay op
Expand Down

0 comments on commit 2f35ee8

Please sign in to comment.