Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Algebraic data types #2442

Merged
merged 61 commits into from
Feb 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
f3b58c3
First pass on ADTs
MarisaKirisame Jan 16, 2019
485889d
Add doc string for tag field
slyubomirsky Jan 16, 2019
54e5196
Visit constructors in TypeVisitor for TypeData
slyubomirsky Jan 16, 2019
671f615
Add to description of type call
slyubomirsky Jan 16, 2019
f9e48e8
Add type call to type solving and unification
slyubomirsky Jan 17, 2019
94748c1
Make type mutator for typecall consistent with others (only create ne…
slyubomirsky Jan 17, 2019
7b1126f
Ensure kindchecking can handle type calls and typedata
slyubomirsky Jan 17, 2019
f2d36d6
Fix bad nesting in module constructor
slyubomirsky Jan 21, 2019
7dc26ef
Correctly construct call in typecall test
slyubomirsky Jan 22, 2019
0c949c7
Add call override for ordinary vars (do we want this?)
slyubomirsky Jan 22, 2019
9443bdc
Remove generalization hack from type inference because it was breakin…
slyubomirsky Jan 22, 2019
33fc344
Check that there are no free type vars in exprs after inferring type
slyubomirsky Jan 22, 2019
c208071
Free var checks need module because of ADT constructors
slyubomirsky Jan 22, 2019
af7da11
Typecall test can't have unbound type var, make it global
slyubomirsky Jan 22, 2019
4968b48
Uncomment tmap test and remove comments about failing to infer ret ty…
slyubomirsky Jan 22, 2019
87f82b7
Put in dummy visits for ADTs in graph runtime codegen to placate pylint
slyubomirsky Jan 22, 2019
e0f4c08
Fix Relay type infer test module constructor
slyubomirsky Jan 22, 2019
b646740
Mark override for TypeCallNode in type solver
slyubomirsky Jan 22, 2019
e11f58e
Ensure free vars check treats patern vars as bound
slyubomirsky Jan 23, 2019
ec56e4a
Run interpreter in more ADT test cases
slyubomirsky Jan 23, 2019
4f74545
Refactor kind check to return the kind, like typechecking
slyubomirsky Jan 23, 2019
276e028
Fix invalid typecall in test
slyubomirsky Jan 23, 2019
7963e7e
Add kind check to type inference, do not use nulls in func_type_annot…
slyubomirsky Jan 23, 2019
e2d6219
Redundant whitespace
slyubomirsky Jan 23, 2019
40c7410
Make TypeData a separate kind
slyubomirsky Jan 23, 2019
673bcd6
Make ADT handles a separate kind too, document calling convention better
slyubomirsky Jan 23, 2019
5e61378
Remove nats and tree from prelude, move to test, document prelude
slyubomirsky Jan 23, 2019
3db3c64
Restore and document nat and tree to prelude, add more tree tests
slyubomirsky Jan 23, 2019
0041b46
Add alpha equality tests for match cases, fix variable binding bug
slyubomirsky Jan 24, 2019
d232beb
Add more kind check tests for ADTs
slyubomirsky Jan 24, 2019
7c6d737
Add more tests for finding free or bound vars in match exprs
slyubomirsky Jan 24, 2019
7322866
Add unification tests for type call
slyubomirsky Jan 24, 2019
5f3a2f4
Update main() for alpha equality tests
slyubomirsky Jan 24, 2019
90ee405
Add simple type inference test cases for match exprs and ADT construc…
slyubomirsky Jan 24, 2019
d4a54a1
Add more ADT interpreter tests
slyubomirsky Jan 24, 2019
609f56e
Allow incomplete types when typechecking match cases
slyubomirsky Jan 24, 2019
089813a
Type inference for pattern vars should use the type annotation if it'…
slyubomirsky Jan 24, 2019
ebec99c
Two more specific test cases for ADT matching
slyubomirsky Jan 24, 2019
00963de
Add option ADT to prelude
slyubomirsky Jan 24, 2019
47babdb
Fix broken reference to kind enum
slyubomirsky Jan 25, 2019
a37d927
Fix rebase snags
slyubomirsky Jan 28, 2019
0c660aa
Do not attach checked types to constructors
slyubomirsky Jan 31, 2019
f5cec3e
More docstrings for module fields
slyubomirsky Feb 1, 2019
b56bc36
Use proper wrapper for indexing into module type data
slyubomirsky Feb 1, 2019
6b8dbb8
checked_type for constructors is not populated
slyubomirsky Feb 1, 2019
07ea915
Expand type call docstring
slyubomirsky Feb 1, 2019
b7cfc59
Rename PatternConstructor con field
slyubomirsky Feb 1, 2019
8cd15f2
Use error reporter for pattern constructor case
slyubomirsky Feb 4, 2019
1d9ae48
Condense error reporting in kind check, use error reporter
slyubomirsky Feb 4, 2019
acc2ec0
Expand docstrings and rename ADT fields
slyubomirsky Feb 4, 2019
737514b
Rename 'option' ADT to 'optional' for consistency with Python
slyubomirsky Feb 11, 2019
1a6e48a
Add various list iterators and utility functions to prelude
slyubomirsky Feb 15, 2019
511a931
Add smoke tests for new iterators in prelude
slyubomirsky Feb 15, 2019
3eeca4c
Add concat to prelude
slyubomirsky Feb 15, 2019
868f76e
Add smoke test for concat
slyubomirsky Feb 15, 2019
5a85aa8
Correct docstrings in prelude
slyubomirsky Feb 15, 2019
bd9bfc7
Ensure that type defs are written in module initialization
slyubomirsky Feb 15, 2019
de7ea6e
Various requested renamings
slyubomirsky Feb 15, 2019
8909e18
Correct rebase snags
slyubomirsky Feb 15, 2019
ffd5c80
Add kind check tests for ref types
slyubomirsky Feb 15, 2019
7a6eff0
Update the main() for kind checking tests
slyubomirsky Feb 15, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 244 additions & 0 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/adt.h
* \brief Algebraic data types for Relay
*/
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_

#include <tvm/attrs.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"
#include "./expr.h"

namespace tvm {
namespace relay {

/*! \brief Base type for declaring relay pattern. */
class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node);
};

/*!
* \brief Pattern is the base type for an ADT match pattern in Relay.
*
* Given an ADT value, a pattern might accept it and bind the pattern variable to some value
* (typically a subnode of the input or the input). Otherwise, the pattern rejects the value.
*
* ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
*/
class Pattern : public NodeRef {
public:
Pattern() {}
explicit Pattern(NodePtr<tvm::Node> p) : NodeRef(p) {}

using ContainerType = PatternNode;
};

/*! \brief A wildcard pattern: Accepts all input and binds nothing. */
class PatternWildcard;
/*! \brief PatternWildcard container node */
class PatternWildcardNode : public PatternNode {
public:
PatternWildcardNode() {}

TVM_DLL static PatternWildcard make();

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern);

/*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar;
/*! \brief PatternVar container node */
class PatternVarNode : public PatternNode {
public:
PatternVarNode() {}

/*! \brief Variable that stores the matched value. */
tvm::relay::Var var;

TVM_DLL static PatternVar make(tvm::relay::Var var);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern);

/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
*/
class Constructor;
/*! \brief Constructor container node. */
class ConstructorNode : public ExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
tvm::Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int tag = -1;
slyubomirsky marked this conversation as resolved.
Show resolved Hide resolved

ConstructorNode() {}

TVM_DLL static Constructor make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr);

