Skip to content

Commit

Permalink
Add implementation of alpha equivalence (apache#35)
Browse files Browse the repository at this point in the history
* Refactor base types to reuse HalideIR::Type

* Tweak BaseType constructors

* Add start on alpha_eq for expressions

* Fix free_vars.h

* Add alpha_eq implementation and tests

* Restore reverse_ad

* Add test skeleton

* Refactor type_visitor.h

* Update free_type_vars.cc

* Stub type alpha-equivalence

* Fix style in AlphaEq

* Fill in unfinished test cases

* Get half of the tests passing

* More changes, only 13 failures left

* Fix a few more bugs

* Fix Tensor and Product cases

* Fix cast case

* Fix test cases

* Fix lint

* Fix remaining test cases
  • Loading branch information
jroesch committed Aug 16, 2018
1 parent 7173216 commit 7babba8
Show file tree
Hide file tree
Showing 20 changed files with 601 additions and 268 deletions.
19 changes: 19 additions & 0 deletions relay/include/relay/alpha_eq.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*!
* Copyright (c) 2018 by Contributors
* \file alpha_eq.h
* \brief Check expressions for structural equivalence.
*/
#ifndef NNVM_RELAY_ALPHA_EQ_H_
#define NNVM_RELAY_ALPHA_EQ_H_

#include "node.h"

namespace nnvm {
namespace relay {

bool alpha_eq(const Expr & e1, const Expr & e2);
bool alpha_eq(const Type & t1, const Type & t2);

} // namespace relay
} // namespace nnvm
#endif // NNVM_RELAY_ALPHA_EQ_H_
1 change: 1 addition & 0 deletions relay/include/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode);
RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
RELAY_EXPR_FUNCTOR_DISPATCH(DebugNode);
RELAY_EXPR_FUNCTOR_DISPATCH(UnaryOpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(BinaryOpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
Expand Down
97 changes: 77 additions & 20 deletions relay/include/relay/expr_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,84 @@
namespace nnvm {
namespace relay {

class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
template <typename... Args>
class ExprVisitor : public ExprFunctor<void(const Expr& n, Args...)> {
public:
void VisitExpr_(const LocalIdNode* op) override;
void VisitExpr_(const GlobalIdNode* op) override;
void VisitExpr_(const IntrinsicIdNode* op) override;
void VisitExpr_(const FloatLitNode* op) override;
void VisitExpr_(const BoolLitNode* op) override;
void VisitExpr_(const IntLitNode* op) override;
void VisitExpr_(const TensorLitNode* op) override;
void VisitExpr_(const ProductLitNode* op) override;
void VisitExpr_(const CastNode* op) override;
void VisitExpr_(const ParamNode* op) override;
void VisitExpr_(const FunctionNode* op) override;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const DebugNode* op) override;
void VisitExpr_(const UnaryOpNode* op) override;
void VisitExpr_(const BinaryOpNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const ReverseNode* op) override;
void VisitExpr_(const AccumulateNode* op) override;
void VisitExpr_(const ZeroNode* op) override;
void VisitExpr_(const LocalIdNode* op, Args... args) override { return; }

void VisitExpr_(const GlobalIdNode* op, Args... args) override { return; }

void VisitExpr_(const IntrinsicIdNode* op, Args... args) override { return; }

void VisitExpr_(const FloatLitNode* op, Args... args) override { return; }

void VisitExpr_(const BoolLitNode* op, Args... args) override { return; }

void VisitExpr_(const IntLitNode* op, Args... args) override { return; }

void VisitExpr_(const TensorLitNode* op, Args... args) override {
// todo
return;
}

void VisitExpr_(const ProductLitNode* op, Args... args) override {
// todo
return;
}

void VisitExpr_(const CastNode* op, Args... args) override {
this->VisitExpr(op->node, args...);
}

void VisitExpr_(const ParamNode* op, Args... args) override {
this->VisitExpr(op->id, args...);
}

void VisitExpr_(const FunctionNode* op, Args... args) override {
for (auto param : op->params) {
this->VisitExpr(param, args...);
}

this->VisitExpr(op->body, args...);
}

void VisitExpr_(const CallNode* op, Args... args) override {
this->VisitExpr(op->fn, args...);
for (auto arg : op->args) {
this->VisitExpr(arg, args...);
}
}

void VisitExpr_(const DebugNode* op, Args... args) override {
this->VisitExpr(op->node, args...);
}

void VisitExpr_(const UnaryOpNode* op, Args... args) override {
this->VisitExpr(op->node, args...);
}

void VisitExpr_(const BinaryOpNode* op, Args... args) override {
this->VisitExpr(op->left, args...);
this->VisitExpr(op->right, args...);
}
void VisitExpr_(const LetNode* op, Args... args) override {
this->VisitExpr(op->id, args...);
this->VisitExpr(op->value, args...);
}

void VisitExpr_(const ReverseNode* op, Args... args) override {
this->VisitExpr(op->node, args...);
}

void VisitExpr_(const AccumulateNode* op, Args... args) override {
// todo
return;
}

void VisitExpr_(const ZeroNode* op, Args... args) override {
// todo
return;
}
};

} // namespace relay
Expand Down
6 changes: 0 additions & 6 deletions relay/include/relay/free_type_vars.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@
namespace nnvm {
namespace relay {

struct FreeTypeVars : TypeVisitor {
public:
std::set<TypeVar> free_vars;
void VisitType_(const TypeVar& op) override;
};

std::set<TypeVar> free_type_vars(const Type& e);

} // namespace relay
Expand Down
2 changes: 1 addition & 1 deletion relay/include/relay/free_vars.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace nnvm {
namespace relay {

struct FreeVars : ExprVisitor {
struct FreeVars : ExprVisitor<> {
public:
std::set<LocalId> free_vars;

Expand Down
94 changes: 42 additions & 52 deletions relay/include/relay/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,39 @@

#include <nnvm/node.h>
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node.h>
#include <string>

/*! \brief Macro to make it easy to define node ref type given node */
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefPtr) \
class TypeName : public NodeRefPtr { \
public: \
TypeName() {} \
explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefPtr(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
}; \
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefPtr) \
class TypeName : public NodeRefPtr { \
public: \
TypeName() {} \
explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefPtr(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
};

/*! \brief Macro to make it easy to define node ref type given node */
#define RELAY_DEFINE_EXPR(TypeName, NodeName) \
RELAY_DEFINE_NODE_REF(TypeName, NodeName, Expr)
#define RELAY_DEFINE_EXPR(TypeName, NodeName) \
RELAY_DEFINE_NODE_REF(TypeName, NodeName, Expr)

/*! \brief Macro to make it easy to define node ref type given node */
#define RELAY_DEFINE_VALUE(TypeName, NodeName) \
RELAY_DEFINE_NODE_REF(TypeName, NodeName, Value)
#define RELAY_DEFINE_VALUE(TypeName, NodeName) \
RELAY_DEFINE_NODE_REF(TypeName, NodeName, Value)

/*! \brief Macro to make it easy to define node ref type given node */
#define RELAY_DEFINE_TYPE(TypeName, NodeName) RELAY_DEFINE_NODE_REF(TypeName, NodeName, Type)
#define RELAY_DEFINE_TYPE(TypeName, NodeName) \
RELAY_DEFINE_NODE_REF(TypeName, NodeName, Type)

namespace nnvm {
namespace relay {

typedef HalideIR::Type HType;

struct Node : public tvm::Node {};

/*!
Expand Down Expand Up @@ -141,8 +145,6 @@ class FloatValueNode : public ValueNode {

RELAY_DEFINE_VALUE(FloatValue, FloatValueNode);



// end move me

/*! \brief Base type of the Relay type hiearchy. */
Expand All @@ -163,43 +165,36 @@ struct Type : public NodeRef {
using ContainerType = TypeNode;
};

class IntType;
class BaseType;

/*! \brief The type of integer values. */
class IntTypeNode : public TypeNode {
class BaseTypeNode : public TypeNode {
public:
unsigned width;

IntTypeNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("width", reinterpret_cast<int*>(&width));
}
HalideIR::Type type;

TVM_DLL static IntType make(int width);
BaseTypeNode() {}

static constexpr const char* _type_key = "nnvm.IntType";
TVM_DECLARE_NODE_TYPE_INFO(IntTypeNode, TypeNode);
};
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("dtype", &type); }

RELAY_DEFINE_TYPE(IntType, IntTypeNode);
TVM_DLL static BaseType make(HalideIR::Type type);

class BoolType;
/** Constructing an unsigned integer type */
TVM_DLL static BaseType Int(int bits, int lanes = 1);

/*! \brief The type of boolean values. */
class BoolTypeNode : public TypeNode {
public:
BoolTypeNode() {}
/** Constructing an unsigned integer type */
TVM_DLL static BaseType UInt(int bits, int lanes = 1);

void VisitAttrs(tvm::AttrVisitor* v) final {}
/** Construct a floating-point type */
TVM_DLL static BaseType Float(int bits, int lanes = 1);

TVM_DLL static BoolType make();
/** Construct a boolean type */
TVM_DLL static BaseType Bool(int lanes = 1);

static constexpr const char* _type_key = "nnvm.BoolType";
TVM_DECLARE_NODE_TYPE_INFO(BoolTypeNode, TypeNode);
static constexpr const char* _type_key = "nnvm.BaseType";
TVM_DECLARE_NODE_TYPE_INFO(BaseTypeNode, TypeNode);
};

RELAY_DEFINE_TYPE(BoolType, BoolTypeNode);
RELAY_DEFINE_TYPE(BaseType, BaseTypeNode);

class TypeVar;

Expand Down Expand Up @@ -362,13 +357,13 @@ class TensorLit;
/*! \brief Tensor literal [t1, [x1, ..., xn]]. */
class TensorLitNode : public ExprNode {
public:
tvm::Array<tvm::NodeRef> data;
tvm::Array<Expr> data;

TensorLitNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); }

TVM_DLL static TensorLit make(tvm::Array<tvm::NodeRef> data);
TVM_DLL static TensorLit make(tvm::Array<Expr> data);

static constexpr const char* _type_key = "nnvm.TensorLit";
TVM_DECLARE_NODE_TYPE_INFO(TensorLitNode, ExprNode);
Expand All @@ -381,13 +376,13 @@ class ProductLit;
/*! \brief Product literal (x, ... y). */
class ProductLitNode : public ExprNode {
public:
tvm::Array<tvm::NodeRef> fields;
tvm::Array<Expr> fields;

ProductLitNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }

TVM_DLL static ProductLit make(tvm::Array<tvm::NodeRef> data);
TVM_DLL static ProductLit make(tvm::Array<Expr> data);

static constexpr const char* _type_key = "nnvm.ProductLit";
TVM_DECLARE_NODE_TYPE_INFO(ProductLitNode, ExprNode);
Expand Down Expand Up @@ -585,10 +580,7 @@ RELAY_DEFINE_EXPR(Debug, DebugNode)

class UnaryOp;

enum UOp : int {
NEG = 0,
SQ = 1
};
enum UOp : int { NEG = 0, SQ = 1 };

/*! \brief Unary Operator. */
class UnaryOpNode : public ExprNode {
Expand Down Expand Up @@ -879,9 +871,7 @@ struct hash<nnvm::relay::LocalId> {
* \param id global id.
* \return hash code.
*/
size_t operator()(const nnvm::relay::LocalId& id) const {
return id.hash();
}
size_t operator()(const nnvm::relay::LocalId& id) const { return id.hash(); }
};

} // namespace std
Expand Down
32 changes: 16 additions & 16 deletions relay/include/relay/reverse_ad.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,22 @@ namespace relay {
struct ReverseAD : ExprFunctor<Expr(const Expr& n)> {
ReverseAD() {}
Expr AD(const Expr& expr);
Expr VisitExpr_(const LocalId& op);
Expr VisitExpr_(const GlobalId& op);
Expr VisitExpr_(const IntrinsicId& op);
Expr VisitExpr_(const FloatLit& op);
Expr VisitExpr_(const BoolLit& op);
Expr VisitExpr_(const IntLit& op);
Expr VisitExpr_(const TensorLit& op);
Expr VisitExpr_(const ProductLit& op);
Expr VisitExpr_(const Cast& op);
Expr VisitExpr_(const Param& op);
Expr VisitExpr_(const Function& op);
Expr VisitExpr_(const Call& op);
Expr VisitExpr_(const Debug& op);
Expr VisitExpr_(const UnaryOp& op);
Expr VisitExpr_(const BinaryOp& op);
Expr VisitExpr_(const Let& op);
Expr VisitExpr_(const LocalIdNode* op);
Expr VisitExpr_(const GlobalIdNode* op);
Expr VisitExpr_(const IntrinsicIdNode* op);
Expr VisitExpr_(const FloatLitNode* op);
Expr VisitExpr_(const BoolLitNode* op);
Expr VisitExpr_(const IntLitNode* op);
Expr VisitExpr_(const TensorLitNode* op);
Expr VisitExpr_(const ProductLitNode* op);
Expr VisitExpr_(const CastNode* op);
Expr VisitExpr_(const ParamNode* op);
Expr VisitExpr_(const FunctionNode* op);
Expr VisitExpr_(const CallNode* op);
Expr VisitExpr_(const DebugNode* op);
Expr VisitExpr_(const UnaryOpNode* op);
Expr VisitExpr_(const BinaryOpNode* op);
Expr VisitExpr_(const LetNode* op);
};

} // namespace relay
Expand Down
Loading

0 comments on commit 7babba8

Please sign in to comment.