From f3b58c395082cf0025877c5e4ce087e83ed7b1f4 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Tue, 15 Jan 2019 18:45:52 -0800 Subject: [PATCH 01/61] First pass on ADTs --- include/tvm/relay/adt.h | 228 ++++++++++++++++++++++++ include/tvm/relay/expr_functor.h | 14 ++ include/tvm/relay/interpreter.h | 22 +++ include/tvm/relay/module.h | 40 ++++- include/tvm/relay/op.h | 4 +- include/tvm/relay/pattern_functor.h | 143 +++++++++++++++ include/tvm/relay/type.h | 78 ++++++-- python/tvm/relay/__init__.py | 13 ++ python/tvm/relay/adt.py | 166 +++++++++++++++++ python/tvm/relay/backend/interpreter.py | 7 + python/tvm/relay/expr_functor.py | 12 ++ python/tvm/relay/module.py | 73 ++++++-- python/tvm/relay/prelude.py | 114 ++++++++++++ python/tvm/relay/ty.py | 58 ++++++ src/relay/backend/interpreter.cc | 71 +++++++- src/relay/ir/adt.cc | 161 +++++++++++++++++ src/relay/ir/alpha_equal.cc | 89 ++++++++- src/relay/ir/expr_functor.cc | 39 ++++ src/relay/ir/hash.cc | 77 +++++++- src/relay/ir/module.cc | 81 ++++++++- src/relay/ir/pattern_functor.cc | 75 ++++++++ src/relay/ir/text_printer.cc | 53 ++++++ src/relay/ir/type.cc | 51 +++++- src/relay/ir/type_functor.cc | 33 ++++ src/relay/ir/type_functor.h | 14 ++ src/relay/pass/fuse_ops.cc | 9 + src/relay/pass/kind_check.cc | 3 +- src/relay/pass/let_list.h | 2 +- src/relay/pass/type_infer.cc | 101 ++++++++++- src/relay/pass/util.cc | 1 + tests/python/relay/test_adt.py | 138 ++++++++++++++ tests/python/relay/test_typecall.py | 25 +++ 32 files changed, 1930 insertions(+), 65 deletions(-) create mode 100644 include/tvm/relay/adt.h create mode 100644 include/tvm/relay/pattern_functor.h create mode 100644 python/tvm/relay/adt.py create mode 100644 python/tvm/relay/prelude.py create mode 100644 src/relay/ir/adt.cc create mode 100644 src/relay/ir/pattern_functor.cc create mode 100644 tests/python/relay/test_adt.py create mode 100644 tests/python/relay/test_typecall.py diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h new file mode 100644 index 000000000000..c0fcc7629c51 --- /dev/null +++ b/include/tvm/relay/adt.h @@ -0,0 +1,228 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/adt.h + * \brief Algebraic data types for Relay + */ +#ifndef TVM_RELAY_ADT_H_ +#define TVM_RELAY_ADT_H_ + +#include +#include +#include +#include "./base.h" +#include "./type.h" +#include "./expr.h" + +namespace tvm { +namespace relay { + +/*! \brief Base type for declaring relay pattern. */ +class PatternNode : public RelayNode { + public: + static constexpr const char* _type_key = "relay.Pattern"; + TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node); +}; + +/*! + * \brief Pattern is the base type for an ADT match pattern in Relay. + * + * Given an ADT value, a pattern might accept it and bind the pattern variable to some value + * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value. + * + * ADT pattern matching thus takes a list of values and bings to the first that accepts the value. + */ +class Pattern : public NodeRef { + public: + Pattern() {} + explicit Pattern(NodePtr p) : NodeRef(p) {} + + using ContainerType = PatternNode; +}; + +/*! \brief A wildcard pattern: Accepts all input and binds nothing. */ +class PatternWildcard; +/*! \brief PatternWildcard container node */ +class PatternWildcardNode : public PatternNode { + public: + PatternWildcardNode() {} + + TVM_DLL static PatternWildcard make(); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternWildcard"; + TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern); + +/*! \brief A var pattern. Accept all input and bind to a var. */ +class PatternVar; +/*! \brief PatternVar container node */ +class PatternVarNode : public PatternNode { + public: + PatternVarNode() {} + + tvm::relay::Var var; + + TVM_DLL static PatternVar make(tvm::relay::Var var); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternVar"; + TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern); + +/*! + * \brief ADT constructor. + * Constructors compare by pointer equality. + */ +class Constructor; +/*! \brief Constructor container node. */ +class ConstructorNode : public ExprNode { + public: + /*! \brief The name (only a hint) */ + std::string name_hint; + /*! \brief Input to the constructor. */ + tvm::Array inp; + /*! \brief The datatype the constructor will construct. */ + GlobalTypeVar belong_to; + mutable int tag = -1; + + ConstructorNode() {} + + TVM_DLL static Constructor make(std::string name_hint, + tvm::Array inp, + GlobalTypeVar belong_to); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name_hint", &name_hint); + v->Visit("inp", &inp); + v->Visit("belong_to", &belong_to); + v->Visit("tag", &tag); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + static constexpr const char* _type_key = "relay.Constructor"; + TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr); + +/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ +class PatternConstructor; +/*! \brief PatternVar container node */ +class PatternConstructorNode : public PatternNode { + public: + Constructor con; + tvm::Array pat; + + PatternConstructorNode() {} + + TVM_DLL static PatternConstructor make(Constructor con, tvm::Array var); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("con", &con); + v->Visit("pat", &pat); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternConstructor"; + TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); + +/*! + * \brief Stores all data for an Algebraic Data Type (ADT). + */ +class TypeData; +/*! \brief TypeData container node */ +class TypeDataNode : public TypeNode { + public: + /*! + * \brief The header is simply the name of the ADT. + * We adopt nominal typing for ADT definitions; + * that is, differently-named ADT definitions with same constructors + * have different types. + */ + GlobalTypeVar header; + /*! \brief The type variables (to allow for polymorphism). */ + tvm::Array tv; + /*! \brief The constructors. */ + tvm::Array constructors; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("header", &header); + v->Visit("tv", &tv); + v->Visit("constructors", &constructors); + v->Visit("span", &span); + } + + TVM_DLL static TypeData make(GlobalTypeVar header, + tvm::Array tv, + tvm::Array constructors); + + static constexpr const char* _type_key = "relay.TypeData"; + TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type); + +/*! \brief A clause in a match expression. */ +class Clause; +/*! \brief Clause container node. */ +class ClauseNode : public Node { + public: + Pattern lhs; + Expr rhs; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("lhs", &lhs); + v->Visit("rhs", &rhs); + } + + TVM_DLL static Clause make(Pattern lhs, Expr rhs); + + static constexpr const char* _type_key = "relay.Clause"; + TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node); +}; + +RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef); + +/*! \brief ADT pattern matching exression. */ +class Match; +/*! \brief Match container node. */ +class MatchNode : public ExprNode { + public: + Expr data; + + tvm::Array pattern; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("data", &data); + v->Visit("pattern", &pattern); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static Match make(Expr data, tvm::Array pattern); + + static constexpr const char* _type_key = "relay.Match"; + TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ADT_H_ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e7b66bc1bbde..fd68139495b4 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -10,6 +10,7 @@ #include #include #include "./expr.h" +#include "./adt.h" #include "./op.h" #include "./error.h" @@ -92,6 +93,8 @@ class ExprFunctor { virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { throw Error(std::string("Do not have a default for ") + op->type_key()); } @@ -114,6 +117,8 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); + RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); + RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); return vtable; } }; @@ -142,7 +147,11 @@ class ExprVisitor void VisitExpr_(const RefCreateNode* op) override; void VisitExpr_(const RefReadNode* op) override; void VisitExpr_(const RefWriteNode* op) override; + void VisitExpr_(const ConstructorNode* op) override; + void VisitExpr_(const MatchNode* op) override; virtual void VisitType(const Type& t); + virtual void VisitClause(const Clause& c); + virtual void VisitPattern(const Pattern& c); protected: // Internal visiting counter @@ -180,6 +189,9 @@ class ExprMutator Expr VisitExpr_(const RefCreateNode* op) override; Expr VisitExpr_(const RefReadNode* op) override; Expr VisitExpr_(const RefWriteNode* op) override; + Expr VisitExpr_(const ConstructorNode* op) override; + Expr VisitExpr_(const MatchNode* op) override; + /*! * \brief Used to visit the types inside of expressions. * @@ -188,6 +200,8 @@ class ExprMutator * visitor for types which transform them appropriately. */ virtual Type VisitType(const Type& t); + virtual Clause VisitClause(const Clause& c); + virtual Pattern VisitPattern(const Pattern& c); protected: /*! \brief Internal map used for memoization. */ diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 08aeef1827b6..f235a20065af 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -160,6 +160,28 @@ struct RefValueNode : ValueNode { RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); +/*! \brief An ADT constructor value. */ +class ConValue; + +struct ConValueNode : ValueNode { + Constructor con; + + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("con", &con); + v->Visit("fields", &fields); + } + + TVM_DLL static ConValue make(Constructor con, + tvm::Array fields); + + static constexpr const char* _type_key = "relay.ConValue"; + TVM_DECLARE_NODE_TYPE_INFO(ConValueNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(ConValue, ConValueNode, Value); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_INTERPRETER_H_ diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 45ccfe3a8089..988c7402a506 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -35,13 +36,14 @@ struct Module; * * The functional style allows users to construct custom * environments easily, for example each thread can store - * an Module while auto-tuning. + * a Module while auto-tuning. * */ class ModuleNode : public RelayNode { public: /*! \brief A map from ids to all global functions. */ tvm::Map functions; + tvm::Map type_definitions; /*! \brief The entry function (i.e. "main"). */ GlobalVar entry_func; @@ -50,21 +52,31 @@ class ModuleNode : public RelayNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); + v->Visit("type_definitions", &type_definitions); v->Visit("global_var_map_", &global_var_map_); v->Visit("entry_func", &entry_func); + v->Visit("global_type_var_map_", &global_type_var_map_); } - TVM_DLL static Module make(tvm::Map global_funcs); + TVM_DLL static Module make(tvm::Map global_funcs, + tvm::Map global_type_defs); /*! * \brief Add a function to the global environment. - * \param var The name of the global function. + * \param var The var of the global function. * \param func The function. * \param update Controls whether you can replace a definition in the * environment. */ void Add(const GlobalVar& var, const Function& func, bool update = false); + /*! + * \brief Add a type-level definition to the global environment. + * \param var The var of the global type definition. + * \param type The type definition. + */ + void AddDef(const GlobalTypeVar& var, const TypeData& type); + /*! * \brief Add a function to the global environment. * \param var The name of the global function. @@ -94,6 +106,13 @@ class ModuleNode : public RelayNode { */ GlobalVar GetGlobalVar(const std::string& str); + /*! + * \brief Look up a global function by its name. + * \param str The unique string specifying the global variable. + * \returns The global variable. + */ + GlobalTypeVar GetGlobalTypeVar(const std::string& str); + /*! * \brief Lookup a global function by its variable. * \param var The global var to lookup. @@ -108,6 +127,20 @@ class ModuleNode : public RelayNode { */ Function Lookup(const std::string& name); + /*! + * \brief Lookup a global type definition by its variable. + * \param var The var of the global type definition. + * \return The type definition. + */ + TypeData LookupDef(const GlobalTypeVar& var); + + /*! + * \brief Lookup a global type definition by its name. + * \param var The name of the global type definition. + * \return The type definition. + */ + TypeData LookupDef(const std::string& var); + /*! * \brief Update the functions inside this environment by * functions in another environment. @@ -137,6 +170,7 @@ class ModuleNode : public RelayNode { * ensures global uniqueness. */ tvm::Map global_var_map_; + tvm::Map global_type_var_map_; }; struct Module : public NodeRef { diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 0fd54ff5b8fa..583491ca2613 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -422,7 +422,7 @@ inline OpRegistry& OpRegistry::add_type_rel( std::string input_name_prefix = "in"; for (int i = 0; i < get()->num_inputs; i++) { auto name = input_name_prefix + std::to_string(i); - auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType); + auto param = TypeVarNode::make(name, Kind::kType); type_params.push_back(param); arg_types.push_back(param); } @@ -430,7 +430,7 @@ inline OpRegistry& OpRegistry::add_type_rel( Array ty_call_args = arg_types; // Add output type. - auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType); + auto out_param = TypeVarNode::make("out", Kind::kType); type_params.push_back(out_param); // this will trigger copy on write. ty_call_args.push_back(out_param); diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h new file mode 100644 index 000000000000..f9833201ea0a --- /dev/null +++ b/include/tvm/relay/pattern_functor.h @@ -0,0 +1,143 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/pattern_functor.h + * \brief A more powerful visitor on ADT patterns that enables defining + * arbitrary function signatures with type-based dispatch on first argument. + */ +#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_ +#define TVM_RELAY_PATTERN_FUNCTOR_H_ + +#include +#include +#include "./expr.h" +#include "./op.h" +#include "./error.h" +#include "./adt.h" + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor on ADT patterns that dispatches on its first argument. + * You can use this as a more powerful visitor, since it allows you to + * define the types of further arguments to VisitPattern. + * + * \sa tvm/ir_functor.h + * + * \tparam FType function signiture + * This type is only defined for FType with function signature R(const Pattern&, + * Args...) + */ +template +class PatternFunctor; + +// functions to be overriden. +#define PATTERN_FUNCTOR_DEFAULT \ + { return VisitPatternDefault_(op, std::forward(args)...); } + +#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); + +template +class PatternFunctor { + private: + using TSelf = PatternFunctor; + using FType = tvm::IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~PatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Pattern& n, Args... args) { + return VisitPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitPattern(const Pattern& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitPattern_(const PatternWildcardNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternVarNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternConstructorNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPatternDefault_(const Node* op, Args...) { + throw Error(std::string("Do not have a default for ") + op->type_key()); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode); + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode); + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode); + return vtable; + } +}; + +/*! \brief A simple visitor wrapper around PatternFunctor. + * + * Exposes two visitors with default traversal strategies, one + * which doesn't compute a result but can mutate internal state, + * and another which functionally builds a new pattern. + */ +class PatternVisitor : public ::tvm::relay::PatternFunctor { + public: + void VisitPattern_(const PatternWildcardNode* op) override; + void VisitPattern_(const PatternVarNode* op) override; + void VisitPattern_(const PatternConstructorNode* op) override; + virtual void VisitType(const Type& t); + virtual void VisitVar(const Var& v); + virtual void VisitConstructor(const Constructor& c); +}; + +/*! \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. + */ +class PatternMutator + : public ::tvm::relay::PatternFunctor { + public: + Pattern Mutate(const Pattern& pat); + Pattern VisitPattern_(const PatternWildcardNode* op) override; + Pattern VisitPattern_(const PatternVarNode* op) override; + Pattern VisitPattern_(const PatternConstructorNode* op) override; + /*! \brief Used to visit the types inside of patterns. + * + * Can be overloaded to transform the types in arbitrary + * ways, one way would be to define a sub-class of type + * visitor for types which transform them appropriately. + */ + virtual Type VisitType(const Type& t); + /*! \brief Used to visit the vars inside of patterns. */ + virtual Var VisitVar(const Var& v); + /*! \brief Used to visit the vars inside of patterns. */ + virtual Constructor VisitConstructor(const Constructor& c); + private: + std::unordered_map var_map_; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PATTERN_FUNCTOR_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 0ee265e5f3b0..600075bcb8cc 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -98,6 +98,15 @@ class TensorTypeNode : public BaseTensorTypeNode { RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); +/*! \brief possible kinds of Type */ +enum Kind : int { + /*! \brief template variable in shape expression */ + kType = 0, + kShapeVar = 1, + kBaseType = 2, + kShape = 3 +}; + /*! * \brief Type parameter in the function. * This can be viewed as template parameter in c++ template function. @@ -119,14 +128,6 @@ class TypeVar; /*! \brief TypeVar container node */ class TypeVarNode : public TypeNode { public: - /*! \brief possible kinds of TypeVar */ - enum Kind : int { - /*! \brief template variable in shape expression */ - kType = 0, - kShapeVar = 1, - kBaseType = 2, - kShape = 3 - }; /*! * \brief The variable itself is only meaningful when * kind is ShapeVar, otherwise, we only use the name. @@ -149,6 +150,63 @@ class TypeVarNode : public TypeNode { RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type); +/*! + * \brief A global type variable that is used for defining new types or type aliases. + */ +class GlobalTypeVar; +/*! \brief GlobalTypeVar container node */ +class GlobalTypeVarNode : public TypeNode { + public: + /*! + * \brief The variable itself is only meaningful when + * kind is ShapeVar; otherwise, we only use the name. + */ + tvm::Var var; + /*! \brief The kind of type parameter */ + Kind kind; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("kind", &kind); + v->Visit("span", &span); + } + + TVM_DLL static GlobalTypeVar make(std::string name, Kind kind); + + static constexpr const char* _type_key = "relay.GlobalTypeVar"; + TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type); + +/*! + * \brief Type application. + */ +class TypeCall; +/*! \brief TypeCall container node */ +class TypeCallNode : public TypeNode { + public: + /*! + * \brief The type-level function. + */ + Type func; + /*! \brief The arguments. */ + tvm::Array args; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("args", &args); + v->Visit("span", &span); + } + + TVM_DLL static TypeCall make(Type func, tvm::Array args); + + static constexpr const char* _type_key = "relay.TypeCall"; + TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); + /*! * \brief IncompleteType. * This is intermediate values that is used during type inference. @@ -162,14 +220,14 @@ class IncompleteType; /*! \brief IncompleteType container node */ class IncompleteTypeNode : public TypeNode { public: - TypeVarNode::Kind kind; + Kind kind; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); v->Visit("span", &span); } - TVM_DLL static IncompleteType make(TypeVarNode::Kind kind); + TVM_DLL static IncompleteType make(Kind kind); static constexpr const char* _type_key = "relay.IncompleteType"; TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 0af164bc7a73..fe00877c0fb0 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -7,8 +7,10 @@ from . import expr from . import expr_functor from . import module +from . import adt from . import ir_pass from .build_module import build, build_config, create_executor, optimize +from . import prelude from . import parser from . import debug @@ -45,6 +47,8 @@ IncompleteType = ty.IncompleteType scalar_type = ty.scalar_type RefType = ty.RefType +GlobalTypeVar = ty.GlobalTypeVar +TypeCall = ty.TypeCall # Expr Expr = expr.Expr @@ -61,6 +65,15 @@ RefRead = expr.RefRead RefWrite = expr.RefWrite +# ADT +PatternWildcard = adt.PatternWildcard +PatternVar = adt.PatternVar +PatternConstructor = adt.PatternConstructor +Constructor = adt.Constructor +TypeData = adt.TypeData +Clause = adt.Clause +Match = adt.Match + # helper functions var = expr.var const = expr.const diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py new file mode 100644 index 000000000000..7e25ec16fb52 --- /dev/null +++ b/python/tvm/relay/adt.py @@ -0,0 +1,166 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""Algebraic data types in Relay.""" +from .base import RelayNode, register_relay_node, NodeBase +from . import _make +from .ty import Type +from .expr import Expr + + +class Pattern(RelayNode): + """Base type for pattern matching constructs.""" + pass + +@register_relay_node +class PatternWildcard(Pattern): + """Wildcard pattern in Relay: Matches any ADT and binds nothing.""" + + def __init__(self): + """Constructs a wildcard pattern. + + Parameters + ---------- + None + + Returns + ------- + wildcard: PatternWildcard + a wildcard pattern. + """ + self.__init_handle_by_constructor__(_make.PatternWildcard) + + +@register_relay_node +class PatternVar(Pattern): + """Variable pattern in Relay: Matches anything and binds it to the variable.""" + + def __init__(self, var): + """Construct a variable pattern. + + Parameters + ---------- + var: tvm.relay.Var + + Returns + ------- + pv: PatternVar + A variable pattern. + """ + self.__init_handle_by_constructor__(_make.PatternVar, var) + + +@register_relay_node +class PatternConstructor(Pattern): + """Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively.""" + + def __init__(self, con, pat=None): + """Construct a constructor pattern. + + Parameters + ---------- + con: Constructor + The constructor. + pat: Optional[List[Pattern]] + Optional subpatterns: for each field of the constructor, + match to the given subpattern (treated as a variable pattern by default). + + Returns + ------- + wildcard: PatternWildcard + a wildcard pattern. + """ + if pat is None: + pat = [] + self.__init_handle_by_constructor__(_make.PatternConstructor, con, pat) + + +@register_relay_node +class Constructor(Expr): + """Relay ADT constructor.""" + + def __init__(self, name_hint, inp, belong_to): + """Defines an ADT constructor. + + Parameters + ---------- + name_hint : str + Name of constructor (only a hint). + inp : List[Type] + Input types. + belong_to : tvm.relay.GlobalTypeVar + Denotes which ADT the constructor belongs to. + + Returns + ------- + con: Constructor + A constructor. + """ + self.__init_handle_by_constructor__(_make.Constructor, name_hint, inp, belong_to) + + +@register_relay_node +class TypeData(Type): + """Stores the definition for an Algebraic Data Type (ADT) in Relay.""" + + def __init__(self, header, tv, constructors): + """Defines a TypeData object. + + Parameters + ---------- + header: tvm.relay.GlobalTypeVar + The name of the ADT. + ADTs with the same constructors but different names are + treated as different types. + tv: List[TypeVar] + Type variables that appear in constructors. + constructors: List[tvm.relay.Constructor] + The constructors for the ADT. + + Returns + ------- + type_data: TypeData + The adt declaration. + """ + self.__init_handle_by_constructor__(_make.TypeData, header, tv, constructors) + + +@register_relay_node +class Clause(NodeBase): + """Clause for pattern matching in Relay.""" + + def __init__(self, lhs, rhs): + """Construct a clause. + + Parameters + ---------- + lhs: tvm.relay.Pattern + Left-hand side of match clause. + rhs: tvm.relay.Expr + Right-hand side of match clause. + + Returns + ------- + clause: Clause + The Clause. + """ + self.__init_handle_by_constructor__(_make.Clause, lhs, rhs) + + +@register_relay_node +class Match(Expr): + """Pattern matching expression in Relay.""" + + def __init__(self, data, pattern): + """Construct a Match. + + Parameters + ---------- + data: tvm.relay.Expr + The value being deconstructed and matched. + pattern: [tvm.relay.Clause] + The pattern match clauses. + Returns + ------- + match: tvm.relay.Expr + The match expression. + """ + self.__init_handle_by_constructor__(_make.Match, data, pattern) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index b21eab185c28..732567375fcc 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -52,6 +52,13 @@ class Closure(Value): pass +@register_relay_node +class ConValue(Value): + def __init__(self, con, fields, types): + self.__init_handle_by_constructor__( + _make.ConValue, con, fields, types) + + @register_relay_node class TensorValue(Value): """A Tensor value produced by the interpreter.""" diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index b22a4e7562e2..199d66baa45a 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -2,6 +2,7 @@ """The expression functor of Relay.""" from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant +from .adt import Constructor, Match, Clause from .op import Op class ExprFunctor: @@ -47,6 +48,10 @@ def visit(self, expr): res = self.visit_ref_read(expr) elif isinstance(expr, RefWrite): res = self.visit_ref_write(expr) + elif isinstance(expr, Constructor): + res = self.visit_constructor(expr) + elif isinstance(expr, Match): + res = self.visit_match(expr) else: raise Exception("warning unhandled case: {0}".format(type(expr))) @@ -96,6 +101,13 @@ def visit_ref_write(self, _): def visit_ref_read(self, _): raise NotImplementedError() + def visit_constructor(self, _): + raise NotImplementedError() + + def visit_match(self, _): + raise NotImplementedError() + + class ExprMutator(ExprFunctor): """ A functional visitor over Expr. diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 024c6baf7012..5812d9a2481b 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -6,6 +6,7 @@ from . import _module from . import expr as _expr +from . import ty as _ty @register_relay_node class Module(RelayNode): @@ -20,7 +21,7 @@ class Module(RelayNode): functions : dict, optional. Map of global var to Function """ - def __init__(self, functions=None): + def __init__(self, functions=None, type_definitions=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -32,28 +33,45 @@ def __init__(self, functions=None): raise TypeError("Expect functions to be Dict[GlobalVar, Function]") mapped_funcs[k] = v functions = mapped_funcs - self.__init_handle_by_constructor__(_make.Module, functions) - - def __setitem__(self, var, func): - """Add a function to the module. + if type_definitions is None: + type_definitions = {} + elif isinstance(type_definitions, dict): + mapped_type_defs = {} + for k, v in type_definitions.items(): + if isinstance(k, _base.string_types): + k = _ty.GlobalTypeVar(k) + if not isinstance(k, _ty.GlobalTypeVar): + raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") + type_definitions = mapped_type_defs + self.__init_handle_by_constructor__(_make.Module, functions, type_definitions) + + + def __setitem__(self, var, val): + """Add a mapping to the module. Parameters --------- var: GlobalVar - The global variable which names the function. + The global variable. - func: Function - The function. + val: Union[Function, Type] + The value. """ - return self._add(var, func) + return self._add(var, val) - def _add(self, var, func, update=False): - if isinstance(var, _base.string_types): - var = _expr.GlobalVar(var) - return _module.Module_Add(self, var, func, update) + def _add(self, var, val, update=False): + if isinstance(val, _expr.Function): + if isinstance(var, _base.string_types): + var = _expr.GlobalVar(var) + _make.Module_Add(self, var, val, update) + else: + assert isinstance(val, _ty.Type) + if isinstance(var, _base.string_types): + var = _ty.GlobalTypeVar(var) + _module.Module_AddDef(self, var, val) def __getitem__(self, var): - """Lookup a global function by name or by variable. + """Lookup a global definition by name or by variable. Parameters ---------- @@ -62,13 +80,15 @@ def __getitem__(self, var): Returns ------- - func: Function - The function referenced by :code:`var`. + val: Union[Function, Type] + The definition referenced by :code:`var` (either a function or type). """ if isinstance(var, _base.string_types): return _module.Module_Lookup_str(self, var) - else: + elif isinstance(var, _expr.GlobalVar): return _module.Module_Lookup(self, var) + else: + return _module.Module_LookupDef(self, var) def update(self, other): """Insert functions in another Module to current one. @@ -100,3 +120,22 @@ def get_global_var(self, name): tvm.TVMError if we cannot find corresponding global var. """ return _module.Module_GetGlobalVar(self, name) + + def get_global_type_var(self, name): + """Get a global type variable in the function by name. + + Parameters + ---------- + name: str + The name of the global type variable. + + Returns + ------- + global_type_var: GlobalTypeVar + The global variable mapped to :code:`name`. + + Raises + ------ + tvm.TVMError if we cannot find corresponding global type var. + """ + return _module.Module_GetGlobalTypeVar(self, name) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py new file mode 100644 index 000000000000..29fcc40b9361 --- /dev/null +++ b/python/tvm/relay/prelude.py @@ -0,0 +1,114 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""Include some preloaded term/type definitions.""" +from .ty import GlobalTypeVar, TypeVar, FuncType +from .expr import Var, Function, GlobalVar +from .adt import Constructor, TypeData, Clause, Match +from .adt import PatternConstructor, PatternVar, PatternWildcard + +class Prelude: + """Contain standard definitions.""" + def __init__(self, mod): + self.mod = mod + self.nat = GlobalTypeVar("nat") + self.z = Constructor("z", [], self.nat) + self.s = Constructor("s", [self.nat()], self.nat) + mod[self.nat] = TypeData(self.nat, [], [self.z, self.s]) + + self.double = GlobalVar("double") + x = Var("x", self.nat()) + y = Var("y") + z_case = Clause(PatternConstructor(self.z), self.z()) + s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.s(self.s(self.double(y)))) + mod[self.double] = Function([x], Match(x, [z_case, s_case])) + + self.add = GlobalVar("add") + x = Var("x", self.nat()) + y = Var("y", self.nat()) + a = Var("a") + z_case = Clause(PatternConstructor(self.z), y) + s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]), self.s(self.add(a, y))) + mod[self.add] = Function([x, y], Match(x, [z_case, s_case])) + + self.l = GlobalTypeVar("list") + a = TypeVar("a") + self.nil = Constructor("nil", [], self.l) + self.cons = Constructor("cons", [a, self.l(a)], self.l) + mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) + + self.length = GlobalVar("length") + a = TypeVar("a") + x = Var("x", self.l(a)) + y = Var("y") + nil_case = Clause(PatternConstructor(self.nil), self.z()) + cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]), + self.s(self.length(y))) + mod[self.length] = Function([x], Match(x, [nil_case, cons_case]), None, [a]) + + self.map = GlobalVar("map") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a], b)) + x = Var("x", self.l(a)) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil), self.nil()) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + self.cons(f(y), self.map(f, z))) + mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), None, [a, b]) + + self.foldl = GlobalVar("foldl") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a, b], a)) + av = Var("av", a) + bv = Var("bv", self.l(b)) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil), av) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + self.foldl(f, f(av, y), z)) + mod[self.foldl] = Function([f, av, bv], Match(bv, [nil_case, cons_case]), None, [a, b]) + + self.tree = GlobalTypeVar("tree") + a = TypeVar("a") + self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree) + mod[self.tree] = TypeData(self.tree, [a], [self.rose]) + + self.foldr = GlobalVar("foldr") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a, b], b)) + av = Var("av", self.l(a)) + bv = Var("bv", b) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil), bv) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + f(y, self.foldr(f, bv, z))) + mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), None, [a, b]) + + self.sum = GlobalVar("sum") + a = Var("a", self.l(self.nat())) + mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a)) + + self.tmap = GlobalVar("tmap") + a = TypeVar("a") + b = TypeVar("b") + t = Var("t", self.tree(a)) + f = Var("f", FuncType([a], b)) + x = Var("x", self.tree(a)) + y = Var("y") + z = Var("z") + rose_case = Clause(PatternConstructor(self.rose, [PatternVar(y), PatternVar(z)]), + self.rose(f(y), self.map(Function([x], self.tmap(f, x)), z))) + mod[self.tmap] = Function([f, t], Match(t, [rose_case]), self.tree(b), [a, b]) + + self.size = GlobalVar("size") + a = TypeVar("a") + t = Var("t", self.tree(a)) + x = Var("x", self.tree(a)) + z = Var("z") + rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), + self.s(self.sum(self.map(Function([x], self.size(x)), z)))) + mod[self.size] = Function([t], Match(t, [rose_case]), self.nat(), [a]) + # cannot infer return type here diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index bed293d1e3ca..ee95a82fb90b 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -21,6 +21,19 @@ def same_as(self, other): """Compares two Relay types by referential equality.""" return super().__eq__(other) + def __call__(self, *args): + """Create a type call from this type. + + Parameters + ---------- + args: List[relay.Type] + The arguments to the type call. + + Returns + ------- + call: relay.TypeCall + """ + return TypeCall(self, args) @register_relay_node class TensorType(Type): @@ -106,6 +119,51 @@ def __init__(self, var, kind=Kind.Type): self.__init_handle_by_constructor__(_make.TypeVar, var, kind) +@register_relay_node +class GlobalTypeVar(Type): + """A global type variable in Relay. + GlobalTypeVar is used to refer to the global type-level definitions + stored in the environment. + """ + + def __init__(self, var, kind=Kind.Type): + """Construct a GlobalTypeVar. + + Parameters + ---------- + var: tvm.Var + The tvm.Var which backs the type parameter. + kind: Kind, optional + The kind of the type parameter, Kind.Type by default. + + Returns + ------- + type_var: GlobalTypeVar + The global type variable. + """ + self.__init_handle_by_constructor__(_make.GlobalTypeVar, var, kind) + + +@register_relay_node +class TypeCall(Type): + """Type-level function application in Relay.""" + + def __init__(self, func, args): + """Construct a TypeCall. + Parameters + ---------- + func: tvm.relay.Type + The function. + args: List[tvm.expr.Type] + The arguments. + Returns + ------- + type_call: TypeCall + The type function application. + """ + self.__init_handle_by_constructor__(_make.TypeCall, func, args) + + @register_relay_node class TypeConstraint(Type): """Abstract class representing a type constraint.""" diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 893e66b41b42..3a59387c7b4b 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -92,6 +93,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); +ConValue ConValueNode::make(Constructor con, + tvm::Array fields) { + NodePtr n = make_node(); + n->con = con; + n->fields = fields; + return ConValue(n); +} + +TVM_REGISTER_API("relay._make.ConValue") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ConValueNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ConValueNode* node, + tvm::IRPrinter* p) { + p->stream << "ConValueNode(" << node->con + << node->fields << ")"; +}); + /*! * \brief A stack frame in the Relay interpreter. * @@ -185,7 +206,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { // // Conversion to ANF is recommended before running the interpretation. class Interpreter : - public ExprFunctor { + public ExprFunctor, + PatternFunctor { public: Interpreter(Module mod, DLContext context, @@ -209,7 +231,7 @@ class Interpreter : } Value Eval(const Expr& expr) { - return (*this)(expr); + return VisitExpr(expr); } Value VisitExpr(const Expr& expr) final { @@ -401,6 +423,9 @@ class Interpreter : << "; operators should be removed by future passes; try " "fusing and lowering"; } + if (auto con = call->op.as()) { + return ConValueNode::make(GetRef(con), args); + } // Now we just evaluate and expect to find a closure. Value fn_val = Eval(call->op); if (const ClosureNode* closure_node = fn_val.as()) { @@ -474,6 +499,44 @@ class Interpreter : } } + Value VisitExpr_(const MatchNode* op) final { + Value v = Eval(op->data); + for (const Clause& c : op->pattern) { + if (VisitPattern(c->lhs, v)) { + return VisitExpr(c->rhs); + } + } + LOG(FATAL) << "did not find any match"; + return Value(); + } + + bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final { + const ConValueNode* cvn = v.as(); + CHECK(cvn) << "need to be a constructor for match"; + CHECK_NE(op->con->tag, -1); + CHECK_NE(cvn->con->tag, -1); + if (op->con->tag == cvn->con->tag) { + // todo(M.K.): should use ptr equality but it is broken + CHECK(op->pat.size() == cvn->fields.size()); + for (size_t i = 0; i < op->pat.size(); ++i) { + if (!VisitPattern(op->pat[i], cvn->fields[i])) { + return false; + } + } + return true; + } + return false; + } + + bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final { + return true; + } + + bool VisitPattern_(const PatternVarNode* op, const Value& v) final { + extend(op->var, v); + return true; + } + InterpreterState get_state(Expr e = Expr()) const { InterpreterStateNode::Stack stack; for (auto fr : this->stack_.frames) { @@ -485,14 +548,14 @@ class Interpreter : } private: - // module + // Module Module mod_; // For simplicity we only run the interpreter on a single context. // Context to run the interpreter on. DLContext context_; // Target parameter being used by the interpreter. Target target_; - // value stack. + // Value stack. Stack stack_; // Backend compile engine. CompileEngine engine_; diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc new file mode 100644 index 000000000000..4bcbc5f27e4b --- /dev/null +++ b/src/relay/ir/adt.cc @@ -0,0 +1,161 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/ir/adt.cc + * \brief AST nodes for Relay algebraic data types (ADTs). + */ +#include +#include + +namespace tvm { +namespace relay { + +PatternWildcard PatternWildcardNode::make() { + NodePtr n = make_node(); + return PatternWildcard(n); +} + +TVM_REGISTER_NODE_TYPE(PatternWildcardNode); + +TVM_REGISTER_API("relay._make.PatternWildcard") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PatternWildcardNode::make(); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternWildcardNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternWildcardNode()"; +}); + +PatternVar PatternVarNode::make(tvm::relay::Var var) { + NodePtr n = make_node(); + n->var = std::move(var); + return PatternVar(n); +} + +TVM_REGISTER_NODE_TYPE(PatternVarNode); + +TVM_REGISTER_API("relay._make.PatternVar") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PatternVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternVarNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternVarNode(" << node->var << ")"; +}); + +PatternConstructor PatternConstructorNode::make(Constructor con, tvm::Array pat) { + NodePtr n = make_node(); + n->con = std::move(con); + n->pat = std::move(pat); + return PatternConstructor(n); +} + +TVM_REGISTER_NODE_TYPE(PatternConstructorNode); + +TVM_REGISTER_API("relay._make.PatternConstructor") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PatternConstructorNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternConstructorNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternConstructorNode(" << node->con + << ", " << node->pat << ")"; +}); + +Constructor ConstructorNode::make(std::string name_hint, + tvm::Array inp, + GlobalTypeVar belong_to) { + NodePtr n = make_node(); + n->name_hint = std::move(name_hint); + n->inp = std::move(inp); + n->belong_to = std::move(belong_to); + return Constructor(n); +} + +TVM_REGISTER_NODE_TYPE(ConstructorNode); + +TVM_REGISTER_API("relay._make.Constructor") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ConstructorNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ConstructorNode* node, + tvm::IRPrinter* p) { + p->stream << "ConstructorNode(" << node->name_hint << ", " + << node->inp << ", " << node->belong_to << ")"; +}); + +TypeData TypeDataNode::make(GlobalTypeVar header, + tvm::Array tv, + tvm::Array constructors) { + NodePtr n = make_node(); + n->header = std::move(header); + n->tv = std::move(tv); + n->constructors = std::move(constructors); + return TypeData(n); +} + +TVM_REGISTER_NODE_TYPE(TypeDataNode); + +TVM_REGISTER_API("relay._make.TypeData") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TypeDataNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TypeDataNode* node, + tvm::IRPrinter* p) { + p->stream << "TypeDataNode(" << node->header << ", " << node->tv << ", " + << node->constructors << ")"; +}); + +Clause ClauseNode::make(Pattern lhs, Expr rhs) { + NodePtr n = make_node(); + n->lhs = std::move(lhs); + n->rhs = std::move(rhs); + return Clause(n); +} + +TVM_REGISTER_NODE_TYPE(ClauseNode); + +TVM_REGISTER_API("relay._make.Clause") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ClauseNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ClauseNode* node, + tvm::IRPrinter* p) { + p->stream << "ClauseNode(" << node->lhs << ", " + << node->rhs << ")"; + }); + +Match MatchNode::make(Expr data, tvm::Array pattern) { + NodePtr n = make_node(); + n->data = std::move(data); + n->pattern = std::move(pattern); + return Match(n); +} + +TVM_REGISTER_NODE_TYPE(MatchNode); + +TVM_REGISTER_API("relay._make.Match") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = MatchNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const MatchNode* node, + tvm::IRPrinter* p) { + p->stream << "MatchNode(" << node->data << ", " + << node->pattern << ")"; +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index d0cc004994d4..2e3d2223181c 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include #include "type_functor.h" @@ -17,7 +18,8 @@ namespace relay { class AlphaEqualHandler: public AttrsEqualHandler, public TypeFunctor, - public ExprFunctor { + public ExprFunctor, + public PatternFunctor { public: explicit AlphaEqualHandler(bool map_free_var) : map_free_var_(map_free_var) {} @@ -160,7 +162,7 @@ class AlphaEqualHandler: } equal_map_[lhs->type_params[i]] = rhs->type_params[i]; // set up type parameter equal - if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) { + if (lhs->type_params[i]->kind == Kind::kShapeVar) { // map variable equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var; } @@ -215,6 +217,26 @@ class AlphaEqualHandler: return false; } + bool VisitType_(const GlobalTypeVarNode* op, const Type& t2) final { + return GetRef(op) == t2; + } + + bool VisitType_(const TypeCallNode* op, const Type& t2) final { + const TypeCallNode* pt = t2.as(); + if (pt == nullptr + || op->args.size() != pt->args.size() + || !TypeEqual(op->func, pt->func)) { + return false; + } + + for (size_t i = 0; i < op->args.size(); ++i) { + if (!TypeEqual(op->args[i], pt->args[i])) { + return false; + } + } + return true; + } + // Expr equal checking. bool NDArrayEqual(const runtime::NDArray& lhs, const runtime::NDArray& rhs) { @@ -261,11 +283,9 @@ class AlphaEqualHandler: bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final { if (const GlobalVarNode* rhs = other.as()) { // use name equality for global var for now. - if (lhs->name_hint != rhs->name_hint) return false; - return true; - } else { - return false; + return lhs->name_hint == rhs->name_hint; } + return false; } bool VisitExpr_(const TupleNode* lhs, const Expr& other) final { @@ -392,6 +412,63 @@ class AlphaEqualHandler: return false; } } + + bool VisitExpr_(const ConstructorNode* op, const Expr& e2) final { + return GetRef(op) == e2; + } + + bool ClauseEqual(const Clause& l, const Clause& r) { + return PatternEqual(l->lhs, r->lhs) && ExprEqual(l->rhs, r->rhs); + } + + bool PatternEqual(const Pattern& l, const Pattern& r) { + return VisitPattern(l, r); + } + + bool VisitPattern_(const PatternWildcardNode* op, const Pattern& r) { + return r.as(); + } + + bool VisitPattern_(const PatternVarNode* op, const Pattern& e2) { + if (const auto* r = e2.as()) { + return ExprEqual(op->var, r->var); + } + return false; + } + + bool VisitPattern_(const PatternConstructorNode* op, const Pattern& e2) { + const auto* r = e2.as(); + if (r == nullptr + || !ExprEqual(op->con, r->con) + || op->pat.size() != r->pat.size()) { + return false; + } + + for (size_t i = 0; i < op->pat.size(); i++) { + if (!PatternEqual(op->pat[i], r->pat[i])) { + return false; + } + } + return true; + } + + bool VisitExpr_(const MatchNode* op, const Expr& e2) final { + const MatchNode* r = e2.as(); + + if (r == nullptr + || !ExprEqual(op->data, r->data) + || op->pattern.size() != r->pattern.size()) { + return false; + } + + for (size_t i = 0; i < op->pattern.size(); ++i) { + if (!ClauseEqual(op->pattern[i], r->pattern[i])) { + return false; + } + } + return true; + } + private: // whether to map open terms. bool map_free_var_{false}; diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 9bdfa00ce298..c89c40dc1461 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -185,6 +185,24 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { } } +Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { + return GetRef(c); +} + +Expr ExprMutator::VisitExpr_(const MatchNode* m) { + std::vector pattern; + for (const Clause& p : m->pattern) { + pattern.push_back(VisitClause(p)); + } + return MatchNode::make(VisitExpr(m->data), pattern); +} + +Clause ExprMutator::VisitClause(const Clause& c) { + return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs)); +} + +Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } + Type ExprMutator::VisitType(const Type& t) { return t; } void ExprVisitor::VisitExpr(const Expr& expr) { @@ -267,6 +285,27 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) { this->VisitExpr(op->value); } +void ExprVisitor::VisitExpr_(const ConstructorNode* op) { + for (const Type& t : op->inp) { + this->VisitType(t); + } + this->VisitType(op->belong_to); +} + +void ExprVisitor::VisitExpr_(const MatchNode* op) { + this->VisitExpr(op->data); + for (const Clause& c : op->pattern) { + this->VisitClause(c); + } +} + +void ExprVisitor::VisitClause(const Clause& op) { + this->VisitPattern(op->lhs); + this->VisitExpr(op->rhs); +} + +void ExprVisitor::VisitPattern(const Pattern& p) { return; } + void ExprVisitor::VisitType(const Type& t) { return; } // visitor to implement apply diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index d984bb051e43..75a2dc0aea28 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include #include @@ -19,6 +20,8 @@ class RelayHashHandler: public AttrsHashHandler, public TypeFunctor, public ExprFunctor { + public ExprFunctor, + public PatternFunctor { public: explicit RelayHashHandler() {} @@ -201,7 +204,7 @@ class RelayHashHandler: hash_map_[var] = hash; const auto* ty_param = var.as(); - if (ty_param && ty_param->kind == TypeVarNode::Kind::kShapeVar) { + if (ty_param && ty_param->kind == Kind::kShapeVar) { hash_map_[ty_param->var] = hash; } return hash; @@ -236,7 +239,7 @@ class RelayHashHandler: } hash = Combine(hash, TypeHash(func->ret_type)); - hash = Combine(hash, ExprHash(func->body)); + hash = Combine(hash, ExprHash(func->body)); return hash; } @@ -249,6 +252,10 @@ class RelayHashHandler: hash = Combine(hash, ExprHash(arg)); } + for (auto t : call->type_args) { + hash = Combine(hash, TypeHash(t)); + } + hash = Combine(hash, AttrHash(call->attrs)); return hash; @@ -304,6 +311,72 @@ class RelayHashHandler: hash = Combine(hash, ExprHash(rn->value)); return hash; } + + size_t VisitExpr_(const MatchNode* mn) final { + size_t hash = std::hash()(MatchNode::_type_key); + hash = Combine(hash, ExprHash(mn->data)); + for (const auto& c : mn->pattern) { + hash = Combine(hash, PatternHash(c->lhs)); + hash = Combine(hash, ExprHash(c->rhs)); + } + return hash; + } + + size_t VisitExpr_(const ConstructorNode* cn) final { + size_t hash = std::hash()(ConstructorNode::_type_key); + hash = Combine(hash, std::hash()(cn->name_hint)); + return hash; + } + + size_t VisitType_(const TypeCallNode* tcn) final { + size_t hash = std::hash()(TypeCallNode::_type_key); + hash = Combine(hash, TypeHash(tcn->func)); + for (const auto& t : tcn->args) { + hash = Combine(hash, TypeHash(t)); + } + return hash; + } + + size_t VisitType_(const TypeDataNode* tdn) final { + size_t hash = std::hash()(TypeDataNode::_type_key); + hash = Combine(hash, TypeHash(tdn->header)); + for (const auto& tv : tdn->tv) { + hash = Combine(hash, TypeHash(tv)); + } + for (const auto& cn : tdn->constructors) { + hash = Combine(hash, ExprHash(cn)); + } + return hash; + } + + size_t VisitType_(const GlobalTypeVarNode* tvn) final { + return BindVar(GetRef(tvn)); + } + + size_t PatternHash(const Pattern& p) { + return VisitPattern(p); + } + + size_t VisitPattern_(const PatternConstructorNode* pcn) final { + size_t hash = std::hash()(PatternConstructorNode::_type_key); + hash = Combine(hash, ExprHash(pcn->con)); + for (const auto& p : pcn->pat) { + hash = Combine(hash, PatternHash(p)); + } + return hash; + } + + size_t VisitPattern_(const PatternVarNode* pvn) final { + size_t hash = std::hash()(PatternVarNode::_type_key); + hash = Combine(hash, ExprHash(pvn->var)); + return hash; + } + + size_t VisitPattern_(const PatternWildcardNode* pwn) final { + size_t hash = std::hash()(PatternWildcardNode::_type_key); + return hash; + } + private: // renaming of NodeRef to indicate two nodes equals to each other std::unordered_map hash_map_; diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 9ba5efecec80..4308fe785100 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -13,18 +13,28 @@ namespace relay { using tvm::IRPrinter; using namespace runtime; -Module ModuleNode::make(tvm::Map global_funcs) { +Module ModuleNode::make(tvm::Map global_funcs, + tvm::Map global_type_defs) { auto n = make_node(); n->functions = std::move(global_funcs); + n->type_definitions = std::move(global_type_defs); for (const auto& kv : n->functions) { - // set gloval var map + // set global var map CHECK(!n->global_var_map_.count(kv.first->name_hint)) - << "Duplicate global function name " << kv.first->name_hint; + << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } n->entry_func = GlobalVarNode::make("main"); + + for (const auto& kv : n->type_definitions) { + // set global typevar map + CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint)) + << "Duplicate global type definition name " << kv.first->var->name_hint; + n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); + } + return Module(n); } @@ -51,6 +61,13 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, global_var_map_.Set(var->name_hint, var); } +GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) { + auto it = global_type_var_map_.find(name); + CHECK(it != global_type_var_map_.end()) + << "Cannot find global type var " << name << " in the Module"; + return (*it).second; +} + void ModuleNode::Add(const GlobalVar& var, const Function& func, bool update) { @@ -69,6 +86,19 @@ void ModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } +void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { + // kind checker is broken, not checking them rn. + // TODO(slyubomirsky, MarisaKirisame): fix the kind checker. + this->type_definitions.Set(var, type); + // set global type var map + CHECK(!global_type_var_map_.count(var->var->name_hint)) + << "Duplicate global type definition name " << var->var->name_hint; + global_type_var_map_.Set(var->var->name_hint, var); + for (size_t i = 0; i < type->constructors.size(); ++i) { + type->constructors[i]->tag = i; + } + } + void ModuleNode::Update(const GlobalVar& var, const Function& func) { this->Add(var, func, true); } @@ -92,6 +122,18 @@ Function ModuleNode::Lookup(const std::string& name) { return this->Lookup(id); } +TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) { + auto it = type_definitions.find(var); + CHECK(it != type_definitions.end()) + << "There is no definition of " << var->var->name_hint; + return (*it).second; +} + +TypeData ModuleNode::LookupDef(const std::string& name) { + GlobalTypeVar id = this->GetGlobalTypeVar(name); + return this->LookupDef(id); +} + void ModuleNode::Update(const Module& mod) { for (auto pair : mod->functions) { this->Update(pair.first, pair.second); @@ -117,21 +159,33 @@ TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ModuleNode::make(args[0]); + *ret = ModuleNode::make(args[0], args[1]); }); -TVM_REGISTER_API("relay._module.Module_Add") +TVM_REGISTER_API("relay._make.Module_Add") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; mod->Add(args[1], args[2], args[3]); }); +TVM_REGISTER_API("relay._module.Module_AddDef") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + mod->AddDef(args[1], args[2]); + }); + TVM_REGISTER_API("relay._module.Module_GetGlobalVar") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; *ret = mod->GetGlobalVar(args[1]); }); +TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + *ret = mod->GetGlobalTypeVar(args[1]); + }); + TVM_REGISTER_API("relay._module.Module_Lookup") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; @@ -143,8 +197,21 @@ TVM_REGISTER_API("relay._module.Module_Lookup_str") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; std::string var_name = args[1]; - auto var = mod->GetGlobalVar(var_name); - *ret = mod->Lookup(var); + *ret = mod->Lookup(var_name); + }); + +TVM_REGISTER_API("relay._module.Module_LookupDef") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + GlobalTypeVar var = args[1]; + *ret = mod->LookupDef(var); + }); + +TVM_REGISTER_API("relay._module.Module_LookupDef_str") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + std::string var_name = args[1]; + *ret = mod->LookupDef(var_name); }); TVM_REGISTER_API("relay._module.Module_Update") diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc new file mode 100644 index 000000000000..71002058fe49 --- /dev/null +++ b/src/relay/ir/pattern_functor.cc @@ -0,0 +1,75 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/pattern_functor.cc + * \brief Implementations of visitors and mutators for ADT patterns. + */ + +#include + +namespace tvm { +namespace relay { + +Pattern PatternMutator::Mutate(const Pattern& pat) { + return (*this)(pat); +} + +Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { + return GetRef(op); +} + +Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) { + return PatternVarNode::make(VisitVar(op->var)); +} + +Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { + std::vector pat; + for (const auto& p : op->pat) { + pat.push_back(VisitPattern(p)); + } + return PatternConstructorNode::make(VisitConstructor(op->con), pat); +} + +Type PatternMutator::VisitType(const Type& t) { + return t; +} + +Var PatternMutator::VisitVar(const Var& v) { + if (var_map_.count(v) == 0) { + var_map_.insert(std::pair(v, + VarNode::make(v->name_hint(), + VisitType(v->type_annotation)))); + } + return var_map_.at(v); +} + +Constructor PatternMutator::VisitConstructor(const Constructor& v) { + return v; +} + +void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { } + +void PatternVisitor::VisitPattern_(const PatternVarNode* op) { + VisitVar(op->var); +} + +void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { + VisitConstructor(op->con); + for (const auto& p : op->pat) { + VisitPattern(p); + } +} + +void PatternVisitor::VisitType(const Type& t) { } + +void PatternVisitor::VisitVar(const Var& v) { + VisitType(v->type_annotation); +} + +void PatternVisitor::VisitConstructor(const Constructor& c) { + for (const auto& inp : c->inp) { + VisitType(inp); + } +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 05179d584d84..3cca40e605e6 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include "type_functor.h" #include "../../lang/attr_functor.h" @@ -23,6 +24,12 @@ struct TextValue { TextValue() {} // constructor explicit TextValue(std::string name) : name(name) {} + TextValue operator+(const TextValue& rhs) const { + return TextValue(name + rhs.name); + } + TextValue operator+(const std::string& str) const { + return TextValue(name + str); + } }; // operator overloading @@ -128,6 +135,7 @@ class TextMetaDataContext { class TextPrinter : public ExprFunctor, + public PatternFunctor, public TypeFunctor, // NOLINT(*) public AttrFunctor { // NOLINT(*) public: @@ -213,6 +221,9 @@ class TextPrinter : memo_[expr] = val; return val; } + TextValue GetValue(const Pattern& p) { + return this->VisitPattern(p); + } //------------------------------------ // Overload of Expr printing functions //------------------------------------ @@ -391,6 +402,36 @@ class TextPrinter : return id; } + TextValue VisitExpr_(const MatchNode* op) final { + TextValue data = GetValue(op->data); + this->PrintIndent(); + TextValue id = this->AllocTempVar(); + stream_ << id << " = " << "Match " << data << " with"; + this->PrintEndInst("\n"); + for (const auto& c : op->pattern) { + this->PrintIndent(); + stream_ << GetValue(c->lhs) << " to " << GetValue(c->rhs); + this->PrintEndInst("\n"); + } + return id; + } + + TextValue VisitPattern_(const PatternConstructorNode* p) final { + TextValue ret(p->con->name_hint + "("); + for (const Pattern& pat : p->pat) { + ret = ret + " " + GetValue(pat); + } + return ret + ")"; + } + + TextValue VisitPattern_(const PatternVarNode* pv) final { + return GetValue(pv->var); + } + + TextValue VisitExpr_(const ConstructorNode* n) final { + return TextValue(n->name_hint); + } + /*! * \brief Print the type to os * \param type The type to be printed. @@ -437,6 +478,18 @@ class TextPrinter : VisitTypeDefault_(node, os); } + void VisitType_(const TypeCallNode* node, std::ostream& os) final { + os << node->func << "(" << node->args << ")"; + } + + void VisitType_(const GlobalTypeVarNode* node, std::ostream& os) final { + VisitTypeDefault_(node, os); + } + + void VisitType_(const TypeDataNode* node, std::ostream& os) final { + VisitTypeDefault_(node, os); + } + void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) // by default always print as meta-data os << meta_.GetMetaNode(GetRef(node)); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index e829d8abd63c..25b7beb5356a 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -48,7 +48,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); -TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) { +TypeVar TypeVarNode::make(std::string name, Kind kind) { NodePtr n = make_node(); n->var = tvm::Var(name); n->kind = std::move(kind); @@ -61,7 +61,7 @@ TVM_REGISTER_API("relay._make.TypeVar") .set_body([](TVMArgs args, TVMRetValue* ret) { int kind = args[1]; *ret = - TypeVarNode::make(args[0], static_cast(kind)); + TypeVarNode::make(args[0], static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -71,7 +71,50 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->kind << ")"; }); -IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) { +GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { + NodePtr n = make_node(); + n->var = tvm::Var(name); + n->kind = std::move(kind); + return GlobalTypeVar(n); +} + +TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); + +TVM_REGISTER_API("relay._make.GlobalTypeVar") +.set_body([](TVMArgs args, TVMRetValue* ret) { + int kind = args[1]; + *ret = GlobalTypeVarNode::make(args[0], static_cast(kind)); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const GlobalTypeVarNode *node, + tvm::IRPrinter *p) { + p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", " + << node->kind << ")"; +}); + +TypeCall TypeCallNode::make(Type func, tvm::Array args) { + NodePtr n = make_node(); + n->func = std::move(func); + n->args = std::move(args); + return TypeCall(n); +} + +TVM_REGISTER_NODE_TYPE(TypeCallNode); + +TVM_REGISTER_API("relay._make.TypeCall") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TypeCallNode::make(args[0], args[1]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TypeCallNode* node, + tvm::IRPrinter* p) { + p->stream << "TypeCallNode(" << node->func << ", " + << node->args << ")"; +}); + +IncompleteType IncompleteTypeNode::make(Kind kind) { auto n = make_node(); n->kind = std::move(kind); return IncompleteType(n); @@ -82,7 +125,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); TVM_REGISTER_API("relay._make.IncompleteType") .set_body([](TVMArgs args, TVMRetValue* ret) { int kind = args[0]; - *ret = IncompleteTypeNode::make(static_cast(kind)); + *ret = IncompleteTypeNode::make(static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 100c633a2997..61978d5ccba0 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -48,6 +48,23 @@ void TypeVisitor::VisitType_(const TypeRelationNode* op) { } } +void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) { +} + +void TypeVisitor::VisitType_(const TypeCallNode* op) { + this->VisitType(op->func); + for (const Type& t : op->args) { + this->VisitType(t); + } +} + +void TypeVisitor::VisitType_(const TypeDataNode* op) { + this->VisitType(op->header); + for (const auto& v : op->tv) { + this->VisitType(v); + } + // TODO(slyubomirsky, MarisaKirisame): visit constructors +} // Type Mutator. Array TypeMutator::MutateArray(Array arr) { @@ -139,6 +156,22 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { } } +Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { + return GetRef(op); +} + +Type TypeMutator::VisitType_(const TypeCallNode* op) { + std::vector args; + for (const auto& a : op->args) { + args.push_back(VisitType(a)); + } + return TypeCallNode::make(VisitType(op->func), args); +} + +Type TypeMutator::VisitType_(const TypeDataNode* op) { + return GetRef(op); +} + // Implements bind. class TypeBinder : public TypeMutator { public: diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index 1be55e78eee6..36f77967c253 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -69,6 +70,10 @@ class TypeFunctor { virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); throw; // unreachable, written to stop compiler warning @@ -87,6 +92,9 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode); return vtable; } }; @@ -103,6 +111,9 @@ class TypeVisitor : public TypeFunctor { void VisitType_(const TupleTypeNode* op) override; void VisitType_(const TypeRelationNode* op) override; void VisitType_(const RefTypeNode* op) override; + void VisitType_(const GlobalTypeVarNode* op) override; + void VisitType_(const TypeCallNode* op) override; + void VisitType_(const TypeDataNode* op) override; }; // Mutator that transform a type to another one. @@ -115,6 +126,9 @@ class TypeMutator : public TypeFunctor { Type VisitType_(const TupleTypeNode* op) override; Type VisitType_(const TypeRelationNode* type_rel) override; Type VisitType_(const RefTypeNode* op) override; + Type VisitType_(const GlobalTypeVarNode* op) override; + Type VisitType_(const TypeCallNode* op) override; + Type VisitType_(const TypeDataNode* op) override; private: Array MutateArray(Array arr); diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index a6298ba448f3..0012a117a56c 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -296,6 +296,15 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); this->AddNode(op); } + + void VisitExpr_(const MatchNode* op) final { + this->Update(op->data, nullptr, kOpaque); + for (const Clause& c : op->pattern) { + this->Update(c->rhs, nullptr, kOpaque); + } + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } }; IndexedForwardGraph IndexedForwardGraph::Create( diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 200f5385a37a..22f2f2e731c0 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -20,7 +20,6 @@ namespace tvm { namespace relay { using namespace tvm::runtime; -using Kind = TypeVarNode::Kind; struct KindChecker : TypeVisitor { bool valid; @@ -113,7 +112,7 @@ bool KindCheck(const Type& t, const Module& mod) { TVM_REGISTER_API("relay._ir_pass.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { - *ret = KindCheck(args[0], ModuleNode::make({})); + *ret = KindCheck(args[0], ModuleNode::make({}, {})); } else { *ret = KindCheck(args[0], args[1]); } diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 2fecc8ba3727..3afbcba96ae6 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -62,7 +62,7 @@ class LetList { * \return a Var that hold the inserted expr. */ Var Push(Expr expr) { - return Push(IncompleteTypeNode::make(TypeVarNode::kType), expr); + return Push(IncompleteTypeNode::make(Kind::kType), expr); } /*! diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 10ba3b127bbf..5360550870e3 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -22,7 +22,9 @@ #include #include +#include #include +#include "./pass_util.h" #include "type_solver.h" #include "../ir/type_functor.h" @@ -79,7 +81,8 @@ struct ResolvedTypeInfo { // - Solve the constraints (solver_.Solve) // - Recreate expression with the resolved checked_type (Resolver.VisitExpr) // -class TypeInferencer : private ExprFunctor { +class TypeInferencer : private ExprFunctor, + private PatternFunctor { public: // constructors @@ -184,7 +187,7 @@ class TypeInferencer : private ExprFunctor { if (op->type_annotation.defined()) { return op->type_annotation; } else { - return IncompleteTypeNode::make(TypeVarNode::kType); + return IncompleteTypeNode::make(Kind::kType); } } @@ -219,7 +222,7 @@ class TypeInferencer : private ExprFunctor { EnvFunc::Get("tvm.relay.type_relation.TupleGetItem").node_); } Type tuple_type = GetType(op->tuple); - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type rtype = IncompleteTypeNode::make(Kind::kType); auto attrs = make_node(); attrs->index = op->index; solver_.AddConstraint(TypeRelationNode::make( @@ -227,6 +230,45 @@ class TypeInferencer : private ExprFunctor { return rtype; } + void VisitPattern_(const PatternConstructorNode* con, const Type& t) { + CHECK(mod_.defined()) + << "Cannot do type inference without a environment:" + << con->con->name_hint; + TypeData td = mod_->type_definitions.at(con->con->belong_to); + auto* tc = t.as(); + CHECK(tc) << "must be type call"; + CHECK_EQ(td->header, tc->func); + CHECK(td->tv.size() == tc->args.size()) << "both side must be equal"; + std::unordered_map type_var_map_; + for (size_t i = 0; i < td->tv.size(); ++i) { + type_var_map_[td->tv[i]] = tc->args[i]; + } + CHECK(con->con->inp.size() == con->pat.size()) << "not enough pattern"; + for (size_t i = 0; i < con->con->inp.size(); ++i) { + VisitPattern(con->pat[i], Bind(con->con->inp[i], type_var_map_)); + } + } + + void VisitPattern_(const PatternVarNode* pv, const Type& t) { + type_map_[pv->var] = ResolvedTypeInfo(t, {}); + } + + void VisitPattern_(const PatternWildcardNode* wc, const Type& t) { } + + Type VisitExpr_(const MatchNode* op) final { + Type dtype = GetType(op->data); + for (const auto& c : op->pattern) { + VisitPattern(c->lhs, dtype); + } + Type rtype = IncompleteTypeNode::make(Kind::kType); + for (const auto& c : op->pattern) { + rtype = this->Unify(rtype, + GetType(c->rhs), + op->span); + } + return rtype; + } + Type VisitExpr_(const OpNode* op) final { return op->op_type; } @@ -276,7 +318,7 @@ class TypeInferencer : private ExprFunctor { for (size_t i = 0; i < op->type_params.size(); ++i) { if (!op->type_params[i].same_as(rel->args[i])) return Type(); } - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type rtype = IncompleteTypeNode::make(Kind::kType); arg_types.push_back(rtype); // we can do simple replacement here solver_.AddConstraint(TypeRelationNode::make( @@ -302,7 +344,7 @@ class TypeInferencer : private ExprFunctor { // This is a temporary work around to check recursive functions whose // return type is not yet known. if (!ret_type.defined()) { - ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + ret_type = IncompleteTypeNode::make(Kind::kType); } Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, @@ -448,9 +490,21 @@ class TypeInferencer : private ExprFunctor { this->Unify(GetType(op->value), it, GetRef(op)); return TupleTypeNode::make({}); } + + Type VisitExpr_(const ConstructorNode* c) final { + CHECK(mod_.defined()) + << "Cannot do type inference without a environment:" + << c->name_hint; + TypeData td = mod_->type_definitions.at(c->belong_to); + std::vector types; + for (const auto & t : td->tv) { + types.push_back(t); + } + return FuncTypeNode::make(c->inp, TypeCallNode::make(c->belong_to, types), td->tv, {}); + } }; -class TypeInferencer::Resolver : public ExprMutator { +class TypeInferencer::Resolver : public ExprMutator, PatternMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) @@ -458,7 +512,7 @@ class TypeInferencer::Resolver : public ExprMutator { } Expr VisitExpr_(const VarNode* op) final { - return AttachCheckedType(op); + return VisitVar(GetRef(op)); } Expr VisitExpr_(const ConstantNode* op) final { @@ -509,6 +563,25 @@ class TypeInferencer::Resolver : public ExprMutator { return AttachCheckedType(op); } + Expr VisitExpr_(const ConstructorNode* op) final { + return AttachCheckedType(op); + } + + Expr VisitExpr_(const MatchNode* op) final { + return AttachCheckedType(op); + } + + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Var VisitVar(const Var& v) final { + if (vmap_.count(v) == 0) { + vmap_[v] = GetRef(AttachCheckedType(v.as()).as()); + } + return vmap_.at(v); + } + // attach checked type to the mutated node. template Expr AttachCheckedType(const T* op) { @@ -601,6 +674,7 @@ class TypeInferencer::Resolver : public ExprMutator { } private: + std::unordered_map vmap_; const std::unordered_map& tmap_; TypeSolver* solver_; // whether attach the checked type as type_annotation @@ -625,6 +699,18 @@ Expr TypeInferencer::Infer(Expr expr) { return resolved_expr; } +struct AllCheckTypePopulated : ExprVisitor { + void VisitExpr(const Expr& e) { + if (e.as()) { return; } + if (e.as()) { return; } + CHECK(e->checked_type_.defined()) << "Expression: " << e; + return ExprVisitor::VisitExpr(e); + } +}; + +void EnsureCheckedType(const Expr& e) { + AllCheckTypePopulated().VisitExpr(e); +} Expr InferType(const Expr& expr, const Module& mod_ref) { if (!mod_ref.defined()) { @@ -645,6 +731,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { } else { auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); CHECK(WellFormed(e)); + EnsureCheckedType(e); return e; } } diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 403863c1d757..499046321953 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -7,6 +7,7 @@ */ #include #include +#include #include "../ir/type_functor.h" namespace tvm { diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py new file mode 100644 index 000000000000..2f2d1f98906e --- /dev/null +++ b/tests/python/relay/test_adt.py @@ -0,0 +1,138 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import infer_type +from tvm.relay.backend.interpreter import Value, TupleValue, ConValue +from tvm.relay import testing, create_executor +from tvm.relay.prelude import Prelude + +mod = relay.Module() +p = Prelude(mod) +ctx = tvm.context("llvm", 0) +intrp = create_executor(mod=mod, ctx=ctx, target="llvm") + +z = p.z +s = p.s +nat = p.nat +double = p.double +add = p.add + +nil = p.nil +cons = p.cons +l = p.l +length = p.length +map = p.map +foldl = p.foldl +foldr = p.foldr +sum = p.sum + +tree = p.tree +rose = p.rose +tmap = p.tmap +size = p.size + +# this is an example of using the adt value in python side +def count(n): + assert isinstance(n, ConValue) + if n.con.name_hint == 's': + return 1 + count(n.fields[0]) + else: + assert n.con.name_hint == 'z' + return 0 + +# this is an example of creating the adt value in python side +def make_nat(n): + if n != 0: + return ConValue(s, [make_nat(n - 1)], []) + else: + return ConValue(z, [], []) + + +def test_nat_value(): + assert count(make_nat(10)) == 10 + + +def test_nat_constructor(): + assert relay.ir_pass.infer_type(z(), mod).checked_type == nat() + assert relay.ir_pass.infer_type(s, mod).checked_type == relay.FuncType([nat()], nat()) + assert relay.ir_pass.infer_type(s(z()), mod).checked_type == nat() + + +def test_double(): + assert mod[double].checked_type == relay.FuncType([nat()], nat()) + + +def test_add(): + assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) + res = intrp.evaluate(add(s(z()), s(z()))) + assert count(res) == 2 + + +def test_list_constructor(): + a = relay.TypeVar("a") + assert relay.ir_pass.infer_type(nil, mod).checked_type == relay.FuncType([], l(a), [a]) + assert relay.ir_pass.infer_type(cons(z(), nil()), mod).checked_type == l(nat()) + + +def test_length(): + a = relay.TypeVar("a") + assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a]) + + +def test_map(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[map].checked_type + rhs = relay.FuncType([relay.FuncType([a], b), l(a)], l(b), [a, b]) + assert lhs == rhs + + +def test_foldl(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[foldl].checked_type + rhs = relay.FuncType([relay.FuncType([a, b], a), a, l(b)], a, [a, b]) + assert lhs == rhs + + +def test_foldr(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[foldr].checked_type + rhs = relay.FuncType([relay.FuncType([a, b], b), b, l(a)], b, [a, b]) + assert lhs == rhs + + +def test_sum(): + assert mod[sum].checked_type == relay.FuncType([l(nat())], nat()) + + +def test_tmap(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + # cannot infer return type of tmap! + lhs = mod[tmap].checked_type + rhs = relay.FuncType([relay.FuncType([a], b), tree(a)], tree(b), [a, b]) + # print(lhs) + # print(rhs) + # assert lhs == rhs + # this is broken, need some way to add type annotation + +def test_size(): + a = relay.TypeVar("a") + lhs = mod[size].checked_type + rhs = relay.FuncType([tree(a)], nat(), [a]) + assert lhs == rhs + + +if __name__ == "__main__": + test_nat_constructor() + test_double() + test_add() + test_list_constructor() + test_length() + test_map() + test_foldl() + test_foldr() + test_sum() + test_tmap() + test_size() diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py new file mode 100644 index 000000000000..8f556c0af67e --- /dev/null +++ b/tests/python/relay/test_typecall.py @@ -0,0 +1,25 @@ +from tvm import relay +from tvm.relay.ir_pass import infer_type + +def test_dup_type(): + a = relay.TypeVar("a") + av = relay.Var("av", a) + make_id = relay.Function([av], relay.Tuple([av, av]), None, [a]) + t = relay.scalar_type("float32") + b = relay.Var("b", t) + assert relay.ir_pass.infer_type(make_id(b)).checked_type == relay.TupleType([t, t]) + + +def test_id_type(): + mod = relay.Module() + id_type = relay.TypeVar("id") + a = relay.TypeVar("a") + make_id = relay.Var("make_id", relay.FuncType([a], id_type(a), [a])) + t = relay.scalar_type("float32") + b = relay.Var("b", t) + assert relay.ir_pass.infer_type(make_id(b), mod).checked_type == id_type(t) + + +if __name__ == "__main__": + test_dup_type() + test_id_type() From 485889d09ce0b4dd8332221382550f86e95b2d27 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 16 Jan 2019 15:01:05 -0800 Subject: [PATCH 02/61] Add doc string for tag field --- include/tvm/relay/adt.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index c0fcc7629c51..7ac29fee2896 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -94,6 +94,7 @@ class ConstructorNode : public ExprNode { tvm::Array inp; /*! \brief The datatype the constructor will construct. */ GlobalTypeVar belong_to; + /*! \brief Index in the table of constructors (set when the type is registered). */ mutable int tag = -1; ConstructorNode() {} From 54e5196323881dc04273be318b9a37e0f016a8ab Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 16 Jan 2019 15:22:14 -0800 Subject: [PATCH 03/61] Visit constructors in TypeVisitor for TypeData --- src/relay/ir/type_functor.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 61978d5ccba0..acb88d0a349c 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -63,7 +63,13 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { for (const auto& v : op->tv) { this->VisitType(v); } - // TODO(slyubomirsky, MarisaKirisame): visit constructors + + for (const auto& c : op->constructors) { + this->VisitType(c->belong_to); + for (const auto& t : c->inp) { + this->VisitType(t); + } + } } // Type Mutator. From 671f615ccebe17636adc9cdfdf2e968cb314b3ac Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 16 Jan 2019 15:30:39 -0800 Subject: [PATCH 04/61] Add to description of type call --- include/tvm/relay/type.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 600075bcb8cc..ad722ce85e41 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -187,7 +187,7 @@ class TypeCall; class TypeCallNode : public TypeNode { public: /*! - * \brief The type-level function. + * \brief The type-level function (ADT that takes type params). */ Type func; /*! \brief The arguments. */ From f9e48e805eb2a0e84babae2da5c01351f825e78b Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 16 Jan 2019 16:36:11 -0800 Subject: [PATCH 05/61] Add type call to type solving and unification --- src/relay/pass/type_infer.cc | 10 +++++----- src/relay/pass/type_solver.cc | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 5360550870e3..274ac429cc61 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -151,12 +151,12 @@ class TypeInferencer : private ExprFunctor, } for (auto type_param : ft->type_params) { - instantiation_map_.Set(type_param, IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + instantiation_map_.Set(type_param, IncompleteTypeNode::make(Kind::kType)); } Type ret_type = ft->ret_type; if (!ret_type.defined()) { - ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + ret_type = IncompleteTypeNode::make(Kind::kType); } auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints); @@ -277,7 +277,7 @@ class TypeInferencer : private ExprFunctor, // if the definition is a function literal, permit recursion bool is_functional_literal = let->value.as() != nullptr; if (is_functional_literal) { - type_map_[let->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + type_map_[op->var].checked_type = IncompleteTypeNode::make(Kind::kType); } Type vtype = GetType(let->value); @@ -380,7 +380,7 @@ class TypeInferencer : private ExprFunctor, // incomplete type => it must be a function taking the arg types // with an unknown return type if (inc_ty_node != nullptr) { - Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type ret_type = IncompleteTypeNode::make(Kind::kType); Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); Type unified = this->Unify(ftype, func_type, GetRef(call)); fn_ty_node = unified.as(); @@ -389,7 +389,7 @@ class TypeInferencer : private ExprFunctor, Array type_args = call->type_args; if (type_args.size() == 0) { for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) { - type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + type_args.push_back(IncompleteTypeNode::make(Kind::kType)); } } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index fcd39e791339..21988f876991 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -189,6 +189,20 @@ class TypeSolver::Unifier : public TypeFunctor { return RefTypeNode::make(Unify(op->value, rtn->value)); } + Type VisitType_(const TypeCallNode* op, const Type& tn) { + const auto* tcn = tn.as(); + if (!tcn || tcn->args.size() != op->args.size()) { + return Type(); + } + + Type func = Unify(op->func, tcn->func); + tvm::Array args; + for (size_t i = 0; i < op->args.size(); i++) { + args.push_back(Unify(op->args[i], tcn->args[i])); + } + return TypeCallNode::make(func, args); + } + private: TypeSolver* solver_; }; @@ -266,6 +280,16 @@ class TypeSolver::Propagator : public TypeFunctor { } } + void VisitType_(const TypeCallNode* op) override { + TypeCall tc = GetRef(op); + UpdateRelSet(tc); + + Propagate(tc->func); + for (auto arg : tc->args) { + Propagate(arg); + } + } + private: TypeSolver* solver_; const std::unordered_set* rels_; From 94748c19950f18b574f14074b3a595b4fac78473 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 16 Jan 2019 17:10:06 -0800 Subject: [PATCH 06/61] Make type mutator for typecall consistent with others (only create new node if there's a change) --- src/relay/ir/type_functor.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index acb88d0a349c..49fd905e2297 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -167,11 +167,13 @@ Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { } Type TypeMutator::VisitType_(const TypeCallNode* op) { - std::vector args; - for (const auto& a : op->args) { - args.push_back(VisitType(a)); + Type new_func = VisitType(op->func); + Array new_args = MutateArray(op->args); + if (new_args.same_as(op->args) && new_func.same_as(op->func)) { + return GetRef(op); + } else { + return TypeCallNode::make(new_func, new_args); } - return TypeCallNode::make(VisitType(op->func), args); } Type TypeMutator::VisitType_(const TypeDataNode* op) { From 7b1126f6e028bceffa4132359546c443a61725dd Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 16 Jan 2019 18:45:51 -0800 Subject: [PATCH 07/61] Ensure kindchecking can handle type calls and typedata --- src/relay/ir/module.cc | 15 +++++---- src/relay/pass/kind_check.cc | 59 ++++++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 4308fe785100..a64d145ccdaf 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -87,18 +87,21 @@ void ModuleNode::Add(const GlobalVar& var, } void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { - // kind checker is broken, not checking them rn. - // TODO(slyubomirsky, MarisaKirisame): fix the kind checker. this->type_definitions.Set(var, type); // set global type var map CHECK(!global_type_var_map_.count(var->var->name_hint)) << "Duplicate global type definition name " << var->var->name_hint; - global_type_var_map_.Set(var->var->name_hint, var); - for (size_t i = 0; i < type->constructors.size(); ++i) { - type->constructors[i]->tag = i; - } + global_type_var_map_.Set(var->var->name_hint, var); + for (size_t i = 0; i < type->constructors.size(); ++i) { + type->constructors[i]->tag = i; } + // need to kind check at the end because the check can look up + // a definition potentially + CHECK(KindCheck(type, GetRef(this))) + << "Kind-checking fails for the type data given: " << type; +} + void ModuleNode::Update(const GlobalVar& var, const Function& func) { this->Add(var, func, true); } diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 22f2f2e731c0..57fa5d529459 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -23,8 +23,9 @@ using namespace tvm::runtime; struct KindChecker : TypeVisitor { bool valid; + const Module& mod; - KindChecker() : valid(true) {} + explicit KindChecker(const Module& mod) : valid(true), mod(mod) {} // checks if t is an incomplete node of kind k or a type param of kind k bool MatchKind(const Type& t, Kind k) { @@ -36,6 +37,10 @@ struct KindChecker : TypeVisitor { return tp->kind == k; } + if (const GlobalTypeVarNode* gtp = t.as()) { + return gtp->kind == k; + } + return false; } @@ -44,7 +49,8 @@ struct KindChecker : TypeVisitor { return true; } - return t.as_derived() || t.as() || t.as(); + return t.as_derived() || t.as() || t.as() + || t.as(); } void VisitType_(const TupleTypeNode* op) override { @@ -98,6 +104,53 @@ struct KindChecker : TypeVisitor { } } + void VisitType_(const TypeCallNode* op) override { + // type call func should be a global type var, args should be type + const auto* gtv = op->func.as(); + valid = valid && gtv != nullptr && IsTypeKind(op->func); + if (!valid) { + return; + } + for (const Type& t : op->args) { + this->VisitType(t); + valid = valid && IsTypeKind(t); + if (!valid) { + return; + } + } + + // finally we need to check the module to check the number of type params + auto var = GetRef(gtv); + auto data = mod->LookupDef(var); + valid = valid && data->tv.size() == op->args.size(); + } + + void VisitType_(const TypeDataNode* op) override { + // Constructors can reference the header var, but no other GlobalTypeVars. + // In theory, a TypeData could be nested, so the header scope + // should be tracked recursively, but it is unclear that we need + // to support it. + valid = valid && op->header->kind == Kind::kType; + for (const auto& var : op->tv) { + valid = valid && IsTypeKind(var); + if (!valid) { + return; + } + } + for (const auto& con : op->constructors) { + valid = valid && con->belong_to.same_as(op->header); + for (const Type& t : con->inp) { + valid = valid && IsTypeKind(t); + if (const auto* gtv = t.as()) { + valid = valid && GetRef(gtv).same_as(op->header); + } + if (!valid) { + return; + } + } + } + } + bool Check(const Type& t) { this->VisitType(t); return valid; @@ -105,7 +158,7 @@ struct KindChecker : TypeVisitor { }; bool KindCheck(const Type& t, const Module& mod) { - KindChecker kc; + KindChecker kc(mod); return kc.Check(t); } From f2d36d6dcf97c8e4e1dd5b2157318aa2d926e8f4 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 15:58:13 -0800 Subject: [PATCH 08/61] Fix bad nesting in module constructor --- python/tvm/relay/module.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 5812d9a2481b..01ac84212aa3 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -33,17 +33,17 @@ def __init__(self, functions=None, type_definitions=None): raise TypeError("Expect functions to be Dict[GlobalVar, Function]") mapped_funcs[k] = v functions = mapped_funcs - if type_definitions is None: - type_definitions = {} - elif isinstance(type_definitions, dict): - mapped_type_defs = {} - for k, v in type_definitions.items(): - if isinstance(k, _base.string_types): - k = _ty.GlobalTypeVar(k) - if not isinstance(k, _ty.GlobalTypeVar): - raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") - type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_make.Module, functions, type_definitions) + if type_definitions is None: + type_definitions = {} + elif isinstance(type_definitions, dict): + mapped_type_defs = {} + for k, v in type_definitions.items(): + if isinstance(k, _base.string_types): + k = _ty.GlobalTypeVar(k) + if not isinstance(k, _ty.GlobalTypeVar): + raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") + type_definitions = mapped_type_defs + self.__init_handle_by_constructor__(_make.Module, functions, type_definitions) def __setitem__(self, var, val): From 7dc26ef10dafa74dec2bcaa80509ac5244af2492 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 16:02:19 -0800 Subject: [PATCH 09/61] Correctly construct call in typecall test --- tests/python/relay/test_typecall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py index 8f556c0af67e..d36fbb008940 100644 --- a/tests/python/relay/test_typecall.py +++ b/tests/python/relay/test_typecall.py @@ -17,7 +17,7 @@ def test_id_type(): make_id = relay.Var("make_id", relay.FuncType([a], id_type(a), [a])) t = relay.scalar_type("float32") b = relay.Var("b", t) - assert relay.ir_pass.infer_type(make_id(b), mod).checked_type == id_type(t) + assert relay.ir_pass.infer_type(relay.Call(make_id, [b]), mod).checked_type == id_type(t) if __name__ == "__main__": From 0c949c7e02f47d3783307cdb29604d542e907c3f Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 16:33:36 -0800 Subject: [PATCH 10/61] Add call override for ordinary vars (do we want this?) --- python/tvm/relay/adt.py | 17 ++++++++++++++++- python/tvm/relay/expr.py | 14 ++++++++++++++ tests/python/relay/test_typecall.py | 2 +- 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py index 7e25ec16fb52..d3fd94359d8d 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/adt.py @@ -3,7 +3,7 @@ from .base import RelayNode, register_relay_node, NodeBase from . import _make from .ty import Type -from .expr import Expr +from .expr import Expr, Call class Pattern(RelayNode): @@ -96,6 +96,21 @@ def __init__(self, name_hint, inp, belong_to): """ self.__init_handle_by_constructor__(_make.Constructor, name_hint, inp, belong_to) + def __call__(self, *args): + """Call the constructor. + + Parameters + ---------- + args: List[relay.Expr] + The arguments to the constructor. + + Returns + ------- + call: relay.Call + A call to the constructor. + """ + return Call(self, args) + @register_relay_node class TypeData(Type): diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 71b89d0b4777..9257bad7dd58 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -172,6 +172,20 @@ def name_hint(self): name = self.vid.name_hint return name + def __call__(self, *args): + """Call the variable (if it represents a function). + + Parameters + ---------- + args: List[relay.Expr] + The arguments to the call. + + Returns + ------- + call: Call + A call taking the variable as a function. + """ + return Call(self, args) @register_relay_node class GlobalVar(Expr): diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py index d36fbb008940..8f556c0af67e 100644 --- a/tests/python/relay/test_typecall.py +++ b/tests/python/relay/test_typecall.py @@ -17,7 +17,7 @@ def test_id_type(): make_id = relay.Var("make_id", relay.FuncType([a], id_type(a), [a])) t = relay.scalar_type("float32") b = relay.Var("b", t) - assert relay.ir_pass.infer_type(relay.Call(make_id, [b]), mod).checked_type == id_type(t) + assert relay.ir_pass.infer_type(make_id(b), mod).checked_type == id_type(t) if __name__ == "__main__": From 9443bdc41b69b87e0b3d8db6b31003dc120447b2 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 17:31:16 -0800 Subject: [PATCH 11/61] Remove generalization hack from type inference because it was breaking ADT constructors --- src/relay/pass/type_infer.cc | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 274ac429cc61..ce057016e4d9 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -110,10 +110,6 @@ class TypeInferencer : private ExprFunctor, // type inferencer will populate it up std::unordered_map type_map_; - // used to ensure we don't have free type vars hanging around - // (a temporary measure until we have proper generalization implemented) - Map instantiation_map_; - // The solver used by the inferencer. TypeSolver solver_; // relation function @@ -138,31 +134,6 @@ class TypeInferencer : private ExprFunctor, } } - // Substitutes every type var in t with a corresponding incomplete type. - // This is a temporary measure to ensure type vars behave until - // generalization is properly implemented. - Type Instantiate(const Type &t) { - if (!t.defined()) { - return t; - } - auto* ft = t.as(); - if (ft == nullptr) { - return Bind(t, instantiation_map_); - } - - for (auto type_param : ft->type_params) { - instantiation_map_.Set(type_param, IncompleteTypeNode::make(Kind::kType)); - } - - Type ret_type = ft->ret_type; - if (!ret_type.defined()) { - ret_type = IncompleteTypeNode::make(Kind::kType); - } - - auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints); - return Bind(strip_tvs, instantiation_map_); - } - // Lazily get type for expr // expression, we will populate it now, and return the result. Type GetType(const Expr &expr) { @@ -170,7 +141,7 @@ class TypeInferencer : private ExprFunctor, if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; } - Type ret = Instantiate(this->VisitExpr(expr)); + Type ret = this->VisitExpr(expr); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; From 33fc344355a28a5289742a44301f5fd409e249dd Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 17:58:20 -0800 Subject: [PATCH 12/61] Check that there are no free type vars in exprs after inferring type --- src/relay/pass/type_infer.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ce057016e4d9..d989e50203c4 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -702,6 +702,9 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { } else { auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); CHECK(WellFormed(e)); + auto free_tvars = FreeTypeVars(e); + CHECK(free_tvars.size() == 0) + << "Found unbound type variables in " << e << ": " << free_tvars; EnsureCheckedType(e); return e; } @@ -716,6 +719,9 @@ Function InferType(const Function& func, Expr func_ret = TypeInferencer(mod, var).Infer(func_copy); mod->Remove(var); CHECK(WellFormed(func_ret)); + auto free_tvars = FreeTypeVars(func_ret); + CHECK(free_tvars.size() == 0) + << "Found unbound type variables in " << func << ": " << free_tvars; return Downcast(func_ret); } From c208071ce2f298f0badf2ca22b0ae07308da102f Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 18:42:39 -0800 Subject: [PATCH 13/61] Free var checks need module because of ADT constructors --- include/tvm/relay/pass.h | 18 ++++++++---- python/tvm/relay/ir_pass.py | 22 +++++++++++---- src/relay/pass/type_infer.cc | 2 +- src/relay/pass/util.cc | 54 +++++++++++++++++++++++------------- 4 files changed, 64 insertions(+), 32 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 1558e65a6b36..d55386885d28 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -144,10 +144,11 @@ TVM_DLL tvm::Array AllVars(const Expr& expr); * type in the context. * * \param expr the expression. + * \param mod the module. * * \return List of free vars, in the PostDFS order visited by expr. */ -TVM_DLL tvm::Array FreeTypeVars(const Expr& expr); +TVM_DLL tvm::Array FreeTypeVars(const Expr& expr, const Module& mod); /*! \brief Get free TypeVars from type t. * @@ -155,10 +156,11 @@ TVM_DLL tvm::Array FreeTypeVars(const Expr& expr); * type in the context. * * \param t the type. + * \param mod the module. * * \return List of free type vars, in the PostDFS order visited by type. */ -TVM_DLL tvm::Array FreeTypeVars(const Type& t); +TVM_DLL tvm::Array FreeTypeVars(const Type& t, const Module& mod); /*! \brief Get all bound type variables from expression expr. * @@ -166,10 +168,11 @@ TVM_DLL tvm::Array FreeTypeVars(const Type& t); * They only have meaning inside that expr, and can only be used in it. * * \param expr the expression. + * \param mod the module. * * \return List of bound type vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array BoundTypeVars(const Expr& expr); +TVM_DLL tvm::Array BoundTypeVars(const Expr& expr, const Module& mod); /*! \brief Get all bound type variables from type t. * @@ -177,26 +180,29 @@ TVM_DLL tvm::Array BoundTypeVars(const Expr& expr); * They only have meaning inside that type, and can only be used in it. * * \param t the type + * \param mod the module. * * \return List of bound type vars, in the PostDFS order visited by type. */ -TVM_DLL tvm::Array BoundTypeVars(const Type& t); +TVM_DLL tvm::Array BoundTypeVars(const Type& t, const Module& mod); /*! \brief Get all type variables in expression expr. * * \param expr the expression. + * \param mod the module. * * \return List of type vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array AllTypeVars(const Expr& expr); +TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); /*! \brief Get all type variables in type t. * * \param t the type. + * \param mod the module. * * \return List of type vars, in the PostDFS order visited by type. */ -TVM_DLL tvm::Array AllTypeVars(const Type& t); +TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); /*! \brief Remove expressions which does not effect the program result. * diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index b27f030e459a..e274926f0b7a 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -9,6 +9,7 @@ from . import _make from .expr import Expr from .ty import Type +from .module import Module def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, @@ -190,52 +191,61 @@ def all_vars(expr): return _ir_pass.all_vars(expr) -def free_type_vars(expr): +def free_type_vars(expr, mod=None): """Get free type variables from expression/type e Parameters ---------- expr: Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type + mod: tvm.relay.Module, optional + The global module Returns ------- free : List[tvm.relay.TypeVar] The list of free type variables in post-DFS order """ - return _ir_pass.free_type_vars(expr) + use_mod = mod if mod is not None else Module() + return _ir_pass.free_type_vars(expr, use_mod) -def bound_type_vars(expr): +def bound_type_vars(expr, mod=None): """Get bound type variables from expression/type e Parameters ---------- expr: Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type + mod: tvm.relay.Module, optional + The global module Returns ------- free : List[tvm.relay.TypeVar] The list of bound type variables in post-DFS order """ - return _ir_pass.bound_type_vars(expr) + use_mod = mod if mod is not None else Module() + return _ir_pass.bound_type_vars(expr, use_mod) -def all_type_vars(expr): +def all_type_vars(expr, mod=None): """Get all type variables from expression/type e Parameters ---------- expr: Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type + mod: tvm.relay.Module, optional + The global module Returns ------- free : List[tvm.relay.TypeVar] The list of all type variables in post-DFS order """ - return _ir_pass.all_type_vars(expr) + use_mod = mod if mod is not None else Module() + return _ir_pass.all_type_vars(expr, use_mod) def simplify_inference(expr): diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index d989e50203c4..d89ba234660c 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -719,7 +719,7 @@ Function InferType(const Function& func, Expr func_ret = TypeInferencer(mod, var).Infer(func_copy); mod->Remove(var); CHECK(WellFormed(func_ret)); - auto free_tvars = FreeTypeVars(func_ret); + auto free_tvars = FreeTypeVars(func_ret, mod); CHECK(free_tvars.size() == 0) << "Found unbound type variables in " << func << ": " << free_tvars; return Downcast(func_ret); diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 499046321953..0ed13b40105b 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -30,7 +30,7 @@ class TypeVarTVisitor : public TypeVisitor { TypeVarTVisitor( InsertionSet* type_vars, InsertionSet* bound_type_vars) - : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } + : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } void VisitType_(const TypeVarNode* tp) final { TypeVar var = GetRef(tp); @@ -52,6 +52,8 @@ class TypeVarTVisitor : public TypeVisitor { class TypeVarEVisitor : private ExprVisitor { public: + explicit TypeVarEVisitor(const Module& mod) : mod_(mod) {} + Array CollectFree() { Array ret; for (const auto& v : type_vars_.data) { @@ -116,6 +118,16 @@ class TypeVarEVisitor : private ExprVisitor { ExprVisitor::VisitExpr_(f); } + void VisitExpr_(const ConstructorNode* cn) final { + // for constructors, type vars will be bound in the module + auto data = mod_->LookupDef(cn->belong_to); + for (const auto& tv : data->tv) { + type_vars_.Insert(tv); + bound_type_vars_.Insert(tv); + } + ExprVisitor::VisitExpr_(cn); + } + void VisitType(const Type& t) final { TypeVarTVisitor(&type_vars_, &bound_type_vars_) .VisitType(t); @@ -124,6 +136,7 @@ class TypeVarEVisitor : private ExprVisitor { private: InsertionSet type_vars_; InsertionSet bound_type_vars_; + const Module& mod_; }; class VarVisitor : protected ExprVisitor { @@ -184,28 +197,28 @@ class VarVisitor : protected ExprVisitor { InsertionSet bound_vars_; }; -tvm::Array FreeTypeVars(const Expr& expr) { - return TypeVarEVisitor().Free(expr); +tvm::Array FreeTypeVars(const Expr& expr, const Module& mod) { + return TypeVarEVisitor(mod).Free(expr); } -tvm::Array FreeTypeVars(const Type& type) { - return TypeVarEVisitor().Free(type); +tvm::Array FreeTypeVars(const Type& type, const Module& mod) { + return TypeVarEVisitor(mod).Free(type); } -tvm::Array BoundTypeVars(const Expr& expr) { - return TypeVarEVisitor().Bound(expr); +tvm::Array BoundTypeVars(const Expr& expr, const Module& mod) { + return TypeVarEVisitor(mod).Bound(expr); } -tvm::Array BoundTypeVars(const Type& type) { - return TypeVarEVisitor().Bound(type); +tvm::Array BoundTypeVars(const Type& type, const Module& mod) { + return TypeVarEVisitor(mod).Bound(type); } -tvm::Array AllTypeVars(const Expr& expr) { - return TypeVarEVisitor().All(expr); +tvm::Array AllTypeVars(const Expr& expr, const Module& mod) { + return TypeVarEVisitor(mod).All(expr); } -tvm::Array AllTypeVars(const Type& type) { - return TypeVarEVisitor().All(type); +tvm::Array AllTypeVars(const Type& type, const Module& mod) { + return TypeVarEVisitor(mod).All(type); } tvm::Array FreeVars(const Expr& expr) { @@ -238,30 +251,33 @@ TVM_REGISTER_API("relay._ir_pass.all_vars") TVM_REGISTER_API("relay._ir_pass.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; + Module mod = args[1]; if (x.as_derived()) { - *ret = FreeTypeVars(Downcast(x)); + *ret = FreeTypeVars(Downcast(x), mod); } else { - *ret = FreeTypeVars(Downcast(x)); + *ret = FreeTypeVars(Downcast(x), mod); } }); TVM_REGISTER_API("relay._ir_pass.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; + Module mod = args[1]; if (x.as_derived()) { - *ret = BoundTypeVars(Downcast(x)); + *ret = BoundTypeVars(Downcast(x), mod); } else { - *ret = BoundTypeVars(Downcast(x)); + *ret = BoundTypeVars(Downcast(x), mod); } }); TVM_REGISTER_API("relay._ir_pass.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; + Module mod = args[1]; if (x.as_derived()) { - *ret = AllTypeVars(Downcast(x)); + *ret = AllTypeVars(Downcast(x), mod); } else { - *ret = AllTypeVars(Downcast(x)); + *ret = AllTypeVars(Downcast(x), mod); } }); From af7da113566729ade5ea1297aca0376b806127dc Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 18:43:17 -0800 Subject: [PATCH 14/61] Typecall test can't have unbound type var, make it global --- tests/python/relay/test_typecall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py index 8f556c0af67e..c45e435de27c 100644 --- a/tests/python/relay/test_typecall.py +++ b/tests/python/relay/test_typecall.py @@ -12,7 +12,7 @@ def test_dup_type(): def test_id_type(): mod = relay.Module() - id_type = relay.TypeVar("id") + id_type = relay.GlobalTypeVar("id") a = relay.TypeVar("a") make_id = relay.Var("make_id", relay.FuncType([a], id_type(a), [a])) t = relay.scalar_type("float32") From 4968b48066c166f496701dd4dce718691fb72ec9 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 18:53:03 -0800 Subject: [PATCH 15/61] Uncomment tmap test and remove comments about failing to infer ret type; those work now --- python/tvm/relay/prelude.py | 1 - tests/python/relay/test_adt.py | 6 +----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 29fcc40b9361..35235e0f7190 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -111,4 +111,3 @@ def __init__(self, mod): rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), self.s(self.sum(self.map(Function([x], self.size(x)), z)))) mod[self.size] = Function([t], Match(t, [rose_case]), self.nat(), [a]) - # cannot infer return type here diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 2f2d1f98906e..5fb02ba86c25 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -109,13 +109,9 @@ def test_sum(): def test_tmap(): a = relay.TypeVar("a") b = relay.TypeVar("b") - # cannot infer return type of tmap! lhs = mod[tmap].checked_type rhs = relay.FuncType([relay.FuncType([a], b), tree(a)], tree(b), [a, b]) - # print(lhs) - # print(rhs) - # assert lhs == rhs - # this is broken, need some way to add type annotation + assert lhs == rhs def test_size(): a = relay.TypeVar("a") From 87f82b7db55b1b9e18ae26da22bdb33a76190248 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 19:03:38 -0800 Subject: [PATCH 16/61] Put in dummy visits for ADTs in graph runtime codegen to placate pylint --- python/tvm/relay/backend/graph_runtime_codegen.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index cc510b2290cf..fba4d11aaf72 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -292,6 +292,12 @@ def visit_ref_read(self, _): def visit_ref_write(self, _): raise RuntimeError("reference not supported") + def visit_constructor(self, _): + raise Exception("ADT constructor case not yet implemented") + + def visit_match(self, _): + raise Exception("match case not yet implemented") + def _get_json(self): """ Convert the sequence of nodes stored by the compiler into the From e0f4c08b066dfbcefb29166951e3181aa3afdfaf Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 21 Jan 2019 19:17:53 -0800 Subject: [PATCH 17/61] Fix Relay type infer test module constructor --- tests/cpp/relay_pass_type_infer_test.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 50aed4c57338..62f1d914b510 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -13,7 +13,10 @@ TEST(Relay, SelfReference) { auto y = relay::VarNode::make("y", tensor_type); auto call = relay::CallNode::make(f, Array{ y }); auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); - auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); + auto empty_module = + relay::ModuleNode::make(Map{}, + Map{}); + auto type_fx = relay::InferType(fx, empty_module); auto expected = relay::FuncTypeNode::make(tvm::Array{ tensor_type }, tensor_type, {}, {}); CHECK(AlphaEqual(type_fx->checked_type(), expected)); From b6467405a5391f47548771e4bc47b6e73beaa325 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 22 Jan 2019 13:53:52 -0800 Subject: [PATCH 18/61] Mark override for TypeCallNode in type solver --- src/relay/pass/type_solver.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 21988f876991..4749e8934b36 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -189,7 +189,7 @@ class TypeSolver::Unifier : public TypeFunctor { return RefTypeNode::make(Unify(op->value, rtn->value)); } - Type VisitType_(const TypeCallNode* op, const Type& tn) { + Type VisitType_(const TypeCallNode* op, const Type& tn) override { const auto* tcn = tn.as(); if (!tcn || tcn->args.size() != op->args.size()) { return Type(); From e11f58ec468a4fdf295695eb100ae2a93c9c6aa8 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 22 Jan 2019 16:41:41 -0800 Subject: [PATCH 19/61] Ensure free vars check treats patern vars as bound --- src/relay/pass/util.cc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 0ed13b40105b..33efac9414fa 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -139,7 +139,7 @@ class TypeVarEVisitor : private ExprVisitor { const Module& mod_; }; -class VarVisitor : protected ExprVisitor { +class VarVisitor : protected ExprVisitor, protected PatternVisitor { public: Array Free(const Expr& expr) { this->VisitExpr(expr); @@ -192,6 +192,14 @@ class VarVisitor : protected ExprVisitor { VisitExpr(op->body); } + void VisitPattern(const Pattern& p) final { + PatternVisitor::VisitPattern(p); + } + + void VisitPattern_(const PatternVarNode* op) final { + MarkBounded(op->var); + } + private: InsertionSet vars_; InsertionSet bound_vars_; From ec56e4abdab76a937ca9dfdc10ed19126e553bbb Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 22 Jan 2019 17:16:28 -0800 Subject: [PATCH 20/61] Run interpreter in more ADT test cases --- tests/python/relay/test_adt.py | 56 ++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 5fb02ba86c25..d501c7d74a62 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -46,6 +46,26 @@ def make_nat(n): else: return ConValue(z, [], []) +def build_nat(n): + assert n >= 0 + ret = z() + while n > 0: + ret = s(ret) + n = n - 1 + return ret + +def to_list(l): + assert isinstance(l, ConValue) + val = l + ret = [] + while True: + if val.con.name_hint == 'cons': + ret.append(val.fields[0]) + val = val.fields[1] + else: + assert val.con.name_hint == 'nil' + break + return ret def test_nat_value(): assert count(make_nat(10)) == 10 @@ -59,6 +79,8 @@ def test_nat_constructor(): def test_double(): assert mod[double].checked_type == relay.FuncType([nat()], nat()) + res = intrp.evaluate(double(s(z()))) + assert count(res) == 2 def test_add(): @@ -76,6 +98,8 @@ def test_list_constructor(): def test_length(): a = relay.TypeVar("a") assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a]) + res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil()))))) + assert count(res) == 3 def test_map(): @@ -85,6 +109,13 @@ def test_map(): rhs = relay.FuncType([relay.FuncType([a], b), l(a)], l(b), [a, b]) assert lhs == rhs + x = relay.Var("x") + add_one = relay.Function([x], s(x)) + res = intrp.evaluate(map(add_one, cons(z(), cons(z(), nil())))) + ones = to_list(res) + assert len(ones) == 2 + assert count(ones[0]) == 1 and count(ones[1]) == 1 + def test_foldl(): a = relay.TypeVar("a") @@ -93,6 +124,17 @@ def test_foldl(): rhs = relay.FuncType([relay.FuncType([a, b], a), a, l(b)], a, [a, b]) assert lhs == rhs + x = relay.Var("x") + y = relay.Var("y") + rev = relay.Function([y, x], cons(x, y)) + res = intrp.evaluate(foldl(rev, nil(), + cons(build_nat(1), + cons(build_nat(2), + cons(build_nat(3), nil()))))) + reversed = to_list(res) + assert len(reversed) == 3 + assert count(reversed[0]) == 3 and count(reversed[1]) == 2 and count(reversed[2]) == 1 + def test_foldr(): a = relay.TypeVar("a") @@ -101,9 +143,22 @@ def test_foldr(): rhs = relay.FuncType([relay.FuncType([a, b], b), b, l(a)], b, [a, b]) assert lhs == rhs + x = relay.Var("x") + y = relay.Var("y") + identity = relay.Function([x, y], cons(x, y)) + res = intrp.evaluate(foldr(identity, nil(), + cons(build_nat(1), + cons(build_nat(2), + cons(build_nat(3), nil()))))) + same = to_list(res) + assert len(same) == 3 + assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3 + def test_sum(): assert mod[sum].checked_type == relay.FuncType([l(nat())], nat()) + res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil())))) + assert count(res) == 3 def test_tmap(): @@ -113,6 +168,7 @@ def test_tmap(): rhs = relay.FuncType([relay.FuncType([a], b), tree(a)], tree(b), [a, b]) assert lhs == rhs + def test_size(): a = relay.TypeVar("a") lhs = mod[size].checked_type From 4f745451ad739f73d951b44b2f67218d13f26752 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 22 Jan 2019 18:57:59 -0800 Subject: [PATCH 21/61] Refactor kind check to return the kind, like typechecking --- include/tvm/relay/pass.h | 4 +- include/tvm/relay/type.h | 3 +- python/tvm/relay/ir_pass.py | 10 +- python/tvm/relay/ty.py | 1 + src/relay/ir/module.cc | 3 +- src/relay/pass/kind_check.cc | 169 +++++++++++---------- tests/python/relay/test_pass_check_kind.py | 49 ++++-- 7 files changed, 140 insertions(+), 99 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index d55386885d28..b87f9319a3d3 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -56,9 +56,9 @@ TVM_DLL Function InferType(const Function& f, const Module& mod, * \param t The type to check. * \param mod The global module. * - * \return true if the rules are satisified otherwise false + * \return The kind of the passed type. */ -TVM_DLL bool KindCheck(const Type& t, const Module& mod); +TVM_DLL Kind KindCheck(const Type& t, const Module& mod); /*! \brief Compare two expressions for structural equivalence. * diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index ad722ce85e41..09c4161c6d96 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -104,7 +104,8 @@ enum Kind : int { kType = 0, kShapeVar = 1, kBaseType = 2, - kShape = 3 + kShape = 3, + kConstraint = 4 }; /*! diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index e274926f0b7a..90d038ebc784 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -108,7 +108,7 @@ def well_formed(expr): def check_kind(t, mod=None): - """Check that the type is well kinded. + """Check that the type is well kinded and return the kind. For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes. Parameters @@ -121,15 +121,15 @@ def check_kind(t, mod=None): Returns ------- - well_kinded : bool - whether the input type is well kinded. + kind : Kind + the kind of t Examples -------- .. code:: python - assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) - assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) + assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) == Shape + assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type """ if mod is not None: return _ir_pass.check_kind(t, mod) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index ee95a82fb90b..9eef6724ced6 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -88,6 +88,7 @@ class Kind(IntEnum): ShapeVar = 1 BaseType = 2 Shape = 3 + Constraint = 4 @register_relay_node class TypeVar(Type): diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index a64d145ccdaf..c8e32505d979 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -98,8 +98,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { // need to kind check at the end because the check can look up // a definition potentially - CHECK(KindCheck(type, GetRef(this))) - << "Kind-checking fails for the type data given: " << type; + KindCheck(type, GetRef(this)); } void ModuleNode::Update(const GlobalVar& var, const Function& func) { diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 57fa5d529459..b28737bbfb37 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -21,143 +21,156 @@ namespace relay { using namespace tvm::runtime; -struct KindChecker : TypeVisitor { - bool valid; +struct KindChecker : TypeFunctor { const Module& mod; - explicit KindChecker(const Module& mod) : valid(true), mod(mod) {} + explicit KindChecker(const Module& mod) : mod(mod) {} - // checks if t is an incomplete node of kind k or a type param of kind k - bool MatchKind(const Type& t, Kind k) { - if (const IncompleteTypeNode* tv = t.as()) { - return tv->kind == k; - } - - if (const TypeVarNode* tp = t.as()) { - return tp->kind == k; - } - - if (const GlobalTypeVarNode* gtp = t.as()) { - return gtp->kind == k; - } + Kind VisitType_(const IncompleteTypeNode* op) override { + return op->kind; + } - return false; + Kind VisitType_(const TypeVarNode* op) override { + return op->kind; } - bool IsTypeKind(const Type& t) { - if (MatchKind(t, Kind::kType)) { - return true; - } + Kind VisitType_(const GlobalTypeVarNode* op) override { + return op->kind; + } - return t.as_derived() || t.as() || t.as() - || t.as(); + Kind VisitType_(const TensorTypeNode* op) override { + return Kind::kType; } - void VisitType_(const TupleTypeNode* op) override { + Kind VisitType_(const TupleTypeNode* op) override { // tuples should only contain normal types for (const Type& t : op->fields) { - this->VisitType(t); - valid = valid && IsTypeKind(t); - if (!valid) { - return; - } + Kind k = this->VisitType(t); + CHECK(k == Kind::kType) + << "All types in tuple type must be of a type kind but " + << t << " in " << GetRef(op) << " is of kind " << k; } + return Kind::kType; } - void VisitType_(const FuncTypeNode* op) override { + Kind VisitType_(const FuncTypeNode* op) override { // Func types should only take normal types for arguments // and only return a normal type. They should also have // well-formed constraints + FuncType ft = GetRef(op); for (const Type& t : op->arg_types) { - this->VisitType(t); - valid = valid && IsTypeKind(t); - if (!valid) { - return; - } + Kind k = this->VisitType(t); + CHECK(k == Kind::kType) + << "Function parameters must be of the type kind but parameter " + << t << " of " << ft << " is of kind " << k; } + Kind ret_kind = this->VisitType(ft->ret_type); + CHECK(ret_kind == Kind::kType) + << "The function return type must be of the type kind but " + << ft->ret_type << " of " << ft << " is of kind " << ret_kind; + for (const TypeConstraint& tc : op->type_constraints) { - this->VisitType(tc); - if (!valid) { - return; - } + Kind k = this->VisitType(tc); + CHECK(k == Kind::kConstraint) + << "All function type constraints are of the constraint kind but " + << tc << " of " << ft << " is of kind " << k; } - this->VisitType(op->ret_type); - valid = valid && IsTypeKind(op->ret_type); + return Kind::kType; } - void VisitType_(const RefTypeNode* op) override { - // tuples should only contain normal types - this->VisitType(op->value); - valid = valid && IsTypeKind(op->value); + Kind VisitType_(const RefTypeNode* op) override { + // ref types should only contain normal types + Kind k = this->VisitType(op->value); + CHECK(k == Kind::kType) + << "The value inside a ref must be of the type kind but " + << op->value << " of " << GetRef(op) << " is of kind " << k; + return Kind::kType; } - void VisitType_(const TypeRelationNode* op) override { + Kind VisitType_(const TypeRelationNode* op) override { // arguments to type relation should be normal types for (const Type& t : op->args) { - this->VisitType(t); - valid = valid && IsTypeKind(t); - if (!valid) { - return; - } + Kind k = this->VisitType(t); + CHECK(k == Kind::kType) + << "All arguments to type relations must be of the type kind but " + << t << " of " << GetRef(op) << " is of kind " << k; } + return Kind::kConstraint; } - void VisitType_(const TypeCallNode* op) override { + Kind VisitType_(const TypeCallNode* op) override { // type call func should be a global type var, args should be type + TypeCall tc = GetRef(op); const auto* gtv = op->func.as(); - valid = valid && gtv != nullptr && IsTypeKind(op->func); - if (!valid) { - return; - } + CHECK(gtv != nullptr) + << "Type call must be calling a global type var"; + + Kind func_kind = this->VisitType(op->func); + CHECK(func_kind == Kind::kType) + << "Type calls must call a global type var of the type kind but " + << op->func << " of " << tc << " is of kind " << func_kind; + for (const Type& t : op->args) { - this->VisitType(t); - valid = valid && IsTypeKind(t); - if (!valid) { - return; - } + Kind k = this->VisitType(t); + CHECK(k == Kind::kType) + << "Type call arguments must be of the type kind but " + << t << " of " << tc << " is of kind " << k; } // finally we need to check the module to check the number of type params auto var = GetRef(gtv); auto data = mod->LookupDef(var); - valid = valid && data->tv.size() == op->args.size(); + CHECK(data->tv.size() == op->args.size()) + << "Incorrect arity in " << tc + << " Expected: " << data->tv.size() + << " Given: " << op->args.size(); + return Kind::kType; } - void VisitType_(const TypeDataNode* op) override { + Kind VisitType_(const TypeDataNode* op) override { // Constructors can reference the header var, but no other GlobalTypeVars. // In theory, a TypeData could be nested, so the header scope // should be tracked recursively, but it is unclear that we need // to support it. - valid = valid && op->header->kind == Kind::kType; + TypeData td = GetRef(op); + Kind header_kind = this->VisitType(op->header); + CHECK(header_kind == Kind::kType) + << "The header for ADT type data must be of the type kind but " + << op->header << " of " << td << " is of kind " << header_kind; + for (const auto& var : op->tv) { - valid = valid && IsTypeKind(var); - if (!valid) { - return; - } + Kind k = this->VisitType(var); + CHECK(k == Kind::kType) + << "All type params for ADT type data must be of the type kind but " + << var << " of " << td << " is of kind " << k; } + for (const auto& con : op->constructors) { - valid = valid && con->belong_to.same_as(op->header); + CHECK(con->belong_to.same_as(op->header)) + << "Constructors should have same global type var as type data"; + for (const Type& t : con->inp) { - valid = valid && IsTypeKind(t); + Kind k = this->VisitType(t); + CHECK(k == Kind::kType) + << "All inputs to a constructor must be of the type kind but" + << t << " of " << con << " is of kind " << k; if (const auto* gtv = t.as()) { - valid = valid && GetRef(gtv).same_as(op->header); - } - if (!valid) { - return; + CHECK(GetRef(gtv).same_as(op->header)) + << "A global type var taken by a constructor must be the one the constructor makes"; } } } + return Kind::kType; } - bool Check(const Type& t) { - this->VisitType(t); - return valid; + Kind Check(const Type& t) { + return this->VisitType(t); } }; -bool KindCheck(const Type& t, const Module& mod) { +Kind KindCheck(const Type& t, const Module& mod) { KindChecker kc(mod); return kc.Check(t); } diff --git a/tests/python/relay/test_pass_check_kind.py b/tests/python/relay/test_pass_check_kind.py index 5ead501157c5..9b81e8225ace 100644 --- a/tests/python/relay/test_pass_check_kind.py +++ b/tests/python/relay/test_pass_check_kind.py @@ -1,6 +1,19 @@ import tvm from tvm import relay from tvm.relay.ir_pass import check_kind +from nose.tools import raises + + +def test_typevar_kind(): + # returns the same kind + tp1 = relay.TypeVar('tp1', relay.Kind.Type) + tp2 = relay.TypeVar('tp2', relay.Kind.Shape) + tp3 = relay.TypeVar('tp3', relay.Kind.Constraint) + + assert check_kind(tp1) == relay.Kind.Type + assert check_kind(tp2) == relay.Kind.Shape + assert check_kind(tp3) == relay.Kind.Constraint + def test_tuple_kind(): # only contain type kinds @@ -10,7 +23,7 @@ def test_tuple_kind(): fields = tvm.convert([tp, tf, tt]) tup_ty = relay.TupleType(fields) - assert check_kind(tup_ty) + assert check_kind(tup_ty) == relay.Kind.Type def test_func_kind(): @@ -30,7 +43,7 @@ def test_func_kind(): ret_type = relay.TupleType(tvm.convert([tp2, tensor_type])) tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) - assert check_kind(tf) + assert check_kind(tf) == relay.Kind.Type def test_relation_kind(): @@ -41,9 +54,10 @@ def test_relation_kind(): args = tvm.convert([tf, tt, tp]) tr = relay.TypeRelation(None, args, 2, None) - assert check_kind(tr) + assert check_kind(tr) == relay.Kind.Constraint +@raises(tvm._ffi.base.TVMError) def test_invalid_tuple_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) @@ -51,9 +65,10 @@ def test_invalid_tuple_kind(): fields = tvm.convert([tp1, tp2, tp3]) tup_ty = relay.TupleType(fields) - assert not check_kind(tup_ty) + check_kind(tup_ty) +@raises(tvm._ffi.base.TVMError) def test_invalid_func_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) @@ -65,51 +80,63 @@ def test_invalid_func_kind(): ret_type = tp3 tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) - assert not check_kind(tf) + check_kind(tf) +@raises(tvm._ffi.base.TVMError) def test_invalid_relation_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) args = tvm.convert([tp1, tp2, tp3]) - tr = relay.TypeRelation(None, args, 2, None) - assert not check_kind(tr) + func = tvm.get_env_func("tvm.relay.type_relation.Broadcast") + tr = relay.TypeRelation(func, args, 2, None) + check_kind(tr) +@raises(tvm._ffi.base.TVMError) def test_func_with_invalid_ret_type(): tp1 = relay.TypeVar('tp1', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Shape) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) + check_kind(tf) + +@raises(tvm._ffi.base.TVMError) def test_func_with_invalid_arg_types(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.Type) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) + check_kind(tf) + +@raises(tvm._ffi.base.TVMError) def test_func_with_invalid_tuple(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1])) tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([])) - assert not check_kind(tf) + check_kind(tf) +@raises(tvm._ffi.base.TVMError) def test_func_with_invalid_relation(): tp1 = relay.TypeVar('tp1', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Shape) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) - tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None) + func = tvm.get_env_func("tvm.relay.type_relation.Identity") + tr = relay.TypeRelation(func, tvm.convert([tp2, tp3]), 1, None) tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr])) - assert not check_kind(tf) + check_kind(tf) +@raises(tvm._ffi.base.TVMError) def test_tuple_with_invalid_func(): tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') @@ -117,7 +144,7 @@ def test_tuple_with_invalid_func(): tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([])) tup_ty = relay.TupleType(tvm.convert([tensor_type, tf])) - assert not check_kind(tup_ty) + check_kind(tup_ty) if __name__ == "__main__": From 276e02866359120a3c0d69f4218dab67bae05f66 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 22 Jan 2019 18:58:15 -0800 Subject: [PATCH 22/61] Fix invalid typecall in test --- tests/python/relay/test_typecall.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py index c45e435de27c..6f002c438a52 100644 --- a/tests/python/relay/test_typecall.py +++ b/tests/python/relay/test_typecall.py @@ -14,7 +14,10 @@ def test_id_type(): mod = relay.Module() id_type = relay.GlobalTypeVar("id") a = relay.TypeVar("a") - make_id = relay.Var("make_id", relay.FuncType([a], id_type(a), [a])) + mod[id_type] = relay.TypeData(id_type, [a], []) + + b = relay.TypeVar("b") + make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b])) t = relay.scalar_type("float32") b = relay.Var("b", t) assert relay.ir_pass.infer_type(make_id(b), mod).checked_type == id_type(t) From 7963e7e9e8d36a3b4fcf40175c79d138f0a33b24 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 22 Jan 2019 19:29:54 -0800 Subject: [PATCH 23/61] Add kind check to type inference, do not use nulls in func_type_annotation()! --- src/relay/ir/expr.cc | 9 +++++++-- src/relay/pass/type_infer.cc | 2 ++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index bc6eee3ebc03..6d8927624009 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -130,9 +130,14 @@ Function FunctionNode::make(tvm::Array params, FuncType FunctionNode::func_type_annotation() const { Array param_types; for (auto param : this->params) { - param_types.push_back(param->type_annotation); + Type param_type = (param->type_annotation.defined()) ? param->type_annotation + : IncompleteTypeNode::make(Kind::kType); + param_types.push_back(param_type); } - return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); + + Type ret_type = (this->ret_type.defined()) ? this->ret_type + : IncompleteTypeNode::make(Kind::kType); + return FuncTypeNode::make(param_types, ret_type, this->type_params, {}); } bool FunctionNode::IsPrimitive() const { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index d89ba234660c..52cdb6db3e20 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -142,6 +142,7 @@ class TypeInferencer : private ExprFunctor, return it->second.checked_type; } Type ret = this->VisitExpr(expr); + KindCheck(ret, mod_); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; @@ -441,6 +442,7 @@ class TypeInferencer : private ExprFunctor, if (f->ret_type.defined()) { rtype = this->Unify(f->ret_type, rtype, GetRef(f)); } + CHECK(rtype.defined()); auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); return solver_.Resolve(ret); } From e2d6219a5dbfa0d14f7a70d5e556c5770046d752 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 22 Jan 2019 19:31:00 -0800 Subject: [PATCH 24/61] Redundant whitespace --- src/relay/ir/expr.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 6d8927624009..29fe98ba78f5 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -134,7 +134,7 @@ FuncType FunctionNode::func_type_annotation() const { : IncompleteTypeNode::make(Kind::kType); param_types.push_back(param_type); } - + Type ret_type = (this->ret_type.defined()) ? this->ret_type : IncompleteTypeNode::make(Kind::kType); return FuncTypeNode::make(param_types, ret_type, this->type_params, {}); From 40c74108f393d6d80d33f365f855fc7ef619f731 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 12:16:09 -0800 Subject: [PATCH 25/61] Make TypeData a separate kind --- include/tvm/relay/type.h | 3 ++- python/tvm/relay/ty.py | 1 + src/relay/ir/module.cc | 3 ++- src/relay/pass/kind_check.cc | 2 +- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 09c4161c6d96..82d84358a9c0 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -105,7 +105,8 @@ enum Kind : int { kShapeVar = 1, kBaseType = 2, kShape = 3, - kConstraint = 4 + kConstraint = 4, + kTypeData = 5 }; /*! diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 9eef6724ced6..c8ae9c558a1e 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -89,6 +89,7 @@ class Kind(IntEnum): BaseType = 2 Shape = 3 Constraint = 4 + TypeData = 5 @register_relay_node class TypeVar(Type): diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index c8e32505d979..920c15b2e756 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -98,7 +98,8 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { // need to kind check at the end because the check can look up // a definition potentially - KindCheck(type, GetRef(this)); + CHECK(KindCheck(type, GetRef(this)) == Kind::kTypeData) + << "Invalid or malformed typedata given to module: " << type; } void ModuleNode::Update(const GlobalVar& var, const Function& func) { diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index b28737bbfb37..a83b2834b1c1 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -162,7 +162,7 @@ struct KindChecker : TypeFunctor { } } } - return Kind::kType; + return Kind::kTypeData; } Kind Check(const Type& t) { From 673bcd6f699019a0ebdb61a46506027d09b30df0 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 13:04:30 -0800 Subject: [PATCH 26/61] Make ADT handles a separate kind too, document calling convention better --- include/tvm/relay/adt.h | 8 ++++++++ include/tvm/relay/type.h | 3 ++- python/tvm/relay/adt.py | 8 +++++++- python/tvm/relay/ty.py | 7 ++++--- src/relay/pass/kind_check.cc | 8 ++++---- 5 files changed, 25 insertions(+), 9 deletions(-) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 7ac29fee2896..d64f9614b42e 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -144,6 +144,14 @@ RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); /*! * \brief Stores all data for an Algebraic Data Type (ADT). + * + * In particular, it stores the handle (global type var) for an ADT + * and the constructors used to build it and is kept in the module. Note + * that type parameters are also indicated in the type data: this means that + * for any instance of an ADT, the type parameters must be indicated. That is, + * an ADT definition is treated as a type-level function, so an ADT handle + * must be wrapped in a TypeCall node that instantiates the type-level arguments. + * The kind checker enforces this. */ class TypeData; /*! \brief TypeData container node */ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 82d84358a9c0..6c164ab6bcea 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -106,7 +106,8 @@ enum Kind : int { kBaseType = 2, kShape = 3, kConstraint = 4, - kTypeData = 5 + kAdtHandle = 5, + kTypeData = 6 }; /*! diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py index d3fd94359d8d..ad276e0c9fb1 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/adt.py @@ -114,7 +114,13 @@ def __call__(self, *args): @register_relay_node class TypeData(Type): - """Stores the definition for an Algebraic Data Type (ADT) in Relay.""" + """Stores the definition for an Algebraic Data Type (ADT) in Relay. + + Note that ADT definitions are treated as type-level functions because + the type parameters need to be given for an instance of the ADT. Thus, + any global type var that is an ADT header needs to be wrapped in a + type call that passes in the type params. + """ def __init__(self, header, tv, constructors): """Defines a TypeData object. diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index c8ae9c558a1e..7ac2360aa473 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -89,7 +89,8 @@ class Kind(IntEnum): BaseType = 2 Shape = 3 Constraint = 4 - TypeData = 5 + AdtHandle = 5 + TypeData = 6 @register_relay_node class TypeVar(Type): @@ -128,7 +129,7 @@ class GlobalTypeVar(Type): stored in the environment. """ - def __init__(self, var, kind=Kind.Type): + def __init__(self, var, kind=Kind.AdtHandle): """Construct a GlobalTypeVar. Parameters @@ -136,7 +137,7 @@ def __init__(self, var, kind=Kind.Type): var: tvm.Var The tvm.Var which backs the type parameter. kind: Kind, optional - The kind of the type parameter, Kind.Type by default. + The kind of the type parameter, Kind.AdtHandle by default. Returns ------- diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index a83b2834b1c1..effb4b2f5c54 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -108,8 +108,8 @@ struct KindChecker : TypeFunctor { << "Type call must be calling a global type var"; Kind func_kind = this->VisitType(op->func); - CHECK(func_kind == Kind::kType) - << "Type calls must call a global type var of the type kind but " + CHECK(func_kind == Kind::kAdtHandle) + << "Type calls must call a global type var that is an ADT handle but " << op->func << " of " << tc << " is of kind " << func_kind; for (const Type& t : op->args) { @@ -136,8 +136,8 @@ struct KindChecker : TypeFunctor { // to support it. TypeData td = GetRef(op); Kind header_kind = this->VisitType(op->header); - CHECK(header_kind == Kind::kType) - << "The header for ADT type data must be of the type kind but " + CHECK(header_kind == Kind::kAdtHandle) + << "The header for ADT type data must be an ADT handle but " << op->header << " of " << td << " is of kind " << header_kind; for (const auto& var : op->tv) { From 5e61378d34c59332fac1bead65dee1f926ac8958 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 13:45:35 -0800 Subject: [PATCH 27/61] Remove nats and tree from prelude, move to test, document prelude --- python/tvm/relay/prelude.py | 102 +++++++++++++-------------------- tests/python/relay/test_adt.py | 63 +++++++++++++------- 2 files changed, 80 insertions(+), 85 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 35235e0f7190..c0dba067d2e3 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -1,49 +1,30 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name -"""Include some preloaded term/type definitions.""" +"""Adds certain standard global functions and ADT definitions to the module.""" from .ty import GlobalTypeVar, TypeVar, FuncType from .expr import Var, Function, GlobalVar from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard class Prelude: - """Contain standard definitions.""" - def __init__(self, mod): - self.mod = mod - self.nat = GlobalTypeVar("nat") - self.z = Constructor("z", [], self.nat) - self.s = Constructor("s", [self.nat()], self.nat) - mod[self.nat] = TypeData(self.nat, [], [self.z, self.s]) - - self.double = GlobalVar("double") - x = Var("x", self.nat()) - y = Var("y") - z_case = Clause(PatternConstructor(self.z), self.z()) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.s(self.s(self.double(y)))) - mod[self.double] = Function([x], Match(x, [z_case, s_case])) - - self.add = GlobalVar("add") - x = Var("x", self.nat()) - y = Var("y", self.nat()) - a = Var("a") - z_case = Clause(PatternConstructor(self.z), y) - s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]), self.s(self.add(a, y))) - mod[self.add] = Function([x, y], Match(x, [z_case, s_case])) + """Contains standard definitions.""" + def define_list_adt(self): + """Defines a LISP-style list ADT. An empty list is + represented by nil(). A member x can be appended to the + front of a list l via the constructor cons(x, l).""" self.l = GlobalTypeVar("list") a = TypeVar("a") self.nil = Constructor("nil", [], self.l) self.cons = Constructor("cons", [a, self.l(a)], self.l) - mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) + self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) - self.length = GlobalVar("length") - a = TypeVar("a") - x = Var("x", self.l(a)) - y = Var("y") - nil_case = Clause(PatternConstructor(self.nil), self.z()) - cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]), - self.s(self.length(y))) - mod[self.length] = Function([x], Match(x, [nil_case, cons_case]), None, [a]) + def define_list_map(self): + """Defines a function for mapping a function over a list's + elements. That is, map(f, l) returns a new list where + the ith member is f applied to the ith member of l. + map(f, l) : fun(fun(a) -> b, list) -> list + """ self.map = GlobalVar("map") a = TypeVar("a") b = TypeVar("b") @@ -54,8 +35,16 @@ def __init__(self, mod): nil_case = Clause(PatternConstructor(self.nil), self.nil()) cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), self.cons(f(y), self.map(f, z))) - mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), None, [a, b]) + self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), None, [a, b]) + + def define_list_foldl(self): + """Defines a left-way fold over a list. + foldl(f, z, l) : fun(fun(b, a) -> b, b, list) -> b + + foldl(f, z, cons(a1, cons(a2, cons(a3, cons(..., nil))))) + evaluates to f(...f(f(f(z, a1), a2), a3)...) + """ self.foldl = GlobalVar("foldl") a = TypeVar("a") b = TypeVar("b") @@ -67,13 +56,17 @@ def __init__(self, mod): nil_case = Clause(PatternConstructor(self.nil), av) cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), self.foldl(f, f(av, y), z)) - mod[self.foldl] = Function([f, av, bv], Match(bv, [nil_case, cons_case]), None, [a, b]) + self.mod[self.foldl] = Function([f, av, bv], + Match(bv, [nil_case, cons_case]), None, [a, b]) - self.tree = GlobalTypeVar("tree") - a = TypeVar("a") - self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree) - mod[self.tree] = TypeData(self.tree, [a], [self.rose]) + def define_list_foldr(self): + """Defines a right-way fold over a list. + + foldr(f, l, z) : fun(fun(a, b) -> b, list, b) -> b + foldr(f, cons(a1, cons(a2, cons(..., cons(an, nil)))), z) + evalutes to f(a1, f(a2, f(..., f(an, z)))...) + """ self.foldr = GlobalVar("foldr") a = TypeVar("a") b = TypeVar("b") @@ -85,29 +78,12 @@ def __init__(self, mod): nil_case = Clause(PatternConstructor(self.nil), bv) cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), f(y, self.foldr(f, bv, z))) - mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), None, [a, b]) + self.mod[self.foldr] = Function([f, bv, av], + Match(av, [nil_case, cons_case]), None, [a, b]) - self.sum = GlobalVar("sum") - a = Var("a", self.l(self.nat())) - mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a)) - - self.tmap = GlobalVar("tmap") - a = TypeVar("a") - b = TypeVar("b") - t = Var("t", self.tree(a)) - f = Var("f", FuncType([a], b)) - x = Var("x", self.tree(a)) - y = Var("y") - z = Var("z") - rose_case = Clause(PatternConstructor(self.rose, [PatternVar(y), PatternVar(z)]), - self.rose(f(y), self.map(Function([x], self.tmap(f, x)), z))) - mod[self.tmap] = Function([f, t], Match(t, [rose_case]), self.tree(b), [a, b]) - - self.size = GlobalVar("size") - a = TypeVar("a") - t = Var("t", self.tree(a)) - x = Var("x", self.tree(a)) - z = Var("z") - rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), - self.s(self.sum(self.map(Function([x], self.size(x)), z)))) - mod[self.size] = Function([t], Match(t, [rose_case]), self.nat(), [a]) + def __init__(self, mod): + self.mod = mod + self.define_list_adt() + self.define_list_map() + self.define_list_foldl() + self.define_list_foldr() diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index d501c7d74a62..9ef1fe1b62d5 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -1,5 +1,9 @@ import tvm from tvm import relay +from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType +from tvm.relay.expr import Var, Function, GlobalVar +from tvm.relay.adt import Constructor, TypeData, Clause, Match +from tvm.relay.adt import PatternConstructor, PatternVar, PatternWildcard from tvm.relay.ir_pass import infer_type from tvm.relay.backend.interpreter import Value, TupleValue, ConValue from tvm.relay import testing, create_executor @@ -10,6 +14,43 @@ ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") +# defines peano nats and related functions for testing purposes +def add_nat_definitions(): + p.nat = GlobalTypeVar("nat") + p.z = Constructor("z", [], p.nat) + p.s = Constructor("s", [p.nat()], p.nat) + mod[p.nat] = TypeData(p.nat, [], [p.z, p.s]) + + p.double = GlobalVar("double") + x = Var("x", p.nat()) + y = Var("y") + z_case = Clause(PatternConstructor(p.z), p.z()) + s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(p.s(p.double(y)))) + mod[p.double] = Function([x], Match(x, [z_case, s_case])) + + p.add = GlobalVar("add") + x = Var("x", p.nat()) + y = Var("y", p.nat()) + a = Var("a") + z_case = Clause(PatternConstructor(p.z), y) + s_case = Clause(PatternConstructor(p.s, [PatternVar(a)]), p.s(p.add(a, y))) + mod[p.add] = Function([x, y], Match(x, [z_case, s_case])) + + p.sum = GlobalVar("sum") + a = Var("a", p.l(p.nat())) + mod[p.sum] = Function([a], p.foldl(p.add, p.z(), a)) + + p.length = GlobalVar("length") + a = TypeVar("a") + x = Var("x", p.l(a)) + y = Var("y") + nil_case = Clause(PatternConstructor(p.nil), p.z()) + cons_case = Clause(PatternConstructor(p.cons, [PatternWildcard(), PatternVar(y)]), + p.s(p.length(y))) + mod[p.length] = Function([x], Match(x, [nil_case, cons_case]), None, [a]) + +add_nat_definitions() + z = p.z s = p.s nat = p.nat @@ -25,11 +66,6 @@ foldr = p.foldr sum = p.sum -tree = p.tree -rose = p.rose -tmap = p.tmap -size = p.size - # this is an example of using the adt value in python side def count(n): assert isinstance(n, ConValue) @@ -161,21 +197,6 @@ def test_sum(): assert count(res) == 3 -def test_tmap(): - a = relay.TypeVar("a") - b = relay.TypeVar("b") - lhs = mod[tmap].checked_type - rhs = relay.FuncType([relay.FuncType([a], b), tree(a)], tree(b), [a, b]) - assert lhs == rhs - - -def test_size(): - a = relay.TypeVar("a") - lhs = mod[size].checked_type - rhs = relay.FuncType([tree(a)], nat(), [a]) - assert lhs == rhs - - if __name__ == "__main__": test_nat_constructor() test_double() @@ -186,5 +207,3 @@ def test_size(): test_foldl() test_foldr() test_sum() - test_tmap() - test_size() From 3db3c64e70a5e9023f54638eb0b31ef2b72d4a9f Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 14:44:01 -0800 Subject: [PATCH 28/61] Restore and document nat and tree to prelude, add more tree tests --- python/tvm/relay/prelude.py | 99 ++++++++++++++++++++++++++++++++++ tests/python/relay/test_adt.py | 96 +++++++++++++++++++-------------- 2 files changed, 154 insertions(+), 41 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index c0dba067d2e3..48fa559bbc14 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -81,9 +81,108 @@ def define_list_foldr(self): self.mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), None, [a, b]) + def define_nat_adt(self): + """Defines a Peano (unary) natural number ADT. + Zero is represented by z(). s(n) adds 1 to a nat n.""" + self.nat = GlobalTypeVar("nat") + self.z = Constructor("z", [], self.nat) + self.s = Constructor("s", [self.nat()], self.nat) + self.mod[self.nat] = TypeData(self.nat, [], [self.z, self.s]) + + def define_nat_double(self): + """Defines a function that doubles a nat.""" + self.double = GlobalVar("double") + x = Var("x", self.nat()) + y = Var("y") + z_case = Clause(PatternConstructor(self.z), self.z()) + s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), + self.s(self.s(self.double(y)))) + self.mod[self.double] = Function([x], Match(x, [z_case, s_case])) + + def define_nat_add(self): + """Defines a function that adds two nats.""" + self.add = GlobalVar("add") + x = Var("x", self.nat()) + y = Var("y", self.nat()) + a = Var("a") + z_case = Clause(PatternConstructor(self.z), y) + s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]), + self.s(self.add(a, y))) + self.mod[self.add] = Function([x, y], Match(x, [z_case, s_case])) + + def define_list_sum(self): + """Defines a function that computes the sum of a list of nats.""" + self.sum = GlobalVar("sum") + a = Var("a", self.l(self.nat())) + self.mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a)) + + def define_list_length(self): + """Defines a function that returns the length of a list as a nat""" + self.length = GlobalVar("length") + a = TypeVar("a") + x = Var("x", self.l(a)) + y = Var("y") + nil_case = Clause(PatternConstructor(self.nil), self.z()) + cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]), + self.s(self.length(y))) + self.mod[self.length] = Function([x], + Match(x, [nil_case, cons_case]), None, [a]) + + def define_tree_adt(self): + """Defines a tree ADT. A tree can contain any type. + It has only one constructor, rose(x, l), where x is the content + of that point of the tree and l is a list of more trees of the + same type. A leaf is thus rose(x, nil()). + """ + self.tree = GlobalTypeVar("tree") + a = TypeVar("a") + self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree) + self.mod[self.tree] = TypeData(self.tree, [a], [self.rose]) + + def define_tree_map(self): + """Defines a function that maps over a tree. The function + is applied to each subtree's contents. + + Signature: fun(f : fun(a) -> b, t : tree) -> tree + """ + self.tmap = GlobalVar("tmap") + a = TypeVar("a") + b = TypeVar("b") + t = Var("t", self.tree(a)) + f = Var("f", FuncType([a], b)) + x = Var("x", self.tree(a)) + y = Var("y") + z = Var("z") + rose_case = Clause(PatternConstructor(self.rose, [PatternVar(y), PatternVar(z)]), + self.rose(f(y), self.map(Function([x], self.tmap(f, x)), z))) + self.mod[self.tmap] = Function([f, t], + Match(t, [rose_case]), self.tree(b), [a, b]) + + def define_tree_size(self): + """Defines a function that computes the size of a tree as a nat.""" + self.size = GlobalVar("size") + a = TypeVar("a") + t = Var("t", self.tree(a)) + x = Var("x", self.tree(a)) + z = Var("z") + rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), + self.s(self.sum(self.map(Function([x], self.size(x)), z)))) + self.mod[self.size] = Function([t], + Match(t, [rose_case]), self.nat(), [a]) + def __init__(self, mod): self.mod = mod self.define_list_adt() self.define_list_map() self.define_list_foldl() self.define_list_foldr() + + self.define_nat_adt() + self.define_nat_double() + self.define_nat_add() + self.define_list_length() + self.define_list_sum() + + self.define_tree_adt() + self.define_tree_map() + self.define_tree_size() diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 9ef1fe1b62d5..5c624da53739 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -1,9 +1,5 @@ import tvm from tvm import relay -from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType -from tvm.relay.expr import Var, Function, GlobalVar -from tvm.relay.adt import Constructor, TypeData, Clause, Match -from tvm.relay.adt import PatternConstructor, PatternVar, PatternWildcard from tvm.relay.ir_pass import infer_type from tvm.relay.backend.interpreter import Value, TupleValue, ConValue from tvm.relay import testing, create_executor @@ -14,43 +10,6 @@ ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") -# defines peano nats and related functions for testing purposes -def add_nat_definitions(): - p.nat = GlobalTypeVar("nat") - p.z = Constructor("z", [], p.nat) - p.s = Constructor("s", [p.nat()], p.nat) - mod[p.nat] = TypeData(p.nat, [], [p.z, p.s]) - - p.double = GlobalVar("double") - x = Var("x", p.nat()) - y = Var("y") - z_case = Clause(PatternConstructor(p.z), p.z()) - s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(p.s(p.double(y)))) - mod[p.double] = Function([x], Match(x, [z_case, s_case])) - - p.add = GlobalVar("add") - x = Var("x", p.nat()) - y = Var("y", p.nat()) - a = Var("a") - z_case = Clause(PatternConstructor(p.z), y) - s_case = Clause(PatternConstructor(p.s, [PatternVar(a)]), p.s(p.add(a, y))) - mod[p.add] = Function([x, y], Match(x, [z_case, s_case])) - - p.sum = GlobalVar("sum") - a = Var("a", p.l(p.nat())) - mod[p.sum] = Function([a], p.foldl(p.add, p.z(), a)) - - p.length = GlobalVar("length") - a = TypeVar("a") - x = Var("x", p.l(a)) - y = Var("y") - nil_case = Clause(PatternConstructor(p.nil), p.z()) - cons_case = Clause(PatternConstructor(p.cons, [PatternWildcard(), PatternVar(y)]), - p.s(p.length(y))) - mod[p.length] = Function([x], Match(x, [nil_case, cons_case]), None, [a]) - -add_nat_definitions() - z = p.z s = p.s nat = p.nat @@ -66,6 +25,11 @@ def add_nat_definitions(): foldr = p.foldr sum = p.sum +tree = p.tree +rose = p.rose +tmap = p.tmap +size = p.size + # this is an example of using the adt value in python side def count(n): assert isinstance(n, ConValue) @@ -103,6 +67,17 @@ def to_list(l): break return ret +def tree_to_dict(t): + assert isinstance(t, ConValue) + ret = {} + assert t.con.name_hint == 'rose' + ret['member'] = t.fields[0] + ret['children'] = [] + for subtree in to_list(t.fields[1]): + l = tree_to_dict(subtree) + ret['children'].append(l) + return ret + def test_nat_value(): assert count(make_nat(10)) == 10 @@ -197,6 +172,43 @@ def test_sum(): assert count(res) == 3 +def test_tmap(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[tmap].checked_type + rhs = relay.FuncType([relay.FuncType([a], b), tree(a)], tree(b), [a, b]) + assert lhs == rhs + + x = relay.Var("x") + add_one = relay.Function([x], s(x)) + res = intrp.evaluate(tmap(add_one, + rose(z(), + cons(rose(z(), nil()), + cons(rose(z(), nil()), + nil()))))) + + tree_dict = tree_to_dict(res) + assert count(tree_dict['member']) == 1 + assert len(tree_dict['children']) == 2 + for subtree in tree_dict['children']: + assert count(subtree['member']) == 1 + assert len(subtree['children']) == 0 + + +def test_size(): + a = relay.TypeVar("a") + lhs = mod[size].checked_type + rhs = relay.FuncType([tree(a)], nat(), [a]) + assert lhs == rhs + + root = rose(z(), cons(rose(z(), nil()), + cons(rose(z(), nil()), + nil()))) + t = rose(z(), cons(root, cons(root, cons(root, nil())))) + res = intrp.evaluate(size(t)) + assert count(res) == 10 + + if __name__ == "__main__": test_nat_constructor() test_double() @@ -207,3 +219,5 @@ def test_sum(): test_foldl() test_foldr() test_sum() + test_tmap() + test_size() From 0041b4601c0d146d4d05869854c9bc10298407ae Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 16:31:28 -0800 Subject: [PATCH 29/61] Add alpha equality tests for match cases, fix variable binding bug --- src/relay/ir/alpha_equal.cc | 8 +- src/relay/ir/hash.cc | 2 +- tests/python/relay/test_pass_alpha_equal.py | 96 +++++++++++++++++++++ 3 files changed, 101 insertions(+), 5 deletions(-) diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 2e3d2223181c..8ad0cc0c684a 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -425,18 +425,18 @@ class AlphaEqualHandler: return VisitPattern(l, r); } - bool VisitPattern_(const PatternWildcardNode* op, const Pattern& r) { + bool VisitPattern_(const PatternWildcardNode* op, const Pattern& r) final { return r.as(); } - bool VisitPattern_(const PatternVarNode* op, const Pattern& e2) { + bool VisitPattern_(const PatternVarNode* op, const Pattern& e2) final { if (const auto* r = e2.as()) { - return ExprEqual(op->var, r->var); + return MergeVarDecl(op->var, r->var); } return false; } - bool VisitPattern_(const PatternConstructorNode* op, const Pattern& e2) { + bool VisitPattern_(const PatternConstructorNode* op, const Pattern& e2) final { const auto* r = e2.as(); if (r == nullptr || !ExprEqual(op->con, r->con) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 75a2dc0aea28..9c991e76ac60 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -368,7 +368,7 @@ class RelayHashHandler: size_t VisitPattern_(const PatternVarNode* pvn) final { size_t hash = std::hash()(PatternVarNode::_type_key); - hash = Combine(hash, ExprHash(pvn->var)); + hash = Combine(hash, BindVar(pvn->var)); return hash; } diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 5158d5c7cc9c..202b2cb4bdd8 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -171,6 +171,29 @@ def test_type_relation_alpha_equal(): assert bigger != diff_num_inputs +def test_type_call_alpha_equal(): + h1 = relay.GlobalTypeVar("h1") + h2 = relay.GlobalTypeVar("h2") + t1 = relay.TensorType((1, 2), "float32") + t2 = relay.TensorType((1, 2, 3), "float32") + t3 = relay.TensorType((1, 2, 3, 4), "float32") + t4 = relay.TensorType((), "float32") + + tc = relay.TypeCall(h1, [t1, t2, t3]) + same = relay.TypeCall(h1, [t1, t2, t3]) + + different_func = relay.TypeCall(h2, [t1, t2, t3]) + different_arg = relay.TypeCall(h1, [t1, t2, t4]) + fewer_args = relay.TypeCall(h1, [t1, t2]) + more_args = relay.TypeCall(h1, [t1, t2, t3, t4]) + different_order_args = relay.TypeCall(h1, [t3, t2, t1]) + + assert tc == same + assert tc != different_func + assert tc != fewer_args + assert tc != more_args + assert tc != different_order_args + def test_constant_alpha_equal(): x = relay.const(1) @@ -453,6 +476,79 @@ def test_if_alpha_equal(): assert not alpha_equal(if_sample, different_false) +def test_constructor_alpha_equal(): + # smoke test: it should be pointer equality + mod = relay.Module() + p = relay.prelude.Prelude(mod) + + assert alpha_equal(p.nil, p.nil) + assert alpha_equal(p.cons, p.cons) + assert not alpha_equal(p.nil, p.cons) + + +def test_match_alpha_equal(): + mod = relay.Module() + p = relay.prelude.Prelude(mod) + + x = relay.Var('x') + y = relay.Var('y') + nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil()) + cons_case = relay.Clause(relay.PatternConstructor(p.cons, + [relay.PatternVar(x), + relay.PatternVar(y)]), + p.cons(x, y)) + + z = relay.Var('z') + a = relay.Var('a') + equivalent_cons = relay.Clause(relay.PatternConstructor(p.cons, + [relay.PatternVar(z), + relay.PatternVar(a)]), + p.cons(z, a)) + + data = p.cons(p.z(), p.cons(p.z(), p.nil())) + + match = relay.Match(data, [nil_case, cons_case]) + equivalent = relay.Match(data, [nil_case, equivalent_cons]) + empty = relay.Match(data, []) + no_cons = relay.Match(data, [nil_case]) + no_nil = relay.Match(data, [cons_case]) + different_data = relay.Match(p.nil(), [nil_case, cons_case]) + different_order = relay.Match(data, [cons_case, nil_case]) + different_nil = relay.Match(data, [ + relay.Clause(relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil())), + cons_case + ]) + different_cons = relay.Match(data, [ + nil_case, + relay.Clause(relay.PatternConstructor(p.cons, + [relay.PatternWildcard(), + relay.PatternWildcard()]), + p.nil()) + ]) + another_case = relay.Match(data, [ + nil_case, + cons_case, + relay.Clause(relay.PatternWildcard(), p.nil()) + ]) + wrong_constructors = relay.Match(data, [ + relay.Clause(relay.PatternConstructor(p.z), p.nil()), + relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), + p.cons(x, p.nil())) + ]) + + assert alpha_equal(match, match) + assert alpha_equal(match, equivalent) + assert not alpha_equal(match, no_cons) + assert not alpha_equal(match, no_nil) + assert not alpha_equal(match, empty) + assert not alpha_equal(match, different_data) + assert not alpha_equal(match, different_order) + assert not alpha_equal(match, different_nil) + assert not alpha_equal(match, different_cons) + assert not alpha_equal(match, another_case) + assert not alpha_equal(match, wrong_constructors) + + def test_op_alpha_equal(): # only checks names op1 = relay.op.get("add") From d232beba193131058af5c6eed414509c49ff062e Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 16:49:36 -0800 Subject: [PATCH 30/61] Add more kind check tests for ADTs --- tests/python/relay/test_pass_check_kind.py | 53 ++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/python/relay/test_pass_check_kind.py b/tests/python/relay/test_pass_check_kind.py index 9b81e8225ace..dafb789fda73 100644 --- a/tests/python/relay/test_pass_check_kind.py +++ b/tests/python/relay/test_pass_check_kind.py @@ -57,6 +57,31 @@ def test_relation_kind(): assert check_kind(tr) == relay.Kind.Constraint +def test_global_typevar_kind(): + v1 = relay.GlobalTypeVar('gtv1', relay.Kind.AdtHandle) + v2 = relay.GlobalTypeVar('gtv2', relay.Kind.Type) + + assert check_kind(v1) == relay.Kind.AdtHandle + assert check_kind(v2) == relay.Kind.Type + + +def test_typecall_kind(): + gtv = relay.GlobalTypeVar('gtv') + + mod = relay.Module() + data = relay.TypeData(gtv, [], []) + mod[gtv] = data + empty_call = relay.TypeCall(gtv, []) + assert check_kind(empty_call, mod) == relay.Kind.Type + + new_mod = relay.Module() + tv = relay.TypeVar('tv') + new_data = relay.TypeData(gtv, [tv], []) + new_mod[gtv] = new_data + call = relay.TypeCall(gtv, [relay.TupleType([])]) + assert check_kind(call, new_mod) == relay.Kind.Type + + @raises(tvm._ffi.base.TVMError) def test_invalid_tuple_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) @@ -95,6 +120,34 @@ def test_invalid_relation_kind(): check_kind(tr) +@raises(tvm._ffi.base.TVMError) +def test_typecall_invalid_callee(): + # global type var must be an ADT handle + gtv = relay.GlobalTypeVar('v1', relay.Kind.Type) + check_kind(relay.TypeCall(gtv, [])) + + +@raises(tvm._ffi.base.TVMError) +def test_typecall_invalid_args(): + # args must all be type kind + mod = relay.Module() + gtv = relay.GlobalTypeVar('v1') + data = relay.TypeData(gtv, [], []) + mod[gtv] = data + + check_kind(relay.TypeCall(gtv, [data])) + + +@raises(tvm._ffi.base.TVMError) +def test_typecall_invalid_num_args(): + mod = relay.Module() + gtv = relay.GlobalTypeVar('v1') + tv = relay.TypeVar('tv') + data = relay.TypeData(gtv, [tv], []) + mod[gtv] = data + check_kind(relay.TypeCall(gtv, [])) + + @raises(tvm._ffi.base.TVMError) def test_func_with_invalid_ret_type(): tp1 = relay.TypeVar('tp1', relay.Kind.Type) From 7c6d73777199b85937273de3e5426a7ac0d4a881 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 17:01:06 -0800 Subject: [PATCH 31/61] Add more tests for finding free or bound vars in match exprs --- tests/python/relay/test_pass_vars.py | 36 +++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_vars.py b/tests/python/relay/test_pass_vars.py index c8d3d6d14992..afdaddca922a 100644 --- a/tests/python/relay/test_pass_vars.py +++ b/tests/python/relay/test_pass_vars.py @@ -65,6 +65,40 @@ def test_bound_vars(): assert_vars_match(bound_vars(f2), [x, y]) +def test_match_vars(): + mod = relay.Module() + p = relay.prelude.Prelude(mod) + + x = relay.Var('x') + y = relay.Var('y') + z = relay.Var('z') + + match1 = relay.Match(p.nil(), [ + relay.Clause(relay.PatternConstructor(p.nil), z), + relay.Clause(relay.PatternConstructor(p.cons, + [relay.PatternVar(x), + relay.PatternVar(y)]), + p.cons(x, y)) + ]) + + match2 = relay.Match(p.nil(), [ + relay.Clause(relay.PatternConstructor(p.cons, [ + relay.PatternWildcard(), + relay.PatternVar(x) + ]), + y), + relay.Clause(relay.PatternWildcard(), z) + ]) + + assert_vars_match(bound_vars(match1), [x, y]) + assert_vars_match(free_vars(match1), [z]) + assert_vars_match(all_vars(match1), [z, x, y]) + + assert_vars_match(bound_vars(match2), [x]) + assert_vars_match(free_vars(match2), [y, z]) + assert_vars_match(all_vars(match2), [x, y, z]) + + def test_bound_type_vars(): a = relay.TypeVar("a") b = relay.TypeVar("b") @@ -127,7 +161,7 @@ def test_all_type_vars(): x = relay.Var("x", a) y = relay.Var("y", b) z = relay.Var("z", c) - + f1 = relay.Function([x], y, b, [a]) assert_vars_match(all_type_vars(f1), [a, b]) From 7322866d3754f4e65d5309b0f120cdd0a659f540 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 17:40:35 -0800 Subject: [PATCH 32/61] Add unification tests for type call --- tests/python/relay/test_type_solver.py | 65 ++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 1e2fed0af1f8..8bcd912f841e 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -62,6 +62,30 @@ def test_unify_tuple(): assert unified == tup2 +def test_unify_global_type_var(): + # should only be able to unify if they're the same + solver = make_solver() + gtv = relay.GlobalTypeVar('gtv') + unified = solver.Unify(gtv, gtv) + assert unified == gtv + + +def test_unify_typecall(): + solver = make_solver() + gtv = relay.GlobalTypeVar('gtv') + + # yeah, typecalls are shaped like tuples so the same + # tests work out + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.TensorType((10, 20), "float32") + + tc1 = relay.ty.TypeCall(gtv, [t1, t2]) + tc2 = relay.ty.TypeCall(gtv, [t3, t3]) + unified = solver.Unify(tc1, tc2) + assert unified == tc2 + + def test_unify_functype(): solver = make_solver() t1 = relay.ty.IncompleteType() @@ -205,10 +229,49 @@ def test_bad_recursive_unification(): solver.Unify(t1, relay.ty.TupleType([t1, t1])) +@raises(tvm._ffi.base.TVMError) +def test_unify_invalid_global_typevars(): + solver = make_solver() + gtv1 = relay.GlobalTypeVar('gtv1') + gtv2 = relay.GlobalTypeVar('gtv2') + solver.Unify(gtv1, gtv2) + + +@raises(tvm._ffi.base.TVMError) +def test_incompatible_typecall_var_unification(): + solver = make_solver() + gtv1 = relay.GlobalTypeVar('gtv1') + gtv2 = relay.GlobalTypeVar('gtv2') + + t1 = relay.IncompleteType() + t2 = relay.IncompleteType() + + tc1 = relay.TypeCall(gtv1, [t1]) + tc2 = relay.TypeCall(gtv2, [t2]) + solver.Unify(tc1, tc2) + + +@raises(tvm._ffi.base.TVMError) +def test_incompatible_typecall_args_unification(): + solver = make_solver() + gtv = relay.GlobalTypeVar('gtv1') + t1 = relay.IncompleteType() + t2 = relay.IncompleteType() + + tensor1 = relay.TensorType((1, 2, 3), "float32") + tensor2 = relay.TensorType((2, 3), "float32") + tensor3 = relay.TensorType((3,), "float32") + + tc1 = relay.TypeCall(gtv, [relay.TupleType([t1, t1]), t2]) + tc2 = relay.TypeCall(gtv, [relay.TupleType([tensor1, tensor2]), tensor3]) + solver.Unify(tc1, tc2) + + if __name__ == "__main__": test_bcast() test_backward_solving() test_unify_tuple() + test_unify_typecall() test_unify_functype() test_recursive_unify() test_unify_vars_under_tuples() @@ -216,3 +279,5 @@ def test_bad_recursive_unification(): test_backward_solving_after_child_update() test_incompatible_tuple_unification() test_bad_recursive_unification() + test_incompatible_typecall_var_unification() + test_incompatible_typecall_args_unification() From 5f3a2f4effea523290a09d96f0ade0254da67813 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 17:42:25 -0800 Subject: [PATCH 33/61] Update main() for alpha equality tests --- tests/python/relay/test_pass_alpha_equal.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 202b2cb4bdd8..ca86aaa3313e 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -587,6 +587,7 @@ def test_graph_equal(): test_func_type_alpha_equal() test_tuple_type_alpha_equal() test_type_relation_alpha_equal() + test_type_call_alpha_equal() test_constant_alpha_equal() test_global_var_alpha_equal() test_tuple_alpha_equal() @@ -595,6 +596,8 @@ def test_graph_equal(): test_call_alpha_equal() test_let_alpha_equal() test_if_alpha_equal() + test_constructor_alpha_equal() + test_match_alpha_equal() test_op_alpha_equal() test_var_alpha_equal() test_graph_equal() From 90ee4057a568d6bb2be1c6f146f00f23aa2548af Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 18:07:53 -0800 Subject: [PATCH 34/61] Add simple type inference test cases for match exprs and ADT constructors --- tests/python/relay/test_type_infer.py | 54 +++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index eeefbc6c3051..b37dae3ab3ce 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -17,6 +17,16 @@ def assert_has_type(expr, typ, mod=relay.module.Module({})): checked_type, typ)) +# initializes simple ADT for tests +def initialize_box_adt(mod): + box = relay.GlobalTypeVar('box') + tv = relay.TypeVar('tv') + constructor = relay.Constructor('constructor', [tv], box) + data = relay.TypeData(box, [tv], [constructor]) + mod[box] = data + return (box, constructor) + + def test_monomorphic_let(): "Program: let x = 1; return x" sb = relay.ScopeBuilder() @@ -190,6 +200,47 @@ def test_equal(): assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool')) +def test_constructor_type(): + mod = relay.Module() + box, constructor = initialize_box_adt(mod) + + ct = relay.ir_pass.infer_type(constructor, mod) + a = relay.TypeVar('a') + expected = relay.FuncType([a], box(a), [a]) + assert ct.checked_type == expected + + +def test_constructor_call(): + mod = relay.Module() + box, constructor = initialize_box_adt(mod) + + box_unit = constructor(relay.Tuple([])) + box_constant = constructor(relay.const(0, 'float32')) + + ut = relay.ir_pass.infer_type(box_unit, mod) + ct = relay.ir_pass.infer_type(box_constant, mod) + assert ut.checked_type == box(relay.TupleType([])) + assert ct.checked_type == box(relay.TensorType((), 'float32')) + + +def test_adt_match(): + mod = relay.Module() + box, constructor = initialize_box_adt(mod) + + v = relay.Var('v', relay.TensorType((), 'float32')) + match = relay.Match(constructor(relay.const(0, 'float32')), + [relay.Clause( + relay.PatternConstructor(constructor, + [relay.PatternVar(v)]), + relay.Tuple([])), + # redundant but shouldn't matter to typechecking + relay.Clause(relay.PatternWildcard(), + relay.Tuple([]))]) + + mt = relay.ir_pass.infer_type(match, mod) + assert mt.checked_type == relay.TupleType([]) + + if __name__ == "__main__": test_free_expr() test_dual_op() @@ -205,3 +256,6 @@ def test_equal(): test_global_var_recursion() test_equal() test_ref() + test_constructor_type() + test_constructor_call() + test_adt_match() From d4a54a1274cc3f4a4db721c7e22836052cb1b735 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 18:57:17 -0800 Subject: [PATCH 35/61] Add more ADT interpreter tests --- tests/python/relay/test_adt.py | 72 ++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 5c624da53739..5d76d13746a9 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -209,6 +209,78 @@ def test_size(): assert count(res) == 10 +def test_wildcard_match_solo(): + x = relay.Var('x', nat()) + copy = relay.Function([x], + relay.Match(x, [relay.Clause(relay.PatternWildcard(), x)]), + nat()) + + res = intrp.evaluate(copy(s(s(s(z()))))) + assert count(res) == 3 + + +def test_wildcard_match_order(): + x = relay.Var('x', l(nat())) + y = relay.Var('y') + a = relay.Var('a') + return_zero = relay.Function( + [x], + relay.Match(x, [ + relay.Clause(relay.PatternWildcard(), z()), + relay.Clause( + relay.PatternConstructor( + cons, [relay.PatternVar(y), relay.PatternVar(a)]), + y), + relay.Clause(relay.PatternConstructor(nil), s(z())) + ]), + nat()) + + res = intrp.evaluate(return_zero(cons(s(z()), nil()))) + # wildcard pattern is evaluated first + assert count(res) == 0 + + +def test_nested_matches(): + a = relay.TypeVar('a') + x = relay.Var('x', l(l(a))) + y = relay.Var('y') + w = relay.Var('w') + h = relay.Var('h') + t = relay.Var('t') + flatten = relay.GlobalVar('flatten') + + # flatten could be written using a fold, but this way has nested matches + inner_match = relay.Match( + y, [ + relay.Clause(relay.PatternConstructor(nil), flatten(w)), + relay.Clause(relay.PatternConstructor( + cons, [relay.PatternVar(h), relay.PatternVar(t)]), + cons(h, flatten(cons(t, w)))) + ]) + + mod[flatten] = relay.Function( + [x], + relay.Match(x, [ + relay.Clause(relay.PatternConstructor(nil), nil()), + relay.Clause(relay.PatternConstructor( + cons, [relay.PatternVar(y), relay.PatternVar(w)]), + inner_match) + ]), l(a), [a]) + + first_list = cons(build_nat(1), cons(build_nat(2), + cons(build_nat(3), nil()))) + second_list = cons(build_nat(4), cons(build_nat(5), + cons(build_nat(6), nil()))) + final_list = cons(first_list, cons(second_list, nil())) + + res = intrp.evaluate(flatten(final_list)) + + flat = to_list(res) + assert len(flat) == 6 + for i in range(6): + assert count(flat[i]) == i + 1 + + if __name__ == "__main__": test_nat_constructor() test_double() From 609f56e148eb98f708f7e3d278b0a84e52664e94 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 23 Jan 2019 19:10:22 -0800 Subject: [PATCH 36/61] Allow incomplete types when typechecking match cases --- src/relay/pass/type_infer.cc | 11 ++++++++++- tests/python/relay/test_adt.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 52cdb6db3e20..4cd4a92e81e7 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -207,7 +207,16 @@ class TypeInferencer : private ExprFunctor, << "Cannot do type inference without a environment:" << con->con->name_hint; TypeData td = mod_->type_definitions.at(con->con->belong_to); - auto* tc = t.as(); + + // we can expect a certain number of arguments + Array unknown_args; + for (size_t i = 0; i < td->tv.size(); i++) { + unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); + } + Type expected = TypeCallNode::make(con->con->belong_to, unknown_args); + Type unified = Unify(t, expected, con->span); + + auto* tc = unified.as(); CHECK(tc) << "must be type call"; CHECK_EQ(td->header, tc->func); CHECK(td->tv.size() == tc->args.size()) << "both side must be equal"; diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 5d76d13746a9..15e5b8f71a01 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -242,7 +242,7 @@ def test_wildcard_match_order(): def test_nested_matches(): a = relay.TypeVar('a') - x = relay.Var('x', l(l(a))) + x = relay.Var('x') y = relay.Var('y') w = relay.Var('w') h = relay.Var('h') From 089813adbdfbb80d506529f7afbbc393e78b60c6 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 24 Jan 2019 12:17:54 -0800 Subject: [PATCH 37/61] Type inference for pattern vars should use the type annotation if it's there --- src/relay/pass/type_infer.cc | 3 ++- tests/python/relay/test_type_infer.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 4cd4a92e81e7..067e37897e52 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -231,7 +231,8 @@ class TypeInferencer : private ExprFunctor, } void VisitPattern_(const PatternVarNode* pv, const Type& t) { - type_map_[pv->var] = ResolvedTypeInfo(t, {}); + Type vt = GetType(pv->var); + Unify(vt, t, pv->span); } void VisitPattern_(const PatternWildcardNode* wc, const Type& t) { } diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index b37dae3ab3ce..f476181922df 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -241,6 +241,26 @@ def test_adt_match(): assert mt.checked_type == relay.TupleType([]) +def test_adt_match_type_annotations(): + mod = relay.Module() + box, constructor = initialize_box_adt(mod) + + # the only type annotation is inside the match pattern var + # but that should be enough info + tt = relay.TensorType((2, 2), 'float32') + x = relay.Var('x') + mv = relay.Var('mv', tt) + match = relay.Match(constructor(x), + [relay.Clause( + relay.PatternConstructor(constructor, + [relay.PatternVar(mv)]), + relay.Tuple([]))]) + + func = relay.Function([x], match) + ft = relay.ir_pass.infer_type(func, mod) + assert ft.checked_type == relay.FuncType([tt], relay.TupleType([])) + + if __name__ == "__main__": test_free_expr() test_dual_op() From ebec99c64cb19f005528345bb1f8ca70acdf3baa Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 24 Jan 2019 12:28:51 -0800 Subject: [PATCH 38/61] Two more specific test cases for ADT matching --- tests/python/relay/test_adt.py | 46 ++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 15e5b8f71a01..6c386fd77d26 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -281,6 +281,52 @@ def test_nested_matches(): assert count(flat[i]) == i + 1 +def test_match_full_var(): + x = relay.Var('x') + v = relay.Var('v') + id_func = relay.Function([x], + relay.Match(x, + [relay.Clause(relay.PatternVar(v), + v)])) + + res1 = intrp.evaluate(id_func(nil())) + res2 = intrp.evaluate(id_func(cons(z(), cons(z(), nil())))) + + empty = to_list(res1) + assert len(empty) == 0 + + zeroes = to_list(res2) + assert len(zeroes) == 2 + assert count(zeroes[0]) == 0 + assert count(zeroes[1]) == 0 + + +def test_nested_pattern_match(): + x = relay.Var('x', l(nat())) + h1 = relay.Var('h1') + h2 = relay.Var('h2') + t = relay.Var('t') + match = relay.Match( + x, + [relay.Clause( + relay.PatternConstructor( + cons, + [relay.PatternVar(h1), + relay.PatternConstructor( + cons, + [relay.PatternVar(h2), relay.PatternVar(t)])]), + h2), + relay.Clause(relay.PatternWildcard(), z()) + ]) + get_second = relay.Function([x], match) + + res = intrp.evaluate(get_second(cons(s(z()), + cons(s(s(z())), + nil())))) + + assert count(res) == 2 + + if __name__ == "__main__": test_nat_constructor() test_double() From 00963ded08a5aa6231dd049d3839bf259c36335d Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 24 Jan 2019 15:33:33 -0800 Subject: [PATCH 39/61] Add option ADT to prelude --- python/tvm/relay/prelude.py | 11 +++++++++++ tests/python/relay/test_adt.py | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 48fa559bbc14..1803e0485499 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -81,6 +81,15 @@ def define_list_foldr(self): self.mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), None, [a, b]) + def define_option_adt(self): + """Defines an option ADT, which can either contain some other + type or nothing at all.""" + self.option = GlobalTypeVar("option") + a = TypeVar("a") + self.some = Constructor("some", [a], self.option) + self.none = Constructor("none", [], self.option) + self.mod[self.option] = TypeData(self.option, [a], [self.some, self.none]) + def define_nat_adt(self): """Defines a Peano (unary) natural number ADT. Zero is represented by z(). s(n) adds 1 to a nat n.""" @@ -177,6 +186,8 @@ def __init__(self, mod): self.define_list_foldl() self.define_list_foldr() + self.define_option_adt() + self.define_nat_adt() self.define_nat_double() self.define_nat_add() diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 6c386fd77d26..fa742f9d6592 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -16,6 +16,9 @@ double = p.double add = p.add +some = p.some +none = p.none + nil = p.nil cons = p.cons l = p.l @@ -172,6 +175,27 @@ def test_sum(): assert count(res) == 3 +def test_option_matching(): + x = relay.Var('x') + y = relay.Var('y') + v = relay.Var('v') + condense = relay.Function( + [x, y], + relay.Match(x, [ + relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(v)]), cons(v, y)), + relay.Clause(relay.PatternConstructor(none), y) + ])) + + res = intrp.evaluate(foldr(condense, nil(), cons( + some(build_nat(3)), + cons(none(), cons(some(build_nat(1)), nil()))))) + + reduced = to_list(res) + assert len(reduced) == 2 + assert count(reduced[0]) == 3 + assert count(reduced[1]) == 1 + + def test_tmap(): a = relay.TypeVar("a") b = relay.TypeVar("b") From 47babdb2381f2fa5a47b6daee173fd241f944e80 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 24 Jan 2019 16:44:07 -0800 Subject: [PATCH 40/61] Fix broken reference to kind enum --- src/relay/pass/to_anf.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/to_anf.cc b/src/relay/pass/to_anf.cc index a724d5f2e855..6d65fe449fb0 100644 --- a/src/relay/pass/to_anf.cc +++ b/src/relay/pass/to_anf.cc @@ -274,7 +274,7 @@ class Fill : ExprFunctor { } Expr VisitExpr(const Expr& e) { - Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(TypeVarNode::kType)); + Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); return this->VisitExpr(e, v); } From a37d92789226d75d3cb15160a633f137843dce47 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 28 Jan 2019 14:16:13 -0800 Subject: [PATCH 41/61] Fix rebase snags --- src/relay/ir/module.cc | 2 +- src/relay/pass/type_infer.cc | 10 +++++----- src/relay/pass/type_solver.cc | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 920c15b2e756..da273265ae33 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -146,7 +146,7 @@ void ModuleNode::Update(const Module& mod) { Module ModuleNode::FromExpr( const Expr& expr, const tvm::Map& global_funcs) { - auto mod = ModuleNode::make(global_funcs); + auto mod = ModuleNode::make(global_funcs, {}); auto func_node = expr.as(); Function func; if (func_node) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 067e37897e52..054711efebd0 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -118,7 +118,7 @@ class TypeInferencer : private ExprFunctor, // Perform unification on two types and report the error at the expression // or the span of the expression. - Type Unify(const Type& t1, const Type& t2, const Expr& expr) { + Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) { // TODO(tqchen, jroesch): propagate span to solver try { return solver_.Unify(t1, t2, expr); @@ -148,7 +148,7 @@ class TypeInferencer : private ExprFunctor, return ret; } - void ReportFatalError(const Expr& expr, const Error& err) { + void ReportFatalError(const NodeRef& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); this->err_reporter.RenderErrors(this->mod_); @@ -214,7 +214,7 @@ class TypeInferencer : private ExprFunctor, unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } Type expected = TypeCallNode::make(con->con->belong_to, unknown_args); - Type unified = Unify(t, expected, con->span); + Type unified = Unify(t, expected, GetRef(con)); auto* tc = unified.as(); CHECK(tc) << "must be type call"; @@ -259,7 +259,7 @@ class TypeInferencer : private ExprFunctor, // if the definition is a function literal, permit recursion bool is_functional_literal = let->value.as() != nullptr; if (is_functional_literal) { - type_map_[op->var].checked_type = IncompleteTypeNode::make(Kind::kType); + type_map_[let->var].checked_type = IncompleteTypeNode::make(Kind::kType); } Type vtype = GetType(let->value); @@ -714,7 +714,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { } else { auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); CHECK(WellFormed(e)); - auto free_tvars = FreeTypeVars(e); + auto free_tvars = FreeTypeVars(e, mod_ref); CHECK(free_tvars.size() == 0) << "Found unbound type variables in " << e << ": " << free_tvars; EnsureCheckedType(e); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 4749e8934b36..fd15c91e79f7 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -518,7 +518,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") } else if (name == "AddConstraint") { return TypedPackedFunc([solver](TypeConstraint c) { Expr e = VarNode::make("dummy_var", - IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + IncompleteTypeNode::make(Kind::kType)); return solver->AddConstraint(c, e); }); } else { From 0c660aaec073f219b4e3aeca133b9a3132a7fd30 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 30 Jan 2019 22:35:45 -0800 Subject: [PATCH 42/61] Do not attach checked types to constructors --- src/relay/pass/type_infer.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 054711efebd0..771b13d28e91 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -547,7 +547,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } Expr VisitExpr_(const ConstructorNode* op) final { - return AttachCheckedType(op); + return GetRef(op); } Expr VisitExpr_(const MatchNode* op) final { @@ -686,6 +686,7 @@ struct AllCheckTypePopulated : ExprVisitor { void VisitExpr(const Expr& e) { if (e.as()) { return; } if (e.as()) { return; } + if (e.as()) { return; } CHECK(e->checked_type_.defined()) << "Expression: " << e; return ExprVisitor::VisitExpr(e); } From f5cec3eb65ac3e59ec1134a1ffe34af5400af05c Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 1 Feb 2019 15:24:15 -0800 Subject: [PATCH 43/61] More docstrings for module fields --- include/tvm/relay/module.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 988c7402a506..6de3b22f6566 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -43,6 +43,7 @@ class ModuleNode : public RelayNode { public: /*! \brief A map from ids to all global functions. */ tvm::Map functions; + /*! \brief A map from global type vars to ADT type data. */ tvm::Map type_definitions; /*! \brief The entry function (i.e. "main"). */ @@ -170,6 +171,10 @@ class ModuleNode : public RelayNode { * ensures global uniqueness. */ tvm::Map global_var_map_; + + /*! \brief A map from string names to global type variables (ADT names) + * that ensures global uniqueness. + */ tvm::Map global_type_var_map_; }; From b56bc3692ba55510430c37b4fb7741e3cd360880 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 1 Feb 2019 15:24:35 -0800 Subject: [PATCH 44/61] Use proper wrapper for indexing into module type data --- src/relay/pass/type_infer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 771b13d28e91..36b79d24c780 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -478,7 +478,7 @@ class TypeInferencer : private ExprFunctor, CHECK(mod_.defined()) << "Cannot do type inference without a environment:" << c->name_hint; - TypeData td = mod_->type_definitions.at(c->belong_to); + TypeData td = mod_->LookupDef(c->belong_to); std::vector types; for (const auto & t : td->tv) { types.push_back(t); From 6b8dbb85883c5b6a8728cf0aed0906d9200632be Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 1 Feb 2019 15:25:12 -0800 Subject: [PATCH 45/61] checked_type for constructors is not populated --- tests/python/relay/test_adt.py | 2 -- tests/python/relay/test_type_infer.py | 4 +++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index fa742f9d6592..7fe48df14bba 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -87,7 +87,6 @@ def test_nat_value(): def test_nat_constructor(): assert relay.ir_pass.infer_type(z(), mod).checked_type == nat() - assert relay.ir_pass.infer_type(s, mod).checked_type == relay.FuncType([nat()], nat()) assert relay.ir_pass.infer_type(s(z()), mod).checked_type == nat() @@ -105,7 +104,6 @@ def test_add(): def test_list_constructor(): a = relay.TypeVar("a") - assert relay.ir_pass.infer_type(nil, mod).checked_type == relay.FuncType([], l(a), [a]) assert relay.ir_pass.infer_type(cons(z(), nil()), mod).checked_type == l(nat()) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index f476181922df..05f8b8fd22f9 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -204,8 +204,10 @@ def test_constructor_type(): mod = relay.Module() box, constructor = initialize_box_adt(mod) - ct = relay.ir_pass.infer_type(constructor, mod) a = relay.TypeVar('a') + x = relay.Var('x', a) + ct = relay.ir_pass.infer_type( + relay.Function([x], constructor(x), box(a), [a]), mod) expected = relay.FuncType([a], box(a), [a]) assert ct.checked_type == expected From 07ea9152ae91f8591eb149b0604542c51f43f3cf Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 1 Feb 2019 15:45:58 -0800 Subject: [PATCH 46/61] Expand type call docstring --- python/tvm/relay/ty.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 7ac2360aa473..1cfa96aa7213 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -149,7 +149,9 @@ def __init__(self, var, kind=Kind.AdtHandle): @register_relay_node class TypeCall(Type): - """Type-level function application in Relay.""" + """Type-level function application in Relay. + A type call applies argument types to a constructor (type-level function). + """ def __init__(self, func, args): """Construct a TypeCall. From b7cfc59bb0194667d052400f3641eb6785dc864b Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 1 Feb 2019 15:59:35 -0800 Subject: [PATCH 47/61] Rename PatternConstructor con field --- include/tvm/relay/adt.h | 6 +++--- src/relay/backend/interpreter.cc | 4 ++-- src/relay/ir/adt.cc | 7 ++++--- src/relay/ir/alpha_equal.cc | 2 +- src/relay/ir/hash.cc | 2 +- src/relay/ir/pattern_functor.cc | 4 ++-- src/relay/ir/text_printer.cc | 2 +- src/relay/pass/type_infer.cc | 12 ++++++------ 8 files changed, 20 insertions(+), 19 deletions(-) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index d64f9614b42e..8f82e38482ff 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -123,15 +123,15 @@ class PatternConstructor; /*! \brief PatternVar container node */ class PatternConstructorNode : public PatternNode { public: - Constructor con; + Constructor constructor; tvm::Array pat; PatternConstructorNode() {} - TVM_DLL static PatternConstructor make(Constructor con, tvm::Array var); + TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array var); void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("con", &con); + v->Visit("constructor", &constructor); v->Visit("pat", &pat); v->Visit("span", &span); } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 3a59387c7b4b..c71dede6a2ac 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -513,9 +513,9 @@ class Interpreter : bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final { const ConValueNode* cvn = v.as(); CHECK(cvn) << "need to be a constructor for match"; - CHECK_NE(op->con->tag, -1); + CHECK_NE(op->constructor->tag, -1); CHECK_NE(cvn->con->tag, -1); - if (op->con->tag == cvn->con->tag) { + if (op->constructor->tag == cvn->con->tag) { // todo(M.K.): should use ptr equality but it is broken CHECK(op->pat.size() == cvn->fields.size()); for (size_t i = 0; i < op->pat.size(); ++i) { diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 4bcbc5f27e4b..872deff75a19 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -46,9 +46,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "PatternVarNode(" << node->var << ")"; }); -PatternConstructor PatternConstructorNode::make(Constructor con, tvm::Array pat) { +PatternConstructor PatternConstructorNode::make(Constructor constructor, + tvm::Array pat) { NodePtr n = make_node(); - n->con = std::move(con); + n->constructor = std::move(constructor); n->pat = std::move(pat); return PatternConstructor(n); } @@ -63,7 +64,7 @@ TVM_REGISTER_API("relay._make.PatternConstructor") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PatternConstructorNode* node, tvm::IRPrinter* p) { - p->stream << "PatternConstructorNode(" << node->con + p->stream << "PatternConstructorNode(" << node->constructor << ", " << node->pat << ")"; }); diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 8ad0cc0c684a..e92700aacb0d 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -439,7 +439,7 @@ class AlphaEqualHandler: bool VisitPattern_(const PatternConstructorNode* op, const Pattern& e2) final { const auto* r = e2.as(); if (r == nullptr - || !ExprEqual(op->con, r->con) + || !ExprEqual(op->constructor, r->constructor) || op->pat.size() != r->pat.size()) { return false; } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 9c991e76ac60..4cccbba31ccf 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -359,7 +359,7 @@ class RelayHashHandler: size_t VisitPattern_(const PatternConstructorNode* pcn) final { size_t hash = std::hash()(PatternConstructorNode::_type_key); - hash = Combine(hash, ExprHash(pcn->con)); + hash = Combine(hash, ExprHash(pcn->constructor)); for (const auto& p : pcn->pat) { hash = Combine(hash, PatternHash(p)); } diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc index 71002058fe49..d1b5a9049142 100644 --- a/src/relay/ir/pattern_functor.cc +++ b/src/relay/ir/pattern_functor.cc @@ -26,7 +26,7 @@ Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { for (const auto& p : op->pat) { pat.push_back(VisitPattern(p)); } - return PatternConstructorNode::make(VisitConstructor(op->con), pat); + return PatternConstructorNode::make(VisitConstructor(op->constructor), pat); } Type PatternMutator::VisitType(const Type& t) { @@ -53,7 +53,7 @@ void PatternVisitor::VisitPattern_(const PatternVarNode* op) { } void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { - VisitConstructor(op->con); + VisitConstructor(op->constructor); for (const auto& p : op->pat) { VisitPattern(p); } diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 3cca40e605e6..d3f354be6458 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -417,7 +417,7 @@ class TextPrinter : } TextValue VisitPattern_(const PatternConstructorNode* p) final { - TextValue ret(p->con->name_hint + "("); + TextValue ret(p->constructor->name_hint + "("); for (const Pattern& pat : p->pat) { ret = ret + " " + GetValue(pat); } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 36b79d24c780..3c6126db42e7 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -205,15 +205,15 @@ class TypeInferencer : private ExprFunctor, void VisitPattern_(const PatternConstructorNode* con, const Type& t) { CHECK(mod_.defined()) << "Cannot do type inference without a environment:" - << con->con->name_hint; - TypeData td = mod_->type_definitions.at(con->con->belong_to); + << con->constructor->name_hint; + TypeData td = mod_->type_definitions.at(con->constructor->belong_to); // we can expect a certain number of arguments Array unknown_args; for (size_t i = 0; i < td->tv.size(); i++) { unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } - Type expected = TypeCallNode::make(con->con->belong_to, unknown_args); + Type expected = TypeCallNode::make(con->constructor->belong_to, unknown_args); Type unified = Unify(t, expected, GetRef(con)); auto* tc = unified.as(); @@ -224,9 +224,9 @@ class TypeInferencer : private ExprFunctor, for (size_t i = 0; i < td->tv.size(); ++i) { type_var_map_[td->tv[i]] = tc->args[i]; } - CHECK(con->con->inp.size() == con->pat.size()) << "not enough pattern"; - for (size_t i = 0; i < con->con->inp.size(); ++i) { - VisitPattern(con->pat[i], Bind(con->con->inp[i], type_var_map_)); + CHECK(con->constructor->inp.size() == con->pat.size()) << "not enough pattern"; + for (size_t i = 0; i < con->constructor->inp.size(); ++i) { + VisitPattern(con->pat[i], Bind(con->constructor->inp[i], type_var_map_)); } } From 8cd15f239c1726bae996c44bd296f2868cb35b7c Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 4 Feb 2019 11:08:23 -0800 Subject: [PATCH 48/61] Use error reporter for pattern constructor case --- src/relay/pass/type_infer.cc | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 3c6126db42e7..e80ad221e9b0 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -207,6 +207,7 @@ class TypeInferencer : private ExprFunctor, << "Cannot do type inference without a environment:" << con->constructor->name_hint; TypeData td = mod_->type_definitions.at(con->constructor->belong_to); + auto pc = GetRef(con); // we can expect a certain number of arguments Array unknown_args; @@ -217,14 +218,28 @@ class TypeInferencer : private ExprFunctor, Type unified = Unify(t, expected, GetRef(con)); auto* tc = unified.as(); - CHECK(tc) << "must be type call"; - CHECK_EQ(td->header, tc->func); - CHECK(td->tv.size() == tc->args.size()) << "both side must be equal"; + if (!tc) { + this->ReportFatalError(pc, RELAY_ERROR("Expected a type call, got " << unified)); + } + if (td->header != tc->func) { + this->ReportFatalError(pc, RELAY_ERROR("ADT headers must match, but we have " + << td->header << " and " << tc->func)); + } + if (td->tv.size() != tc->args.size()) { + this->ReportFatalError(pc, RELAY_ERROR("The number of type args must match" + << "the number of type vars in the type data: " + << td->tv.size() << " != " << tc->args.size())); + } std::unordered_map type_var_map_; for (size_t i = 0; i < td->tv.size(); ++i) { type_var_map_[td->tv[i]] = tc->args[i]; } CHECK(con->constructor->inp.size() == con->pat.size()) << "not enough pattern"; + if (con->constructor->inp.size() != con->pat.size()) { + this->ReportFatalError(pc, RELAY_ERROR("Not enough inputs for the constructor; " + << "expected " << con->constructor->inp.size() + << ", got " << con->pat.size())); + } for (size_t i = 0; i < con->constructor->inp.size(); ++i) { VisitPattern(con->pat[i], Bind(con->constructor->inp[i], type_var_map_)); } From 1d9ae484562851beab94f2a4ee6ce9cb271fa7f2 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 4 Feb 2019 11:56:51 -0800 Subject: [PATCH 49/61] Condense error reporting in kind check, use error reporter --- src/relay/pass/kind_check.cc | 97 ++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 53 deletions(-) diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index effb4b2f5c54..e518d7008bf4 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -14,6 +14,7 @@ * contains a data type such as `int`, `float`, `uint`. */ #include +#include #include "../ir/type_functor.h" namespace tvm { @@ -23,8 +24,26 @@ using namespace tvm::runtime; struct KindChecker : TypeFunctor { const Module& mod; + ErrorReporter err_reporter; - explicit KindChecker(const Module& mod) : mod(mod) {} + explicit KindChecker(const Module& mod) : mod(mod), err_reporter() {} + + void ReportFatalError(const Error& err) { + this->err_reporter.Report(err); + this->err_reporter.RenderErrors(mod); + } + + void CheckKindMatches(const Type& t, const Type& outer, + Kind expected, const std::string& description) { + Kind k = this->VisitType(t); + if (k != expected) { + ReportFatalError(RELAY_ERROR("Incorrect kind for a " << description + << ". Type " << t << " inside " << outer + << " is of kind " << k + << " but was expected to be " + << expected)); + } + } Kind VisitType_(const IncompleteTypeNode* op) override { return op->kind; @@ -45,10 +64,8 @@ struct KindChecker : TypeFunctor { Kind VisitType_(const TupleTypeNode* op) override { // tuples should only contain normal types for (const Type& t : op->fields) { - Kind k = this->VisitType(t); - CHECK(k == Kind::kType) - << "All types in tuple type must be of a type kind but " - << t << " in " << GetRef(op) << " is of kind " << k; + CheckKindMatches(t, GetRef(op), Kind::kType, + "tuple member"); } return Kind::kType; } @@ -59,22 +76,13 @@ struct KindChecker : TypeFunctor { // well-formed constraints FuncType ft = GetRef(op); for (const Type& t : op->arg_types) { - Kind k = this->VisitType(t); - CHECK(k == Kind::kType) - << "Function parameters must be of the type kind but parameter " - << t << " of " << ft << " is of kind " << k; + CheckKindMatches(t, ft, Kind::kType, "function type parameter"); } - Kind ret_kind = this->VisitType(ft->ret_type); - CHECK(ret_kind == Kind::kType) - << "The function return type must be of the type kind but " - << ft->ret_type << " of " << ft << " is of kind " << ret_kind; + CheckKindMatches(ft->ret_type, ft, Kind::kType, "function return type"); for (const TypeConstraint& tc : op->type_constraints) { - Kind k = this->VisitType(tc); - CHECK(k == Kind::kConstraint) - << "All function type constraints are of the constraint kind but " - << tc << " of " << ft << " is of kind " << k; + CheckKindMatches(tc, ft, Kind::kConstraint, "function type constraint"); } return Kind::kType; @@ -92,10 +100,8 @@ struct KindChecker : TypeFunctor { Kind VisitType_(const TypeRelationNode* op) override { // arguments to type relation should be normal types for (const Type& t : op->args) { - Kind k = this->VisitType(t); - CHECK(k == Kind::kType) - << "All arguments to type relations must be of the type kind but " - << t << " of " << GetRef(op) << " is of kind " << k; + CheckKindMatches(t, GetRef(op), Kind::kType, + "argument to type relation"); } return Kind::kConstraint; } @@ -104,28 +110,24 @@ struct KindChecker : TypeFunctor { // type call func should be a global type var, args should be type TypeCall tc = GetRef(op); const auto* gtv = op->func.as(); - CHECK(gtv != nullptr) - << "Type call must be calling a global type var"; + if (gtv == nullptr) { + ReportFatalError(RELAY_ERROR("The callee in " << tc + << " is not a global type var, but is " << op->func)); + } - Kind func_kind = this->VisitType(op->func); - CHECK(func_kind == Kind::kAdtHandle) - << "Type calls must call a global type var that is an ADT handle but " - << op->func << " of " << tc << " is of kind " << func_kind; + CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function"); for (const Type& t : op->args) { - Kind k = this->VisitType(t); - CHECK(k == Kind::kType) - << "Type call arguments must be of the type kind but " - << t << " of " << tc << " is of kind " << k; + CheckKindMatches(t, tc, Kind::kType, "type call argument"); } // finally we need to check the module to check the number of type params auto var = GetRef(gtv); auto data = mod->LookupDef(var); - CHECK(data->tv.size() == op->args.size()) - << "Incorrect arity in " << tc - << " Expected: " << data->tv.size() - << " Given: " << op->args.size(); + if (data->tv.size() != op->args.size()) { + ReportFatalError(RELAY_ERROR("Expected " << data->tv.size() << "arguments for " << tc + << "; got " << op->args.size())); + } return Kind::kType; } @@ -135,31 +137,20 @@ struct KindChecker : TypeFunctor { // should be tracked recursively, but it is unclear that we need // to support it. TypeData td = GetRef(op); - Kind header_kind = this->VisitType(op->header); - CHECK(header_kind == Kind::kAdtHandle) - << "The header for ADT type data must be an ADT handle but " - << op->header << " of " << td << " is of kind " << header_kind; + CheckKindMatches(op->header, td, Kind::kAdtHandle, "type data header"); for (const auto& var : op->tv) { - Kind k = this->VisitType(var); - CHECK(k == Kind::kType) - << "All type params for ADT type data must be of the type kind but " - << var << " of " << td << " is of kind " << k; + CheckKindMatches(var, td, Kind::kType, "ADT type var"); } for (const auto& con : op->constructors) { - CHECK(con->belong_to.same_as(op->header)) - << "Constructors should have same global type var as type data"; + if (!con->belong_to.same_as(op->header)) { + ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to + << " but " << op << "has header " << op->header)); + } for (const Type& t : con->inp) { - Kind k = this->VisitType(t); - CHECK(k == Kind::kType) - << "All inputs to a constructor must be of the type kind but" - << t << " of " << con << " is of kind " << k; - if (const auto* gtv = t.as()) { - CHECK(GetRef(gtv).same_as(op->header)) - << "A global type var taken by a constructor must be the one the constructor makes"; - } + CheckKindMatches(t, td, Kind::kType, "ADT constructor input"); } } return Kind::kTypeData; From acc2ec025cd0bb2d97bed6f58e9f34bafba21c2d Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 4 Feb 2019 12:20:20 -0800 Subject: [PATCH 50/61] Expand docstrings and rename ADT fields --- include/tvm/relay/adt.h | 17 ++++++++++++----- python/tvm/relay/adt.py | 12 ++++++------ src/relay/backend/interpreter.cc | 2 +- src/relay/ir/adt.cc | 12 ++++++------ src/relay/ir/alpha_equal.cc | 6 +++--- src/relay/ir/expr_functor.cc | 10 +++++----- src/relay/ir/hash.cc | 4 ++-- src/relay/ir/text_printer.cc | 2 +- src/relay/ir/type_functor.cc | 2 +- src/relay/pass/fuse_ops.cc | 2 +- src/relay/pass/kind_check.cc | 6 +++--- src/relay/pass/type_infer.cc | 18 +++++++++--------- src/relay/pass/util.cc | 2 +- 13 files changed, 51 insertions(+), 44 deletions(-) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 8f82e38482ff..4f5c7a995894 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -29,7 +29,7 @@ class PatternNode : public RelayNode { * Given an ADT value, a pattern might accept it and bind the pattern variable to some value * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value. * - * ADT pattern matching thus takes a list of values and bings to the first that accepts the value. + * ADT pattern matching thus takes a list of values and binds to the first that accepts the value. */ class Pattern : public NodeRef { public: @@ -65,6 +65,7 @@ class PatternVarNode : public PatternNode { public: PatternVarNode() {} + /*! \brief Variable that stores the matched value. */ tvm::relay::Var var; TVM_DLL static PatternVar make(tvm::relay::Var var); @@ -123,7 +124,9 @@ class PatternConstructor; /*! \brief PatternVar container node */ class PatternConstructorNode : public PatternNode { public: + /*! Constructor matched by the pattern. */ Constructor constructor; + /*! Sub-patterns to match against each input to the constructor. */ tvm::Array pat; PatternConstructorNode() {} @@ -165,13 +168,13 @@ class TypeDataNode : public TypeNode { */ GlobalTypeVar header; /*! \brief The type variables (to allow for polymorphism). */ - tvm::Array tv; + tvm::Array ty_vars; /*! \brief The constructors. */ tvm::Array constructors; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("header", &header); - v->Visit("tv", &tv); + v->Visit("ty_vars", &ty_vars); v->Visit("constructors", &constructors); v->Visit("span", &span); } @@ -191,7 +194,9 @@ class Clause; /*! \brief Clause container node. */ class ClauseNode : public Node { public: + /*! \brief The pattern the clause matches. */ Pattern lhs; + /*! \brief The resulting value. */ Expr rhs; void VisitAttrs(tvm::AttrVisitor* v) final { @@ -212,13 +217,15 @@ class Match; /*! \brief Match container node. */ class MatchNode : public ExprNode { public: + /*! \brief The input being deconstructed. */ Expr data; - tvm::Array pattern; + /*! \brief The match node clauses. */ + tvm::Array clauses; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); - v->Visit("pattern", &pattern); + v->Visit("clause", &clauses); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py index ad276e0c9fb1..5bcfbc6817fd 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/adt.py @@ -122,7 +122,7 @@ class TypeData(Type): type call that passes in the type params. """ - def __init__(self, header, tv, constructors): + def __init__(self, header, ty_vars, constructors): """Defines a TypeData object. Parameters @@ -131,7 +131,7 @@ def __init__(self, header, tv, constructors): The name of the ADT. ADTs with the same constructors but different names are treated as different types. - tv: List[TypeVar] + ty_vars: List[TypeVar] Type variables that appear in constructors. constructors: List[tvm.relay.Constructor] The constructors for the ADT. @@ -141,7 +141,7 @@ def __init__(self, header, tv, constructors): type_data: TypeData The adt declaration. """ - self.__init_handle_by_constructor__(_make.TypeData, header, tv, constructors) + self.__init_handle_by_constructor__(_make.TypeData, header, ty_vars, constructors) @register_relay_node @@ -170,18 +170,18 @@ def __init__(self, lhs, rhs): class Match(Expr): """Pattern matching expression in Relay.""" - def __init__(self, data, pattern): + def __init__(self, data, clauses): """Construct a Match. Parameters ---------- data: tvm.relay.Expr The value being deconstructed and matched. - pattern: [tvm.relay.Clause] + clauses: List[tvm.relay.Clause] The pattern match clauses. Returns ------- match: tvm.relay.Expr The match expression. """ - self.__init_handle_by_constructor__(_make.Match, data, pattern) + self.__init_handle_by_constructor__(_make.Match, data, clauses) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index c71dede6a2ac..43f17e27a839 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -501,7 +501,7 @@ class Interpreter : Value VisitExpr_(const MatchNode* op) final { Value v = Eval(op->data); - for (const Clause& c : op->pattern) { + for (const Clause& c : op->clauses) { if (VisitPattern(c->lhs, v)) { return VisitExpr(c->rhs); } diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 872deff75a19..8ea3abe02d0e 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -93,11 +93,11 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TypeData TypeDataNode::make(GlobalTypeVar header, - tvm::Array tv, + tvm::Array ty_vars, tvm::Array constructors) { NodePtr n = make_node(); n->header = std::move(header); - n->tv = std::move(tv); + n->ty_vars = std::move(ty_vars); n->constructors = std::move(constructors); return TypeData(n); } @@ -112,7 +112,7 @@ TVM_REGISTER_API("relay._make.TypeData") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TypeDataNode* node, tvm::IRPrinter* p) { - p->stream << "TypeDataNode(" << node->header << ", " << node->tv << ", " + p->stream << "TypeDataNode(" << node->header << ", " << node->ty_vars << ", " << node->constructors << ")"; }); @@ -137,10 +137,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->rhs << ")"; }); -Match MatchNode::make(Expr data, tvm::Array pattern) { +Match MatchNode::make(Expr data, tvm::Array clauses) { NodePtr n = make_node(); n->data = std::move(data); - n->pattern = std::move(pattern); + n->clauses = std::move(clauses); return Match(n); } @@ -155,7 +155,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const MatchNode* node, tvm::IRPrinter* p) { p->stream << "MatchNode(" << node->data << ", " - << node->pattern << ")"; + << node->clauses << ")"; }); } // namespace relay diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index e92700aacb0d..8e27e65f5d8b 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -457,12 +457,12 @@ class AlphaEqualHandler: if (r == nullptr || !ExprEqual(op->data, r->data) - || op->pattern.size() != r->pattern.size()) { + || op->clauses.size() != r->clauses.size()) { return false; } - for (size_t i = 0; i < op->pattern.size(); ++i) { - if (!ClauseEqual(op->pattern[i], r->pattern[i])) { + for (size_t i = 0; i < op->clauses.size(); ++i) { + if (!ClauseEqual(op->clauses[i], r->clauses[i])) { return false; } } diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index c89c40dc1461..ae65d21c614b 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -190,11 +190,11 @@ Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { } Expr ExprMutator::VisitExpr_(const MatchNode* m) { - std::vector pattern; - for (const Clause& p : m->pattern) { - pattern.push_back(VisitClause(p)); + std::vector clauses; + for (const Clause& p : m->clauses) { + clauses.push_back(VisitClause(p)); } - return MatchNode::make(VisitExpr(m->data), pattern); + return MatchNode::make(VisitExpr(m->data), clauses); } Clause ExprMutator::VisitClause(const Clause& c) { @@ -294,7 +294,7 @@ void ExprVisitor::VisitExpr_(const ConstructorNode* op) { void ExprVisitor::VisitExpr_(const MatchNode* op) { this->VisitExpr(op->data); - for (const Clause& c : op->pattern) { + for (const Clause& c : op->clauses) { this->VisitClause(c); } } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 4cccbba31ccf..798e3f4e20ab 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -315,7 +315,7 @@ class RelayHashHandler: size_t VisitExpr_(const MatchNode* mn) final { size_t hash = std::hash()(MatchNode::_type_key); hash = Combine(hash, ExprHash(mn->data)); - for (const auto& c : mn->pattern) { + for (const auto& c : mn->clauses) { hash = Combine(hash, PatternHash(c->lhs)); hash = Combine(hash, ExprHash(c->rhs)); } @@ -340,7 +340,7 @@ class RelayHashHandler: size_t VisitType_(const TypeDataNode* tdn) final { size_t hash = std::hash()(TypeDataNode::_type_key); hash = Combine(hash, TypeHash(tdn->header)); - for (const auto& tv : tdn->tv) { + for (const auto& tv : tdn->ty_vars) { hash = Combine(hash, TypeHash(tv)); } for (const auto& cn : tdn->constructors) { diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index d3f354be6458..c00e0ede3f06 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -408,7 +408,7 @@ class TextPrinter : TextValue id = this->AllocTempVar(); stream_ << id << " = " << "Match " << data << " with"; this->PrintEndInst("\n"); - for (const auto& c : op->pattern) { + for (const auto& c : op->clauses) { this->PrintIndent(); stream_ << GetValue(c->lhs) << " to " << GetValue(c->rhs); this->PrintEndInst("\n"); diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 49fd905e2297..40df5dda4f22 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -60,7 +60,7 @@ void TypeVisitor::VisitType_(const TypeCallNode* op) { void TypeVisitor::VisitType_(const TypeDataNode* op) { this->VisitType(op->header); - for (const auto& v : op->tv) { + for (const auto& v : op->ty_vars) { this->VisitType(v); } diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 0012a117a56c..169aef3b6a4a 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -299,7 +299,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const MatchNode* op) final { this->Update(op->data, nullptr, kOpaque); - for (const Clause& c : op->pattern) { + for (const Clause& c : op->clauses) { this->Update(c->rhs, nullptr, kOpaque); } ExprVisitor::VisitExpr_(op); diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index e518d7008bf4..da9d07350c5e 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -124,8 +124,8 @@ struct KindChecker : TypeFunctor { // finally we need to check the module to check the number of type params auto var = GetRef(gtv); auto data = mod->LookupDef(var); - if (data->tv.size() != op->args.size()) { - ReportFatalError(RELAY_ERROR("Expected " << data->tv.size() << "arguments for " << tc + if (data->ty_vars.size() != op->args.size()) { + ReportFatalError(RELAY_ERROR("Expected " << data->ty_vars.size() << "arguments for " << tc << "; got " << op->args.size())); } return Kind::kType; @@ -139,7 +139,7 @@ struct KindChecker : TypeFunctor { TypeData td = GetRef(op); CheckKindMatches(op->header, td, Kind::kAdtHandle, "type data header"); - for (const auto& var : op->tv) { + for (const auto& var : op->ty_vars) { CheckKindMatches(var, td, Kind::kType, "ADT type var"); } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index e80ad221e9b0..ab7c0b92c03c 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -211,7 +211,7 @@ class TypeInferencer : private ExprFunctor, // we can expect a certain number of arguments Array unknown_args; - for (size_t i = 0; i < td->tv.size(); i++) { + for (size_t i = 0; i < td->ty_vars.size(); i++) { unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } Type expected = TypeCallNode::make(con->constructor->belong_to, unknown_args); @@ -225,14 +225,14 @@ class TypeInferencer : private ExprFunctor, this->ReportFatalError(pc, RELAY_ERROR("ADT headers must match, but we have " << td->header << " and " << tc->func)); } - if (td->tv.size() != tc->args.size()) { + if (td->ty_vars.size() != tc->args.size()) { this->ReportFatalError(pc, RELAY_ERROR("The number of type args must match" << "the number of type vars in the type data: " - << td->tv.size() << " != " << tc->args.size())); + << td->ty_vars.size() << " != " << tc->args.size())); } std::unordered_map type_var_map_; - for (size_t i = 0; i < td->tv.size(); ++i) { - type_var_map_[td->tv[i]] = tc->args[i]; + for (size_t i = 0; i < td->ty_vars.size(); ++i) { + type_var_map_[td->ty_vars[i]] = tc->args[i]; } CHECK(con->constructor->inp.size() == con->pat.size()) << "not enough pattern"; if (con->constructor->inp.size() != con->pat.size()) { @@ -254,11 +254,11 @@ class TypeInferencer : private ExprFunctor, Type VisitExpr_(const MatchNode* op) final { Type dtype = GetType(op->data); - for (const auto& c : op->pattern) { + for (const auto& c : op->clauses) { VisitPattern(c->lhs, dtype); } Type rtype = IncompleteTypeNode::make(Kind::kType); - for (const auto& c : op->pattern) { + for (const auto& c : op->clauses) { rtype = this->Unify(rtype, GetType(c->rhs), op->span); @@ -495,10 +495,10 @@ class TypeInferencer : private ExprFunctor, << c->name_hint; TypeData td = mod_->LookupDef(c->belong_to); std::vector types; - for (const auto & t : td->tv) { + for (const auto & t : td->ty_vars) { types.push_back(t); } - return FuncTypeNode::make(c->inp, TypeCallNode::make(c->belong_to, types), td->tv, {}); + return FuncTypeNode::make(c->inp, TypeCallNode::make(c->belong_to, types), td->ty_vars, {}); } }; diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 33efac9414fa..ce49aa84e91c 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -121,7 +121,7 @@ class TypeVarEVisitor : private ExprVisitor { void VisitExpr_(const ConstructorNode* cn) final { // for constructors, type vars will be bound in the module auto data = mod_->LookupDef(cn->belong_to); - for (const auto& tv : data->tv) { + for (const auto& tv : data->ty_vars) { type_vars_.Insert(tv); bound_type_vars_.Insert(tv); } From 737514b0ad012ca9b78fb9a7dc3775c1084d5ecf Mon Sep 17 00:00:00 2001 From: slyubomirsky Date: Mon, 11 Feb 2019 15:00:54 -0800 Subject: [PATCH 51/61] Rename 'option' ADT to 'optional' for consistency with Python --- python/tvm/relay/prelude.py | 22 +++++++++++----------- tests/python/relay/test_adt.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 1803e0485499..c4503755ad5b 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -23,7 +23,7 @@ def define_list_map(self): elements. That is, map(f, l) returns a new list where the ith member is f applied to the ith member of l. - map(f, l) : fun(fun(a) -> b, list) -> list + map(f, l) : fn(fn(a) -> b, list[a]) -> list[b] """ self.map = GlobalVar("map") a = TypeVar("a") @@ -40,7 +40,7 @@ def define_list_map(self): def define_list_foldl(self): """Defines a left-way fold over a list. - foldl(f, z, l) : fun(fun(b, a) -> b, b, list) -> b + foldl(f, z, l) : fn(fn(b, a) -> b, b, list[a]) -> b foldl(f, z, cons(a1, cons(a2, cons(a3, cons(..., nil))))) evaluates to f(...f(f(f(z, a1), a2), a3)...) @@ -62,7 +62,7 @@ def define_list_foldl(self): def define_list_foldr(self): """Defines a right-way fold over a list. - foldr(f, l, z) : fun(fun(a, b) -> b, list, b) -> b + foldr(f, l, z) : fn(fn(a, b) -> b, list[a], b) -> b foldr(f, cons(a1, cons(a2, cons(..., cons(an, nil)))), z) evalutes to f(a1, f(a2, f(..., f(an, z)))...) @@ -81,14 +81,14 @@ def define_list_foldr(self): self.mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), None, [a, b]) - def define_option_adt(self): - """Defines an option ADT, which can either contain some other + def define_optional_adt(self): + """Defines an optional ADT, which can either contain some other type or nothing at all.""" - self.option = GlobalTypeVar("option") + self.optional = GlobalTypeVar("optional") a = TypeVar("a") - self.some = Constructor("some", [a], self.option) - self.none = Constructor("none", [], self.option) - self.mod[self.option] = TypeData(self.option, [a], [self.some, self.none]) + self.some = Constructor("some", [a], self.optional) + self.none = Constructor("none", [], self.optional) + self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none]) def define_nat_adt(self): """Defines a Peano (unary) natural number ADT. @@ -152,7 +152,7 @@ def define_tree_map(self): """Defines a function that maps over a tree. The function is applied to each subtree's contents. - Signature: fun(f : fun(a) -> b, t : tree) -> tree + Signature: fn(f : fn(a) -> b, t : tree[a]) -> tree[b] """ self.tmap = GlobalVar("tmap") a = TypeVar("a") @@ -186,7 +186,7 @@ def __init__(self, mod): self.define_list_foldl() self.define_list_foldr() - self.define_option_adt() + self.define_optional_adt() self.define_nat_adt() self.define_nat_double() diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 7fe48df14bba..0189d7328b21 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -173,7 +173,7 @@ def test_sum(): assert count(res) == 3 -def test_option_matching(): +def test_optional_matching(): x = relay.Var('x') y = relay.Var('y') v = relay.Var('v') From 1a6e48a4aa04a717dd26132e3d72b9ac37dfe162 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 14 Feb 2019 16:32:34 -0800 Subject: [PATCH 52/61] Add various list iterators and utility functions to prelude --- python/tvm/relay/prelude.py | 176 ++++++++++++++++++++++++++++++++++-- 1 file changed, 170 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index c4503755ad5b..4129f5642094 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -1,7 +1,7 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """Adds certain standard global functions and ADT definitions to the module.""" -from .ty import GlobalTypeVar, TypeVar, FuncType -from .expr import Var, Function, GlobalVar +from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type +from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard @@ -35,12 +35,12 @@ def define_list_map(self): nil_case = Clause(PatternConstructor(self.nil), self.nil()) cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), self.cons(f(y), self.map(f, z))) - self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), None, [a, b]) + self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b]) def define_list_foldl(self): """Defines a left-way fold over a list. - foldl(f, z, l) : fn(fn(b, a) -> b, b, list[a]) -> b + foldl(f, z, l) : fn(fn(a, b) -> a, a, list[a]) -> a foldl(f, z, cons(a1, cons(a2, cons(a3, cons(..., nil))))) evaluates to f(...f(f(f(z, a1), a2), a3)...) @@ -57,7 +57,7 @@ def define_list_foldl(self): cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), self.foldl(f, f(av, y), z)) self.mod[self.foldl] = Function([f, av, bv], - Match(bv, [nil_case, cons_case]), None, [a, b]) + Match(bv, [nil_case, cons_case]), a, [a, b]) def define_list_foldr(self): """Defines a right-way fold over a list. @@ -79,7 +79,123 @@ def define_list_foldr(self): cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), f(y, self.foldr(f, bv, z))) self.mod[self.foldr] = Function([f, bv, av], - Match(av, [nil_case, cons_case]), None, [a, b]) + Match(av, [nil_case, cons_case]), b, [a, b]) + + def define_list_filter(self): + """Defines a function that filters a list. + + filter(f, l) : fn(fn(a) -> Tensor[(), bool], list[a]) -> list[a] + + It returns a the sublist of l consisting of the elements for which f returns true. + """ + self.filter = GlobalVar("filter") + a = TypeVar("a") + f = Var("f", FuncType([a], scalar_type("bool"))) + l = Var("l", self.l(a)) + h = Var("h") + t = Var("t") + nil_case = Clause(PatternConstructor(self.nil), self.nil()) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h), PatternVar(t)]), + If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t))) + self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a]) + + def define_list_zip(self): + """Defines a function that combines two lists into a list of tuples of their elements. + + zip(l, m) : fn(list[a], list[b]) -> list[(a, b)] + + The zipped list will be the length of the shorter list. + """ + self.zip = GlobalVar("zip") + a = TypeVar("a") + b = TypeVar("b") + nil_case = Clause(PatternConstructor(self.nil), self.nil()) + l1 = Var("l1") + l2 = Var("l2") + h1 = Var("h1") + h2 = Var("h2") + t1 = Var("t1") + t2 = Var("t2") + inner_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h2), PatternVar(t2)]), + self.cons(Tuple([h1, h2]), self.zip(t1, t2))) + outer_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h1), PatternVar(t1)]), + Match(l2, [nil_case, inner_cons_case])) + self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]), + self.l(TupleType([a, b])), [a, b]) + + def define_list_rev(self): + """Defines a function that reverses a list. + + rev(l) : fn(list[a]) -> list[a] + """ + self.rev = GlobalVar("rev") + a = TypeVar("a") + l = Var("l", self.l(a)) + x = Var("x") + y = Var("y") + updater = Function([y, x], self.cons(x, y)) + self.mod[self.rev] = Function([l], + self.foldl(updater, self.nil(), l), + self.l(a), [a]) + + def define_list_map_accumr(self): + """Defines an accumulative map, which is a fold that simulataneously updates + an accumulator value and a list of results. + + map_accumr(f, s, l) : fn(fn(a, b) -> (a, c), a, list[b]) -> (a, list[c]) + + This map proceeds through l from right to left. + """ + self.map_accumr = GlobalVar("map_accumr") + a = TypeVar("a") + b = TypeVar("b") + c = TypeVar("c") + f = Var("f", FuncType([a, b], TupleType([a, c]))) + acc = Var("acc", a) + l = Var("l", self.l(b)) + v = Var("v", b) + p = Var("p", TupleType([a, self.l(c)])) + f_out = Var("f_out", TupleType([a, c])) + updater = Function([v, p], + Let(f_out, f(TupleGetItem(p, 0), v), + Tuple([TupleGetItem(f_out, 0), + self.cons(TupleGetItem(f_out, 1), + TupleGetItem(p, 1))])), + TupleType([a, self.l(c)])) + self.mod[self.map_accumr] = Function([f, acc, l], + self.foldr(updater, Tuple([acc, self.nil()]), l), + TupleType([a, self.l(c)]), + [a, b, c]) + + def define_list_map_accuml(self): + """Defines an accumulative map, which is a fold that simulataneously updates + an accumulator value and a list of results. + + map_accuml(f, s, l) : fn(fn(a, b) -> (a, c), a, list[b]) -> (a, list[c]) + + This map proceeds through l from left to right. + """ + self.map_accuml = GlobalVar("map_accuml") + a = TypeVar("a") + b = TypeVar("b") + c = TypeVar("c") + f = Var("f", FuncType([a, b], TupleType([a, c]))) + acc = Var("acc", a) + l = Var("l", self.l(b)) + v = Var("v", b) + p = Var("p", TupleType([a, self.l(c)])) + f_out = Var("f_out", TupleType([a, c])) + updater = Function([p, v], + Let(f_out, f(TupleGetItem(p, 0), v), + Tuple([TupleGetItem(f_out, 0), + self.cons(TupleGetItem(f_out, 1), + TupleGetItem(p, 1))])), + TupleType([a, self.l(c)])) + self.mod[self.map_accuml] = Function([f, acc, l], + self.foldl(updater, Tuple([acc, self.nil()]), l), + TupleType([a, self.l(c)]), + [a, b, c]) + def define_optional_adt(self): """Defines an optional ADT, which can either contain some other @@ -90,6 +206,47 @@ def define_optional_adt(self): self.none = Constructor("none", [], self.optional) self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none]) + def define_list_unfoldr(self): + """Defines a function that builds up a list starting from a seed value. + + unfoldr(f, s) : fn(fn(a) -> Optional[(a, b)], a) -> list[b] + + f returns an option containing a new seed and an output value. f will + continue to be called on the new seeds until it returns None. All the + output values will be combined into a list, right to left. + """ + self.unfoldr = GlobalVar("unfoldr") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a], self.optional(TupleType([a, b])))) + s = Var("s", a) + p = Var("p", TupleType([a, b])) + none_case = Clause(PatternConstructor(self.none), self.nil()) + some_case = Clause(PatternConstructor(self.some, [PatternVar(p)]), + self.cons(TupleGetItem(p, 1), + self.unfoldr(f, TupleGetItem(p, 0)))) + self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]), + self.l(b), [a, b]) + + def define_list_unfoldl(self): + """Defines a function that builds up a list starting from a seed value. + + unfoldl(f, s) : fn(fn(a) -> Optional[(a, b)], a) -> list[b] + + f returns an option containing a new seed and an output value. f will + continue to be called on the new seeds until it returns None. All the + output values will be combined into a list, left to right. + """ + self.unfoldl = GlobalVar("unfoldl") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a], self.optional(TupleType([a, b])))) + s = Var("s", a) + # easiest way to implement is to do a right unfold and reverse + self.mod[self.unfoldl] = Function([f, s], + self.rev(self.unfoldr(f, s)), + self.l(b), [a, b]) + def define_nat_adt(self): """Defines a Peano (unary) natural number ADT. Zero is represented by z(). s(n) adds 1 to a nat n.""" @@ -185,8 +342,15 @@ def __init__(self, mod): self.define_list_map() self.define_list_foldl() self.define_list_foldr() + self.define_list_filter() + self.define_list_zip() + self.define_list_rev() + self.define_list_map_accumr() + self.define_list_map_accuml() self.define_optional_adt() + self.define_list_unfoldr() + self.define_list_unfoldl() self.define_nat_adt() self.define_nat_double() From 511a9312ee25489b448fb9f548f8abf75e875a94 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 14 Feb 2019 17:50:19 -0800 Subject: [PATCH 53/61] Add smoke tests for new iterators in prelude --- tests/python/relay/test_adt.py | 220 ++++++++++++++++++++++++++++++++- 1 file changed, 216 insertions(+), 4 deletions(-) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 0189d7328b21..3036088a45b3 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -16,6 +16,7 @@ double = p.double add = p.add +optional = p.optional some = p.some none = p.none @@ -28,6 +29,14 @@ foldr = p.foldr sum = p.sum +filter = p.filter +zip = p.zip +rev = p.rev +unfoldl = p.unfoldl +unfoldr = p.unfoldr +map_accumr = p.map_accumr +map_accuml = p.map_accuml + tree = p.tree rose = p.rose tmap = p.tmap @@ -138,14 +147,16 @@ def test_foldl(): x = relay.Var("x") y = relay.Var("y") - rev = relay.Function([y, x], cons(x, y)) - res = intrp.evaluate(foldl(rev, nil(), + rev_dup = relay.Function([y, x], cons(x, cons(x, y))) + res = intrp.evaluate(foldl(rev_dup, nil(), cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))))) reversed = to_list(res) - assert len(reversed) == 3 - assert count(reversed[0]) == 3 and count(reversed[1]) == 2 and count(reversed[2]) == 1 + assert len(reversed) == 6 + assert count(reversed[0]) == 3 and count(reversed[1]) == 3 + assert count(reversed[2]) == 2 and count(reversed[3]) == 2 + assert count(reversed[4]) == 1 and count(reversed[5]) == 1 def test_foldr(): @@ -173,6 +184,207 @@ def test_sum(): assert count(res) == 3 +def test_filter(): + a = relay.TypeVar("a") + expected_type = relay.FuncType([ + relay.FuncType([a], relay.scalar_type("bool")), l(a) + ], l(a), [a]) + assert mod[filter].checked_type == expected_type + + x = relay.Var("x", nat()) + greater_than_one = relay.Function( + [x], + relay.Match(x, [ + relay.Clause( + relay.PatternConstructor(s, [ + relay.PatternConstructor( + s, [relay.PatternWildcard()]) + ]), + relay.const(True)), + relay.Clause(relay.PatternWildcard(), relay.const(False)) + ])) + res = intrp.evaluate( + filter(greater_than_one, + cons(build_nat(1), + cons(build_nat(1), + cons(build_nat(3), + cons(build_nat(1), + cons(build_nat(5), + cons(build_nat(1), + nil())))))))) + filtered = to_list(res) + assert len(filtered) == 2 + assert count(filtered[0]) == 3 + assert count(filtered[1]) == 5 + + +def test_zip(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + expected_type = relay.FuncType([l(a), l(b)], + l(relay.TupleType([a, b])), [a, b]) + assert mod[zip].checked_type == expected_type + + l1 = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + l2 = cons(nil(), + cons(cons(nil(), nil()), + cons(cons(nil(), cons(nil(), nil())), + nil()))) + + res = intrp.evaluate(zip(l1, l2)) + zipped = to_list(res) + assert len(zipped) == 3 + assert count(zipped[0][0]) == 1 + assert len(to_list(zipped[0][1])) == 0 + assert count(zipped[1][0]) == 2 + assert len(to_list(zipped[1][1])) == 1 + assert count(zipped[2][0]) == 3 + assert len(to_list(zipped[2][1])) == 2 + + # test truncation + l3 = cons(build_nat(4), cons(build_nat(5), nil())) + shorter_res = intrp.evaluate(zip(l3, l2)) + truncated = to_list(shorter_res) + assert len(truncated) == 2 + assert count(truncated[0][0]) == 4 + assert len(to_list(truncated[0][1])) == 0 + assert count(truncated[1][0]) == 5 + assert len(to_list(truncated[1][1])) == 1 + + l4 = cons(nil(), nil()) + shortest_res = intrp.evaluate(zip(l3, l4)) + singleton = to_list(shortest_res) + assert len(singleton) == 1 + assert count(singleton[0][0]) == 4 + assert len(to_list(singleton[0][1])) == 0 + + +def test_rev(): + a = relay.TypeVar("a") + assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a]) + + res = intrp.evaluate(rev(cons(build_nat(1), + cons(build_nat(2), + cons(build_nat(3), nil()))))) + reversed = to_list(res) + + assert len(reversed) == 3 + assert count(reversed[0]) == 3 + assert count(reversed[1]) == 2 + assert count(reversed[2]) == 1 + + +def test_unfoldr(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + expected_type = relay.FuncType([ + relay.FuncType([a], optional(relay.TupleType([a, b]))), a], + l(b), [a, b]) + + x = relay.Var("x", nat()) + n = relay.Var("n", nat()) + count_down = relay.Function( + [x], + relay.Match(x, [ + relay.Clause(relay.PatternConstructor( + s, [relay.PatternVar(n)]), + some(relay.Tuple([n, x]))), + relay.Clause(relay.PatternConstructor(z, []), none()) + ])) + + res = intrp.evaluate(unfoldr(count_down, build_nat(3))) + unfolded = to_list(res) + + assert len(unfolded) == 3 + assert count(unfolded[0]) == 3 + assert count(unfolded[1]) == 2 + assert count(unfolded[2]) == 1 + + +def test_unfoldl(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + expected_type = relay.FuncType([ + relay.FuncType([a], optional(relay.TupleType([a, b]))), a], + l(b), [a, b]) + + x = relay.Var("x", nat()) + n = relay.Var("n", nat()) + count_down = relay.Function( + [x], + relay.Match(x, [ + relay.Clause(relay.PatternConstructor( + s, [relay.PatternVar(n)]), + some(relay.Tuple([n, x]))), + relay.Clause(relay.PatternConstructor(z, []), none()) + ])) + + res = intrp.evaluate(unfoldl(count_down, build_nat(3))) + unfolded = to_list(res) + + assert len(unfolded) == 3 + assert count(unfolded[0]) == 1 + assert count(unfolded[1]) == 2 + assert count(unfolded[2]) == 3 + + +def test_map_accumr(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + expected_type = relay.FuncType([ + relay.FuncType([a, b], relay.TupleType([a, c])), + a, l(b) + ], relay.TupleType([a, l(c)]), [a, b, c]) + assert mod[map_accumr].checked_type == expected_type + + acc = relay.Var("acc", nat()) + x = relay.Var("x", nat()) + add_acc_to_each = relay.Function([acc, x], + relay.Tuple([add(x, acc), + add(x, acc)])) + + vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals)) + + sum = count(res[0]) + new_vals = to_list(res[1]) + + assert sum == 6 + assert len(new_vals) == 3 + assert count(new_vals[0]) == 6 + assert count(new_vals[1]) == 5 + assert count(new_vals[2]) == 3 + + +def test_map_accuml(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + expected_type = relay.FuncType([ + relay.FuncType([a, b], relay.TupleType([a, c])), + a, l(b) + ], relay.TupleType([a, l(c)]), [a, b, c]) + assert mod[map_accuml].checked_type == expected_type + + acc = relay.Var("acc", nat()) + x = relay.Var("x", nat()) + add_to_acc = relay.Function([acc, x], + relay.Tuple([add(x, acc), x])) + + vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + res = intrp.evaluate(map_accuml(add_to_acc, z(), vals)) + + sum = count(res[0]) + new_vals = to_list(res[1]) + + assert sum == 6 + assert len(new_vals) == 3 + assert count(new_vals[0]) == 3 + assert count(new_vals[1]) == 2 + assert count(new_vals[2]) == 1 + + def test_optional_matching(): x = relay.Var('x') y = relay.Var('y') From 3eeca4c79f925249c08432b58de13196cd174ebe Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 14 Feb 2019 19:05:07 -0800 Subject: [PATCH 54/61] Add concat to prelude --- python/tvm/relay/prelude.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 4129f5642094..74b314b712a4 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -81,6 +81,19 @@ def define_list_foldr(self): self.mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), b, [a, b]) + def define_list_concat(self): + """Defines a function that concatenates two lists. + + concat(l1, l2) : fn(list[a], list[a]) -> list[a]""" + self.concat = GlobalVar("concat") + a = TypeVar("a") + l1 = Var("l1", self.l(a)) + l2 = Var("l2", self.l(a)) + updater = Function([h, t], self.cons(h, t)) + self.mod[self.concat] = Function([l1, l2], + self.foldr(updater, l2, l1), + self.l(a), [a]) + def define_list_filter(self): """Defines a function that filters a list. @@ -342,6 +355,7 @@ def __init__(self, mod): self.define_list_map() self.define_list_foldl() self.define_list_foldr() + self.define_list_concat() self.define_list_filter() self.define_list_zip() self.define_list_rev() From 868f76e9fcc51bc72c8c804c3f1684bd1c9fdca2 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 14 Feb 2019 19:24:49 -0800 Subject: [PATCH 55/61] Add smoke test for concat --- python/tvm/relay/prelude.py | 2 ++ tests/python/relay/test_adt.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 74b314b712a4..03761b27d65a 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -89,6 +89,8 @@ def define_list_concat(self): a = TypeVar("a") l1 = Var("l1", self.l(a)) l2 = Var("l2", self.l(a)) + h = Var("h") + t = Var("t") updater = Function([h, t], self.cons(h, t)) self.mod[self.concat] = Function([l1, l2], self.foldr(updater, l2, l1), diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 3036088a45b3..a05265a46256 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -29,6 +29,7 @@ foldr = p.foldr sum = p.sum +concat = p.concat filter = p.filter zip = p.zip rev = p.rev @@ -184,6 +185,22 @@ def test_sum(): assert count(res) == 3 +def test_concat(): + a = relay.TypeVar("a") + assert mod[concat].checked_type == relay.FuncType([l(a), l(a)], l(a), [a]) + + l1 = cons(build_nat(1), cons(build_nat(2), nil())) + l2 = cons(build_nat(3), cons(build_nat(4), nil())) + res = intrp.evaluate(concat(l1, l2)) + + catted = to_list(res) + assert len(catted) == 4 + assert count(catted[0]) == 1 + assert count(catted[1]) == 2 + assert count(catted[2]) == 3 + assert count(catted[3]) == 4 + + def test_filter(): a = relay.TypeVar("a") expected_type = relay.FuncType([ @@ -570,6 +587,14 @@ def test_nested_pattern_match(): test_map() test_foldl() test_foldr() + test_concat() + test_filter() + test_zip() + test_rev() + test_unfoldl() + test_unfoldr() + test_map_accumr() + test_map_accuml() test_sum() test_tmap() test_size() From 5a85aa814f3f785c7ba360d5e1a6da226f31f905 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 15 Feb 2019 12:10:49 -0800 Subject: [PATCH 56/61] Correct docstrings in prelude --- python/tvm/relay/prelude.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 03761b27d65a..99b6c8d1c766 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -40,7 +40,7 @@ def define_list_map(self): def define_list_foldl(self): """Defines a left-way fold over a list. - foldl(f, z, l) : fn(fn(a, b) -> a, a, list[a]) -> a + foldl(f, z, l) : fn(fn(a, b) -> a, a, list[b]) -> a foldl(f, z, cons(a1, cons(a2, cons(a3, cons(..., nil))))) evaluates to f(...f(f(f(z, a1), a2), a3)...) @@ -101,7 +101,7 @@ def define_list_filter(self): filter(f, l) : fn(fn(a) -> Tensor[(), bool], list[a]) -> list[a] - It returns a the sublist of l consisting of the elements for which f returns true. + It returns the sublist of l consisting of the elements for which f returns true. """ self.filter = GlobalVar("filter") a = TypeVar("a") From bd9bfc75d5ab8882fbab94eb5c19c1984ebf5774 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 15 Feb 2019 12:13:35 -0800 Subject: [PATCH 57/61] Ensure that type defs are written in module initialization --- python/tvm/relay/module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 01ac84212aa3..ef496333d828 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -42,6 +42,7 @@ def __init__(self, functions=None, type_definitions=None): k = _ty.GlobalTypeVar(k) if not isinstance(k, _ty.GlobalTypeVar): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") + mapped_type_defs[k] = v type_definitions = mapped_type_defs self.__init_handle_by_constructor__(_make.Module, functions, type_definitions) From de7ea6ed074036c444e9a0a75550a6f55cc69941 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 15 Feb 2019 12:48:17 -0800 Subject: [PATCH 58/61] Various requested renamings --- include/tvm/relay/adt.h | 16 ++++++------ include/tvm/relay/interpreter.h | 18 ++++++------- python/tvm/relay/adt.py | 24 ++++++++--------- python/tvm/relay/backend/interpreter.py | 6 ++--- src/relay/backend/interpreter.cc | 34 ++++++++++++------------- src/relay/ir/adt.cc | 18 ++++++------- src/relay/ir/alpha_equal.cc | 6 ++--- src/relay/ir/expr_functor.cc | 2 +- src/relay/ir/hash.cc | 4 +-- src/relay/ir/pattern_functor.cc | 6 ++--- src/relay/ir/text_printer.cc | 2 +- src/relay/ir/type_functor.cc | 4 +-- src/relay/pass/kind_check.cc | 8 +++--- src/relay/pass/type_infer.cc | 27 ++++++++++---------- src/relay/pass/util.cc | 2 +- tests/python/relay/test_adt.py | 22 ++++++++-------- 16 files changed, 100 insertions(+), 99 deletions(-) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 4f5c7a995894..07c05e89aa86 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -92,7 +92,7 @@ class ConstructorNode : public ExprNode { /*! \brief The name (only a hint) */ std::string name_hint; /*! \brief Input to the constructor. */ - tvm::Array inp; + tvm::Array inputs; /*! \brief The datatype the constructor will construct. */ GlobalTypeVar belong_to; /*! \brief Index in the table of constructors (set when the type is registered). */ @@ -101,12 +101,12 @@ class ConstructorNode : public ExprNode { ConstructorNode() {} TVM_DLL static Constructor make(std::string name_hint, - tvm::Array inp, + tvm::Array inputs, GlobalTypeVar belong_to); void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name_hint", &name_hint); - v->Visit("inp", &inp); + v->Visit("inputs", &inputs); v->Visit("belong_to", &belong_to); v->Visit("tag", &tag); v->Visit("span", &span); @@ -127,7 +127,7 @@ class PatternConstructorNode : public PatternNode { /*! Constructor matched by the pattern. */ Constructor constructor; /*! Sub-patterns to match against each input to the constructor. */ - tvm::Array pat; + tvm::Array patterns; PatternConstructorNode() {} @@ -135,7 +135,7 @@ class PatternConstructorNode : public PatternNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("constructor", &constructor); - v->Visit("pat", &pat); + v->Visit("patterns", &patterns); v->Visit("span", &span); } @@ -168,19 +168,19 @@ class TypeDataNode : public TypeNode { */ GlobalTypeVar header; /*! \brief The type variables (to allow for polymorphism). */ - tvm::Array ty_vars; + tvm::Array type_vars; /*! \brief The constructors. */ tvm::Array constructors; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("header", &header); - v->Visit("ty_vars", &ty_vars); + v->Visit("type_vars", &type_vars); v->Visit("constructors", &constructors); v->Visit("span", &span); } TVM_DLL static TypeData make(GlobalTypeVar header, - tvm::Array tv, + tvm::Array type_vars, tvm::Array constructors); static constexpr const char* _type_key = "relay.TypeData"; diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index f235a20065af..42f0d4e9b0a5 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -161,26 +161,26 @@ struct RefValueNode : ValueNode { RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); /*! \brief An ADT constructor value. */ -class ConValue; +class ConstructorValue; -struct ConValueNode : ValueNode { - Constructor con; +struct ConstructorValueNode : ValueNode { + Constructor constructor; tvm::Array fields; void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("con", &con); + v->Visit("constructor", &constructor); v->Visit("fields", &fields); } - TVM_DLL static ConValue make(Constructor con, - tvm::Array fields); + TVM_DLL static ConstructorValue make(Constructor constructor, + tvm::Array fields); - static constexpr const char* _type_key = "relay.ConValue"; - TVM_DECLARE_NODE_TYPE_INFO(ConValueNode, ValueNode); + static constexpr const char* _type_key = "relay.ConstructorValue"; + TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(ConValue, ConValueNode, Value); +RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value); } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py index 5bcfbc6817fd..bc516a8f3ddb 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/adt.py @@ -52,14 +52,14 @@ def __init__(self, var): class PatternConstructor(Pattern): """Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively.""" - def __init__(self, con, pat=None): + def __init__(self, constructor, patterns=None): """Construct a constructor pattern. Parameters ---------- - con: Constructor + constructor: Constructor The constructor. - pat: Optional[List[Pattern]] + patterns: Optional[List[Pattern]] Optional subpatterns: for each field of the constructor, match to the given subpattern (treated as a variable pattern by default). @@ -68,23 +68,23 @@ def __init__(self, con, pat=None): wildcard: PatternWildcard a wildcard pattern. """ - if pat is None: - pat = [] - self.__init_handle_by_constructor__(_make.PatternConstructor, con, pat) + if patterns is None: + patterns = [] + self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns) @register_relay_node class Constructor(Expr): """Relay ADT constructor.""" - def __init__(self, name_hint, inp, belong_to): + def __init__(self, name_hint, inputs, belong_to): """Defines an ADT constructor. Parameters ---------- name_hint : str Name of constructor (only a hint). - inp : List[Type] + inputs : List[Type] Input types. belong_to : tvm.relay.GlobalTypeVar Denotes which ADT the constructor belongs to. @@ -94,7 +94,7 @@ def __init__(self, name_hint, inp, belong_to): con: Constructor A constructor. """ - self.__init_handle_by_constructor__(_make.Constructor, name_hint, inp, belong_to) + self.__init_handle_by_constructor__(_make.Constructor, name_hint, inputs, belong_to) def __call__(self, *args): """Call the constructor. @@ -122,7 +122,7 @@ class TypeData(Type): type call that passes in the type params. """ - def __init__(self, header, ty_vars, constructors): + def __init__(self, header, type_vars, constructors): """Defines a TypeData object. Parameters @@ -131,7 +131,7 @@ def __init__(self, header, ty_vars, constructors): The name of the ADT. ADTs with the same constructors but different names are treated as different types. - ty_vars: List[TypeVar] + type_vars: List[TypeVar] Type variables that appear in constructors. constructors: List[tvm.relay.Constructor] The constructors for the ADT. @@ -141,7 +141,7 @@ def __init__(self, header, ty_vars, constructors): type_data: TypeData The adt declaration. """ - self.__init_handle_by_constructor__(_make.TypeData, header, ty_vars, constructors) + self.__init_handle_by_constructor__(_make.TypeData, header, type_vars, constructors) @register_relay_node diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 732567375fcc..1d50a571a460 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -53,10 +53,10 @@ class Closure(Value): @register_relay_node -class ConValue(Value): - def __init__(self, con, fields, types): +class ConstructorValue(Value): + def __init__(self, constructor, fields, types): self.__init_handle_by_constructor__( - _make.ConValue, con, fields, types) + _make.ConstructorValue, constructor, fields, types) @register_relay_node diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 43f17e27a839..4ef893f463e9 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -93,23 +93,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); -ConValue ConValueNode::make(Constructor con, - tvm::Array fields) { - NodePtr n = make_node(); - n->con = con; +ConstructorValue ConstructorValueNode::make(Constructor constructor, + tvm::Array fields) { + NodePtr n = make_node(); + n->constructor = constructor; n->fields = fields; - return ConValue(n); + return ConstructorValue(n); } -TVM_REGISTER_API("relay._make.ConValue") +TVM_REGISTER_API("relay._make.ConstructorValue") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ConValueNode::make(args[0], args[1]); + *ret = ConstructorValueNode::make(args[0], args[1]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const ConValueNode* node, - tvm::IRPrinter* p) { - p->stream << "ConValueNode(" << node->con +.set_dispatch([](const ConstructorValueNode* node, + tvm::IRPrinter* p) { + p->stream << "ConstructorValueNode(" << node->constructor << node->fields << ")"; }); @@ -424,7 +424,7 @@ class Interpreter : "fusing and lowering"; } if (auto con = call->op.as()) { - return ConValueNode::make(GetRef(con), args); + return ConstructorValueNode::make(GetRef(con), args); } // Now we just evaluate and expect to find a closure. Value fn_val = Eval(call->op); @@ -511,15 +511,15 @@ class Interpreter : } bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final { - const ConValueNode* cvn = v.as(); + const ConstructorValueNode* cvn = v.as(); CHECK(cvn) << "need to be a constructor for match"; CHECK_NE(op->constructor->tag, -1); - CHECK_NE(cvn->con->tag, -1); - if (op->constructor->tag == cvn->con->tag) { + CHECK_NE(cvn->constructor->tag, -1); + if (op->constructor->tag == cvn->constructor->tag) { // todo(M.K.): should use ptr equality but it is broken - CHECK(op->pat.size() == cvn->fields.size()); - for (size_t i = 0; i < op->pat.size(); ++i) { - if (!VisitPattern(op->pat[i], cvn->fields[i])) { + CHECK(op->patterns.size() == cvn->fields.size()); + for (size_t i = 0; i < op->patterns.size(); ++i) { + if (!VisitPattern(op->patterns[i], cvn->fields[i])) { return false; } } diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 8ea3abe02d0e..21d98036fb0d 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -47,10 +47,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); PatternConstructor PatternConstructorNode::make(Constructor constructor, - tvm::Array pat) { + tvm::Array patterns) { NodePtr n = make_node(); n->constructor = std::move(constructor); - n->pat = std::move(pat); + n->patterns = std::move(patterns); return PatternConstructor(n); } @@ -65,15 +65,15 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PatternConstructorNode* node, tvm::IRPrinter* p) { p->stream << "PatternConstructorNode(" << node->constructor - << ", " << node->pat << ")"; + << ", " << node->patterns << ")"; }); Constructor ConstructorNode::make(std::string name_hint, - tvm::Array inp, + tvm::Array inputs, GlobalTypeVar belong_to) { NodePtr n = make_node(); n->name_hint = std::move(name_hint); - n->inp = std::move(inp); + n->inputs = std::move(inputs); n->belong_to = std::move(belong_to); return Constructor(n); } @@ -89,15 +89,15 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstructorNode* node, tvm::IRPrinter* p) { p->stream << "ConstructorNode(" << node->name_hint << ", " - << node->inp << ", " << node->belong_to << ")"; + << node->inputs << ", " << node->belong_to << ")"; }); TypeData TypeDataNode::make(GlobalTypeVar header, - tvm::Array ty_vars, + tvm::Array type_vars, tvm::Array constructors) { NodePtr n = make_node(); n->header = std::move(header); - n->ty_vars = std::move(ty_vars); + n->type_vars = std::move(type_vars); n->constructors = std::move(constructors); return TypeData(n); } @@ -112,7 +112,7 @@ TVM_REGISTER_API("relay._make.TypeData") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TypeDataNode* node, tvm::IRPrinter* p) { - p->stream << "TypeDataNode(" << node->header << ", " << node->ty_vars << ", " + p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " << node->constructors << ")"; }); diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 8e27e65f5d8b..96517f8dd445 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -440,12 +440,12 @@ class AlphaEqualHandler: const auto* r = e2.as(); if (r == nullptr || !ExprEqual(op->constructor, r->constructor) - || op->pat.size() != r->pat.size()) { + || op->patterns.size() != r->patterns.size()) { return false; } - for (size_t i = 0; i < op->pat.size(); i++) { - if (!PatternEqual(op->pat[i], r->pat[i])) { + for (size_t i = 0; i < op->patterns.size(); i++) { + if (!PatternEqual(op->patterns[i], r->patterns[i])) { return false; } } diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index ae65d21c614b..6265873d8310 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -286,7 +286,7 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) { } void ExprVisitor::VisitExpr_(const ConstructorNode* op) { - for (const Type& t : op->inp) { + for (const Type& t : op->inputs) { this->VisitType(t); } this->VisitType(op->belong_to); diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 798e3f4e20ab..94ffc7a21294 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -340,7 +340,7 @@ class RelayHashHandler: size_t VisitType_(const TypeDataNode* tdn) final { size_t hash = std::hash()(TypeDataNode::_type_key); hash = Combine(hash, TypeHash(tdn->header)); - for (const auto& tv : tdn->ty_vars) { + for (const auto& tv : tdn->type_vars) { hash = Combine(hash, TypeHash(tv)); } for (const auto& cn : tdn->constructors) { @@ -360,7 +360,7 @@ class RelayHashHandler: size_t VisitPattern_(const PatternConstructorNode* pcn) final { size_t hash = std::hash()(PatternConstructorNode::_type_key); hash = Combine(hash, ExprHash(pcn->constructor)); - for (const auto& p : pcn->pat) { + for (const auto& p : pcn->patterns) { hash = Combine(hash, PatternHash(p)); } return hash; diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc index d1b5a9049142..6d2e9d296164 100644 --- a/src/relay/ir/pattern_functor.cc +++ b/src/relay/ir/pattern_functor.cc @@ -23,7 +23,7 @@ Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) { Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { std::vector pat; - for (const auto& p : op->pat) { + for (const auto& p : op->patterns) { pat.push_back(VisitPattern(p)); } return PatternConstructorNode::make(VisitConstructor(op->constructor), pat); @@ -54,7 +54,7 @@ void PatternVisitor::VisitPattern_(const PatternVarNode* op) { void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { VisitConstructor(op->constructor); - for (const auto& p : op->pat) { + for (const auto& p : op->patterns) { VisitPattern(p); } } @@ -66,7 +66,7 @@ void PatternVisitor::VisitVar(const Var& v) { } void PatternVisitor::VisitConstructor(const Constructor& c) { - for (const auto& inp : c->inp) { + for (const auto& inp : c->inputs) { VisitType(inp); } } diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index c00e0ede3f06..932856a2055d 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -418,7 +418,7 @@ class TextPrinter : TextValue VisitPattern_(const PatternConstructorNode* p) final { TextValue ret(p->constructor->name_hint + "("); - for (const Pattern& pat : p->pat) { + for (const Pattern& pat : p->patterns) { ret = ret + " " + GetValue(pat); } return ret + ")"; diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 40df5dda4f22..b88d0ee0e3ab 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -60,13 +60,13 @@ void TypeVisitor::VisitType_(const TypeCallNode* op) { void TypeVisitor::VisitType_(const TypeDataNode* op) { this->VisitType(op->header); - for (const auto& v : op->ty_vars) { + for (const auto& v : op->type_vars) { this->VisitType(v); } for (const auto& c : op->constructors) { this->VisitType(c->belong_to); - for (const auto& t : c->inp) { + for (const auto& t : c->inputs) { this->VisitType(t); } } diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index da9d07350c5e..b769a656a0b8 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -124,8 +124,8 @@ struct KindChecker : TypeFunctor { // finally we need to check the module to check the number of type params auto var = GetRef(gtv); auto data = mod->LookupDef(var); - if (data->ty_vars.size() != op->args.size()) { - ReportFatalError(RELAY_ERROR("Expected " << data->ty_vars.size() << "arguments for " << tc + if (data->type_vars.size() != op->args.size()) { + ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc << "; got " << op->args.size())); } return Kind::kType; @@ -139,7 +139,7 @@ struct KindChecker : TypeFunctor { TypeData td = GetRef(op); CheckKindMatches(op->header, td, Kind::kAdtHandle, "type data header"); - for (const auto& var : op->ty_vars) { + for (const auto& var : op->type_vars) { CheckKindMatches(var, td, Kind::kType, "ADT type var"); } @@ -149,7 +149,7 @@ struct KindChecker : TypeFunctor { << " but " << op << "has header " << op->header)); } - for (const Type& t : con->inp) { + for (const Type& t : con->inputs) { CheckKindMatches(t, td, Kind::kType, "ADT constructor input"); } } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ab7c0b92c03c..1815e9174779 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -211,7 +211,7 @@ class TypeInferencer : private ExprFunctor, // we can expect a certain number of arguments Array unknown_args; - for (size_t i = 0; i < td->ty_vars.size(); i++) { + for (size_t i = 0; i < td->type_vars.size(); i++) { unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } Type expected = TypeCallNode::make(con->constructor->belong_to, unknown_args); @@ -225,23 +225,23 @@ class TypeInferencer : private ExprFunctor, this->ReportFatalError(pc, RELAY_ERROR("ADT headers must match, but we have " << td->header << " and " << tc->func)); } - if (td->ty_vars.size() != tc->args.size()) { + if (td->type_vars.size() != tc->args.size()) { this->ReportFatalError(pc, RELAY_ERROR("The number of type args must match" << "the number of type vars in the type data: " - << td->ty_vars.size() << " != " << tc->args.size())); + << td->type_vars.size() << " != " << tc->args.size())); } std::unordered_map type_var_map_; - for (size_t i = 0; i < td->ty_vars.size(); ++i) { - type_var_map_[td->ty_vars[i]] = tc->args[i]; + for (size_t i = 0; i < td->type_vars.size(); ++i) { + type_var_map_[td->type_vars[i]] = tc->args[i]; } - CHECK(con->constructor->inp.size() == con->pat.size()) << "not enough pattern"; - if (con->constructor->inp.size() != con->pat.size()) { + CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern"; + if (con->constructor->inputs.size() != con->patterns.size()) { this->ReportFatalError(pc, RELAY_ERROR("Not enough inputs for the constructor; " - << "expected " << con->constructor->inp.size() - << ", got " << con->pat.size())); + << "expected " << con->constructor->inputs.size() + << ", got " << con->patterns.size())); } - for (size_t i = 0; i < con->constructor->inp.size(); ++i) { - VisitPattern(con->pat[i], Bind(con->constructor->inp[i], type_var_map_)); + for (size_t i = 0; i < con->constructor->inputs.size(); ++i) { + VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_)); } } @@ -495,10 +495,11 @@ class TypeInferencer : private ExprFunctor, << c->name_hint; TypeData td = mod_->LookupDef(c->belong_to); std::vector types; - for (const auto & t : td->ty_vars) { + for (const auto & t : td->type_vars) { types.push_back(t); } - return FuncTypeNode::make(c->inp, TypeCallNode::make(c->belong_to, types), td->ty_vars, {}); + return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types), + td->type_vars, {}); } }; diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index ce49aa84e91c..76fc0aa1a45e 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -121,7 +121,7 @@ class TypeVarEVisitor : private ExprVisitor { void VisitExpr_(const ConstructorNode* cn) final { // for constructors, type vars will be bound in the module auto data = mod_->LookupDef(cn->belong_to); - for (const auto& tv : data->ty_vars) { + for (const auto& tv : data->type_vars) { type_vars_.Insert(tv); bound_type_vars_.Insert(tv); } diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index a05265a46256..5acae6c70295 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -1,7 +1,7 @@ import tvm from tvm import relay from tvm.relay.ir_pass import infer_type -from tvm.relay.backend.interpreter import Value, TupleValue, ConValue +from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay import testing, create_executor from tvm.relay.prelude import Prelude @@ -45,19 +45,19 @@ # this is an example of using the adt value in python side def count(n): - assert isinstance(n, ConValue) - if n.con.name_hint == 's': + assert isinstance(n, ConstructorValue) + if n.constructor.name_hint == 's': return 1 + count(n.fields[0]) else: - assert n.con.name_hint == 'z' + assert n.constructor.name_hint == 'z' return 0 # this is an example of creating the adt value in python side def make_nat(n): if n != 0: - return ConValue(s, [make_nat(n - 1)], []) + return ConstructorValue(s, [make_nat(n - 1)], []) else: - return ConValue(z, [], []) + return ConstructorValue(z, [], []) def build_nat(n): assert n >= 0 @@ -68,22 +68,22 @@ def build_nat(n): return ret def to_list(l): - assert isinstance(l, ConValue) + assert isinstance(l, ConstructorValue) val = l ret = [] while True: - if val.con.name_hint == 'cons': + if val.constructor.name_hint == 'cons': ret.append(val.fields[0]) val = val.fields[1] else: - assert val.con.name_hint == 'nil' + assert val.constructor.name_hint == 'nil' break return ret def tree_to_dict(t): - assert isinstance(t, ConValue) + assert isinstance(t, ConstructorValue) ret = {} - assert t.con.name_hint == 'rose' + assert t.constructor.name_hint == 'rose' ret['member'] = t.fields[0] ret['children'] = [] for subtree in to_list(t.fields[1]): From 8909e187823fbaf44261e8c495b0079574388eb3 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 15 Feb 2019 13:24:21 -0800 Subject: [PATCH 59/61] Correct rebase snags --- src/relay/ir/hash.cc | 1 - src/relay/pass/kind_check.cc | 6 ++---- src/relay/pass/type_infer.cc | 4 ++-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 94ffc7a21294..5e10906bec84 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -19,7 +19,6 @@ namespace relay { class RelayHashHandler: public AttrsHashHandler, public TypeFunctor, - public ExprFunctor { public ExprFunctor, public PatternFunctor { public: diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index b769a656a0b8..f1e539d71d48 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -90,10 +90,8 @@ struct KindChecker : TypeFunctor { Kind VisitType_(const RefTypeNode* op) override { // ref types should only contain normal types - Kind k = this->VisitType(op->value); - CHECK(k == Kind::kType) - << "The value inside a ref must be of the type kind but " - << op->value << " of " << GetRef(op) << " is of kind " << k; + RefType rt = GetRef(op); + CheckKindMatches(op->value, rt, Kind::kType, "ref contents"); return Kind::kType; } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 1815e9174779..fa3cea610c68 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -477,13 +477,13 @@ class TypeInferencer : private ExprFunctor, } Type VisitExpr_(const RefReadNode* op) final { - Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type it = IncompleteTypeNode::make(Kind::kType); this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef(op)); return it; } Type VisitExpr_(const RefWriteNode* op) final { - Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type it = IncompleteTypeNode::make(Kind::kType); this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef(op)); this->Unify(GetType(op->value), it, GetRef(op)); return TupleTypeNode::make({}); From ffd5c80524beaa3c0a230315325e89e076d2c5fb Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 15 Feb 2019 13:29:40 -0800 Subject: [PATCH 60/61] Add kind check tests for ref types --- tests/python/relay/test_pass_check_kind.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/python/relay/test_pass_check_kind.py b/tests/python/relay/test_pass_check_kind.py index dafb789fda73..ae5e3738847a 100644 --- a/tests/python/relay/test_pass_check_kind.py +++ b/tests/python/relay/test_pass_check_kind.py @@ -46,6 +46,19 @@ def test_func_kind(): assert check_kind(tf) == relay.Kind.Type +def test_ref_kind(): + # only contain type kinds + tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') + ft = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) + + rt1 = relay.RefType(tt) + assert check_kind(rt1) == relay.Kind.Type + rt2 = relay.RefType(ft) + assert check_kind(rt2) == relay.Kind.Type + rt3 = relay.RefType(relay.TupleType([rt1, rt2])) + assert check_kind(rt3) == relay.Kind.Type + + def test_relation_kind(): # only have type kinds for arguments tp = relay.TypeVar('tp', relay.Kind.Type) @@ -108,6 +121,13 @@ def test_invalid_func_kind(): check_kind(tf) +@raises(tvm._ffi.base.TVMError) +def test_invalid_ref_kind(): + tp = relay.TypeVar('tp', relay.Kind.Shape) + rt = relay.RefType(tp) + check_kind(rt) + + @raises(tvm._ffi.base.TVMError) def test_invalid_relation_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) From 7a6eff0e59017ce5469b49177ae92eb882bb6632 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 15 Feb 2019 13:32:25 -0800 Subject: [PATCH 61/61] Update the main() for kind checking tests --- tests/python/relay/test_pass_check_kind.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/python/relay/test_pass_check_kind.py b/tests/python/relay/test_pass_check_kind.py index ae5e3738847a..4eab59a6edd0 100644 --- a/tests/python/relay/test_pass_check_kind.py +++ b/tests/python/relay/test_pass_check_kind.py @@ -223,10 +223,17 @@ def test_tuple_with_invalid_func(): if __name__ == "__main__": test_tuple_kind() test_func_kind() + test_ref_kind() test_relation_kind() + test_global_typevar_kind() + test_typecall_kind() test_invalid_tuple_kind() test_invalid_func_kind() + test_invalid_ref_kind() test_invalid_relation_kind() + test_typecall_invalid_callee() + test_typecall_invalid_args() + test_typecall_invalid_num_args() test_func_with_invalid_ret_type() test_func_with_invalid_arg_types() test_func_with_invalid_tuple()