Add missing tests and modify attributes (#5)
Restores the tests that were lost in the repo port, makes it possible for conv2d to typecheck, and adds some integration tests.

* Add back python tests that were missing from the repo move (history lost, sorry)

* Ensure shape evaluator doesn't trip up on type vars or type ids

* Add tests for shape evaluator when faced with type var or type id

* Use Strings as key for Attributes, repair tests and uses

* Add preliminary information for conv2d operator

* Add test of attr propagation

* Ensure attributes hash by string value rather than string pointer identity (tests still failing though, idk why)

* Adjust shape equality checks to use ordinary visitor, as nested type id information was not transferring

* Repair integration test by using std::unordered_map for attrs, add cases

* Add regression test for alpha-eq comparison of type IDs across nested shapes (was losing the type id equality map in alpha_eq)

* Add more integration test variants

* Correct import names in tests

* Add clarifying comment to typechecker

* Fix missing paren in operators.py

* Correct python source directory for mypy
slyubomirsky authored and jroesch committed Aug 16, 2018
1 parent 3d1b67b commit 0def9c9
Showing 32 changed files with 3,151 additions and 67 deletions.
relay/Makefile: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ cyclean:
 lint: pylint cpplint
 
 mypy:
-	python3.6 -m mypy --ignore-missing-imports python/tvm/relay tests/python/relay/
+	python3.6 -m mypy --ignore-missing-imports python/relay tests/python/relay/
 
 cpplint:
 	python3.6 dmlc-core/scripts/lint.py relay cpp include src
relay/include/tvm/relay/ir/base.h: 50 additions & 3 deletions
@@ -150,6 +150,41 @@ struct Expr : public NodeRef {
   using ContainerType = ExprNode;
 };
 
+struct StringNode;
+
+/*! \brief A reference to a string literal node. */
+class String : public NodeRef {
+ public:
+  /*! \brief default constructor, used internally */
+  String() {}
+  explicit String(std::shared_ptr<tvm::Node> n) : NodeRef(n) {}
+  inline const StringNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = StringNode;
+};
+
+struct StringNode : public ExprNode {
+ public:
+  std::string name;
+
+  StringNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("span", &span);
+    v->Visit("name", &name);
+  }
+
+  TVM_DLL static String make(std::string name);
+
+  static constexpr const char* _type_key = "nnvm.String";
+  TVM_DECLARE_NODE_TYPE_INFO(StringNode, ExprNode);
+};
+
+inline const StringNode* String::operator->() const {
+  return static_cast<const StringNode*>(node_.get());
+}
+
 class LocalId;
 
 /*! \brief A LocalId from the node's current type to target type. */
@@ -194,21 +229,33 @@ class GlobalIdNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(GlobalId, GlobalIdNode, Expr);
 
+struct StringHash {
+  size_t operator()(const String &key) const {
+    return std::hash<std::string>()(key->name);
+  }
+};
+
+struct StringEqual {
+  bool operator()(const String &lhs, const String &rhs) const {
+    return lhs->name == rhs->name;
+  }
+};
 
 class Attributes;
 
 /*! \brief A set of named attributes. */
 class AttributesNode : public Node {
  public:
-  tvm::Map<LocalId, Expr> attributes;
+  std::unordered_map<String, Expr, StringHash, StringEqual> attributes;
 
   AttributesNode() {}
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("span", &span);
     v->Visit("attributes", &attributes);
   }
 
-  TVM_DLL static Attributes make(tvm::Map<LocalId, Expr> attributes);
+  TVM_DLL static Attributes make(std::unordered_map<String, Expr,
+                                 StringHash, StringEqual> attributes);
 
   static constexpr const char* _type_key = "nnvm.Attributes";
   TVM_DECLARE_NODE_TYPE_INFO(AttributesNode, Node);
relay/include/tvm/relay/ir/expr.h: 0 additions & 35 deletions
@@ -17,41 +17,6 @@ namespace relay {
 
 typedef HalideIR::Type HType;
 
-struct StringNode;
-
-/*! \brief an entry that represents output data from a node */
-class String : public NodeRef {
- public:
-  /*! \brief default constructor, used internally */
-  String() {}
-  explicit String(std::shared_ptr<tvm::Node> n) : NodeRef(n) {}
-  inline const StringNode* operator->() const;
-
-  /*! \brief specify container node */
-  using ContainerType = StringNode;
-};
-
-struct StringNode : public ExprNode {
- public:
-  std::string name;
-
-  StringNode() {}
-
-  void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("span", &span);
-    v->Visit("name", &name);
-  }
-
-  TVM_DLL static String make(std::string name);
-
-  static constexpr const char* _type_key = "nnvm.String";
-  TVM_DECLARE_NODE_TYPE_INFO(StringNode, ExprNode);
-};
-
-inline const StringNode* String::operator->() const {
-  return static_cast<const StringNode*>(node_.get());
-}
-
 class FloatLit;
 
 /*! \brief Floating point literal `0.0`, `5e10`. */
relay/include/tvm/relay/ir/type.h: 2 additions & 2 deletions
@@ -271,7 +271,7 @@ class ShapeAttr;
 /*! \brief Shape singleton that captures the value of an attribute */
 class ShapeAttrNode : public TypeNode {
  public:
-  LocalId id;
+  String id;
 
   ShapeAttrNode() {}
 
@@ -280,7 +280,7 @@ class ShapeAttrNode : public TypeNode {
     v->Visit("span", &span);
   }
 
-  TVM_DLL static ShapeAttr make(LocalId id);
+  TVM_DLL static ShapeAttr make(String id);
 
   static constexpr const char* _type_key = "nnvm.ShapeAttr";
   TVM_DECLARE_NODE_TYPE_INFO(ShapeAttrNode, TypeNode);
relay/python/relay/ir/base.py: 1 addition & 3 deletions
@@ -88,9 +88,7 @@ def set_span(self, span: Span):
 
 @register_nnvm_node
 class Attributes(NodeBase):
-    def __getitem__(self, index):
-        return self.attributes[index]
-
+    pass
 
 class Value(NodeBase):
     """Base class of all values.
relay/python/relay/ir/expr.py: 5 additions & 1 deletion
@@ -7,7 +7,11 @@
 
 @register_nnvm_node
 class String(Expr):
-    value: str
+    name: str
+
+    # need to define hash to use in maps (e.g., for attrs)
+    def __hash__(self):
+        return self.name.__hash__()
 
 
 @register_nnvm_node
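
Both the C++ StringHash/StringEqual functors added in base.h above and the __hash__ here key the attribute map on the string's contents rather than on object identity, so two distinct String nodes carrying the same name address the same entry. A minimal self-contained model of that behavior; StrNode is an illustrative stand-in, not the repo's ir.String, and it also makes explicit the value-based __eq__ that a dict lookup needs:

# Stand-in model of value-keyed attribute maps; not the repo's classes.
class StrNode:
    def __init__(self, name: str):
        self.name = name

    def __hash__(self):
        # hash by contents, mirroring StringHash in base.h
        return hash(self.name)

    def __eq__(self, other):
        # compare by contents, mirroring StringEqual in base.h
        return isinstance(other, StrNode) and self.name == other.name

attrs = {StrNode("stride_h"): 2}
assert attrs[StrNode("stride_h")] == 2  # a different node, same name, same slot
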
relay/python/relay/make.pyi: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ def TypeVar() -> ir.Type: ...
 def PlaceholderType() -> ir.Type: ...
 def ShapeSeq(shapes: List[ir.Type]) -> ir.ShapeSeq: ...
 def ShapeSingleton(value: int) -> ir.ShapeSingleton: ...
-def ShapeAttr(id: ir.LocalId) -> ir.ShapeAttr: ...
+def ShapeAttr(id: ir.String) -> ir.ShapeAttr: ...
 def ShapeProjection(shape: ir.Type, value: int) -> ir.ShapeProjection: ...
 def ShapeBinaryOp(op: ir.ShapeOp, left: ir.Type, right: ir.Type) -> ir.ShapeBinaryOp: ...
 
@@ -46,7 +46,7 @@ def TensorLit(value: List[ir.Expr]) -> ir.TensorLit: ...
 def ProductLit(fields: List[ir.Expr]) -> ir.Expr: ...
 def BoolLit(value: bool) -> ir.BoolLit: ...
 def String(value: str) -> ir.String: ...
-def Attributes(attrs: Dict[ir.LocalId, ir.Expr]) -> ir.Attributes: ...
+def Attributes(attrs: Dict[ir.String, ir.Expr]) -> ir.Attributes: ...
 def Call(func: ir.Expr, args: List[ir.Expr], attrs: ir.Attributes) -> ir.Call: ...
 def UnaryOp(op: ir.UOp, arg: ir.Expr) -> ir.Expr: ...
 def BinaryOp(op: ir.BOp, left: ir.Expr, right: ir.Expr) -> ir.Expr: ...
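
With these stubs, attribute maps are constructed with ir.String keys end to end. A sketch of the intended call-side usage, using only constructors declared above; the attribute name and the commented-out Call line are illustrative placeholders, not repo API:

from relay.make import Attributes, BoolLit, String

# Keys are ir.String values now rather than LocalIds. "use_bias" is a
# made-up attribute name for illustration.
attrs = Attributes({String("use_bias"): BoolLit(False)})
# call = Call(fn, args, attrs)   # fn and args elided; see the Call stub above
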
relay/python/relay/operators.py: 52 additions & 1 deletion
@@ -5,8 +5,9 @@
 import topi
 from relay.env import Environment
 import relay.ir as ir
-from relay.make import Operator, IntrinsicId, TypeId, TensorType, FloatType
+from relay.make import String, Operator, IntrinsicId, TypeId, TensorType, FloatType
 from relay.make import TypeQuantifier, TypeArrow, ProductType
+from relay.make import ShapeAttr, ShapeBinaryOp, ShapeProjection, ShapeSingleton, ShapeSeq
 
 # TODO(@jroesch): Fix up my type
 __operator_registry__: Dict[str, Any] = {}
@@ -128,6 +129,15 @@ def broadcast_mul_compiler(func_ty: ir.Type) -> Any:
     module = tvm.build(schedule, Inputs + [Output], __tgt__, target_host=__tgt_host__, name="broadcast_mul_compiler")
     return module.get_function("broadcast_mul_compiler")
 
+# TODO(@jroesch): ensure this interfaces correctly
+# note that the type provided doesn't handle padding
+# feel free to assume some default behavior
+def conv2d_compiler(func_ty: ir.Type) -> Any:
+    Inputs, ret_ty = func_ty_to_placeholders(func_ty)
+    Output = topi.nn.conv2d(*Inputs)
+    schedule = tvm.create_schedule(Output.op)
+    module = tvm.build(schedule, Inputs + [Output], __tgt__, target_host=__tgt_host__, name="conv2d_compiler")
+    return module.get_function("conv2d_compiler")
 
 def initialize_operators(env) -> None:
     """Initialize the default set of operators for the system, this will populate
@@ -175,3 +185,44 @@ def initialize_operators(env) -> None:
     bmul_type = TypeQuantifier(shape, TypeArrow(ProductType([in_out_type, in_out_type]), in_out_type))
     # TODO: reverse mode
     register_op(env, 'broadcast_mul', bmul_type, broadcast_mul_compiler)
+
+    # Conv2d
+    # input: [batch, in_channel, in_height, in_width]
+    # filter: [filter_height, filter_width, in_channel, num_filter]
+    # output shape: [out_height, out_width, num_filter, batch]
+    # out_height = (in_height - filter_h)/stride_h + 1
+    # out_width = (in_width - filter_w)/stride_w + 1
+    stride_h = ShapeAttr(String("stride_h"))
+    stride_w = ShapeAttr(String("stride_w"))
+    btvar = TypeId("bt", Kind.BaseType)
+    input_shape = TypeId("input_shape", Kind.Shape)
+    filter_shape = TypeId("filter_shape", Kind.Shape)
+    output_shape = ShapeSeq([
+        ShapeBinaryOp(ShapeOp.SHPLUS,
+                      ShapeBinaryOp(ShapeOp.SHDIV,
+                                    ShapeBinaryOp(ShapeOp.SHSUB,
+                                                  ShapeProjection(input_shape, 2),
+                                                  ShapeProjection(filter_shape, 0)),
+                                    stride_h),
+                      ShapeSingleton(1)),
+        ShapeBinaryOp(ShapeOp.SHPLUS,
+                      ShapeBinaryOp(ShapeOp.SHDIV,
+                                    ShapeBinaryOp(ShapeOp.SHSUB,
+                                                  ShapeProjection(input_shape, 3),
+                                                  ShapeProjection(filter_shape, 1)),
+                                    stride_w),
+                      ShapeSingleton(1)),
+        ShapeProjection(filter_shape, 3),
+        ShapeProjection(input_shape, 0)
+    ])
+    conv2d_type = TypeQuantifier(
+        btvar,
+        TypeQuantifier(
+            input_shape,
+            TypeQuantifier(
+                filter_shape,
+                TypeArrow(ProductType([TensorType(btvar, input_shape), TensorType(btvar, filter_shape)]),
+                          TensorType(btvar, output_shape)
+                          ))))
+    # TODO: reverse mode
+    register_op(env, 'conv2d', conv2d_type, conv2d_compiler)
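
The shape comments in the hunk above encode standard unpadded convolution arithmetic. As a standalone sanity check of the same formulas in plain Python, independent of the IR:

def conv2d_out_shape(input_shape, filter_shape, stride_h, stride_w):
    # input:  [batch, in_channel, in_height, in_width]
    # filter: [filter_height, filter_width, in_channel, num_filter]
    batch, _, in_h, in_w = input_shape
    filt_h, filt_w, _, num_filter = filter_shape
    out_h = (in_h - filt_h) // stride_h + 1
    out_w = (in_w - filt_w) // stride_w + 1
    # matches the ShapeSeq ordering: [out_height, out_width, num_filter, batch]
    return [out_h, out_w, num_filter, batch]

# e.g. a 32x32 RGB image and eight 5x5 filters at stride 1: (32-5)//1 + 1 = 28
assert conv2d_out_shape([1, 3, 32, 32], [5, 5, 3, 8], 1, 1) == [28, 28, 8, 1]
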
relay/src/tvm/relay/alpha_eq.cc: 6 additions & 5 deletions
@@ -399,8 +399,8 @@ struct TypeAlphaEq : TypeVisitor<const Type &> {
 
   void VisitType_(const ShapeAttrNode *sn1, const Type &t2) override {
     if (const ShapeAttrNode *sn2 = t2.as<ShapeAttrNode>()) {
-      // require exact equality of identifiers
-      equal = equal && (sn1->id == sn2->id);
+      // check equality of names
+      equal = equal && (sn1->id->name == sn2->id->name);
     } else {
       equal = false;
     }
@@ -418,7 +418,7 @@ struct TypeAlphaEq : TypeVisitor<const Type &> {
       auto size = shape1->shapes.size();
       for (size_t i = 0U; i < size; i++) {
         if (!equal) { return; }
-        equal = equal && alpha_eq(shape1->shapes[i], shape2->shapes[i]);
+        this->VisitType(shape1->shapes[i], shape2->shapes[i]);
       }
     } else {
       equal = false;
@@ -433,7 +433,7 @@ struct TypeAlphaEq : TypeVisitor<const Type &> {
         return;
       }
       ShapeProjection proj2 = GetRef<ShapeProjection>(spn2);
-      equal = equal && alpha_eq(proj1->shape, proj2->shape);
+      this->VisitType(proj1->shape, proj2->shape);
     } else {
       equal = false;
     }
@@ -447,7 +447,8 @@ struct TypeAlphaEq : TypeVisitor<const Type &> {
         return;
       }
       ShapeBinaryOp op2 = GetRef<ShapeBinaryOp>(sbn2);
-      equal = equal && alpha_eq(op1->left, op2->left) && alpha_eq(op1->right, op2->right);
+      this->VisitType(op1->left, op2->left);
+      this->VisitType(op1->right, op2->right);
     } else {
       equal = false;
     }
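
All four hunks make the same fix: recursing through this->VisitType keeps the visitor's accumulated state, in particular the map of corresponding type ids built up by enclosing binders, while re-entering through the top-level alpha_eq would restart from an empty map and wrongly reject nested shapes (the regression the commit message describes). A toy model of the distinction, not the repo's C++:

# Toy alpha-equality over tuples; `ids` records corresponding binder names.
def alpha_eq(a, b):
    return eq(a, b, {})  # top-level entry always starts from an EMPTY map

def eq(a, b, ids):
    if a[0] != b[0]:
        return False
    if a[0] == "id":       # bound ids are equal if they correspond in `ids`
        return ids.get(a[1]) == b[1]
    if a[0] == "forall":   # a binder records the correspondence
        return eq(a[2], b[2], {**ids, a[1]: b[1]})
    if a[0] == "seq":      # nested shapes must thread `ids` through;
        return len(a[1]) == len(b[1]) and all(
            eq(x, y, ids)  # calling alpha_eq(x, y) here would drop the map
            for x, y in zip(a[1], b[1]))
    return a[1] == b[1]    # literals and other leaves

s1 = ("forall", "n", ("seq", [("id", "n"), ("lit", 1)]))
s2 = ("forall", "m", ("seq", [("id", "m"), ("lit", 1)]))
assert alpha_eq(s1, s2)    # holds only because `ids` survives the recursion
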
relay/src/tvm/relay/ir/expr.cc: 25 additions & 6 deletions
@@ -22,6 +22,11 @@ TVM_REGISTER_API("nnvm.make.String")
     .set_body([](TVMArgs args,
                  TVMRetValue *ret) { *ret = StringNode::make(args[0]); });
 
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+    .set_dispatch<StringNode>([](const StringNode *node, tvm::IRPrinter *p) {
+      p->stream << "String(" << node->name << ")";
+    });
+
 FloatLit FloatLitNode::make(double value) {
   std::shared_ptr<FloatLitNode> n = std::make_shared<FloatLitNode>();
   n->value = std::move(value);
@@ -262,7 +267,8 @@ Call CallNode::make(Expr fn, tvm::Array<Expr> args, Attributes attrs) {
 
 TVM_REGISTER_API("nnvm.make.Call").set_body([](TVMArgs args, TVMRetValue *ret) {
   if (args.size() < 3) {
-    Attributes attrs = AttributesNode::make(tvm::Map<LocalId, Expr>());
+    Attributes attrs = AttributesNode::make(
+        std::unordered_map<String, Expr, StringHash, StringEqual>());
     *ret = CallNode::make(args[0], args[1], attrs);
   } else {
     *ret = CallNode::make(args[0], args[1], args[2]);
@@ -441,18 +447,31 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
                  << node->value << ", " << node->body;
     });
 
-static void validate_attributes(tvm::Map<LocalId, Expr> attrs) { return; }
+static void
+validate_attributes(std::unordered_map<String, Expr, StringHash, StringEqual> attrs) {
+  return;
+}
 
-Attributes AttributesNode::make(tvm::Map<LocalId, Expr> attrs) {
+Attributes AttributesNode::make(
+    std::unordered_map<String, Expr, StringHash, StringEqual> attrs) {
   std::shared_ptr<AttributesNode> n = std::make_shared<AttributesNode>();
   validate_attributes(attrs);
-  n->attributes = std::move(attrs);
+  n->attributes = attrs;
   return Attributes(n);
 }
 
 TVM_REGISTER_API("nnvm.make.Attributes")
-    .set_body([](TVMArgs args,
-                 TVMRetValue *ret) { *ret = AttributesNode::make(args[0]); });
+    .set_body([](TVMArgs args, TVMRetValue *ret) {
+      // ensure attrs are copied into the appropriate map type
+      tvm::Map<String, Expr> map = args[0];
+      std::unordered_map<String, Expr, StringHash, StringEqual> attrs;
+
+      for (auto p : map) {
+        attrs[p.first] = p.second;
+      }
+
+      *ret = AttributesNode::make(attrs);
+    });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
     .set_dispatch<AttributesNode>([](const AttributesNode *node,
relay/src/tvm/relay/ir/type.cc: 1 addition & 1 deletion
@@ -292,7 +292,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
       p->stream << "ShapeSingletonNode(" << node->value << ")";
     });
 
-ShapeAttr ShapeAttrNode::make(LocalId id) {
+ShapeAttr ShapeAttrNode::make(String id) {
   auto n = std::make_shared<ShapeAttrNode>();
   n->id = id;
   return ShapeAttr(n);
relay/src/tvm/relay/pretty_printer.cc: 4 additions & 0 deletions
@@ -252,6 +252,10 @@ struct TypeDocifier : TypeFunctor<Doc(const Type &n)> {
     return TEXT(std::to_string(op->value));
   }
 
+  Doc VisitType_(const ShapeAttrNode *op) override {
+    return TEXT(op->id->name);
+  }
+
   Doc VisitType_(const ShapeSeqNode *op) override {
     Doc shape = NIL();
     for (size_t i = 0; i < op->shapes.size(); ++i) {
relay/src/tvm/relay/reverse_ad.cc: 3 additions & 3 deletions
@@ -112,9 +112,9 @@ struct ReverseAD : ExprFunctor<Expr(const Expr& n, const Expr& bp)> {
       args.push_back(AD(arg, bp));
     }
 
-    tvm::Map<LocalId, Expr> attr(op->attrs->attributes);
-    for (const std::pair<LocalId, Expr>& p : op->attrs->attributes) {
-      attr.Set(p.first, AD(attr[p.first], bp));
+    std::unordered_map<String, Expr, StringHash, StringEqual> attr(op->attrs->attributes);
+    for (const std::pair<String, Expr>& p : op->attrs->attributes) {
+      attr[p.first] = AD(attr[p.first], bp);
     }
 
     if (const IntrinsicIdNode* iin = op->fn.as<IntrinsicIdNode>()) {
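
The reverse-mode pass now rebuilds the attribute map by running each value through the AD transform while the String keys stay intact. In Python terms the loop amounts to a dict comprehension; the names below are toy stand-ins, not the repo's API:

# ad() stands in for the AD(expr, bp) transform in reverse_ad.cc
def ad(value, bp):
    return ("grad", value, bp)

attrs = {"stride_h": 2, "stride_w": 2}
bp = "backprop"
attrs = {key: ad(value, bp) for key, value in attrs.items()}
assert attrs["stride_h"] == ("grad", 2, "backprop")
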