Skip to content

Commit

Permalink
Merge pull request #7 from MarisaKirisame/type_of
Browse files Browse the repository at this point in the history
Type of
  • Loading branch information
jroesch authored Feb 10, 2019
2 parents 5a1c2f4 + ded0621 commit 1aa0b5e
Show file tree
Hide file tree
Showing 12 changed files with 103 additions and 7 deletions.
31 changes: 31 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,37 @@ std::string RelayPrint(
const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);

/*!
* \brief User defined type relation, is an input-output relation on types.
*/
class TypeOf;
/*!
* \brief TypeRelation container.
* \note This node is not directly serializable.
* The type function need to be lookedup in the module.
*/
class TypeOfNode : public TypeNode {
public:
/*!
* \brief The function on input and output variables which
* this is not directly serializable,
* need to be looked-up in the module.
*/
relay::Expr expr;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
}

TVM_DLL static TypeOf make(relay::Expr expr);

static constexpr const char* _type_key = "relay.TypeOf";
TVM_DECLARE_NODE_TYPE_INFO(TypeOfNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeOf, TypeOfNode, Type);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
scalar_type = ty.scalar_type
GlobalTypeVar = ty.GlobalTypeVar
TypeCall = ty.TypeCall
TypeOf = ty.TypeOf

# Expr
Expr = expr.Expr
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .op import register_gradient
from .op import schedule_injective, OpPattern
from .transform import collapse_sum_like
from .tensor import negative
from .tensor import negative, zeros_like, ones_like


def add_grad(orig, grad):
Expand All @@ -29,6 +29,12 @@ def multiply_grad(orig, grad):

register_gradient("multiply", multiply_grad)

def take_grad(orig, grad):
return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]


register_gradient("take", take_grad)

schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective

Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)

@register_relay_node
class TypeOf(Type):
def __init__(self, expr):
self.__init_handle_by_constructor__(_make.TypeOf, expr)

def scalar_type(dtype):
"""Creates a scalar type.
Expand Down
6 changes: 3 additions & 3 deletions src/relay/ir/error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
std::endl <<
rang::style::reset;

// for (auto pair : err_map) {
// std::cout << "Key: " << pair.first << " Value: " << pair.second << std::endl;
// }
for (auto pair : err_map) {
std::cout << "Key: " << pair.first << std::endl << " Value: " << pair.second << std::endl;
}

// We then call into the Relay printer to generate the program.
//
Expand Down
18 changes: 18 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,24 @@ TVM_REGISTER_API("relay._expr.TempExprRealize")
*ret = temp->Realize();
});

TypeOf TypeOfNode::make(relay::Expr expr) {
NodePtr<TypeOfNode> n = make_node<TypeOfNode>();
n->expr = std::move(expr);
return TypeOf(n);
}

TVM_REGISTER_NODE_TYPE(TypeOfNode);

TVM_REGISTER_API("relay._make.TypeOf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeOfNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeOfNode>([](const TypeOfNode* node,
tvm::IRPrinter* p) {
p->stream << "TypeOf(" << node->expr << ")";
});

} // namespace relay
} // namespace tvm
1 change: 1 addition & 0 deletions src/relay/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,5 +207,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TupleTypeNode(" << node->fields << ")";
});


} // namespace relay
} // namespace tvm
6 changes: 6 additions & 0 deletions src/relay/ir/type_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
}
}

void TypeVisitor::VisitType_(const TypeOfNode* op) {}

// Type Mutator.
Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
// The array will do copy on write
Expand Down Expand Up @@ -172,6 +174,10 @@ Type TypeMutator::VisitType_(const TypeDataNode* op) {
return GetRef<Type>(op);
}

Type TypeMutator::VisitType_(const TypeOfNode* op) {
return GetRef<Type>(op);
}

// Implements bind.
class TypeBinder : public TypeMutator {
public:
Expand Down
4 changes: 4 additions & 0 deletions src/relay/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeOfNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;

virtual R VisitTypeDefault_(const Node* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
Expand All @@ -93,6 +94,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeOfNode);
return vtable;
}
};
Expand All @@ -111,6 +113,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
void VisitType_(const GlobalTypeVarNode* op) override;
void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override;
void VisitType_(const TypeOfNode* op) override;
};

// Mutator that transform a type to another one.
Expand All @@ -125,6 +128,7 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
Type VisitType_(const GlobalTypeVarNode* op) override;
Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override;
Type VisitType_(const TypeOfNode* op) override;

private:
Array<Type> MutateArray(Array<Type> arr);
Expand Down
4 changes: 4 additions & 0 deletions src/relay/pass/kind_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
return Kind::kTypeData;
}

Kind VisitType_(const TypeOfNode* op) override {
return kType;
}

Kind Check(const Type& t) {
return this->VisitType(t);
}
Expand Down
26 changes: 24 additions & 2 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ struct ResolvedTypeInfo {
Array<Type> type_args = Array<Type>(NodePtr<Node>(nullptr));
};

class TypeInferencer;
Type ExpandTypeOf(TypeInferencer* infer, const Type& t);

//
// The inference algorithm can roughly be devided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType)
Expand All @@ -83,6 +86,7 @@ struct ResolvedTypeInfo {
//
class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
private PatternFunctor<void(const Pattern&, const Type&)> {
friend struct ExpandTypeOfMutator;
public:
// constructors

Expand Down Expand Up @@ -116,12 +120,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
TypeRelationFn tuple_getitem_rel_;
TypeRelationFn make_tuple_rel_;

inline Type ExpandTypeOf(const Type& t) {
return ::tvm::relay::ExpandTypeOf(this, t);
}

// Perform unification on two types and report the error at the expression
// or the span of the expression.
Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
// TODO(tqchen, jroesch): propagate span to solver

try {
return solver_.Unify(t1, t2, expr);
return solver_.Unify(ExpandTypeOf(t1), ExpandTypeOf(t2), expr);
} catch (const dmlc::Error &e) {
this->ReportFatalError(
expr,
Expand Down Expand Up @@ -157,7 +166,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// Visitor Logic
Type VisitExpr_(const VarNode* op) final {
if (op->type_annotation.defined()) {
return op->type_annotation;
return ExpandTypeOf(op->type_annotation);
} else {
return IncompleteTypeNode::make(Kind::kType);
}
Expand Down Expand Up @@ -718,6 +727,19 @@ Function InferType(const Function& func,
return Downcast<Function>(func_ret);
}

struct ExpandTypeOfMutator : TypeMutator {
TypeInferencer* infer;
Type VisitType_(const TypeOfNode* type_of) final {
return this->infer->GetType(type_of->expr);
}
ExpandTypeOfMutator(TypeInferencer* infer) : infer(infer) {}
};

inline Type ExpandTypeOf(TypeInferencer* infer, const Type& t) {
auto expand = ExpandTypeOfMutator(infer);
return expand.VisitType(t);
}

TVM_REGISTER_API("relay._ir_pass.infer_type")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = InferType(args[0], args[1]);
Expand Down
1 change: 0 additions & 1 deletion src/relay/pass/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
}
return TypeCallNode::make(func, args);
}

private:
TypeSolver* solver_;
};
Expand Down

0 comments on commit 1aa0b5e

Please sign in to comment.