/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;
/*! \brief PatternVar container node */
class PatternConstructorNode : public PatternNode {
public:
/*! Constructor matched by the pattern. */
Constructor constructor;
/*! Sub-patterns to match against each input to the constructor. */
tvm::Array<Pattern> patterns;

PatternConstructorNode() {}

TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("constructor", &constructor);
v->Visit("patterns", &patterns);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);

/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*/
class TypeData;
slyubomirsky marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
tvm::Array<TypeVar> type_vars;
/*! \brief The constructors. */
tvm::Array<Constructor> constructors;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}

TVM_DLL static TypeData make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors);

static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type);

/*! \brief A clause in a match expression. */
class Clause;
/*! \brief Clause container node. */
class ClauseNode : public Node {
public:
/*! \brief The pattern the clause matches. */
Pattern lhs;
/*! \brief The resulting value. */
Expr rhs;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
}

TVM_DLL static Clause make(Pattern lhs, Expr rhs);

static constexpr const char* _type_key = "relay.Clause";
TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node);
};

RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef);

/*! \brief ADT pattern matching exression. */
class Match;
/*! \brief Match container node. */
class MatchNode : public ExprNode {
public:
/*! \brief The input being deconstructed. */
Expr data;

/*! \brief The match node clauses. */
tvm::Array<Clause> clauses;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("clause", &clauses);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern);

static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_ADT_H_
14 changes: 14 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <tvm/node/ir_functor.h>
#include <string>
#include "./expr.h"
#include "./adt.h"
#include "./op.h"
#include "./error.h"

Expand Down Expand Up @@ -92,6 +93,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -114,6 +117,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
return vtable;
}
};
Expand Down Expand Up @@ -142,7 +147,11 @@ class ExprVisitor
void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
void VisitExpr_(const ConstructorNode* op) override;
void VisitExpr_(const MatchNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);

protected:
// Internal visiting counter
Expand Down Expand Up @@ -180,6 +189,9 @@ class ExprMutator
Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;

/*!
* \brief Used to visit the types inside of expressions.
*
Expand All @@ -188,6 +200,8 @@ class ExprMutator
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
virtual Clause VisitClause(const Clause& c);
virtual Pattern VisitPattern(const Pattern& c);

protected:
/*! \brief Internal map used for memoization. */
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,28 @@ struct RefValueNode : ValueNode {

RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);

/*! \brief An ADT constructor value. */
class ConstructorValue;

struct ConstructorValueNode : ValueNode {
Constructor constructor;

tvm::Array<Value> fields;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("constructor", &constructor);
v->Visit("fields", &fields);
}

TVM_DLL static ConstructorValue make(Constructor constructor,
tvm::Array<Value> fields);

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_INTERPRETER_H_
Loading