Skip to content

Commit

Permalink
[Relay] Algebraic data types (apache#2442)
Browse files Browse the repository at this point in the history
* First pass on ADTs

* Add doc string for tag field

* Visit constructors in TypeVisitor for TypeData

* Add to description of type call

* Add type call to type solving and unification

* Make type mutator for typecall consistent with others (only create new node if there's a change)

* Ensure kindchecking can handle type calls and typedata

* Fix bad nesting in module constructor

* Correctly construct call in typecall test

* Add call override for ordinary vars (do we want this?)

* Remove generalization hack from type inference because it was breaking ADT constructors

* Check that there are no free type vars in exprs after inferring type

* Free var checks need module because of ADT constructors

* Typecall test can't have unbound type var, make it global

* Uncomment tmap test and remove comments about failing to infer ret type; those work now

* Put in dummy visits for ADTs in graph runtime codegen to placate pylint

* Fix Relay type infer test module constructor

* Mark override for TypeCallNode in type solver

* Ensure free vars check treats patern vars as bound

* Run interpreter in more ADT test cases

* Refactor kind check to return the kind, like typechecking

* Fix invalid typecall in test

* Add kind check to type inference, do not use nulls in func_type_annotation()!

* Redundant whitespace

* Make TypeData a separate kind

* Make ADT handles a separate kind too, document calling convention better

* Remove nats and tree from prelude, move to test, document prelude

* Restore and document nat and tree to prelude, add more tree tests

* Add alpha equality tests for match cases, fix variable binding bug

* Add more kind check tests for ADTs

* Add more tests for finding free or bound vars in match exprs

* Add unification tests for type call

* Update main() for alpha equality tests

* Add simple type inference test cases for match exprs and ADT constructors

* Add more ADT interpreter tests

* Allow incomplete types when typechecking match cases

* Type inference for pattern vars should use the type annotation if it's there

* Two more specific test cases for ADT matching

* Add option ADT to prelude

* Fix broken reference to kind enum

* Fix rebase snags

* Do not attach checked types to constructors

* More docstrings for module fields

* Use proper wrapper for indexing into module type data

* checked_type for constructors is not populated

* Expand type call docstring

* Rename PatternConstructor con field

* Use error reporter for pattern constructor case

* Condense error reporting in kind check, use error reporter

* Expand docstrings and rename ADT fields

* Rename 'option' ADT to 'optional' for consistency with Python

* Add various list iterators and utility functions to prelude

* Add smoke tests for new iterators in prelude

* Add concat to prelude

* Add smoke test for concat

* Correct docstrings in prelude

* Ensure that type defs are written in module initialization

* Various requested renamings

* Correct rebase snags

* Add kind check tests for ref types

* Update the main() for kind checking tests
  • Loading branch information
slyubomirsky authored and AWS Neo committed Feb 20, 2019
1 parent 24eb8ee commit 921d084
Show file tree
Hide file tree
Showing 45 changed files with 3,398 additions and 207 deletions.
244 changes: 244 additions & 0 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/adt.h
* \brief Algebraic data types for Relay
*/
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_

#include <tvm/attrs.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"
#include "./expr.h"

namespace tvm {
namespace relay {

/*! \brief Base type for declaring relay pattern. */
class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node);
};

/*!
* \brief Pattern is the base type for an ADT match pattern in Relay.
*
* Given an ADT value, a pattern might accept it and bind the pattern variable to some value
* (typically a subnode of the input or the input). Otherwise, the pattern rejects the value.
*
* ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
*/
class Pattern : public NodeRef {
public:
Pattern() {}
explicit Pattern(NodePtr<tvm::Node> p) : NodeRef(p) {}

using ContainerType = PatternNode;
};

/*! \brief A wildcard pattern: Accepts all input and binds nothing. */
class PatternWildcard;
/*! \brief PatternWildcard container node */
class PatternWildcardNode : public PatternNode {
public:
PatternWildcardNode() {}

TVM_DLL static PatternWildcard make();

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern);

/*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar;
/*! \brief PatternVar container node */
class PatternVarNode : public PatternNode {
public:
PatternVarNode() {}

/*! \brief Variable that stores the matched value. */
tvm::relay::Var var;

TVM_DLL static PatternVar make(tvm::relay::Var var);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern);

/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
*/
class Constructor;
/*! \brief Constructor container node. */
class ConstructorNode : public ExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
tvm::Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int tag = -1;

ConstructorNode() {}

TVM_DLL static Constructor make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr);

/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;
/*! \brief PatternVar container node */
class PatternConstructorNode : public PatternNode {
public:
/*! Constructor matched by the pattern. */
Constructor constructor;
/*! Sub-patterns to match against each input to the constructor. */
tvm::Array<Pattern> patterns;

PatternConstructorNode() {}

TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("constructor", &constructor);
v->Visit("patterns", &patterns);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);

/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*/
class TypeData;
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
tvm::Array<TypeVar> type_vars;
/*! \brief The constructors. */
tvm::Array<Constructor> constructors;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}

TVM_DLL static TypeData make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors);

static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type);

/*! \brief A clause in a match expression. */
class Clause;
/*! \brief Clause container node. */
class ClauseNode : public Node {
public:
/*! \brief The pattern the clause matches. */
Pattern lhs;
/*! \brief The resulting value. */
Expr rhs;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
}

TVM_DLL static Clause make(Pattern lhs, Expr rhs);

static constexpr const char* _type_key = "relay.Clause";
TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node);
};

RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef);

/*! \brief ADT pattern matching exression. */
class Match;
/*! \brief Match container node. */
class MatchNode : public ExprNode {
public:
/*! \brief The input being deconstructed. */
Expr data;

/*! \brief The match node clauses. */
tvm::Array<Clause> clauses;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("clause", &clauses);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern);

static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_ADT_H_
14 changes: 14 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <tvm/node/ir_functor.h>
#include <string>
#include "./expr.h"
#include "./adt.h"
#include "./op.h"
#include "./error.h"

Expand Down Expand Up @@ -92,6 +93,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -114,6 +117,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
return vtable;
}
};
Expand Down Expand Up @@ -142,7 +147,11 @@ class ExprVisitor
void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
void VisitExpr_(const ConstructorNode* op) override;
void VisitExpr_(const MatchNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);

protected:
// Internal visiting counter
Expand Down Expand Up @@ -180,6 +189,9 @@ class ExprMutator
Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;

/*!
* \brief Used to visit the types inside of expressions.
*
Expand All @@ -188,6 +200,8 @@ class ExprMutator
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
virtual Clause VisitClause(const Clause& c);
virtual Pattern VisitPattern(const Pattern& c);

protected:
/*! \brief Internal map used for memoization. */
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,28 @@ struct RefValueNode : ValueNode {

RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);

/*! \brief An ADT constructor value. */
class ConstructorValue;

struct ConstructorValueNode : ValueNode {
Constructor constructor;

tvm::Array<Value> fields;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("constructor", &constructor);
v->Visit("fields", &fields);
}

TVM_DLL static ConstructorValue make(Constructor constructor,
tvm::Array<Value> fields);

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_INTERPRETER_H_
Loading

0 comments on commit 921d084

Please sign in to comment.