-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay] Algebraic data types (#2442)
* First pass on ADTs * Add doc string for tag field * Visit constructors in TypeVisitor for TypeData * Add to description of type call * Add type call to type solving and unification * Make type mutator for typecall consistent with others (only create new node if there's a change) * Ensure kindchecking can handle type calls and typedata * Fix bad nesting in module constructor * Correctly construct call in typecall test * Add call override for ordinary vars (do we want this?) * Remove generalization hack from type inference because it was breaking ADT constructors * Check that there are no free type vars in exprs after inferring type * Free var checks need module because of ADT constructors * Typecall test can't have unbound type var, make it global * Uncomment tmap test and remove comments about failing to infer ret type; those work now * Put in dummy visits for ADTs in graph runtime codegen to placate pylint * Fix Relay type infer test module constructor * Mark override for TypeCallNode in type solver * Ensure free vars check treats patern vars as bound * Run interpreter in more ADT test cases * Refactor kind check to return the kind, like typechecking * Fix invalid typecall in test * Add kind check to type inference, do not use nulls in func_type_annotation()! * Redundant whitespace * Make TypeData a separate kind * Make ADT handles a separate kind too, document calling convention better * Remove nats and tree from prelude, move to test, document prelude * Restore and document nat and tree to prelude, add more tree tests * Add alpha equality tests for match cases, fix variable binding bug * Add more kind check tests for ADTs * Add more tests for finding free or bound vars in match exprs * Add unification tests for type call * Update main() for alpha equality tests * Add simple type inference test cases for match exprs and ADT constructors * Add more ADT interpreter tests * Allow incomplete types when typechecking match cases * Type inference for pattern vars should use the type annotation if it's there * Two more specific test cases for ADT matching * Add option ADT to prelude * Fix broken reference to kind enum * Fix rebase snags * Do not attach checked types to constructors * More docstrings for module fields * Use proper wrapper for indexing into module type data * checked_type for constructors is not populated * Expand type call docstring * Rename PatternConstructor con field * Use error reporter for pattern constructor case * Condense error reporting in kind check, use error reporter * Expand docstrings and rename ADT fields * Rename 'option' ADT to 'optional' for consistency with Python * Add various list iterators and utility functions to prelude * Add smoke tests for new iterators in prelude * Add concat to prelude * Add smoke test for concat * Correct docstrings in prelude * Ensure that type defs are written in module initialization * Various requested renamings * Correct rebase snags * Add kind check tests for ref types * Update the main() for kind checking tests
- Loading branch information
1 parent
c716542
commit 2ae3124
Showing
45 changed files
with
3,398 additions
and
207 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file tvm/relay/adt.h | ||
* \brief Algebraic data types for Relay | ||
*/ | ||
#ifndef TVM_RELAY_ADT_H_ | ||
#define TVM_RELAY_ADT_H_ | ||
|
||
#include <tvm/attrs.h> | ||
#include <string> | ||
#include <functional> | ||
#include "./base.h" | ||
#include "./type.h" | ||
#include "./expr.h" | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
/*! \brief Base type for declaring relay pattern. */ | ||
class PatternNode : public RelayNode { | ||
public: | ||
static constexpr const char* _type_key = "relay.Pattern"; | ||
TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node); | ||
}; | ||
|
||
/*! | ||
* \brief Pattern is the base type for an ADT match pattern in Relay. | ||
* | ||
* Given an ADT value, a pattern might accept it and bind the pattern variable to some value | ||
* (typically a subnode of the input or the input). Otherwise, the pattern rejects the value. | ||
* | ||
* ADT pattern matching thus takes a list of values and binds to the first that accepts the value. | ||
*/ | ||
class Pattern : public NodeRef { | ||
public: | ||
Pattern() {} | ||
explicit Pattern(NodePtr<tvm::Node> p) : NodeRef(p) {} | ||
|
||
using ContainerType = PatternNode; | ||
}; | ||
|
||
/*! \brief A wildcard pattern: Accepts all input and binds nothing. */ | ||
class PatternWildcard; | ||
/*! \brief PatternWildcard container node */ | ||
class PatternWildcardNode : public PatternNode { | ||
public: | ||
PatternWildcardNode() {} | ||
|
||
TVM_DLL static PatternWildcard make(); | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("span", &span); | ||
} | ||
|
||
static constexpr const char* _type_key = "relay.PatternWildcard"; | ||
TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern); | ||
|
||
/*! \brief A var pattern. Accept all input and bind to a var. */ | ||
class PatternVar; | ||
/*! \brief PatternVar container node */ | ||
class PatternVarNode : public PatternNode { | ||
public: | ||
PatternVarNode() {} | ||
|
||
/*! \brief Variable that stores the matched value. */ | ||
tvm::relay::Var var; | ||
|
||
TVM_DLL static PatternVar make(tvm::relay::Var var); | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("var", &var); | ||
v->Visit("span", &span); | ||
} | ||
|
||
static constexpr const char* _type_key = "relay.PatternVar"; | ||
TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern); | ||
|
||
/*! | ||
* \brief ADT constructor. | ||
* Constructors compare by pointer equality. | ||
*/ | ||
class Constructor; | ||
/*! \brief Constructor container node. */ | ||
class ConstructorNode : public ExprNode { | ||
public: | ||
/*! \brief The name (only a hint) */ | ||
std::string name_hint; | ||
/*! \brief Input to the constructor. */ | ||
tvm::Array<Type> inputs; | ||
/*! \brief The datatype the constructor will construct. */ | ||
GlobalTypeVar belong_to; | ||
/*! \brief Index in the table of constructors (set when the type is registered). */ | ||
mutable int tag = -1; | ||
|
||
ConstructorNode() {} | ||
|
||
TVM_DLL static Constructor make(std::string name_hint, | ||
tvm::Array<Type> inputs, | ||
GlobalTypeVar belong_to); | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("name_hint", &name_hint); | ||
v->Visit("inputs", &inputs); | ||
v->Visit("belong_to", &belong_to); | ||
v->Visit("tag", &tag); | ||
v->Visit("span", &span); | ||
v->Visit("_checked_type_", &checked_type_); | ||
} | ||
|
||
static constexpr const char* _type_key = "relay.Constructor"; | ||
TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr); | ||
|
||
/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ | ||
class PatternConstructor; | ||
/*! \brief PatternVar container node */ | ||
class PatternConstructorNode : public PatternNode { | ||
public: | ||
/*! Constructor matched by the pattern. */ | ||
Constructor constructor; | ||
/*! Sub-patterns to match against each input to the constructor. */ | ||
tvm::Array<Pattern> patterns; | ||
|
||
PatternConstructorNode() {} | ||
|
||
TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var); | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("constructor", &constructor); | ||
v->Visit("patterns", &patterns); | ||
v->Visit("span", &span); | ||
} | ||
|
||
static constexpr const char* _type_key = "relay.PatternConstructor"; | ||
TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); | ||
|
||
/*! | ||
* \brief Stores all data for an Algebraic Data Type (ADT). | ||
* | ||
* In particular, it stores the handle (global type var) for an ADT | ||
* and the constructors used to build it and is kept in the module. Note | ||
* that type parameters are also indicated in the type data: this means that | ||
* for any instance of an ADT, the type parameters must be indicated. That is, | ||
* an ADT definition is treated as a type-level function, so an ADT handle | ||
* must be wrapped in a TypeCall node that instantiates the type-level arguments. | ||
* The kind checker enforces this. | ||
*/ | ||
class TypeData; | ||
/*! \brief TypeData container node */ | ||
class TypeDataNode : public TypeNode { | ||
public: | ||
/*! | ||
* \brief The header is simply the name of the ADT. | ||
* We adopt nominal typing for ADT definitions; | ||
* that is, differently-named ADT definitions with same constructors | ||
* have different types. | ||
*/ | ||
GlobalTypeVar header; | ||
/*! \brief The type variables (to allow for polymorphism). */ | ||
tvm::Array<TypeVar> type_vars; | ||
/*! \brief The constructors. */ | ||
tvm::Array<Constructor> constructors; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("header", &header); | ||
v->Visit("type_vars", &type_vars); | ||
v->Visit("constructors", &constructors); | ||
v->Visit("span", &span); | ||
} | ||
|
||
TVM_DLL static TypeData make(GlobalTypeVar header, | ||
tvm::Array<TypeVar> type_vars, | ||
tvm::Array<Constructor> constructors); | ||
|
||
static constexpr const char* _type_key = "relay.TypeData"; | ||
TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type); | ||
|
||
/*! \brief A clause in a match expression. */ | ||
class Clause; | ||
/*! \brief Clause container node. */ | ||
class ClauseNode : public Node { | ||
public: | ||
/*! \brief The pattern the clause matches. */ | ||
Pattern lhs; | ||
/*! \brief The resulting value. */ | ||
Expr rhs; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("lhs", &lhs); | ||
v->Visit("rhs", &rhs); | ||
} | ||
|
||
TVM_DLL static Clause make(Pattern lhs, Expr rhs); | ||
|
||
static constexpr const char* _type_key = "relay.Clause"; | ||
TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef); | ||
|
||
/*! \brief ADT pattern matching exression. */ | ||
class Match; | ||
/*! \brief Match container node. */ | ||
class MatchNode : public ExprNode { | ||
public: | ||
/*! \brief The input being deconstructed. */ | ||
Expr data; | ||
|
||
/*! \brief The match node clauses. */ | ||
tvm::Array<Clause> clauses; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("data", &data); | ||
v->Visit("clause", &clauses); | ||
v->Visit("span", &span); | ||
v->Visit("_checked_type_", &checked_type_); | ||
} | ||
|
||
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern); | ||
|
||
static constexpr const char* _type_key = "relay.Match"; | ||
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr); | ||
|
||
} // namespace relay | ||
} // namespace tvm | ||
|
||
#endif // TVM_RELAY_ADT_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.