Skip to content

Commit

Permalink
First pass on ADTs
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and slyubomirsky committed Jan 16, 2019
1 parent 9e9cda5 commit a6b4cad
Show file tree
Hide file tree
Showing 33 changed files with 1,934 additions and 73 deletions.
228 changes: 228 additions & 0 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/adt.h
* \brief Algebraic data types for Relay
*/
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_

#include <tvm/attrs.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"
#include "./expr.h"

namespace tvm {
namespace relay {

/*! \brief Base type for declaring relay pattern. */
class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node);
};

/*!
* \brief Pattern is the base type for an ADT match pattern in Relay.
*
* Given an ADT value, a pattern might accept it and bind the pattern variable to some value
* (typically a subnode of the input or the input). Otherwise, the pattern rejects the value.
*
* ADT pattern matching thus takes a list of values and bings to the first that accepts the value.
*/
class Pattern : public NodeRef {
public:
Pattern() {}
explicit Pattern(NodePtr<tvm::Node> p) : NodeRef(p) {}

using ContainerType = PatternNode;
};

/*! \brief A wildcard pattern: Accepts all input and binds nothing. */
class PatternWildcard;
/*! \brief PatternWildcard container node */
class PatternWildcardNode : public PatternNode {
public:
PatternWildcardNode() {}

TVM_DLL static PatternWildcard make();

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern);

/*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar;
/*! \brief PatternVar container node */
class PatternVarNode : public PatternNode {
public:
PatternVarNode() {}

tvm::relay::Var var;

TVM_DLL static PatternVar make(tvm::relay::Var var);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern);

/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
*/
class Constructor;
/*! \brief Constructor container node. */
class ConstructorNode : public ExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
tvm::Array<Type> inp;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
mutable int tag = -1;

ConstructorNode() {}

TVM_DLL static Constructor make(std::string name_hint,
tvm::Array<Type> inp,
GlobalTypeVar belong_to);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("inp", &inp);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr);

/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;
/*! \brief PatternVar container node */
class PatternConstructorNode : public PatternNode {
public:
Constructor con;
tvm::Array<Pattern> pat;

PatternConstructorNode() {}

TVM_DLL static PatternConstructor make(Constructor con, tvm::Array<Pattern> var);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("con", &con);
v->Visit("pat", &pat);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);

/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*/
class TypeData;
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
tvm::Array<TypeVar> tv;
/*! \brief The constructors. */
tvm::Array<Constructor> constructors;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("header", &header);
v->Visit("tv", &tv);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}

TVM_DLL static TypeData make(GlobalTypeVar header,
tvm::Array<TypeVar> tv,
tvm::Array<Constructor> constructors);

static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type);

/*! \brief A clause in a match expression. */
class Clause;
/*! \brief Clause container node. */
class ClauseNode : public Node {
public:
Pattern lhs;
Expr rhs;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
}

TVM_DLL static Clause make(Pattern lhs, Expr rhs);

static constexpr const char* _type_key = "relay.Clause";
TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node);
};

RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef);

/*! \brief ADT pattern matching exression. */
class Match;
/*! \brief Match container node. */
class MatchNode : public ExprNode {
public:
Expr data;

tvm::Array<Clause> pattern;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("pattern", &pattern);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern);

static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_ADT_H_
14 changes: 14 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <tvm/node/ir_functor.h>
#include <string>
#include "./expr.h"
#include "./adt.h"
#include "./op.h"
#include "./error.h"

Expand Down Expand Up @@ -89,6 +90,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -108,6 +111,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
return vtable;
}
};
Expand All @@ -133,7 +138,11 @@ class ExprVisitor
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
void VisitExpr_(const ConstructorNode* op) override;
void VisitExpr_(const MatchNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);

protected:
// Internal visiting counter
Expand Down Expand Up @@ -168,6 +177,9 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;

/*!
* \brief Used to visit the types inside of expressions.
*
Expand All @@ -176,6 +188,8 @@ class ExprMutator
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
virtual Clause VisitClause(const Clause& c);
virtual Pattern VisitPattern(const Pattern& c);

protected:
/*! \brief Internal map used for memoization. */
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,28 @@ struct TensorValueNode : ValueNode {

RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);

/*! \brief An ADT constructor value. */
class ConValue;

struct ConValueNode : ValueNode {
Constructor con;

tvm::Array<Value> fields;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("con", &con);
v->Visit("fields", &fields);
}

TVM_DLL static ConValue make(Constructor con,
tvm::Array<Value> fields);

static constexpr const char* _type_key = "relay.ConValue";
TVM_DECLARE_NODE_TYPE_INFO(ConValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(ConValue, ConValueNode, Value);


} // namespace relay
} // namespace tvm
Expand Down
Loading

0 comments on commit a6b4cad

Please sign in to comment.