From 0def9c9fb43e4d012bf6bc06c5e95ad811eede5b Mon Sep 17 00:00:00 2001
From: "Steven S. Lyubomirsky"
Date: Fri, 8 Jun 2018 01:59:37 -0700
Subject: [PATCH] Add missing tests and modify attributes (#5)

Restores the tests that were lost in the repo port, makes it possible
for conv2d to typecheck, and adds some integration tests.

* Add back python tests that were missing from the repo move (history lost, sorry)
* Ensure shape evaluator doesn't trip up on type vars or type ids
* Add tests for shape evaluator when faced with type vars or type ids
* Use Strings as keys for Attributes; repair tests and uses
* Add preliminary type information for the conv2d operator
* Add test of attr propagation
* Ensure attributes hash by string value rather than string pointer identity (tests were still failing at this point; cause unknown)
* Adjust shape equality checks to use the ordinary visitor, as nested type id information was not being transferred
* Repair integration test by using std::unordered_map for attrs; add cases
* Add regression test for alpha-eq comparison of type IDs across nested shapes (alpha_eq was losing the type id equality map)
* Add more integration test variants
* Correct import names in tests
* Add a clarifying comment to the typechecker
* Add missing paren in operators.py
* Correct python source directory for mypy
---
 relay/Makefile                                |   2 +-
 relay/include/tvm/relay/ir/base.h             |  53 +-
 relay/include/tvm/relay/ir/expr.h             |  35 -
 relay/include/tvm/relay/ir/type.h             |   4 +-
 relay/python/relay/ir/base.py                 |   4 +-
 relay/python/relay/ir/expr.py                 |   6 +-
 relay/python/relay/make.pyi                   |   4 +-
 relay/python/relay/operators.py               |  53 +-
 relay/src/tvm/relay/alpha_eq.cc               |  11 +-
 relay/src/tvm/relay/ir/expr.cc                |  31 +-
 relay/src/tvm/relay/ir/type.cc                |   2 +-
 relay/src/tvm/relay/pretty_printer.cc         |   4 +
 relay/src/tvm/relay/reverse_ad.cc             |   6 +-
 relay/src/tvm/relay/typeck/shape_evaluator.cc |  11 +
 relay/src/tvm/relay/typeck/typechecker.cc     |  18 +-
 relay/src/tvm/relay/typeck/unifier.cc         |   2 +-
 relay/tests/python/__init__.py                |   0
 relay/tests/python/test_alpha_eq.py           | 617 ++++++++++++++++++
 relay/tests/python/test_ast.py                | 236 +++++++
 relay/tests/python/test_decorator.py          | 416 ++++++++++++
 relay/tests/python/test_eval.py               | 152 +++++
 relay/tests/python/test_forward_ad.py         |  20 +
 relay/tests/python/test_grad_descent.py       |  31 +
 relay/tests/python/test_kindchecker.py        | 128 ++++
 relay/tests/python/test_pretty_print.py       |  32 +
 relay/tests/python/test_reverse_ad.py         |  21 +
 relay/tests/python/test_shape_evaluator.py    | 153 +++++
 relay/tests/python/test_softmax.py            |  35 +
 relay/tests/python/test_span.py               |  33 +
 relay/tests/python/test_tyck.py               | 404 ++++++++++++
 relay/tests/python/test_unifier.py            | 480 ++++++++++++++
 relay/tests/python/test_visitor.py            | 214 ++++++
 32 files changed, 3151 insertions(+), 67 deletions(-)
 create mode 100644 relay/tests/python/__init__.py
 create mode 100644 relay/tests/python/test_alpha_eq.py
 create mode 100644 relay/tests/python/test_ast.py
 create mode 100644 relay/tests/python/test_decorator.py
 create mode 100644 relay/tests/python/test_eval.py
 create mode 100644 relay/tests/python/test_forward_ad.py
 create mode 100644 relay/tests/python/test_grad_descent.py
 create mode 100644 relay/tests/python/test_kindchecker.py
 create mode 100644 relay/tests/python/test_pretty_print.py
 create mode 100644 relay/tests/python/test_reverse_ad.py
 create mode 100644 relay/tests/python/test_shape_evaluator.py
 create mode 100644 relay/tests/python/test_softmax.py
 create mode 100644 relay/tests/python/test_span.py
 create mode 100644 relay/tests/python/test_tyck.py
 create mode 100644 relay/tests/python/test_unifier.py
 create mode 100644 relay/tests/python/test_visitor.py

diff --git a/relay/Makefile b/relay/Makefile
index 99b42dcb4ce35..d0421fede8522 100644
--- a/relay/Makefile
+++ b/relay/Makefile
@@ -75,7 +75,7 @@ cyclean:
 lint: pylint cpplint
 
 mypy:
-	python3.6 -m mypy --ignore-missing-imports python/tvm/relay tests/python/relay/
+	python3.6 -m mypy --ignore-missing-imports python/relay tests/python/relay/
 
 cpplint:
 	python3.6 dmlc-core/scripts/lint.py relay cpp include src

diff --git a/relay/include/tvm/relay/ir/base.h b/relay/include/tvm/relay/ir/base.h
index 56b67faabb144..8cb405b0c2cb1 100644
--- a/relay/include/tvm/relay/ir/base.h
+++ b/relay/include/tvm/relay/ir/base.h
@@ -150,6 +150,41 @@ struct Expr : public NodeRef {
   using ContainerType = ExprNode;
 };
 
+struct StringNode;
+
+/*! \brief an entry that represents output data from a node */
+class String : public NodeRef {
+ public:
+  /*! \brief default constructor, used internally */
+  String() {}
+  explicit String(std::shared_ptr<Node> n) : NodeRef(n) {}
+  inline const StringNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = StringNode;
+};
+
+struct StringNode : public ExprNode {
+ public:
+  std::string name;
+
+  StringNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("span", &span);
+    v->Visit("name", &name);
+  }
+
+  TVM_DLL static String make(std::string name);
+
+  static constexpr const char* _type_key = "nnvm.String";
+  TVM_DECLARE_NODE_TYPE_INFO(StringNode, ExprNode);
+};
+
+inline const StringNode* String::operator->() const {
+  return static_cast<const StringNode*>(node_.get());
+}
+
 class LocalId;
 
 /*! \brief A LocalId from the node's current type to target type. */
@@ -194,21 +229,33 @@ class GlobalIdNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(GlobalId, GlobalIdNode, Expr);
 
+struct StringHash {
+  size_t operator()(const String &key) const {
+    return std::hash<std::string>()(key->name);
+  }
+};
+
+struct StringEqual {
+  bool operator()(const String &lhs, const String &rhs) const {
+    return lhs->name == rhs->name;
+  }
+};
+
 class Attributes;
 
 /*! \brief A floating point value. */
 class AttributesNode : public Node {
  public:
-  tvm::Map<LocalId, Expr> attributes;
+  std::unordered_map<String, Expr, StringHash, StringEqual> attributes;
 
   AttributesNode() {}
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("span", &span);
-    v->Visit("attributes", &attributes);
   }
 
-  TVM_DLL static Attributes make(tvm::Map<LocalId, Expr> attributes);
+  TVM_DLL static Attributes make(
+      std::unordered_map<String, Expr, StringHash, StringEqual> attributes);
 
   static constexpr const char* _type_key = "nnvm.Attributes";
   TVM_DECLARE_NODE_TYPE_INFO(AttributesNode, Node);
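A note on the hashing change above, seen from the Python side: two distinct String nodes with the same name now denote the same attribute key, since StringHash/StringEqual (and ir.String.__hash__ below) go through the name rather than pointer identity. A minimal sketch using the relay.make constructors exercised by the tests in this patch (the values are illustrative, not part of the patch):

    from relay.make import Attributes, IntLit, String

    # k1 and k2 are different nodes but hash equal by name, so they refer
    # to the same "stride_h" entry in an Attributes map.
    k1 = String("stride_h")
    k2 = String("stride_h")
    attrs = Attributes({k1: IntLit(2)})
    assert hash(k1) == hash(k2)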
diff --git a/relay/include/tvm/relay/ir/expr.h b/relay/include/tvm/relay/ir/expr.h
index b46ad4628be2f..9ef59881e4dff 100644
--- a/relay/include/tvm/relay/ir/expr.h
+++ b/relay/include/tvm/relay/ir/expr.h
@@ -17,41 +17,6 @@ namespace relay {
 
 typedef HalideIR::Type HType;
 
-struct StringNode;
-
-/*! \brief an entry that represents output data from a node */
-class String : public NodeRef {
- public:
-  /*! \brief default constructor, used internally */
-  String() {}
-  explicit String(std::shared_ptr<Node> n) : NodeRef(n) {}
-  inline const StringNode* operator->() const;
-
-  /*! \brief specify container node */
-  using ContainerType = StringNode;
-};
-
-struct StringNode : public ExprNode {
- public:
-  std::string name;
-
-  StringNode() {}
-
-  void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("span", &span);
-    v->Visit("name", &name);
-  }
-
-  TVM_DLL static String make(std::string name);
-
-  static constexpr const char* _type_key = "nnvm.String";
-  TVM_DECLARE_NODE_TYPE_INFO(StringNode, ExprNode);
-};
-
-inline const StringNode* String::operator->() const {
-  return static_cast<const StringNode*>(node_.get());
-}
-
 class FloatLit;
 
 /*! \brief Floating point literal `0.0`, `5e10`. */

diff --git a/relay/include/tvm/relay/ir/type.h b/relay/include/tvm/relay/ir/type.h
index c7f7efbd4050f..a9d0639f74e5b 100644
--- a/relay/include/tvm/relay/ir/type.h
+++ b/relay/include/tvm/relay/ir/type.h
@@ -271,7 +271,7 @@ class ShapeAttr;
 /*! \brief Shape singleton that captures the value of an attribute */
 class ShapeAttrNode : public TypeNode {
  public:
-  LocalId id;
+  String id;
 
   ShapeAttrNode() {}
 
@@ -280,7 +280,7 @@
     v->Visit("span", &span);
   }
 
-  TVM_DLL static ShapeAttr make(LocalId id);
+  TVM_DLL static ShapeAttr make(String id);
 
   static constexpr const char* _type_key = "nnvm.ShapeAttr";
   TVM_DECLARE_NODE_TYPE_INFO(ShapeAttrNode, TypeNode);

diff --git a/relay/python/relay/ir/base.py b/relay/python/relay/ir/base.py
index a7b30d8819eb9..2748dcf953573 100644
--- a/relay/python/relay/ir/base.py
+++ b/relay/python/relay/ir/base.py
@@ -88,9 +88,7 @@ def set_span(self, span: Span):
 
 @register_nnvm_node
 class Attributes(NodeBase):
-    def __getitem__(self, index):
-        return self.attributes[index]
-
+    pass
 
 class Value(NodeBase):
     """Base class of all values.

diff --git a/relay/python/relay/ir/expr.py b/relay/python/relay/ir/expr.py
index 38826979dbd64..edfa7bd12d558 100644
--- a/relay/python/relay/ir/expr.py
+++ b/relay/python/relay/ir/expr.py
@@ -7,7 +7,11 @@
 
 @register_nnvm_node
 class String(Expr):
-    value: str
+    name: str
+
+    # need to define hash to use in maps (e.g., for attrs)
+    def __hash__(self):
+        return self.name.__hash__()
 
 
 @register_nnvm_node

diff --git a/relay/python/relay/make.pyi b/relay/python/relay/make.pyi
index c4f692a209793..6d1bee3bdab22 100644
--- a/relay/python/relay/make.pyi
+++ b/relay/python/relay/make.pyi
@@ -29,7 +29,7 @@ def TypeVar() -> ir.Type: ...
 def PlaceholderType() -> ir.Type: ...
 def ShapeSeq(shapes: List[ir.Type]) -> ir.ShapeSeq: ...
 def ShapeSingleton(value: int) -> ir.ShapeSingleton: ...
-def ShapeAttr(id: ir.LocalId) -> ir.ShapeAttr: ...
+def ShapeAttr(id: ir.String) -> ir.ShapeAttr: ...
 def ShapeProjection(shape: ir.Type, value: int) -> ir.ShapeProjection: ...
 def ShapeBinaryOp(op: ir.ShapeOp, left: ir.Type, right: ir.Type) -> ir.ShapeBinaryOp: ...
 
@@ -46,7 +46,7 @@ def TensorLit(value: List[ir.Expr]) -> ir.TensorLit: ...
 def ProductLit(fields: List[ir.Expr]) -> ir.Expr: ...
 def BoolLit(value: bool) -> ir.BoolLit: ...
 def String(value: str) -> ir.String: ...
-def Attributes(attrs: Dict[ir.LocalId, ir.Expr]) -> ir.Attributes: ...
+def Attributes(attrs: Dict[ir.String, ir.Expr]) -> ir.Attributes: ...
 def Call(func: ir.Expr, args: List[ir.Expr], attrs: ir.Attributes) -> ir.Call: ...
 def UnaryOp(op: ir.UOp, arg: ir.Expr) -> ir.Expr: ...
 def BinaryOp(op: ir.BOp, left: ir.Expr, right: ir.Expr) -> ir.Expr: ...
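For reference, the call-site pattern implied by the make.pyi change, mirroring test_call_with_attrs from the tests added below (the stride attributes are made up for illustration):

    from relay.make import (Attributes, Call, Function, IntLit, IntType,
                            LocalId, Param, String)

    fn = Function([Param(LocalId("x"), IntType(32))], IntType(32), LocalId("x"))
    # Attribute keys are now ir.String values instead of ir.LocalId.
    attrs = Attributes({String("stride_h"): IntLit(2), String("stride_w"): IntLit(2)})
    call = Call(fn, [IntLit(1)], attrs)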
diff --git a/relay/python/relay/operators.py b/relay/python/relay/operators.py
index 74a9a51954cd0..c44069dc51d4c 100644
--- a/relay/python/relay/operators.py
+++ b/relay/python/relay/operators.py
@@ -5,8 +5,9 @@
 import topi
 from relay.env import Environment
 import relay.ir as ir
-from relay.make import Operator, IntrinsicId, TypeId, TensorType, FloatType
+from relay.make import String, Operator, IntrinsicId, TypeId, TensorType, FloatType
 from relay.make import TypeQuantifier, TypeArrow, ProductType
+from relay.make import ShapeAttr, ShapeBinaryOp, ShapeProjection, ShapeSingleton, ShapeSeq
 
 # TODO(@jroesch): Fix up my type
 __operator_registry__: Dict[str, Any] = {}
@@ -128,6 +129,15 @@ def broadcast_mul_compiler(func_ty: ir.Type) -> Any:
     module = tvm.build(schedule, Inputs + [Output], __tgt__, target_host=__tgt_host__, name="broadcast_mul_compiler")
     return module.get_function("broadcast_mul_compiler")
 
+# TODO(@jroesch): ensure this interfaces correctly
+# note that the type provided doesn't handle padding
+# feel free to assume some default behavior
+def conv2d_compiler(func_ty: ir.Type) -> Any:
+    Inputs, ret_ty = func_ty_to_placeholders(func_ty)
+    Output = topi.nn.conv2d(*Inputs)
+    schedule = tvm.create_schedule(Output.op)
+    module = tvm.build(schedule, Inputs + [Output], __tgt__, target_host=__tgt_host__, name="conv2d_compiler")
+    return module.get_function("conv2d_compiler")
 
 def initialize_operators(env) -> None:
     """Initialize the default set of operators for the system, this will populate
@@ -175,3 +185,44 @@ def initialize_operators(env) -> None:
     bmul_type = TypeQuantifier(shape, TypeArrow(ProductType([in_out_type, in_out_type]), in_out_type))
     # TODO: reverse mode
     register_op(env, 'broadcast_mul', bmul_type, broadcast_mul_compiler)
+
+    # Conv2d
+    # input: [batch, in_channel, in_height, in_width]
+    # filter: [filter_height, filter_width, in_channel, num_filter]
+    # output shape: [out_height, out_width, num_filter, batch]
+    # out_height = (in_height - filter_h)/stride_h + 1
+    # out_width = (in_width - filter_w)/stride_w + 1
+    stride_h = ShapeAttr(String("stride_h"))
+    stride_w = ShapeAttr(String("stride_w"))
+    btvar = TypeId("bt", Kind.BaseType)
+    input_shape = TypeId("input_shape", Kind.Shape)
+    filter_shape = TypeId("filter_shape", Kind.Shape)
+    output_shape = ShapeSeq([
+        ShapeBinaryOp(ShapeOp.SHPLUS,
+                      ShapeBinaryOp(ShapeOp.SHDIV,
+                                    ShapeBinaryOp(ShapeOp.SHSUB,
+                                                  ShapeProjection(input_shape, 2),
+                                                  ShapeProjection(filter_shape, 0)),
+                                    stride_h),
+                      ShapeSingleton(1)),
+        ShapeBinaryOp(ShapeOp.SHPLUS,
+                      ShapeBinaryOp(ShapeOp.SHDIV,
+                                    ShapeBinaryOp(ShapeOp.SHSUB,
+                                                  ShapeProjection(input_shape, 3),
+                                                  ShapeProjection(filter_shape, 1)),
+                                    stride_w),
+                      ShapeSingleton(1)),
+        ShapeProjection(filter_shape, 3),
+        ShapeProjection(input_shape, 0)
+    ])
+    conv2d_type = TypeQuantifier(
+        btvar,
+        TypeQuantifier(
+            input_shape,
+            TypeQuantifier(
+                filter_shape,
+                TypeArrow(ProductType([TensorType(btvar, input_shape),
+                                       TensorType(btvar, filter_shape)]),
+                          TensorType(btvar, output_shape)))))
+    # TODO: reverse mode
+    register_op(env, 'conv2d', conv2d_type, conv2d_compiler)
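The conv2d type above encodes standard no-padding convolution arithmetic. A plain-Python sanity check of what the ShapeBinaryOp tree computes (numbers chosen arbitrarily for illustration):

    def conv2d_out_dim(in_dim, filter_dim, stride):
        # out = (in - filter) / stride + 1, per the comments above
        return (in_dim - filter_dim) // stride + 1

    assert conv2d_out_dim(32, 5, 1) == 28  # (32 - 5) / 1 + 1
    assert conv2d_out_dim(28, 3, 2) == 13  # (28 - 3) / 2 + 1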
diff --git a/relay/src/tvm/relay/alpha_eq.cc b/relay/src/tvm/relay/alpha_eq.cc
index edd27d05af556..9c717f042b17c 100644
--- a/relay/src/tvm/relay/alpha_eq.cc
+++ b/relay/src/tvm/relay/alpha_eq.cc
@@ -399,8 +399,8 @@ struct TypeAlphaEq : TypeVisitor {
 
   void VisitType_(const ShapeAttrNode *sn1, const Type &t2) override {
     if (const ShapeAttrNode *sn2 = t2.as<ShapeAttrNode>()) {
-      // require exact quality of identifiers
-      equal = equal && (sn1->id == sn2->id);
+      // check equality of names
+      equal = equal && (sn1->id->name == sn2->id->name);
     } else {
       equal = false;
     }
@@ -418,7 +418,7 @@
       auto size = shape1->shapes.size();
       for (size_t i = 0U; i < size; i++) {
         if (!equal) { return; }
-        equal = equal && alpha_eq(shape1->shapes[i], shape2->shapes[i]);
+        this->VisitType(shape1->shapes[i], shape2->shapes[i]);
       }
     } else {
       equal = false;
@@ -433,7 +433,7 @@
         return;
       }
       ShapeProjection proj2 = GetRef<ShapeProjection>(spn2);
-      equal = equal && alpha_eq(proj1->shape, proj2->shape);
+      this->VisitType(proj1->shape, proj2->shape);
     } else {
       equal = false;
     }
@@ -447,7 +447,8 @@
         return;
       }
       ShapeBinaryOp op2 = GetRef<ShapeBinaryOp>(sbn2);
-      equal = equal && alpha_eq(op1->left, op2->left) && alpha_eq(op1->right, op2->right);
+      this->VisitType(op1->left, op2->left);
+      this->VisitType(op1->right, op2->right);
     } else {
       equal = false;
     }

diff --git a/relay/src/tvm/relay/ir/expr.cc b/relay/src/tvm/relay/ir/expr.cc
index 9db886d1d91e8..24c12118853f9 100644
--- a/relay/src/tvm/relay/ir/expr.cc
+++ b/relay/src/tvm/relay/ir/expr.cc
@@ -22,6 +22,11 @@ TVM_REGISTER_API("nnvm.make.String")
     .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = StringNode::make(args[0]); });
 
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+    .set_dispatch<StringNode>([](const StringNode *node, tvm::IRPrinter *p) {
+      p->stream << "String(" << node->name << ")";
+    });
+
 FloatLit FloatLitNode::make(double value) {
   std::shared_ptr<FloatLitNode> n = std::make_shared<FloatLitNode>();
   n->value = std::move(value);
@@ -262,7 +267,8 @@ Call CallNode::make(Expr fn, tvm::Array<Expr> args, Attributes attrs) {
 
 TVM_REGISTER_API("nnvm.make.Call").set_body([](TVMArgs args, TVMRetValue *ret) {
   if (args.size() < 3) {
-    Attributes attrs = AttributesNode::make(tvm::Map<LocalId, Expr>());
+    Attributes attrs = AttributesNode::make(
+        std::unordered_map<String, Expr, StringHash, StringEqual>());
     *ret = CallNode::make(args[0], args[1], attrs);
   } else {
     *ret = CallNode::make(args[0], args[1], args[2]);
@@ -441,18 +447,31 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
                 << node->value << ", " << node->body;
     });
 
-static void validate_attributes(tvm::Map<LocalId, Expr> attrs) { return; }
+static void
+validate_attributes(tvm::Map<String, Expr> attrs) {
+  return;
+}
 
-Attributes AttributesNode::make(tvm::Map<LocalId, Expr> attrs) {
+Attributes AttributesNode::make(
+    std::unordered_map<String, Expr, StringHash, StringEqual> attrs) {
   std::shared_ptr<AttributesNode> n = std::make_shared<AttributesNode>();
   validate_attributes(attrs);
-  n->attributes = std::move(attrs);
+  n->attributes = attrs;
   return Attributes(n);
 }
 
 TVM_REGISTER_API("nnvm.make.Attributes")
-    .set_body([](TVMArgs args,
-                 TVMRetValue *ret) { *ret = AttributesNode::make(args[0]); });
+    .set_body([](TVMArgs args, TVMRetValue *ret) {
+      // ensure attrs are moved to appropriate map
+      tvm::Map<String, Expr> map = args[0];
+      std::unordered_map<String, Expr, StringHash, StringEqual> attrs;
+
+      for (auto p : map) {
+        attrs[p.first] = p.second;
+      }
+
+      *ret = AttributesNode::make(attrs);
+    });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
     .set_dispatch<AttributesNode>([](const AttributesNode *node,

diff --git a/relay/src/tvm/relay/ir/type.cc b/relay/src/tvm/relay/ir/type.cc
index 0c4b6a041bf28..db6f36e85e70b 100644
--- a/relay/src/tvm/relay/ir/type.cc
+++ b/relay/src/tvm/relay/ir/type.cc
@@ -292,7 +292,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
       p->stream << "ShapeSingletonNode(" << node->value << ")";
     });
 
-ShapeAttr ShapeAttrNode::make(LocalId id) {
+ShapeAttr ShapeAttrNode::make(String id) {
   auto n = std::make_shared<ShapeAttrNode>();
   n->id = id;
   return ShapeAttr(n);
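The switch from recursive alpha_eq calls to this->VisitType in alpha_eq.cc is what keeps the quantifier-introduced type id map alive across nested shapes. A minimal sketch of the regression, in the style of the tests added below (names are illustrative):

    from relay.ir import Kind, alpha_eq
    from relay.make import IntType, ShapeSeq, TensorType, TypeId, TypeQuantifier

    s1 = TypeId("s1", Kind.Shape)
    s2 = TypeId("s2", Kind.Shape)
    tq1 = TypeQuantifier(s1, TensorType(IntType(32), ShapeSeq([s1])))
    tq2 = TypeQuantifier(s2, TensorType(IntType(32), ShapeSeq([s2])))
    # The s1 -> s2 binding from the quantifiers must survive into the
    # ShapeSeq comparison; recursing through a fresh top-level alpha_eq
    # would compare s1 and s2 by pointer identity and report inequality.
    assert alpha_eq(tq1, tq2)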
diff --git a/relay/src/tvm/relay/pretty_printer.cc b/relay/src/tvm/relay/pretty_printer.cc
index 2a9dc48a7e516..6d430885a1565 100644
--- a/relay/src/tvm/relay/pretty_printer.cc
+++ b/relay/src/tvm/relay/pretty_printer.cc
@@ -252,6 +252,10 @@ struct TypeDocifier : TypeFunctor {
     return TEXT(std::to_string(op->value));
   }
 
+  Doc VisitType_(const ShapeAttrNode *op) override {
+    return TEXT(op->id->name);
+  }
+
   Doc VisitType_(const ShapeSeqNode *op) override {
     Doc shape = NIL();
     for (size_t i = 0; i < op->shapes.size(); ++i) {

diff --git a/relay/src/tvm/relay/reverse_ad.cc b/relay/src/tvm/relay/reverse_ad.cc
index ff44f5a1c4199..e2fc6c17b77f3 100644
--- a/relay/src/tvm/relay/reverse_ad.cc
+++ b/relay/src/tvm/relay/reverse_ad.cc
@@ -112,9 +112,9 @@ struct ReverseAD : ExprFunctor {
       args.push_back(AD(arg, bp));
     }
 
-    tvm::Map<LocalId, Expr> attr(op->attrs->attributes);
-    for (const std::pair<LocalId, Expr>& p : op->attrs->attributes) {
-      attr.Set(p.first, AD(attr[p.first], bp));
+    std::unordered_map<String, Expr, StringHash, StringEqual> attr(op->attrs->attributes);
+    for (const std::pair<String, Expr>& p : op->attrs->attributes) {
+      attr[p.first] = AD(attr[p.first], bp);
     }
 
     if (const IntrinsicIdNode* iin = op->fn.as<IntrinsicIdNode>()) {

diff --git a/relay/src/tvm/relay/typeck/shape_evaluator.cc b/relay/src/tvm/relay/typeck/shape_evaluator.cc
index 7617173b52cb0..a3b3fbf341dd3 100644
--- a/relay/src/tvm/relay/typeck/shape_evaluator.cc
+++ b/relay/src/tvm/relay/typeck/shape_evaluator.cc
@@ -72,6 +72,11 @@ struct ConcreteShapeEvaluator : TypeFVisitor {
     ShapeProjection proj = GetRef<ShapeProjection>(op);
     Type inner = this->VisitType(proj->shape);
 
+    // if we have a type var or type id, we can't simplify further
+    if (inner.as<TypeVarNode>() || inner.as<TypeIdNode>()) {
+      return GetRef<ShapeProjection>(op);
+    }
+
     const ShapeSeqNode *ssn = inner.as<ShapeSeqNode>();
     if (!ssn) {
       throw ShapeEvaluationError("Can only project into a shape sequence");
@@ -91,6 +96,12 @@
     auto left = this->VisitType(sn->left);
     auto right = this->VisitType(sn->right);
 
+    // if we have a type var or type id on either side, we can't simplify further
+    if (left.as<TypeVarNode>() || left.as<TypeIdNode>()
+        || right.as<TypeVarNode>() || right.as<TypeIdNode>()) {
+      return GetRef<ShapeBinaryOp>(sn);
+    }
+
     // otherwise, if they're both sequences, treat as though operators
     // were applied to each member
     if (left.as<ShapeSeqNode>() && right.as<ShapeSeqNode>()) {
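The shape evaluator change above means partially symbolic shapes now flow through unevaluated instead of raising. An illustrative value it must tolerate, built with the same constructors operators.py uses (not a test from this patch):

    from relay.ir import Kind
    from relay.make import ShapeProjection, TypeId

    # A projection whose base is still an unresolved type id, as in the
    # conv2d operator type; the evaluator returns it unchanged rather than
    # raising "Can only project into a shape sequence".
    proj = ShapeProjection(TypeId("input_shape", Kind.Shape), 2)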
diff --git a/relay/src/tvm/relay/typeck/typechecker.cc b/relay/src/tvm/relay/typeck/typechecker.cc
index ed872f119d21d..8dbc82a40bd50 100644
--- a/relay/src/tvm/relay/typeck/typechecker.cc
+++ b/relay/src/tvm/relay/typeck/typechecker.cc
@@ -171,7 +171,17 @@ Type Typechecker::VisitExpr_(const FunctionNode *op) {
   std::vector<Type> arg_types;
   for (auto arg : f->params) {
     this->Check(arg);
-    Type arg_type = simple_eval_shape(arg->type);
+    Type arg_type;
+    // if the arg type can be simply evaluated, try it;
+    // this should be replaced with symbolic evaluation once it exists
+    // (we will not have attr information at this point)
+    try {
+      arg_type = simple_eval_shape(arg->type);
+    }
+    catch (const dmlc::Error &e) {
+      this->report_error(e.what(), arg->span);
+      arg_type = arg->type;
+    }
     arg_types.push_back(arg_type);
     this->local_stack.insert(arg->id, arg_type);
   }
@@ -180,7 +190,8 @@
   // TODO(sslyu): should the unified return type override the annotated one?
   Type checked_return = this->Check(f->body);
   Type unified =
-      this->unify(simple_eval_shape(f->ret_type), checked_return, f->span);
+      this->unify(simple_eval_shape(f->ret_type),
+                  simple_eval_shape(checked_return), f->span);
   this->local_stack.pop_frame();
 
   // function type is tuple of args -> return type (no currying unless manual)
@@ -540,7 +551,8 @@ void Typechecker::fatal_error(std::string msg, Span sp) {
   this->env->report_error(msg, sp);
   throw FatalTypeError(
       "internal error: this exception should"
-      "be handled and errors reported with Environment::display_errors");
+      "be handled and errors reported with Environment::display_errors\n"
+      + msg);
 }
 
 Type Typechecker::unify(const Type &t1, const Type &t2, Span sp) {

diff --git a/relay/src/tvm/relay/typeck/unifier.cc b/relay/src/tvm/relay/typeck/unifier.cc
index 830a14bd29e42..c59d03623ff1e 100644
--- a/relay/src/tvm/relay/typeck/unifier.cc
+++ b/relay/src/tvm/relay/typeck/unifier.cc
@@ -419,7 +419,7 @@ Type TypeUnifierNode::VisitType_(const ShapeAttrNode *sn1, const Type t2) {
   // can only unify with another shape attr with the same id
   if (const ShapeAttrNode *sn2 = t2.as<ShapeAttrNode>()) {
-    if (s1->id == sn2->id) {
+    if (s1->id->name == sn2->id->name) {
       return s1;
     }
     throw UnificationError(

diff --git a/relay/tests/python/__init__.py b/relay/tests/python/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d

diff --git a/relay/tests/python/test_alpha_eq.py b/relay/tests/python/test_alpha_eq.py
new file mode 100644
index 0000000000000..97aa7796c1116
--- /dev/null
+++ b/relay/tests/python/test_alpha_eq.py
@@ -0,0 +1,617 @@
+"""Tests alpha-equivalence between expressions."""
+# pylint: disable=invalid-name, missing-docstring
+# pylint: disable=wildcard-import, unused-wildcard-import
+from relay.make import *
+from relay.ir import alpha_eq, BOp, ShapeOp, Kind
+from relay.typing import TYPE_DEFAULTS
+import relay.ir as expr
+
+INT_TYPE_WIDTH = TYPE_DEFAULTS["INT_WIDTH"]
+INT_TYPE_LANES = TYPE_DEFAULTS["INT_LANES"]
+
+def int_type(width=32):
+    return TensorType(IntType(width), ShapeSeq([]))
+
+def float_type(width=32):
+    return TensorType(FloatType(width), ShapeSeq([]))
+
+def bool_type():
+    return TensorType(BoolType(), ShapeSeq([]))
+
+def nest_quantifiers(ids, body):
+    ret = body
+    for tid in reversed(ids):
+        ret = TypeQuantifier(tid, ret)
+    return ret
+
+def test_local_id_not_eq():
+    assert not alpha_eq(LocalId("x"), LocalId("y"))
+
+def test_local_id_eq():
+    x = LocalId("x")
+    assert alpha_eq(x, x)
+
+def test_global_id_not_eq():
+    left = GlobalId("xyz")
+    right = GlobalId("xyz")
+    assert not alpha_eq(left, right)
+
+def test_global_id_eq():
+    ident = GlobalId("xyz")
+    assert alpha_eq(ident, ident)
+
+def test_intrinsic_id_not_eq():
+    left = IntrinsicId("xyz")
+    right = IntrinsicId("xyz")
+    # pointer identity, not name, is the rule
+    assert not alpha_eq(left, right)
+
+def test_intrinsic_id_eq():
+    x = IntrinsicId("xyz")
+    assert alpha_eq(x, x)
+
+def test_float_literal_eq():
+    x = FloatLit(1.0)
+    y = FloatLit(1.0)
+    assert alpha_eq(x, y)
+
+def test_float_literal_not_eq():
+    x = FloatLit(1.0)
+    y = FloatLit(2.0)
+    assert not alpha_eq(x, y)
+
+def test_int_literal_eq():
+    x = IntLit(1)
+    y = IntLit(1)
+    assert alpha_eq(x, y)
+
+def test_int_literal_not_eq():
+    x = IntLit(1)
+    y = IntLit(2)
+    assert not alpha_eq(x, y)
+
+def test_bool_literal_eq():
+    x = BoolLit(True)
+    y = BoolLit(True)
+    assert alpha_eq(x, y)
+
+def test_bool_literal_not_eq():
+    x = BoolLit(True)
+    y = BoolLit(False)
+    assert not alpha_eq(x, y)
+
+def 
test_tensor_literal_eq(): + x = TensorLit([IntLit(1), IntLit(2)]) + y = TensorLit([IntLit(1), IntLit(2)]) + assert alpha_eq(x, y) + +def test_tensor_literal_not_eq(): + x = TensorLit([IntLit(1), IntLit(2)]) + y = TensorLit([IntLit(1), IntLit(3)]) + z = TensorLit([IntLit(1)]) + assert not alpha_eq(x, y) + assert not alpha_eq(x, z) + +def test_product_literal_eq(): + x = ProductLit([IntLit(1), IntLit(2)]) + y = ProductLit([IntLit(1), IntLit(2)]) + assert alpha_eq(x, y) + +def test_product_literal_not_eq(): + x = ProductLit([IntLit(1), IntLit(2)]) + y = ProductLit([IntLit(2), IntLit(2)]) + z = ProductLit([IntLit(1), IntLit(2), IntLit(3)]) + assert not alpha_eq(x, y) + assert not alpha_eq(x, z) + +def test_projection_eq(): + prod = ProductLit([IntLit(3), FloatLit(3.5)]) + + assert alpha_eq(Projection(prod, 0), Projection(prod, 0)) + assert alpha_eq(Projection(prod, 1), Projection(prod, 1)) + +def test_projection_not_eq(): + prod1 = ProductLit([IntLit(3), IntLit(4)]) + prod2 = ProductLit([IntLit(3)]) + prod3 = ProductLit([IntLit(3), IntLit(4), FloatLit(3.5)]) + + assert not alpha_eq(Projection(prod1, 0), Projection(prod1, 1)) + assert not alpha_eq(Projection(prod1, 0), Projection(prod2, 0)) + assert not alpha_eq(Projection(prod1, 0), Projection(prod3, 0)) + assert not alpha_eq(Projection(prod1, 1), Projection(prod3, 1)) + +def test_cast_not_eq(): + left = Cast(IntType(1), IntLit(2)) + right = Cast(IntType(1), IntLit(1)) + assert not alpha_eq(left, right) + + # same literal, different type + left = Cast(IntType(1), IntLit(2)) + right = Cast(IntType(2), IntLit(2)) + assert not alpha_eq(left, right) + +def test_cast_eq(): + left = Cast(IntType(1), IntLit(2)) + right = Cast(IntType(1), IntLit(2)) + assert alpha_eq(left, right) + +def test_param_not_eq(): + left = Param(LocalId("foo"), int_type()) + right = Param(LocalId("foo"), bool_type()) + assert not alpha_eq(left, right) + +def test_param_eq(): + left = Param(LocalId("foo"), int_type()) + right = Param(LocalId("bar"), int_type()) + assert alpha_eq(left, right) + +def test_function_not_eq(): + params1 = [Param(LocalId("x"), int_type())] + fn1 = Function(params1, int_type(), LocalId("x")) + params2 = [Param(LocalId("y"), bool_type())] + fn2 = Function(params2, int_type(), LocalId("y")) + assert not alpha_eq(fn1, fn2) + + params3 = [Param(LocalId("x"), int_type()), Param(LocalId("y"), int_type())] + fn3 = Function(params3, int_type(), LocalId("z")) + assert not alpha_eq(fn1, fn3) + +def test_function_eq(): + x = LocalId("x") + y = LocalId("y") + params1 = [Param(x, int_type())] + fn1 = Function(params1, int_type(), x) + params2 = [Param(y, int_type())] + fn2 = Function(params2, int_type(), y) + assert alpha_eq(fn1, fn2) + +def test_call_not_eq(): + x = LocalId("x") + y = LocalId("y") + params1 = [Param(x, int_type())] + fn1 = Function(params1, int_type(), x) + args1 = [IntLit(1)] + call1 = Call(fn1, args1) + + args2 = [IntLit(2)] + call2 = Call(fn1, args2) + assert not alpha_eq(call1, call2) + + params2 = [Param(y, int_type())] + fn2 = Function(params2, float_type(), FloatLit(0.0)) + call3 = Call(fn2, args1) + assert not alpha_eq(call1, call3) + assert not alpha_eq(call2, call3) + +def test_call_eq(): + x = LocalId("x") + y = LocalId("y") + params1 = [Param(x, int_type())] + fn1 = Function(params1, int_type(), x) + args = [IntLit(1)] + call1 = Call(fn1, args) + + params2 = [Param(y, int_type())] + fn2 = Function(params2, int_type(), y) + call2 = Call(fn2, args) + assert alpha_eq(call1, call2) + +def test_debug_not_eq(): + left = Debug(IntLit(1)) + 
right = Debug(IntLit(2)) + assert not alpha_eq(left, right) + +def test_debug_eq(): + left = Debug(IntLit(1)) + right = Debug(IntLit(1)) + assert alpha_eq(left, right) + +def test_unary_op_not_eq(): + left = UnaryOp(expr.UOp.NEG, FloatLit(2.0)) + right = UnaryOp(expr.UOp.NEG, IntLit(1)) + assert not alpha_eq(left, right) + +def test_unary_op_eq(): + left = UnaryOp(expr.UOp.NEG, IntLit(1)) + right = UnaryOp(expr.UOp.NEG, IntLit(1)) + assert alpha_eq(left, right) + +def test_binary_op_not_eq(): + left = BinaryOp(BOp.PLUS, IntLit(1), IntLit(2)) + right = BinaryOp(BOp.PLUS, FloatLit(1.0), FloatLit(2.0)) + assert not alpha_eq(left, right) + +def test_binary_op_eq(): + left = BinaryOp(expr.BOp.PLUS, IntLit(1), IntLit(2)) + right = BinaryOp(expr.BOp.PLUS, IntLit(1), IntLit(2)) + assert alpha_eq(left, right) + +def test_reverse_not_eq(): + left = Reverse(FloatLit(2.0)) + right = Reverse(IntLit(2)) + assert not alpha_eq(left, right) + +def test_reverse_eq(): + left = Reverse(IntLit(2)) + right = Reverse(IntLit(2)) + assert alpha_eq(left, right) + +def test_zero_not_eq(): + left = Zero(int_type()) + right = Zero(float_type()) + assert not alpha_eq(left, right) + +def test_zero_eq(): + left = Zero(int_type()) + right = Zero(int_type()) + assert alpha_eq(left, right) + +def test_let_not_eq(): + x = LocalId("x") + y = LocalId("y") + let1 = Let(x, int_type(), IntLit(10), IntLit(11)) + let2 = Let(y, int_type(), IntLit(10), IntLit(12)) + assert not alpha_eq(let1, let2) + + let3 = Let(x, int_type(), IntLit(10), x) + let4 = Let(y, int_type(), IntLit(12), y) + assert not alpha_eq(let3, let4) + +def test_let_eq(): + x = LocalId("x") + y = LocalId("y") + let1 = Let(x, int_type(), IntLit(10), x) + let2 = Let(y, int_type(), IntLit(10), y) + assert alpha_eq(let1, let2) + +def test_ref_eq(): + r1 = Ref(IntLit(5)) + r2 = Ref(BinaryOp(BOp.MUL, IntLit(5), IntLit(6))) + assert alpha_eq(r1, r1) + assert alpha_eq(r2, r2) + +def test_ref_not_eq(): + r1 = Ref(IntLit(5)) + r2 = Ref(FloatLit(3.5)) + r3 = Ref(r1) + assert not alpha_eq(r1, r2) + assert not alpha_eq(r1, r3) + assert not alpha_eq(r2, r3) + +def test_val_ref_eq(): + vr1 = ValRef(Ref(IntLit(35))) + vr2 = ValRef(Ref(ProductLit([IntLit(12), FloatLit(2.5)]))) + assert alpha_eq(vr1, vr1) + assert alpha_eq(vr2, vr2) + +def test_val_ref_not_eq(): + vr1 = ValRef(Ref(IntLit(5))) + vr2 = ValRef(Ref(vr1)) + vr3 = ValRef(Ref(FloatLit(5.0))) + assert not alpha_eq(vr1, vr2) + assert not alpha_eq(vr1, vr3) + assert not alpha_eq(vr2, vr3) + +def test_set_ref_eq(): + sr1 = SetRef(Ref(FloatLit(5.0)), FloatLit(6.0)) + sr2 = SetRef(Ref(ProductLit([IntLit(3), BoolLit(False)])), + ProductLit([IntLit(5), BoolLit(True)])) + assert alpha_eq(sr1, sr1) + assert alpha_eq(sr2, sr2) + +def test_set_ref_not_eq(): + r1 = Ref(FloatLit(5.0)) + r2 = Ref(IntLit(5)) + r3 = Ref(IntLit(6)) + + assert not alpha_eq(SetRef(r1, FloatLit(6.0)), + SetRef(r2, IntLit(6))) + assert not alpha_eq(SetRef(r2, IntLit(6)), SetRef(r2, IntLit(7))) + assert not alpha_eq(SetRef(r2, IntLit(7)), SetRef(r3, IntLit(7))) + +# Type alpha-equality tests + +def test_base_type_eq(): + assert alpha_eq(IntType(32), IntType(32)) + assert alpha_eq(BoolType(), BoolType()) + assert alpha_eq(FloatType(32), FloatType(32)) + +def test_tensor_type_eq(): + tt1 = TensorType( + IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) + tt2 = TensorType( + FloatType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) + assert alpha_eq(tt1, tt1) + assert alpha_eq(tt2, tt2) + +def test_tensor_type_not_eq(): + tt1 = 
TensorType(
+        IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)]))
+    tt2 = TensorType(
+        FloatType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)]))
+    tt3 = TensorType(
+        IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)]))
+    assert not alpha_eq(tt1, tt2)
+    assert not alpha_eq(tt1, tt3)
+
+def test_ref_type_eq():
+    rt1 = RefType(int_type())
+    rt2 = RefType(float_type())
+    assert alpha_eq(rt1, rt1)
+    assert alpha_eq(rt2, rt2)
+
+def test_ref_type_not_eq():
+    rt1 = RefType(int_type())
+    rt2 = RefType(float_type())
+    assert not alpha_eq(rt1, rt2)
+
+def test_product_type_eq():
+    pt1 = ProductType([int_type(), RefType(float_type())])
+    pt2 = ProductType([float_type(), float_type(), int_type()])
+    assert alpha_eq(pt1, pt1)
+    assert alpha_eq(pt2, pt2)
+
+def test_product_type_not_eq():
+    pt1 = ProductType([int_type(), int_type()])
+    pt2 = ProductType([int_type(), int_type(), float_type()])
+    pt3 = ProductType([bool_type(), float_type()])
+    assert not alpha_eq(pt1, pt2)
+    assert not alpha_eq(pt1, pt3)
+
+def test_type_id_eq():
+    id1 = TypeId("id1", Kind.Shape)
+    id2 = TypeId("id2", Kind.BaseType)
+    id3 = TypeId("id2", Kind.Type)
+
+    assert alpha_eq(id1, id1)
+    assert alpha_eq(id2, id2)
+    assert alpha_eq(id3, id3)
+
+def test_type_id_not_eq():
+    # name is just a hint, we use pointer equality as the rule
+    # (unless there is a quantifier to give context)
+    id1 = TypeId("id1", Kind.Shape)
+    id2 = TypeId("id1", Kind.Shape)
+    id3 = TypeId("id3", Kind.BaseType)
+
+    assert not alpha_eq(id1, id2)
+    assert not alpha_eq(id1, id3)
+
+def test_arrow_type_eq():
+    ar1 = TypeArrow(int_type(), bool_type())
+    ar2 = TypeArrow(ProductType([int_type(), int_type()]), ProductType([]))
+    assert alpha_eq(ar1, ar1)
+    assert alpha_eq(ar2, ar2)
+
+def test_arrow_type_not_eq():
+    t1 = int_type()
+    t2 = bool_type()
+    t3 = ProductType([int_type(), bool_type()])
+
+    assert not alpha_eq(TypeArrow(t1, t2), TypeArrow(t1, t1))
+    assert not alpha_eq(TypeArrow(t3, t1), TypeArrow(t2, t1))
+    assert not alpha_eq(TypeArrow(t1, TypeArrow(t1, t1)),
+                        TypeArrow(t1, t1))
+
+def test_type_quantifier_eq():
+    id1 = TypeId("id1", Kind.Shape)
+    id2 = TypeId("id2", Kind.Shape)
+    tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1))
+    tq2 = TypeQuantifier(id2, TensorType(IntType(32), id2))
+
+    assert alpha_eq(tq1, tq1)
+    assert alpha_eq(tq1, tq2)
+
+def test_nested_type_quantifier_eq():
+    id1 = TypeId("id1", Kind.BaseType)
+    id2 = TypeId("id2", Kind.Shape)
+    id3 = TypeId("id3", Kind.BaseType)
+    id4 = TypeId("id4", Kind.Shape)
+    tq1 = TypeQuantifier(id1, TypeQuantifier(id2, TensorType(id1, id2)))
+    tq2 = TypeQuantifier(id3, TypeQuantifier(id4, TensorType(id3, id4)))
+
+    assert alpha_eq(tq1, tq1)
+    assert alpha_eq(tq1, tq2)
+
+def test_type_quantifier_not_eq():
+    id1 = TypeId("id1", Kind.Shape)
+    id2 = TypeId("id2", Kind.BaseType)
+    id3 = TypeId("id3", Kind.Shape)
+
+    tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1))
+    tq2 = TypeQuantifier(id2, TensorType(id2, ShapeSeq([ShapeSingleton(3)])))
+    tq3 = TypeQuantifier(id1, TensorType(IntType(32), id3))
+    tq4 = TypeQuantifier(id1, TensorType(FloatType(32), id1))
+
+    assert not alpha_eq(tq1, tq2)
+    assert not alpha_eq(tq1, tq3)
+    assert not alpha_eq(tq1, tq4)
+    assert not alpha_eq(tq2, tq3)
+    assert not alpha_eq(tq2, tq4)
+
+def test_shape_singleton_eq():
+    single1 = ShapeSingleton(10)
+    single2 = ShapeSingleton(10)
+
+    assert alpha_eq(single1, single1)
+    assert alpha_eq(single1, single2)
+
+def test_shape_singleton_not_eq():
+    single1 = ShapeSingleton(10)
+    single2 = ShapeSingleton(11)
+
+    assert not alpha_eq(single1, single2)
+
+def test_shape_attr_eq():
+    attr1 = ShapeAttr(String("x"))
+    attr2 = ShapeAttr(String("x"))
+
+    assert alpha_eq(attr1, attr1)
+    assert alpha_eq(attr1, attr2)
+
+def test_shape_attr_not_eq():
+    id1 = String("x")
+    id2 = String("y")
+    attr1 = ShapeAttr(id1)
+    attr2 = ShapeAttr(id2)
+
+    assert not alpha_eq(attr1, attr2)
+
+def test_shape_seq_eq():
+    empty = ShapeSeq([])
+    seq1 = ShapeSeq([ShapeSingleton(5)])
+    seq2 = ShapeSeq([ShapeSingleton(5)])
+
+    assert alpha_eq(empty, empty)
+    assert alpha_eq(seq1, seq2)
+
+def test_shape_seq_not_eq():
+    empty = ShapeSeq([])
+    seq = ShapeSeq([ShapeSingleton(5)])
+    single = ShapeSingleton(5)
+
+    assert not alpha_eq(empty, seq)
+    assert not alpha_eq(seq, single)
+
+def test_shape_projection_eq():
+    proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0)
+    proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0)
+
+    assert alpha_eq(proj1, proj2)
+
+def test_shape_projection_not_eq():
+    proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0)
+    proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 1)
+    proj3 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 0)
+    proj4 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 1)
+
+    assert not alpha_eq(proj1, proj2)
+    assert not alpha_eq(proj1, proj3)
+    assert not alpha_eq(proj1, proj4)
+    assert not alpha_eq(proj2, proj3)
+    assert not alpha_eq(proj2, proj4)
+    assert not alpha_eq(proj3, proj4)
+
+def test_shape_binary_op_eq():
+    empty = ShapeSeq([])
+    single = ShapeSingleton(5)
+    seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)])
+
+    op1 = ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty)
+    op2 = ShapeBinaryOp(ShapeOp.SHSUB, single, single)
+    op3 = ShapeBinaryOp(ShapeOp.SHMUL, seq, seq)
+    op4 = ShapeBinaryOp(ShapeOp.SHDIV, seq, seq)
+
+    assert alpha_eq(op1, op1)
+    assert alpha_eq(op2, op2)
+    assert alpha_eq(op3, op3)
+    assert alpha_eq(op4, op4)
+
+def test_shape_binary_op_not_eq():
+    empty = ShapeSeq([])
+    single = ShapeSingleton(5)
+    seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)])
+
+    assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), empty)
+    assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHMUL, seq, ShapeSingleton(1)), seq)
+    assert not alpha_eq(
+        ShapeBinaryOp(ShapeOp.SHPLUS, single, single),
+        ShapeBinaryOp(ShapeOp.SHPLUS,
+                      ShapeSeq([single]),
+                      ShapeSeq([single])))
+    assert not alpha_eq(
+        ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty),
+        ShapeBinaryOp(ShapeOp.SHSUB, empty, empty))
+    assert not alpha_eq(
+        ShapeBinaryOp(ShapeOp.SHMUL, empty, empty),
+        ShapeBinaryOp(ShapeOp.SHDIV, empty, empty))
+
+def test_shape_nested_in_quantifier():
+    b1 = TypeId("b", Kind.BaseType)
+    x1 = TypeId("x", Kind.Shape)
+    y1 = TypeId("y", Kind.Shape)
+
+    b2 = TypeId("b", Kind.BaseType)
+    x2 = TypeId("x", Kind.Shape)
+    y2 = TypeId("y", Kind.Shape)
+
+    b3 = TypeId("b", Kind.BaseType)
+    x3 = TypeId("x", Kind.Shape)
+    y3 = TypeId("y", Kind.Shape)
+
+    tq1 = nest_quantifiers(
+        [b1, x1, y1],
+        TypeArrow(
+            ProductType([TensorType(b1, x1), TensorType(b1, y1)]),
+            TensorType(
+                b1,
+                ShapeBinaryOp(ShapeOp.SHPLUS,
+                              ShapeSeq([x1, ShapeProjection(y1, 1),
+                                        ShapeSingleton(5), ShapeAttr(String("att"))]),
+                              ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+    tq2 = nest_quantifiers(
+        [b2, x2, y2],
+        TypeArrow(
+            ProductType([TensorType(b2, x2), TensorType(b2, y2)]),
+            TensorType(
+                b2,
+                ShapeBinaryOp(ShapeOp.SHPLUS,
+                              ShapeSeq([x2, ShapeProjection(y2, 1),
+                                        ShapeSingleton(5), ShapeAttr(String("att"))]),
+                              ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+    # different attr, var order, position, and constant
+    tq3 = nest_quantifiers(
+        [b3, x3, y3],
+        TypeArrow(
+            ProductType([TensorType(b3, x3), TensorType(b3, y3)]),
+            TensorType(
+                b3,
+                ShapeBinaryOp(ShapeOp.SHPLUS,
+                              ShapeSeq([x3, ShapeProjection(y3, 1),
+                                        ShapeSingleton(4), ShapeAttr(String("att"))]),
+                              ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+    tq4 = nest_quantifiers(
+        [b3, x3, y3],
+        TypeArrow(
+            ProductType([TensorType(b3, x3), TensorType(b3, y3)]),
+            TensorType(
+                b3,
+                ShapeBinaryOp(ShapeOp.SHPLUS,
+                              ShapeSeq([x3, ShapeProjection(y3, 2),
+                                        ShapeSingleton(5), ShapeAttr(String("att2"))]),
+                              ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+    tq5 = nest_quantifiers(
+        [b3, x3, y3],
+        TypeArrow(
+            ProductType([TensorType(b3, x3), TensorType(b3, y3)]),
+            TensorType(
+                b3,
+                ShapeBinaryOp(ShapeOp.SHMUL,
+                              ShapeSeq([x3, ShapeProjection(y3, 1),
+                                        ShapeSingleton(5), ShapeAttr(String("att"))]),
+                              ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+    tq6 = nest_quantifiers(
+        [b3, y3, x3],
+        TypeArrow(
+            ProductType([TensorType(b3, x3), TensorType(b3, y3)]),
+            TensorType(
+                b3,
+                ShapeBinaryOp(ShapeOp.SHPLUS,
+                              ShapeSeq([x3, ShapeProjection(y3, 1),
+                                        ShapeSingleton(5), ShapeAttr(String("att"))]),
+                              ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+    assert alpha_eq(tq1, tq2)
+    assert not alpha_eq(tq1, tq3)
+    assert not alpha_eq(tq2, tq3)
+    assert not alpha_eq(tq1, tq4)
+    assert not alpha_eq(tq2, tq4)
+    assert not alpha_eq(tq1, tq5)
+    assert not alpha_eq(tq2, tq5)
+    assert not alpha_eq(tq1, tq6)
+    assert not alpha_eq(tq2, tq6)
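A compact restatement of the type id rule the tests above exercise, combining test_type_id_not_eq with test_type_quantifier_eq (illustrative only):

    from relay.ir import Kind, alpha_eq
    from relay.make import IntType, TensorType, TypeId, TypeQuantifier

    free_a = TypeId("n", Kind.Shape)
    free_b = TypeId("n", Kind.Shape)
    # Unbound type ids compare by pointer identity: same name, not equal.
    assert not alpha_eq(free_a, free_b)
    # Under quantifiers, only the binding structure matters.
    assert alpha_eq(TypeQuantifier(free_a, TensorType(IntType(32), free_a)),
                    TypeQuantifier(free_b, TensorType(IntType(32), free_b)))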
diff --git a/relay/tests/python/test_ast.py b/relay/tests/python/test_ast.py
new file mode 100644
index 0000000000000..aebe16e4a32bc
--- /dev/null
+++ b/relay/tests/python/test_ast.py
@@ -0,0 +1,236 @@
+# pylint: disable-all
+import tvm
+import relay.ir as ir
+from relay.visitor import Visitor
+import relay.eval as re
+from relay.make import *
+
+def test_string_literal():
+    x = String("xyz")
+    assert isinstance(x, ir.String)
+    assert x.name == "xyz"
+
+
+def test_float_literal():
+    x = FloatLit(1.0)
+    assert isinstance(x, ir.FloatLit)
+    assert x.value == 1.0
+
+
+def test_int_literal():
+    x = IntLit(1)
+    assert isinstance(x, ir.IntLit)
+    assert x.value == 1
+
+
+def test_bool_literal():
+    x = BoolLit(False)
+    assert isinstance(x, ir.BoolLit)
+    assert x.value == False
+
+
+def test_tensor_literal():
+    x = TensorLit([IntLit(1), IntLit(2)])
+    assert isinstance(x, ir.TensorLit)
+    assert x.data[0].value == 1
+    assert x.data[1].value == 2
+
+
+def test_product_literal():
+    x = ProductLit([IntLit(1), IntLit(2)])
+    assert isinstance(x, ir.ProductLit)
+    assert x.fields[0].value == 1
+    assert x.fields[1].value == 2
+
+
+def test_int_type1():
+    ty = IntType(1)
+    # TODO: uncomment when IntType works correctly
+    # assert isinstance(ty, ir.IntType)
+    # assert ty.width == 1
+
+
+def test_cast():
+    x = Cast(IntType(1), IntLit(1))
+    assert isinstance(x, ir.Cast)
+    assert x.target == IntType(1)
+    assert x.node == IntLit(1)
+
+
+def test_local_id():
+    x = LocalId("xyz")
+    assert isinstance(x, ir.LocalId)
+    assert x.name == "xyz"
+
+
+def test_global_id():
+    x = GlobalId("xyz")
+    assert isinstance(x, ir.GlobalId)
+    assert x.name == "xyz"
+
+
+def test_intrinsic_id():
+    x = IntrinsicId("xyz")
+    assert isinstance(x, ir.IntrinsicId)
+    assert x.name == "xyz"
+
+
+def test_param():
+    x = Param(LocalId("foo"), IntType(1))
+    assert isinstance(x, 
ir.Param) + # should use eq + + +def test_function(): + params = [Param(LocalId("x"), IntType(2))] + fn = Function(params, IntType(2), LocalId("x")) + assert isinstance(fn, ir.Function) + assert fn.params[0] == params[0] + assert isinstance(fn.body, ir.LocalId) + assert fn.body.name == "x" + + +def test_call(): + params = [Param(LocalId("x"), IntType(2))] + fn = Function(params, IntType(32), LocalId("x")) + args = [IntLit(1)] + call = Call(fn, args) + assert isinstance(call, ir.Call) + assert call.fn == fn + assert call.args[0] == args[0] + + +def test_call_with_attrs(): + params = [Param(LocalId("x"), IntType(2))] + fn = Function(params, IntType(32), LocalId("x")) + args = [IntLit(1)] + attr = String("attr") + call = Call(fn, args, Attributes({attr: IntLit(10)})) + assert isinstance(call, ir.Call) + assert call.fn == fn + assert call.args[0] == args[0] + + +def test_debug(): + x = Debug(IntLit(1)) + assert isinstance(x, ir.Debug) + assert x.node.value == IntLit(1).value + + +def test_unary_op(): + uop = UnaryOp(ir.UOp.NEG, IntLit(1)) + assert uop.node.value == 1 + + +def test_binary_op(): + x = BinaryOp(ir.BOp.PLUS, IntLit(1), IntLit(2)) + assert IntLit(1) == x.left + assert IntLit(2) == x.right + + +def test_reverse(): + rev = Reverse(IntLit(1)) + assert rev.node == IntLit(1) + + +def test_zero(): + zero = Zero(IntType(32)) + assert zero.type == IntType(32) + + +def test_let(): + x = LocalId("x") + lt = Let(x, IntType(32), IntLit(10), IntLit(11)) + assert lt.body == IntLit(11) + + +def test_if(): + guard = BinaryOp(ir.BOp.LT, IntLit(1), IntLit(2)) + ite = If(guard, IntLit(10), IntLit(11)) + assert ite.guard == guard + assert ite.true_b == IntLit(10) + assert ite.false_b == IntLit(11) + + +def test_defn(): + params = [Param(LocalId("x"), IntType(2))] + fn = Function(params, IntType(2), LocalId("x")) + foo = GlobalId("foo") + defn = Defn(foo, IntType(1), fn) + assert defn.body == fn + assert defn.id == foo + + +def test_primitive(): + assert True # TODO + + +def test_int_type2(): + it = IntType(10) + assert it.dtype == 'int10' + + +def test_bool_type(): + bt = BoolType() + +def test_type_var(): + tv1 = TypeVar(1) + tv2 = TypeVar(2) + assert tv1 != tv2 + + +def test_type_id(): + ti = TypeId("hello") + assert ti.name == "hello" + + +def test_type_quantifier(): + ident = TypeId("id") + body = BoolType() + tq = TypeQuantifier(ident, body) + assert tq.id.name == "id" + assert tq.boundType.dtype == 'uint1' + + +def test_int_value(): + iv = IntValue(1337) + assert iv.value == 1337 + + +def test_bool_value(): + bv = BoolValue(True) + assert bv.value == True + + +def test_float_value(): + fv = FloatValue(1.135) + assert fv.value == 1.135 + + +def test_fn_value(): + params = [Param(LocalId("x"), IntType(2))] + fn = Function(params, IntType(2), LocalId("x")) + fn_val = FnValue({}, fn) + assert fn_val.func == fn + + +def test_attributes(): + foo = String("foo") + attrs = Attributes({foo: IntLit(1)}) + + +def test_tensor_type(): + sh = ShapeSeq([ShapeSingleton(10), ShapeSingleton(10), ShapeSingleton(5)]) + tt = TensorType(FloatType(16), sh) + assert tt.shape == sh + assert tt.dtype == FloatType(16) + +def test_not_tensor_type(): + try: + sh = [IntLit(10), IntLit(10), IntLit(5)] + tt = TensorType(FloatType(16), sh) + print(tt.shape) + assert False + except tvm.TVMError: + assert True + diff --git a/relay/tests/python/test_decorator.py b/relay/tests/python/test_decorator.py new file mode 100644 index 0000000000000..01b26ffc8acbd --- /dev/null +++ b/relay/tests/python/test_decorator.py @@ -0,0 +1,416 @@ 
+"""Tests for the Python-to-Relay decorator.""" +from typing import Any, no_type_check +# pylint: disable=broad-except, invalid-name +from relay.frontend import relay, get_env +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.ir import BOp, UOp +from relay.make import * +from relay.typing import Int, UInt, Float, Bool, Tensor, Product +import relay.ir as ir + + +# TODO(@weberlo): Test attributes. + + +@no_type_check +def empty_body() -> Int: + pass + +def test_empty_body_fails(): + try: + gen_and_check_func("empty_body") + assert False + except Exception: + assert True + + +@no_type_check +def default_int_type(x: Int) -> Int: + return x + +def test_default_int_type(): + func = gen_and_check_func("default_int_type") + typ = TensorType(Int.default(), as_shape([])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([], LocalId("x"), func.body) + + +@no_type_check +def default_uint_type(x: UInt) -> UInt: + return x + +def test_default_uint_type(): + func = gen_and_check_func("default_uint_type") + typ = TensorType(UInt.default(), as_shape([])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([], LocalId("x"), func.body) + + +@no_type_check +def default_float_type(x: Float) -> Float: + return x + +def test_default_float_type(): + func = gen_and_check_func("default_float_type") + typ = TensorType(Float.default(), as_shape([])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([], LocalId("x"), func.body) + + +@no_type_check +def default_bool_type(x: Bool) -> Bool: + return x + +def test_default_bool_type(): + func = gen_and_check_func("default_bool_type") + typ = TensorType(Bool.default(), as_shape([])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([], LocalId("x"), func.body) + +@no_type_check +def int_type_with_one_param(x: Int[16]) -> Int[16]: + return x + +def test_int_type_with_one_param(): + func = gen_and_check_func("int_type_with_one_param") + typ = TensorType(IntType(16), as_shape([])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([], LocalId("x"), func.body) + + +@no_type_check +def int_type_with_two_params(x: Int[16, 2]) -> Int[16, 2]: + return x + +def test_int_type_with_two_params(): + func = gen_and_check_func("int_type_with_two_params") + typ = TensorType(IntType(16, 2), as_shape([])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([], LocalId("x"), func.body) + + +@no_type_check +def basic_tensor_type(x: Tensor[Int, 1]) -> Tensor[Int, 1]: + return x + +def test_basic_tensor_type(): + func = gen_and_check_func("basic_tensor_type") + typ = TensorType(Int.default(), as_shape([1])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([], LocalId("x"), func.body) + + +@no_type_check +def wrong_num_tensor_type_args(x: Tensor[Int]) -> Tensor[Int]: + return x + +def test_wrong_num_tensor_type_args_fails(): + try: + gen_and_check_func("wrong_num_tensor_type_args") + assert False + except Exception: + assert True + + +@no_type_check +def wrong_type_shape_args(x: Tensor[Int, (1, int, bool)]) -> Tensor[Int, (10, 10, 10)]: + return x + +def test_wrong_type_shape_args_fails(): + try: + gen_and_check_func("wrong_type_shape_args") + assert False + except Exception: + assert True + + +@no_type_check +def two_dim_tensor_type(x: Tensor[Int, (1, 1)]) -> Tensor[Int, (1, 1)]: + return 
x + +def test_two_dim_tensor_type(): + func = gen_and_check_func("two_dim_tensor_type") + typ = TensorType(Int.default(), as_shape([1, 1])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([], LocalId("x"), func.body) + + +@no_type_check +def complex_tensor_type(x: Tensor[Int[32, 2], (10, 10, 10)]) -> Tensor[Int[32, 2], (10, 10, 10)]: + return x + +def test_complex_tensor_type(): + func = gen_and_check_func("complex_tensor_type") + typ = TensorType(IntType(32, 2), as_shape([10, 10, 10])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([], LocalId("x"), func.body) + + +@no_type_check +def basic_assign(x: Int) -> Int: + y: Int = x + return y + +def test_basic_assign(): + func = gen_and_check_func("basic_assign") + typ = TensorType(Int.default(), as_shape([])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([ + ("y", typ, LocalId("x")) + ], LocalId("y"), func.body) + + +@no_type_check +def assign_chain(x: Int) -> Int: + y: Int = x + z: Int = y + return z + +def test_assign_chain(): + func = gen_and_check_func("assign_chain") + typ = TensorType(Int.default(), as_shape([])) + assert func.ret_type == typ + check_params([ + ("x", typ), + ], func.params) + check_body([ + ("y", typ, LocalId("x")), + ("z", typ, LocalId("y")) + ], LocalId("z"), func.body) + + +@no_type_check +def use_undef_var() -> Int: + y: Int = x + return y + +def test_use_undef_var(): + try: + gen_and_check_func("use_undef_var") + assert False + except Exception: + assert True + + +# TODO(@weberlo): Uncomment once placeholder types are handled by the unifier. +# def no_assign_annotation(x: Int) -> Int: +# y = x +# return y + +# def test_no_assign_annotation(): +# func = gen_and_check_func("no_assign_annotation") +# typ = TensorType(Int.default(), as_shape([])) +# assert func.ret_type == typ +# check_params([ +# ("x", typ), +# ], func.params) +# check_body([ +# ("y", typ, LocalId("x")) +# ], LocalId("y"), func.body) + + +@no_type_check +def three_tuple_lit() -> Product[Int, Float, Bool]: + return (1, 2.0, True) + +def test_three_tuple_lit(): + func = gen_and_check_func("three_tuple_lit") + typ = ProductType([Int.default(), Float.default(), Bool.default()]) + assert func.ret_type == typ + check_params([], func.params) + check_body( + [], ProductLit([IntLit(1), FloatLit(2.0), BoolLit(True)]), func.body) + + +@no_type_check +def tuple_projection() -> Float: + return (1, 2.0, True)[1] + +def test_tuple_projection(): + func = gen_and_check_func("tuple_projection") + assert func.ret_type == TensorType(Float.default(), as_shape([])) + check_params([], func.params) + # body = Projection( + # ProductLit([IntLit(1), FloatLit(2.0), BoolLit(True)]), 1) + # TODO(@weberlo): WUT IS PROJECTION EQUALITY?! 
+ # check_body([], body, func.body) + + +@no_type_check +def bin_op() -> Int: + return 1 + 2 + +def test_bin_op(): + func = gen_and_check_func("bin_op") + assert func.ret_type == TensorType(Int.default(), as_shape([])) + check_params([], func.params) + check_body( + [], BinaryOp(ir.BOp.PLUS, IntLit(1), IntLit(2)), func.body) + + +@no_type_check +def un_op() -> Int: + return -1 + +def test_un_op(): + func = gen_and_check_func("un_op") + assert func.ret_type == TensorType(Int.default(), as_shape([])) + check_params([], func.params) + check_body( + [], UnaryOp(UOp.NEG, IntLit(1)), func.body) + + +@no_type_check +def cmp_op() -> Bool: + return 1 < 2 + +def test_cmp_op(): + func = gen_and_check_func("cmp_op") + assert func.ret_type == TensorType(Bool.default(), as_shape([])) + check_params([], func.params) + check_body( + [], BinaryOp(BOp.LT, IntLit(1), IntLit(2)), func.body) + + +@no_type_check +def tensor_lit() -> Tensor[Int, 3]: + return [1, 2, 3] + +def test_tensor_lit(): + func = gen_and_check_func("tensor_lit") + assert func.ret_type == TensorType(Int.default(), as_shape([3])) + check_params([], func.params) + check_body( + [], TensorLit([IntLit(1), IntLit(2), IntLit(3)]), func.body) + + +@no_type_check +def nested_tensor_lit() -> Tensor[Int, (3, 2)]: + return [[1, 2], [3, 4], [5, 6]] + +def test_nested_tensor_lit(): + func = gen_and_check_func("nested_tensor_lit") + assert func.ret_type == TensorType(Int.default(), as_shape([3, 2])) + check_params([], func.params) + expected_body = TensorLit([ + TensorLit([IntLit(1), IntLit(2)]), + TensorLit([IntLit(3), IntLit(4)]), + TensorLit([IntLit(5), IntLit(6)]), + ]) + check_body([], expected_body, func.body) + + +# pylint: disable=unused-argument +@no_type_check +def ident_typo(target: Int, other: Int) -> Int: + return trageet + +def test_ident_typo(): + try: + gen_and_check_func("ident_typo") + assert False + except Exception: + assert True + + +# TODO(@weberlo): Enable this test when span serialization is fixed. 
+ +# def span_info(x: Int, y: Int) -> Int: +# z = x + y +# return z + +# def test_span_info(): +# def check_span(node, exp_lineno, exp_col_offset): +# assert node.span.lineno == exp_lineno +# assert node.span.col_offset == exp_col_offset + +# func = gen_and_check_func("span_info") +# check_span(func, 1, 0) +# # first param +# node = func.params[0] +# check_span(node, 1, 14) +# # second param +# node = func.params[1] +# check_span(node, 1, 22) +# # return type +# node = func.ret_type +# check_span(node, 1, 33) +# # function body +# node = func.body +# check_span(node, 2, 4) +# # assignment rhs +# node = node.value +# check_span(node, 2, 8) + + +# +# Utility Methods +# + +def gen_and_check_func(py_func_name: str) -> Any: + """Compiles the given function and does a few sanity checks.""" + _ = relay(globals()[py_func_name]) + curr_env = get_env() + global_id = curr_env.global_id(py_func_name) + defn = curr_env.lookup(global_id) + assert isinstance(defn, ir.Defn) + func = defn.body + assert isinstance(func, ir.Function) + return func + +def check_params(expected, actual): + """Checks types and names of function parameters.""" + assert len(expected) == len(actual) + for ((exp_name, exp_type), act) in zip(expected, actual): + assert act.id.name == exp_name + assert act.type == exp_type + +def check_body(expected_bindings, expected_ret_val, actual_body): + """Checks the contents of a function that's in the form of a series of + assignments with a return value.""" + curr_binding = actual_body + for (exp_name, exp_type, exp_rhs) in expected_bindings: + assert isinstance(curr_binding, ir.Let) + assert curr_binding.id.name == exp_name + assert curr_binding.type == exp_type + assert structurally_equal(curr_binding.value, exp_rhs) + curr_binding = curr_binding.body + assert structurally_equal(expected_ret_val, curr_binding) + +def structurally_equal(a, b): + if isinstance(a, ir.LocalId): + return a.name == b.name + return a == b + +def as_shape(dims): + return ShapeSeq(list(map(ShapeSingleton, dims))) diff --git a/relay/tests/python/test_eval.py b/relay/tests/python/test_eval.py new file mode 100644 index 0000000000000..6092323d7b852 --- /dev/null +++ b/relay/tests/python/test_eval.py @@ -0,0 +1,152 @@ +# type: ignore +# pylint: skip-file +import tvm +from tvm._ffi.ndarray import _make_array +from tvm._ffi.runtime_ctypes import TVMArrayHandle +import relay.ir as expr +import numpy as np +from relay.make import * +from relay.visitor import Visitor +from relay.typing import Float, Tensor, Int, UInt +import relay.eval as re +import ctypes +from relay.frontend import relay, get_env +from typing import no_type_check + +@no_type_check +@relay +def one_plus_one_int() -> Int: + return 1 + 1 + +def test_eval_binop_plus_int(): + assert one_plus_one_int() == 2 + +@no_type_check +@relay +def one_sub_one_int() -> Int: + return 1 - 1 + +def test_eval_binop_sub_int(): + assert one_sub_one_int() == 0 + +@no_type_check +@relay +def one_mul_one_int() -> Int: + return 1 * 1 + +def test_eval_binop_mul_int(): + assert one_mul_one_int() == 1 + +@no_type_check +@relay +def one_div_one_int() -> Int: + return 1 / 1 + +def test_eval_binop_div_int(): + assert one_div_one_int() == 1 + +@no_type_check +@relay +def one_plus_one_float() -> Float: + return 1.0 + 1.0 + +def test_eval_binop_plus_float(): + assert one_plus_one_float() == 2.0 + +@no_type_check +@relay +def one_sub_one_float() -> Float: + return 1.0 - 1.0 + +def test_eval_binop_sub_float(): + assert one_sub_one_float() == 0.0 + +@no_type_check +@relay +def 
ten_mul_one_float() -> Float: + return 10.0 * 1.0 + +def test_eval_binop_mul_float(): + assert ten_mul_one_float() == 10.0 + +@no_type_check +@relay +def ten_div_one_float() -> Float: + return 10.0 / 1.0 + +def test_eval_binop_div_float(): + assert ten_div_one_float() == 10.0 + +@no_type_check +@relay +def one_plus_one_uint() -> UInt: + return 1 + 1 + +def test_eval_binop_plus_uint(): + assert one_plus_one_uint() == 2 + +@no_type_check +@relay +def one_sub_one_uint() -> UInt: + return 1 - 1 + +def test_eval_binop_sub_uint(): + assert one_sub_one_uint() == 0 + +@no_type_check +@relay +def ten_mul_one_uint() -> UInt: + return 10 * 1 + +def test_eval_binop_mul_uint(): + assert ten_mul_one_uint() == 10 + +@no_type_check +@relay +def ten_div_one_uint() -> UInt: + return 10 / 1 + +def test_eval_binop_div_uint(): + assert ten_div_one_uint() == 10 + +@no_type_check +@relay +def neg_ten_int() -> Int: + return -10 + +# def test_eval_uop_neg(): +# assert neg_ten_int() == -10 + +# mypy: ignore +# @relay +# def tvm_tanh(x: Tensor[Float, (10, 10)]) -> Tensor[Float, (10, 10)]: +# return relay.tanh(x) + +# @relay +# def tvm_log(x: Tensor[Float, (10, 10)]) -> Tensor[Float, (10, 10)]: +# return relay.log(x) + +# @relay +# def tvm_softmax(x: Tensor[Float, (10, 10)]) -> Tensor[Float, (10, 10)]: +# return relay.softmax(x) + +# def test_eval_tanh(): +# in_array = np.random.uniform(size=10) +# out_array = tvm_tanh(in_array) +# np.testing.assert_allclose(out_array.asnumpy(), np.tanh(in_array), rtol=1e-5) + +# def test_eval_log(): +# in_array = np.random.uniform(size=10) +# out_array = tvm_log(in_array) +# np.testing.assert_allclose(out_array.asnumpy(), np.log(in_array), rtol=1e-5) + +# # https://stackoverflow.com/questions/34968722/how-to-implement-the-softmax-function-in-python +# def np_softmax(x): +# """Compute softmax values for each set of scores in x.""" +# e_x = np.exp(x - np.max(x)) +# return e_x / np.sum(e_x, axis=0) + +# def test_eval_softmax(): +# in_array = np.random.uniform(size=10) +# out_array = tvm_softmax(in_array) +# np.testing.assert_allclose(out_array.asnumpy(), np_softmax(in_array), rtol=1e-5) diff --git a/relay/tests/python/test_forward_ad.py b/relay/tests/python/test_forward_ad.py new file mode 100644 index 0000000000000..718b02e76da58 --- /dev/null +++ b/relay/tests/python/test_forward_ad.py @@ -0,0 +1,20 @@ +"""Tests forward-mode autodiff.""" +# import tvm +# import nnvm +# import nnvm.relay.ir as expr +# from nnvm.relay.make import * +# from nnvm.relay.visitor import Visitor +# import nnvm.relay.forward_ad as ad +# import nnvm.relay.pretty_printer as pp +# from nnvm.relay.frontend import get_env + +# Fix MK +# +# def test_ad(): +# x = FloatLit(42.0) # todo(M.K.) x = FloatLit(42) doesn't work +# y = FloatLit(12.0) +# z = LocalId("z") +# lt = Let(z, x * y, z + z) +# adres = ad.forward_ad(lt) +# print(pp.pretty_print(get_env(), adres)) +# # assert isinstance(adres, expr.ProductLit) todo(M.K.)
fix by asserting type checking diff --git a/relay/tests/python/test_grad_descent.py b/relay/tests/python/test_grad_descent.py new file mode 100644 index 0000000000000..92ee639f263f7 --- /dev/null +++ b/relay/tests/python/test_grad_descent.py @@ -0,0 +1,31 @@ +#pylint: disable-all +"""End-to-end test for scalar linear regression.""" +from typing import no_type_check +# import nnvm.relay.eval as re +from relay.frontend import relay, get_env +from relay.typing import Float, Product + +# # pylint: disable=invalid-name +# @no_type_check +# def grad_descent(w: Float, x: Float, y: Float, lr: Float) -> Float: +# def infer(w: Float, x: Float) -> Float: +# return w * x +# def loss(w: Float, x: Float, y: Float) -> Float: +# diff: Float = infer(w, x) - y +# return diff**2 +# grad: Product[Float, Float, Float] = relay.grad(loss)(w, x, y) +# # Logan needs to handle nested calls +# grad_w: Float = grad[0] +# return lr * grad_w + +# def test_grad_descent(): +# _ = relay(grad_descent) +# env = get_env() +# global_id = env.global_id("grad_descent") +# defn = env.lookup(global_id) +# grad_descent_compiled = defn.body + +# print(grad_descent_compiled) +# # print(pp.pretty_print(get_env(), grad_descent_compiled)) + +# # print(re.eval(get_env(), test_grad_descent())) diff --git a/relay/tests/python/test_kindchecker.py b/relay/tests/python/test_kindchecker.py new file mode 100644 index 0000000000000..e1f4bc444e20c --- /dev/null +++ b/relay/tests/python/test_kindchecker.py @@ -0,0 +1,128 @@ +"""Tests the type system's kind checker.""" +# pylint: disable=invalid-name +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * +from relay.ir import ShapeOp, Kind +from relay.kindchecker import check_kind + +def test_basetypes(): + assert check_kind(IntType(32)) + assert check_kind(BoolType()) + assert check_kind(FloatType(32)) + +def test_typeid(): + assert check_kind(TypeId("x", Kind.Shape)) + assert check_kind(TypeId("x", Kind.BaseType)) + assert check_kind(TypeId("x", Kind.Type)) + +def test_shape_singleton(): + assert check_kind(ShapeSingleton(5)) + +def test_shape_attr(): + ident = String("x") + assert check_kind(ShapeAttr(ident)) + +def test_shape_seq(): + ident = String("x") + assert check_kind(ShapeSeq([ShapeSingleton(5), ShapeSingleton(6), ShapeSingleton(7)])) + assert check_kind( + ShapeSeq([ShapeSingleton(5), + ShapeAttr(ident), ShapeSingleton(7)])) + assert check_kind( + ShapeSeq([ + ShapeBinaryOp( + ShapeOp.SHPLUS, + ShapeSeq([ShapeSingleton(5), + ShapeSingleton(6)]), + ShapeAttr(ident) + ), ShapeSingleton(7), + ShapeProjection( + ShapeSeq([ + ShapeSingleton(10), + ShapeSingleton(6)]), + 0) + ])) + +def test_shape_binary_op(): + assert check_kind(ShapeBinaryOp(ShapeOp.SHPLUS, ShapeSingleton(5), ShapeSeq([]))) + assert check_kind(ShapeBinaryOp( + ShapeOp.SHSUB, + ShapeProjection(ShapeSeq([ShapeSingleton(5), ShapeSingleton(6)]), 0), ShapeSeq([]))) + +def test_shape_projection(): + assert check_kind(ShapeProjection(ShapeSingleton(5), 0)) + assert check_kind(ShapeProjection(ShapeSeq([ShapeSingleton(5), ShapeSingleton(7)]), 1)) + +def test_tensor_type(): + ident = String("x") + assert check_kind(TensorType(IntType(32), ShapeAttr(ident))) + assert check_kind(TensorType(IntType(32), ShapeSeq([]))) + assert check_kind(TensorType(BoolType(), ShapeSingleton(10))) + assert check_kind( + TensorType(FloatType(32), + ShapeBinaryOp( + ShapeOp.SHMUL, + ShapeSeq([ + ShapeSingleton(5), + ShapeSingleton(6)]), + ShapeSeq([ShapeSingleton(8), + ShapeSingleton(8)])))) + +def test_type_quantifier(): + 
x = TypeId("x", Kind.Type) + y = TypeId("y", Kind.BaseType) + z = TypeId("z", Kind.Shape) + assert check_kind(TypeQuantifier(x, x)) + assert check_kind(TypeQuantifier(x, ProductType([x, x, x]))) + assert check_kind(TypeQuantifier(y, TensorType(y, ShapeSeq([])))) + assert check_kind(TypeQuantifier(z, TensorType(IntType(32), ShapeSeq([z, ShapeSingleton(5)])))) + +def test_arrow_type(): + assert check_kind(TypeArrow(ProductType([]), ProductType([]))) + assert check_kind(TypeArrow(TensorType(BoolType(), ShapeSeq([])), ProductType([]))) + +def test_product_type(): + assert check_kind(ProductType([])) + assert check_kind( + ProductType( + [ProductType([]), + TensorType( + IntType(32), + ShapeSeq([ + ShapeSingleton(5), + ShapeSingleton(6) + ])) + ])) + +def test_holes(): + assert check_kind(TypeVar(1)) + assert check_kind(TensorType(TypeVar(2), TypeVar(3))) + assert check_kind(ShapeSeq([TypeVar(2), ShapeSingleton(3)])) + assert check_kind(TypeQuantifier(TypeId("x", Kind.Type), TypeVar(3))) + +def test_bad_tensor_type(): + assert not check_kind(TensorType(ShapeSeq([]), IntType(32))) + assert not check_kind(TensorType(ProductType([]), ProductType([]))) + +def test_bad_arrow(): + assert not check_kind(TypeArrow(ShapeSeq([]), IntType(32))) + +def test_bad_product(): + assert not check_kind(ProductType([ShapeSingleton(5), IntType(32), ShapeSingleton(10)])) + +def test_bad_shapes(): + assert not check_kind(ShapeSeq([IntType(32), ShapeSingleton(5)])) + assert not check_kind(ShapeProjection(TensorType(IntType(32), ShapeSeq([])), 0)) + assert not check_kind(ShapeBinaryOp(ShapeOp.SHPLUS, + ProductType([]), + TypeArrow(ProductType([]), ProductType([])))) + +def test_bad_quantifier(): + x = TypeId("x", Kind.Shape) + y = TypeId("y", Kind.Type) + z = TypeId("z", Kind.BaseType) + assert not check_kind(TypeQuantifier(x, x)) + assert not check_kind(TypeQuantifier(z, z)) + assert not check_kind(TypeQuantifier(y, TensorType(y, y))) + assert not check_kind(TypeQuantifier(x, TensorType(x, IntType(32)))) + assert not check_kind(TypeQuantifier(z, TensorType(ShapeSingleton(3), z))) diff --git a/relay/tests/python/test_pretty_print.py b/relay/tests/python/test_pretty_print.py new file mode 100644 index 0000000000000..27868345c30bc --- /dev/null +++ b/relay/tests/python/test_pretty_print.py @@ -0,0 +1,32 @@ +"""Tests pretty-printing of the AST.""" +from typing import no_type_check +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * +import relay.pretty_printer as pp +from relay.frontend import compile_func, get_env +from relay.typing import Int + +def tprettyprint(prog): + relay_fn = compile_func(prog) + return pp.pretty_print(get_env(), relay_fn.body) + +@no_type_check +def one_plus_one() -> Int: + return 1 + 1 + +def test_pp_binop(): + print("print result:" + tprettyprint(one_plus_one)) + #assert tprettyprint(one_plus_one) == "not the result" + #todo(M.K.) fix test + +@no_type_check +def neg_ten() -> Int: + return -10 + +def test_pp_uop(): + print("print result:" + tprettyprint(neg_ten)) + #assert tprettyprint(neg_ten) == "not the result" + #todo(M.K.) 
fix test + +def test_pp_pair(): + pp.print_expr(get_env(), ProductLit([FloatLit(1.0), IntLit(2)])) diff --git a/relay/tests/python/test_reverse_ad.py b/relay/tests/python/test_reverse_ad.py new file mode 100644 index 0000000000000..33f75ba49b743 --- /dev/null +++ b/relay/tests/python/test_reverse_ad.py @@ -0,0 +1,21 @@ +"""Tests reverse-mode autodiff.""" +from relay.ir.expr import BOp +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * +import relay.reverse_ad as ad +import relay.pretty_printer as pp +from relay.frontend import get_env + +def float_type(): + return TensorType(FloatType(32), ShapeSeq([])) + +# pylint: disable=invalid-name +def test_ad(): + curr_env = get_env() + x = LocalId("x") + y = FloatLit(12.0) + z = LocalId("z") + lt = Let(z, float_type(), BinaryOp(BOp.MUL, x, y), BinaryOp(BOp.PLUS, z, z)) + adres = ad.reverse_ad(curr_env, Function([Param(x, float_type())], float_type(), lt)) + print(pp.pretty_print(curr_env, adres)) + # assert isinstance(adres, expr.ProductLit) todo(M.K.) fix by asserting type checking diff --git a/relay/tests/python/test_shape_evaluator.py b/relay/tests/python/test_shape_evaluator.py new file mode 100644 index 0000000000000..afd880349b253 --- /dev/null +++ b/relay/tests/python/test_shape_evaluator.py @@ -0,0 +1,153 @@ +"""Tests the shape evaluator.""" +# pylint: disable=broad-except, invalid-name, missing-docstring +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * +from relay.ir import ShapeOp, Kind +from relay.shape_evaluator import evaluate_concrete_shape + +def empty_attrs(): + return Attributes({}) + +def eval_empty(sh): + return evaluate_concrete_shape(sh, empty_attrs()) + +def try_eval_exception(sh, curr_env=Attributes({})): + try: + evaluate_concrete_shape(sh, curr_env) + assert False + except Exception: + return + +def test_eval_singleton(): + for i in range(10): + s = eval_empty(ShapeSingleton(i)) + assert s == ShapeSeq([ShapeSingleton(i)]) + +def test_eval_attr(): + for i in range(10): + s = evaluate_concrete_shape(ShapeAttr(String("x")), Attributes({String("x") : IntLit(i)})) + assert s == ShapeSeq([ShapeSingleton(i)]) + +def test_eval_simple_seq(): + seq = ShapeSeq([ShapeSingleton(i) for i in range(5)]) + assert seq == eval_empty(seq) + +def test_eval_seq_of_seq(): + seq = ShapeSeq([ShapeSingleton(i) for i in range(5)]) + seq_of_seq = ShapeSeq([seq for i in range(5)]) + compound = ShapeSeq([ShapeSingleton(i % 5) for i in range(25)]) + + assert compound == eval_empty(seq_of_seq) + + seq3 = ShapeSeq([seq_of_seq for i in range(5)]) + compound3 = ShapeSeq([ShapeSingleton(i % 5) for i in range(125)]) + assert compound3 == eval_empty(seq3) + +def test_eval_seq_with_attr(): + x = String("x") + seq = ShapeSeq([ShapeSingleton(1), ShapeAttr(x), ShapeSingleton(3)]) + final = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)]) + + assert final == evaluate_concrete_shape(seq, Attributes({x : IntLit(2)})) + +def test_proj_singleton(): + single = ShapeSingleton(5) + proj = ShapeProjection(single, 0) + assert ShapeSeq([single]) == eval_empty(proj) + +def test_proj_seq(): + seq = ShapeSeq([ShapeSingleton(i) for i in range(10)]) + for i in range(10): + proj = ShapeProjection(seq, i) + assert ShapeSeq([ShapeSingleton(i)]) == eval_empty(proj) + +def test_binary_op_singleton(): + s1 = ShapeSingleton(4) + s2 = ShapeSingleton(2) + + add = ShapeBinaryOp(ShapeOp.SHPLUS, s1, s2) + sub = ShapeBinaryOp(ShapeOp.SHSUB, s1, s2) + mul = ShapeBinaryOp(ShapeOp.SHMUL, s1, s2) + div = 
ShapeBinaryOp(ShapeOp.SHDIV, s1, s2) + + assert ShapeSeq([ShapeSingleton(6)]) == eval_empty(add) + assert ShapeSeq([ShapeSingleton(2)]) == eval_empty(sub) + assert ShapeSeq([ShapeSingleton(8)]) == eval_empty(mul) + assert ShapeSeq([ShapeSingleton(2)]) == eval_empty(div) + +def test_binary_op_seq(): + seq1 = ShapeSeq([ShapeSingleton(1) for i in range(10)]) + seq2 = ShapeSeq([ShapeSingleton(2) for i in range(10)]) + + add = ShapeBinaryOp(ShapeOp.SHPLUS, seq1, seq2) + sub = ShapeBinaryOp(ShapeOp.SHSUB, seq2, seq1) + mul = ShapeBinaryOp(ShapeOp.SHMUL, seq1, seq2) + div = ShapeBinaryOp(ShapeOp.SHDIV, seq2, seq2) + + seq3 = ShapeSeq([ShapeSingleton(3) for i in range(10)]) + assert seq3 == eval_empty(add) + assert seq1 == eval_empty(sub) + assert seq2 == eval_empty(mul) + assert seq1 == eval_empty(div) + +def test_seq_of_length_1(): + single1 = ShapeSingleton(1) + single2 = ShapeSingleton(2) + + seq1 = ShapeSeq([single1]) + seq2 = ShapeSeq([single2]) + + assert seq1 == eval_empty(seq1) + assert seq2 == eval_empty(ShapeBinaryOp(ShapeOp.SHPLUS, seq1, seq1)) + +def test_typevar_and_id(): + tv = TypeVar(1) + tid = TypeId("id", Kind.Shape) + + seq = ShapeSeq([ShapeSingleton(1), tid, tv, ShapeSingleton(2)]) + proj1 = ShapeProjection(tid, 1) + proj2 = ShapeProjection(tid, 2) + op1 = ShapeBinaryOp(ShapeOp.SHPLUS, tid, tv) + op2 = ShapeBinaryOp(ShapeOp.SHSUB, tid, + ShapeSeq([ShapeSingleton(1), + ShapeSingleton(2)])) + op3 = ShapeBinaryOp(ShapeOp.SHMUL, ShapeSeq([ShapeSingleton(1), + ShapeSingleton(2)]), + tv) + + assert seq == eval_empty(seq) + assert proj1 == eval_empty(proj1) + assert proj2 == eval_empty(proj2) + assert op1 == eval_empty(op1) + assert op2 == eval_empty(op2) + assert op3 == eval_empty(op3) + +def test_undefined_attr_error(): + x = String("x") + seq = ShapeSeq([ShapeSingleton(1), ShapeAttr(x), ShapeSingleton(2)]) + try_eval_exception(seq) + +def test_proj_out_of_bounds(): + single = ShapeSingleton(5) + seq = ShapeSeq([ShapeSingleton(i) for i in range(5)]) + + proj1 = ShapeProjection(single, 1) + proj2 = ShapeProjection(seq, 5) + try_eval_exception(proj1) + try_eval_exception(proj2) + +def test_shape_arithmetic_errors(): + negative = ShapeBinaryOp(ShapeOp.SHSUB, ShapeSingleton(1), ShapeSingleton(2)) + div_by_zero = ShapeBinaryOp(ShapeOp.SHDIV, ShapeSingleton(3), ShapeSingleton(0)) + try_eval_exception(negative) + try_eval_exception(div_by_zero) + +def test_shape_binary_op_different_ranks(): + single = ShapeSingleton(5) + seq1 = ShapeSeq([ShapeSingleton(i) for i in range(5)]) + seq2 = ShapeSeq([ShapeSingleton(i) for i in range(6)]) + + add1 = ShapeBinaryOp(ShapeOp.SHPLUS, single, seq1) + add2 = ShapeBinaryOp(ShapeOp.SHPLUS, seq1, seq2) + try_eval_exception(add1) + try_eval_exception(add2) diff --git a/relay/tests/python/test_softmax.py b/relay/tests/python/test_softmax.py new file mode 100644 index 0000000000000..a8db79b7e8859 --- /dev/null +++ b/relay/tests/python/test_softmax.py @@ -0,0 +1,35 @@ +# type: ignore +# pylint: skip-file +import tvm +from tvm._ffi.ndarray import _make_array +from tvm._ffi.runtime_ctypes import TVMArrayHandle +import relay.ir as expr +import numpy as np +from relay.make import * +from relay.visitor import Visitor +from relay.typing import Float, Tensor, Int +import relay.eval as re +import ctypes +from relay.frontend import relay, get_env +# import nnvm.relay.softmax as sm +import relay.pretty_printer as pp + +def test_linear(): + pass # FIX ME + # # import pdb; pdb.set_trace() + # l = sm.linear() + # print(pp.pretty_print(get_env(), l)) + # 
print(re.eval(get_env(), Call(l, [FloatLit(3.0), FloatLit(4.0), FloatLit(5.0)]))) + +# agenda: +# generalize to float4, probably need to touch tvm +# write softmax, also need to touch tvm to include softmax + +# note: reverse ad needs to be touched to handle intrinsics. +# here is a brief discussion of the internal workings of reverse mode. +# the basic idea is that, as in forward-mode automatic differentiation, +# if we can get a local transformation, we only need to care about Real and operations on Real. +# for generic constructs like control flow and lambdas, we just need to lift the type. +# we can obtain a local transformation on Real by replacing every Real with a triple: Real, Ref Real, () -> (), dubbed a backproper. +# the Real is the original value, the Ref Real is the stored gradient, and the () -> () clears the gradient and passes it all the way up to the outermost parameters by chaining backpropers. +# note that each backproper should be called only once, since we do not reset the gradient after calling; alternatively, we could fix that.
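+
+# to make the triple encoding above concrete, here is a minimal, hypothetical
+# plain-python sketch (kept commented out like the drafts above; Dual and mul
+# are illustrative names, not part of Relay or this patch):
+#
+# class Dual:
+#     def __init__(self, value):
+#         self.value = value              # the original Real
+#         self.grad = 0.0                 # Ref Real: the stored gradient
+#         self.backproper = lambda: None  # () -> (): pushes the gradient up
+#
+# def mul(x, y):
+#     out = Dual(x.value * y.value)
+#     def backproper():
+#         # push out's gradient into the inputs, then chain upward
+#         x.grad += out.grad * y.value
+#         y.grad += out.grad * x.value
+#         x.backproper()
+#         y.backproper()
+#     out.backproper = backproper
+#     return out
+#
+# x, y = Dual(42.0), Dual(12.0)
+# z = mul(x, y)
+# z.grad = 1.0
+# z.backproper()  # called exactly once, per the note above
+# assert x.grad == 12.0 and y.grad == 42.0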
diff --git a/relay/tests/python/test_span.py b/relay/tests/python/test_span.py new file mode 100644 index 0000000000000..0890e94edf6f5 --- /dev/null +++ b/relay/tests/python/test_span.py @@ -0,0 +1,33 @@ +"""Tests setting spans and the like.""" +# pylint: disable=invalid-name, missing-docstring +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * + +def test_set_expr_span(): + x = IntLit(5) + sp1 = Span(FileId(0), 0, 0) + x.set_span(sp1) + assert x.span == sp1 + sp2 = Span(FileId(1), 2, 3) + x.set_span(sp2) + assert x.span == sp2 + +def test_set_type_span(): + x = IntType(32) + sp1 = Span(FileId(0), 0, 0) + x.set_span(sp1) + assert x.span == sp1 + sp2 = Span(FileId(1), 2, 3) + x.set_span(sp2) + assert x.span == sp2 + +def test_set_item_span(): + x = Defn(GlobalId("x"), + TensorType(IntType(32), ShapeSeq([])), + IntLit(5)) + sp1 = Span(FileId(0), 0, 0) + x.set_span(sp1) + assert x.span == sp1 + sp2 = Span(FileId(1), 2, 3) + x.set_span(sp2) + assert x.span == sp2 diff --git a/relay/tests/python/test_tyck.py b/relay/tests/python/test_tyck.py new file mode 100644 index 0000000000000..d8fc65963c0f5 --- /dev/null +++ b/relay/tests/python/test_tyck.py @@ -0,0 +1,404 @@ +"""Tests typechecking.""" +# pylint: disable=invalid-name, missing-docstring +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * +import relay.tyck as tc +from relay.ir import UOp, BOp, Kind + +def check_new_env(exp): + return tc.check_expr(Environment({}), exp) + +def int_type(): + return TensorType(IntType(32), ShapeSeq([])) + +def float_type(): + return TensorType(FloatType(32), ShapeSeq([])) + +def bool_type(): + return TensorType(BoolType(), ShapeSeq([])) + +def unit_type(): + return ProductType([]) + +# pylint: disable=broad-except +def check_exception(test_expr): + try: + check_new_env(test_expr) + assert False + except Exception: + return + +def test_primitive_lit_sanity(): + assert check_new_env(IntLit(5)) == int_type() + assert check_new_env(FloatLit(2.5)) == float_type() + assert check_new_env(BoolLit(True)) == bool_type() + +def test_tensor_lit(): + tensorty1 = TensorType(IntType(32), ShapeSeq([ShapeSingleton(5)])) + tensorty2 = TensorType(IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(5)])) + + lit1 = TensorLit([IntLit(1), IntLit(3), IntLit(5), IntLit(7), IntLit(9)]) + ty1 = check_new_env(lit1) + assert ty1 == tensorty1 + + lit2 = TensorLit([lit1, lit1, lit1]) + ty2 = check_new_env(lit2) + assert ty2 == tensorty2 + +def test_if_typecheck(): + exp = If(BoolLit(True), IntLit(12), IntLit(32)) + ty = check_new_env(exp) + assert ty == int_type() + +def test_uop_typecheck(): + exp1 = UnaryOp(UOp.NEG, IntLit(12)) + ty1 = check_new_env(exp1) + assert ty1 == int_type() + + exp2 = UnaryOp(UOp.SQ, FloatLit(1.0)) + ty2 = check_new_env(exp2) + assert ty2 == float_type() + +def test_binop(): + exp1 = BinaryOp(BOp.PLUS, IntLit(12), IntLit(14)) + ty1 = check_new_env(exp1) + assert ty1 == int_type() + + exp2 = BinaryOp(BOp.DIV, FloatLit(5.0), FloatLit(2.0)) + ty2 = check_new_env(exp2) + assert ty2 == float_type() + +def test_binop_tensor(): + base_tensor = TensorLit([IntLit(i) for i in range(5)]) + t1 = TensorLit([base_tensor for i in range(3)]) + tensor_type = TensorType(IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(5)])) + + exp = BinaryOp(BOp.PLUS, t1, t1) + assert tensor_type == check_new_env(exp) + +def test_binop_comparison(): + exp1 = BinaryOp(BOp.NE, IntLit(23), IntLit(24)) + assert bool_type() == check_new_env(exp1) + + base_tensor = TensorLit([IntLit(i) for i in range(5)]) + t1 = TensorLit([base_tensor for i in range(3)]) + exp2 = BinaryOp(BOp.LE, t1, t1) + tensor_type = TensorType(BoolType(), ShapeSeq([ShapeSingleton(3), ShapeSingleton(5)])) + assert tensor_type == check_new_env(exp2) + +def test_let_typecheck(): + x = LocalId("x") + + exp1 = Let(x, bool_type(), BoolLit(True), IntLit(32)) + ty1 = check_new_env(exp1) + assert ty1 == int_type() + + exp2 = Let(x, bool_type(), BoolLit(True), x) + ty2 = check_new_env(exp2) + assert ty2 == bool_type() + +def test_let_recursive_function(): + x = LocalId("x") + y = LocalId("y") + + # extremely dumb parity test haha + is_even = Function([Param(x, int_type())], bool_type(), + If(BinaryOp(BOp.GT, x, IntLit(1)), + UnaryOp(UOp.NEG, Call(y, [BinaryOp(BOp.SUB, x, IntLit(1))])), + BinaryOp(BOp.EQ, x, IntLit(0)))) + + exp = Let(y, TypeArrow(ProductType([int_type()]), bool_type()), is_even, Call(y, [IntLit(3)])) + ty = check_new_env(exp) + assert ty == bool_type() + +def test_function_node(): + x = LocalId("x") + params = [Param(x, int_type())] + fun = Function(params, int_type(), BinaryOp(BOp.PLUS, x, x)) + funty = TypeArrow(ProductType([int_type()]), int_type()) + ty = check_new_env(fun) + assert ty == funty + +def test_function_call(): + x = LocalId("x") + y = LocalId("y") + z = LocalId("z") + params = [Param(x, bool_type()), + Param(y, int_type()), + Param(z, int_type())] + fun = Function(params, int_type(), + If(x, + BinaryOp(BOp.PLUS, y, z), + BinaryOp(BOp.SUB, z, y))) + args = [BoolLit(True), IntLit(12), IntLit(32)] + ty = check_new_env(Call(fun, args)) + assert ty == int_type() + +def test_product_lit(): + prod = ProductLit([IntLit(12), BoolLit(False), FloatLit(2.0)]) + + tupty = ProductType([int_type(), bool_type(), float_type()]) + + ty = check_new_env(prod) + assert ty == tupty + +def test_projection(): + prod = ProductLit([IntLit(12), BoolLit(False), FloatLit(2.0)]) + + ty1 = check_new_env(Projection(prod, 0)) + assert ty1 == int_type() + ty2 = check_new_env(Projection(prod, 1)) + assert ty2 == bool_type() + ty3 = check_new_env(Projection(prod, 2)) + assert ty3 == float_type() + +def test_ref_lit(): + reflit = Ref(BinaryOp(BOp.PLUS, IntLit(2), IntLit(3))) + rt = RefType(int_type()) + ty = check_new_env(reflit) + assert ty == rt + +def test_val_ref(): + reflit = Ref(FloatLit(2.5)) + valref = ValRef(reflit) + ty = check_new_env(valref) + assert ty == float_type() + +def test_set_ref(): + reflit = Ref(BoolLit(False)) + setref = SetRef(reflit, BoolLit(True)) +
ty = check_new_env(setref) + assert ty == unit_type() + +def test_cast(): + cast1 = Cast(int_type(), FloatLit(3.1)) + assert check_new_env(cast1) == int_type() + + tt = TensorType(IntType(32), ShapeSeq([ShapeSingleton(5)])) + cast2 = Cast(tt, IntLit(3)) + assert check_new_env(cast2) == tt + +def test_zero(): + zero1 = Zero(float_type()) + assert check_new_env(zero1) == float_type() + + tt = TensorType(BoolType(), ShapeSeq([ShapeSingleton(i) for i in range(5)])) + zero2 = Zero(tt) + assert check_new_env(zero2) == tt + +def test_call_with_shape_attrs(): + dim_attr = "new_dim" + btvar = TypeId("bt", Kind.BaseType) + input_shape = TypeId("input_shape", Kind.Shape) + output_shape = ShapeSeq([input_shape, ShapeAttr(String(dim_attr))]) + + ftype = TypeQuantifier( + btvar, + TypeQuantifier( + input_shape, + TypeArrow(ProductType([TensorType(btvar, input_shape)]), + TensorType(btvar, output_shape)))) + + # type is TensorType(FloatType(32), ShapeSeq([ShapeSingleton(2), ShapeSingleton(3)])) + actual_input1 = TensorLit([TensorLit([FloatLit(2.1), FloatLit(3.3), FloatLit(4.4)]), + TensorLit([FloatLit(6.3), FloatLit(7.8), FloatLit(8.9)])]) + # type is TensorType(IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(4)])) + actual_input2 = TensorLit([ + TensorLit([IntLit(2), IntLit(3), IntLit(4), IntLit(5)]), + TensorLit([IntLit(6), IntLit(7), IntLit(9), IntLit(6)]), + TensorLit([IntLit(2), IntLit(3), IntLit(4), IntLit(7)]), + ]) + + concrete_output_type1 = TensorType(FloatType(32), + ShapeSeq([ShapeSingleton(2), + ShapeSingleton(3), + ShapeSingleton(4)])) + + concrete_output_type2 = TensorType(IntType(32), + ShapeSeq([ShapeSingleton(3), + ShapeSingleton(4), + ShapeSingleton(6)])) + + callee = LocalId("callee") + caller1 = Function([Param(callee, ftype)], + concrete_output_type1, + Call(callee, [actual_input1], Attributes({String(dim_attr) : IntLit(4)}))) + + caller2 = Function([Param(callee, ftype)], + concrete_output_type2, + Call(callee, [actual_input2], Attributes({String(dim_attr) : IntLit(6)}))) + + + assert check_new_env(caller1) == TypeArrow(ProductType([ftype]), concrete_output_type1) + assert check_new_env(caller2) == TypeArrow(ProductType([ftype]), concrete_output_type2) + +# negative tests +def test_basetypes_incompatible(): + boollit = BoolLit(False) + intlit = IntLit(2) + floatlit = FloatLit(2.3) + + assert check_new_env(boollit) != int_type() and check_new_env(boollit) != float_type() + assert check_new_env(intlit) != bool_type() and check_new_env(intlit) != float_type() + assert check_new_env(floatlit) != bool_type() and check_new_env(floatlit) != int_type() + +def test_tensor_lit_inconsistent_shape(): + t1 = TensorLit([TensorLit([IntLit(3), IntLit(4)]), + TensorLit([IntLit(5), IntLit(6), IntLit(7)])]) + check_exception(t1) + +def test_tensor_lit_inconsistent_data_types(): + t1 = TensorLit([IntLit(3), FloatLit(2.5), BoolLit(True)]) + check_exception(t1) + +def test_invalid_projection_index(): + p1 = ProductLit([IntLit(3), IntLit(5)]) + check_exception(Projection(p1, 2)) + +def test_invalid_projection_tuple_argument(): + t1 = TensorLit([IntLit(3), IntLit(5)]) + check_exception(Projection(t1, 0)) + +def test_if_node_invalid_guard(): + ifn = If(FloatLit(-3.5), IntLit(5), IntLit(6)) + check_exception(ifn) + +def test_if_node_branches_incompatible(): + intlit = IntLit(2) + floatlit = FloatLit(3.6) + check_exception(If(BoolLit(False), intlit, floatlit)) + check_exception(If(BoolLit(False), floatlit, intlit)) + +def test_binop_incompatible_dtypes(): + intlit = IntLit(2) + floatlit = 
FloatLit(2.3) + check_exception(BinaryOp(BOp.PLUS, intlit, floatlit)) + check_exception(BinaryOp(BOp.PLUS, floatlit, intlit)) + + t1 = TensorLit([IntLit(i) for i in range(5)]) + t2 = TensorLit([FloatLit(float(i)) for i in range(5)]) + check_exception(BinaryOp(BOp.PLUS, t1, t2)) + check_exception(BinaryOp(BOp.PLUS, t2, t1)) + +def test_binop_incompatible_shapes(): + t1 = TensorLit([TensorLit([IntLit(i), IntLit(i + 1)]) for i in range(5)]) + t2 = TensorLit([TensorLit([IntLit(i), IntLit(i + 1), IntLit(i + 2)]) for i in range(5)]) + check_exception(BinaryOp(BOp.SUB, t1, t2)) + check_exception(BinaryOp(BOp.SUB, t2, t1)) + +def test_binop_non_tensor(): + p1 = ProductLit([IntLit(3), IntLit(4)]) + p2 = IntLit(5) + check_exception(BinaryOp(BOp.PLUS, p1, p2)) + check_exception(BinaryOp(BOp.PLUS, p2, p1)) + +def test_unop_non_tensor(): + p1 = ProductLit([IntLit(3), IntLit(4)]) + check_exception(UnaryOp(UOp.NEG, p1)) + check_exception(UnaryOp(UOp.SQ, p1)) + +def test_valref_nonref(): + check_exception(ValRef(IntLit(3))) + check_exception(ValRef(ProductLit([IntLit(3), FloatLit(3.5)]))) + +def test_setref_nonref(): + check_exception(SetRef(IntLit(3), IntLit(9))) + check_exception(SetRef(TensorLit([IntLit(3), IntLit(6)]), IntLit(9))) + +def test_setref_incompatible_types(): + r = Ref(TensorLit([IntLit(i) for i in range(5)])) + check_exception(SetRef(r, TensorLit([FloatLit(float(i)) for i in range(5)]))) + check_exception(SetRef(r, TensorLit([IntLit(i) for i in range(6)]))) + check_exception(SetRef(r, ProductLit([IntLit(i) for i in range(5)]))) + +def test_function_ret_body_mismatch(): + x = LocalId("x") + params = [Param(x, int_type())] + f = Function(params, int_type(), BinaryOp(BOp.NE, x, IntLit(3))) + check_exception(f) + +def test_function_type_mismatch_within_body(): + x = LocalId("x") + y = LocalId("y") + z = LocalId("z") + params = [Param(x, int_type()), Param(y, float_type())] + f = Function(params, bool_type(), Let(z, bool_type(), BinaryOp(BOp.EQ, x, y), z)) + check_exception(f) + +def test_call_non_function(): + check_exception(Call(IntLit(3), [BoolLit(True), FloatLit(2.5)])) + check_exception(Call(ProductLit([BoolLit(False), IntLit(3)]), + [BoolLit(True), FloatLit(2.5)])) + check_exception(Call(ProductLit([BoolLit(False), FloatLit(2.3)]), + [TensorLit([IntLit(3), IntLit(4)])])) + +def test_call_param_type_mismatch(): + x = LocalId("x") + y = LocalId("y") + params = [Param(x, int_type()), Param(y, bool_type())] + f = Function(params, int_type(), If(y, x, IntLit(0))) + + check_exception(Call(f, [FloatLit(3.5), BoolLit(False)])) + check_exception(Call(f, [IntLit(3), TensorLit([BoolLit(True), BoolLit(False)])])) + +def test_call_arity_mismatch(): + x = LocalId("x") + params = [Param(x, int_type())] + f = Function(params, int_type(), x) + + check_exception(Call(f, [])) + check_exception(Call(f, [IntLit(5), IntLit(3)])) + +def test_zero_nontensor(): + pt = ProductType([IntType(32), FloatType(32)]) + rt = RefType(int_type()) + bt = IntType(32) + qt = TypeQuantifier(TypeId("x", Kind.Type), bool_type(), int_type()) + check_exception(Zero(pt)) + check_exception(Zero(rt)) + check_exception(Zero(bt)) + check_exception(Zero(qt)) + +def test_let_nonmatching_annotation(): + x = LocalId("x") + + check_exception(Let(x, bool_type(), IntLit(5), x)) + check_exception(Let(x, int_type(), BoolLit(False), x)) + +def test_let_recursive_nonfunction(): + x = LocalId("x") + + check_exception(Let(x, int_type(), BinaryOp(BOp.PLUS, x, IntLit(1)), + BinaryOp(BOp.PLUS, x, IntLit(1)))) + check_exception(Let(x, 
ProductType([int_type(), int_type()]), + ProductLit([Projection(x, 0), Projection(x, 1)]), + Projection(x, 0))) + +def test_let_recursive_function_bad_use(): + x = LocalId("x") + y = LocalId("y") + + # nonsensical, don't think about it + # key thing is the recursive call should want a bool return + fun = Function([Param(x, int_type())], int_type(), + If(Call(y, [BinaryOp(BOp.DIV, x, IntLit(2))]), + BinaryOp(BOp.DIV, x, IntLit(2)), x)) + + check_exception(Let(y, TypeArrow(ProductType([int_type()]), int_type()), + fun, Call(y, [IntLit(3)]))) + +def test_let_recursive_function_mismatch(): + x = LocalId("x") + y = LocalId("y") + + fib = Function([Param(x, int_type())], int_type(), + If(BinaryOp(BOp.EQ, x, IntLit(0)), IntLit(1), + If(BinaryOp(BOp.EQ, x, IntLit(1)), IntLit(1), + BinaryOp(BOp.PLUS, + Call(y, [BinaryOp(BOp.SUB, x, IntLit(1))]), + Call(y, [BinaryOp(BOp.SUB, x, IntLit(2))])) + ))) + + check_exception(Let(y, + TypeArrow(ProductType([int_type(), int_type()]), int_type()), + fib, Call(y, [IntLit(3)]))) diff --git a/relay/tests/python/test_unifier.py b/relay/tests/python/test_unifier.py new file mode 100644 index 0000000000000..6189ecacdebe5 --- /dev/null +++ b/relay/tests/python/test_unifier.py @@ -0,0 +1,480 @@ +"""Tests unification of types.""" +# pylint: disable=invalid-name, missing-docstring, bare-except +import relay.ir as ir +# pylint: disable=unused-import +import relay.unifier # TODO (@jroesch) fix me +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * + +def unify_types(t1, t2): + unifier = TypeUnifier() + return unifier.unify(t1, t2) + +def int_type(): + return TensorType(IntType(32), ShapeSeq([])) + +def float_type(): + return TensorType(FloatType(32), ShapeSeq([])) + +def bool_type(): + return TensorType(BoolType(), ShapeSeq([])) + +def make_shape(dims): + return ShapeSeq([ShapeSingleton(dim) for dim in dims]) + +def test_insert_and_find(): + uf = UnionFind() + v1 = TypeVar(1) + v2 = TypeVar(2) + uf.insert(v1) + uf.insert(v2) + assert uf.find(v1) == v1 + assert uf.find(v2) == v2 + +def test_insert_error(): + uf = UnionFind() + v1 = TypeVar(1) + v2 = TypeVar(2) + uf.insert(v1) + try: + uf.find(v2) + assert False + except: + return + +def test_unify(): + uf = UnionFind() + v1 = TypeVar(1) + v2 = TypeVar(2) + v3 = TypeVar(3) + uf.insert(v1) + uf.insert(v2) + uf.insert(v3) + uf.unify(v1, v2) + rep = uf.find(v1) + assert (rep == v1 or rep == v2) + assert uf.find(v1) == rep + assert uf.find(v2) == rep + assert uf.find(v3) == v3 + assert v3 != rep + uf.unify(v1, v3) + new_rep = uf.find(v3) + assert (rep == v1 or rep == v2 or rep == v3) + assert uf.find(v1) == new_rep + assert uf.find(v2) == new_rep + assert uf.find(v3) == new_rep + +def test_unify_multiple_levels(): + uf = UnionFind() + v = [TypeVar(i) for i in range(9)] + for var in v: + uf.insert(var) + uf.unify(v[0], v[1]) + uf.unify(v[0], v[2]) + uf.unify(v[3], v[4]) + uf.unify(v[4], v[5]) + uf.unify(v[6], v[7]) + uf.unify(v[6], v[8]) + rep1 = uf.find(v[0]) + rep2 = uf.find(v[3]) + rep3 = uf.find(v[6]) + assert (rep1 == v[0] or rep1 == v[1] or rep1 == v[2]) + assert (rep2 == v[3] or rep2 == v[4] or rep2 == v[5]) + assert (rep3 == v[6] or rep3 == v[7] or rep3 == v[8]) + for i in range(3): + assert uf.find(v[i]) == rep1 + assert uf.find(v[i + 3]) == rep2 + assert uf.find(v[i + 6]) == rep3 + # now unify two of the groups + uf.unify(v[1], v[4]) + new_rep1 = uf.find(v[0]) + new_rep2 = uf.find(v[6]) + assert (new_rep1 == v[0] or new_rep1 == v[1] or new_rep1 == v[2] + or new_rep1 == v[3] or new_rep1 
== v[4] or new_rep1 == v[5]) + assert (new_rep2 == v[6] or new_rep2 == v[7] or new_rep2 == v[8]) + for i in range(6): + assert uf.find(v[i]) == new_rep1 + for i in range(3): + assert uf.find(v[i + 6]) == new_rep2 + +# TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work +def test_unify_int(): + intty = IntType(1) + unified = unify_types(intty, intty) + assert intty == unified + +def test_unify_bool(): + boolty = BoolType() + unified = unify_types(boolty, boolty) + assert boolty == unified + +def test_unify_float(): + floatty = FloatType(4) + unified = unify_types(floatty, floatty) + assert floatty == unified + +def test_unify_incompatible_basetypes(): + bt = BoolType() + intty = IntType(32) + try: + unify_types(bt, intty) + assert False + except: + return + +def test_unify_concrete_type_arrow(): + arr1 = TypeArrow(int_type(), int_type()) + arr2 = TypeArrow(int_type(), int_type()) + unified = unify_types(arr1, arr2) + assert unified == arr1 + +def test_unify_type_arrow_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(1) + unifier.insert(v1) + unifier.unify(v1, bool_type()) + arr1 = TypeArrow(int_type(), bool_type()) + arr2 = TypeArrow(int_type(), v1) + unified = unifier.unify(arr1, arr2) + assert unified == arr1 + + v2 = TypeVar(2) + unifier.insert(v2) + unifier.unify(v2, int_type()) + arr3 = TypeArrow(v2, bool_type()) + unified = unifier.unify(arr1, arr3) + assert unified == arr1 + +def test_reject_incompatible_type_arrows(): + arr1 = TypeArrow(int_type(), bool_type()) + arr2 = TypeArrow(ProductType([int_type(), bool_type()]), bool_type()) + try: + unify_types(arr1, arr2) + assert False + except: + return + +def test_unify_concrete_type_quantifiers(): + tq1 = TypeQuantifier(TypeId("id1"), int_type()) + tq2 = TypeQuantifier(TypeId("id2"), int_type()) + unified = unify_types(tq1, tq2) + assert unified == tq1 + +def test_unify_basetype_with_quantifier_error(): + bt = bool_type() + tq = TypeQuantifier(TypeId("id1"), bt) + try: + unify_types(bt, tq) + assert False + except: + return + +def test_unify_typevars_with_each_other(): + unifier = TypeUnifier() + v1 = TypeVar(1) + v2 = TypeVar(2) + v3 = TypeVar(3) + unifier.insert(v1) + unifier.insert(v2) + unifier.insert(v3) + unified = unifier.unify(v1, v2) + assert (unified == v1 or unified == v2) + assert unified != v3 + new_unified = unifier.unify(v1, v3) + assert (new_unified == v1 or new_unified == v2 or new_unified == v3) + +def test_unify_typevars_with_basetype(): + unifier = TypeUnifier() + bt = BoolType() + v1 = TypeVar(1) + v2 = TypeVar(2) + unifier.insert(v1) + unifier.insert(v2) + unified1 = unifier.unify(v1, bt) + assert unified1 == bt + unified2 = unifier.unify(v1, v2) + assert unified2 == bt + +def test_unify_compatible_typevars(): + unifier = TypeUnifier() + bt = BoolType() + v1 = TypeVar(1) + v2 = TypeVar(2) + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v1, bt) + unifier.unify(v2, bt) + # because types to which v1 and v2 have been assigned are compatible, + # this should proceed without problems + unified = unifier.unify(v1, v2) + assert unified == bt + +def test_unify_incompatible_typevars(): + unifier = TypeUnifier() + v1 = TypeVar(1) + v2 = TypeVar(2) + bt = bool_type() + tq = TypeQuantifier(TypeId("id1"), bt) + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v1, bt) + unifier.unify(v2, tq) + # bt cannot be unified with tq, so unifying v1 and v2 should give an error + try: + unifier.unify(v1, v2) + assert False + except: + return + +def test_unify_typevar_with_quantifier(): + unifier 
= TypeUnifier() + tq = TypeQuantifier(TypeId("id1"), bool_type()) + v1 = TypeVar(1) + unifier.insert(v1) + unified = unifier.unify(v1, tq) + assert unified == tq + +def test_unify_typevars_inside_concrete_quantifier(): + unifier = TypeUnifier() + v1 = TypeVar(1) + unifier.insert(v1) + tq1 = TypeQuantifier(TypeId("id1"), v1) + tq2 = TypeQuantifier(TypeId("id2"), bool_type()) + unified = unifier.unify(tq1, tq2) + assert unified == tq2 + +def test_unify_concrete_tensors(): + bt = BoolType() + shape = make_shape([1, 2, 3]) + tt1 = TensorType(bt, shape) + tt2 = TensorType(bt, shape) + unified = unify_types(tt1, tt2) + assert unified == tt1 + +def test_unify_tensor_shape_reject(): + bt = BoolType() + shape1 = make_shape([1, 2, 3]) + shape2 = make_shape([2, 3, 4]) + tt1 = TensorType(bt, shape1) + tt2 = TensorType(bt, shape2) + try: + unify_types(tt1, tt2) + assert False + except: + return + +def test_unify_tensor_dtype_reject(): + bt1 = BoolType() + bt2 = IntType(32) + shape = make_shape([1, 2, 3]) + tt1 = TensorType(bt1, shape) + tt2 = TensorType(bt2, shape) + try: + unify_types(tt1, tt2) + assert False + except: + return + +def test_unify_quantified_tensors(): + x = TypeId("x", ir.type.Kind.Shape) + y = TypeId("y", ir.type.Kind.Shape) + tq1 = TypeQuantifier(x, TensorType(BoolType(), x)) + tq2 = TypeQuantifier(y, TensorType(BoolType(), y)) + unified = unify_types(tq1, tq2) + assert unified == tq1 + + a = TypeId("a", ir.type.Kind.BaseType) + b = TypeId("b", ir.type.Kind.BaseType) + tq3 = TypeQuantifier(a, TensorType(a, make_shape([1, 2, 3]))) + tq4 = TypeQuantifier(b, TensorType(b, make_shape([1, 2, 3]))) + unified = unify_types(tq3, tq4) + assert unified == tq3 + +def test_unify_concrete_products(): + bt = bool_type() + intty = int_type() + pt1 = ProductType([bt, intty]) + pt2 = ProductType([bt, intty]) + unified = unify_types(pt1, pt2) + assert unified == pt1 + +def test_unify_products_reject_size(): + bt = BoolType() + intty = IntType(32) + pt1 = ProductType([bt, bt, intty]) + pt2 = ProductType([bt, intty]) + try: + unify_types(pt1, pt2) + assert False + except: + return + +def test_unify_products_reject_member(): + bt = BoolType() + intty = IntType(32) + pt1 = ProductType([bt, bt]) + pt2 = ProductType([bt, intty]) + try: + unify_types(pt1, pt2) + assert False + except: + return + +def test_unify_products_typevar(): + unifier = TypeUnifier() + v1 = TypeVar(1) + bt = bool_type() + pt1 = ProductType([bt, bt]) + pt2 = ProductType([v1, bt]) + unifier.insert(v1) + unified = unifier.unify(pt1, pt2) + assert unified == pt1 + +def test_unify_quantified_products(): + x = TypeId("x") + y = TypeId("y") + p1 = TypeQuantifier(x, ProductType([int_type(), x])) + p2 = TypeQuantifier(y, ProductType([int_type(), y])) + unified = unify_types(p1, p2) + assert unified == p1 + +def test_unify_ref_types(): + r1 = RefType(bool_type()) + r2 = RefType(bool_type()) + assert unify_types(r1, r2) == r1 + +def test_unify_ref_reject_inner(): + r1 = RefType(BoolType()) + r2 = RefType(IntType(32)) + try: + unify_types(r1, r2) + assert False + except: + return + +def test_subst_basetype(): + unifier = TypeUnifier() + bt = BoolType() + assert bt == unifier.subst(bt) + +def test_subst_simple_hole(): + unifier = TypeUnifier() + v1 = TypeVar(1) + bt = BoolType() + unifier.insert(v1) + unifier.unify(v1, bt) + assert unifier.subst(v1) == bt + +def test_subst_typevar_for_typevar(): + unifier = TypeUnifier() + v1 = TypeVar(1) + v2 = TypeVar(2) + unifier.insert(v1) + unifier.insert(v2) + + unifier.unify(v1, v2) + assert 
unifier.subst(v1) == v2 + +def test_subst_concrete_arrow(): + unifier = TypeUnifier() + arr1 = TypeArrow(int_type(), int_type()) + assert unifier.subst(arr1) == arr1 + +def test_subst_arrow_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(1) + v2 = TypeVar(2) + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v1, int_type()) + unifier.unify(v2, bool_type()) + arr1 = TypeArrow(v1, v2) + arr2 = TypeArrow(int_type(), bool_type()) + assert unifier.subst(arr1) == arr2 + +def test_subst_concrete_quantifier(): + unifier = TypeUnifier() + v1 = TypeVar(1) + tq = TypeQuantifier(TypeId("id1"), int_type()) + unifier.insert(v1) + unifier.unify(v1, tq) + assert unifier.subst(v1) == tq + +def test_subst_quantifier_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(1) + v2 = TypeVar(2) + tq1 = TypeQuantifier(TypeId("id1"), v2) + intty = int_type() + tq2 = TypeQuantifier(TypeId("id2"), intty) + + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v2, intty) + unifier.unify(v1, tq1) + assert unifier.subst(v1) == tq2 + +def test_subst_concrete_tensor(): + unifier = TypeUnifier() + v1 = TypeVar(1) + unifier.insert(v1) + tt = TensorType(BoolType(), make_shape([1, 2, 3])) + unifier.unify(v1, tt) + assert unifier.subst(v1) == tt + +def test_subst_concrete_product(): + unifier = TypeUnifier() + v1 = TypeVar(1) + unifier.insert(v1) + bt = bool_type() + pt = ProductType([bt, bt]) + unifier.unify(v1, pt) + assert unifier.subst(v1) == pt + +def test_subst_product_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(1) + v2 = TypeVar(2) + v3 = TypeVar(3) + unifier.insert(v1) + unifier.insert(v2) + unifier.insert(v3) + + bt = bool_type() + intty = int_type() + pt1 = ProductType([bt, v2, v3]) + unifier.unify(v2, intty) + unifier.unify(v3, v2) + unifier.unify(v1, pt1) + pt2 = ProductType([bt, intty, intty]) + assert unifier.subst(v1) == pt2 + +def test_subst_concrete_ref(): + unifier = TypeUnifier() + rt = RefType(bool_type()) + assert unifier.subst(rt) == rt + +def test_subst_ref_with_hole(): + unifier = TypeUnifier() + v1 = TypeVar(1) + unifier.insert(v1) + + unifier.unify(v1, bool_type()) + rt1 = RefType(v1) + rt2 = RefType(bool_type()) + assert unifier.subst(rt1) == rt2 + +def test_typevar_on_lhs(): + unifier = TypeUnifier() + v1 = TypeVar(1) + v2 = TypeVar(2) + bt = bool_type() + tq = TypeQuantifier(TypeId("id1"), bt, bt) + unifier.insert(v1) + unifier.insert(v2) + unified1 = unifier.unify(bt, v1) + assert unified1 == bt + unified2 = unifier.unify(tq, v2) + assert unified2 == tq + assert unifier.subst(v1) == bt + assert unifier.subst(v2) == tq diff --git a/relay/tests/python/test_visitor.py b/relay/tests/python/test_visitor.py new file mode 100644 index 0000000000000..3f6be942f9d81 --- /dev/null +++ b/relay/tests/python/test_visitor.py @@ -0,0 +1,214 @@ +"""Tests the Python visitor for the Relay AST.""" +# pylint: disable=undefined-variable, invalid-name +import relay.ir as ir +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * +from relay.visitor import Visitor + +# pylint: disable=missing-docstring, arguments-differ +class SVisitor(Visitor): + def visit_local_id(self, name, *_): + assert name == "foo" + + def visit_global_id(self, name, *_): + assert name == "glob" + + def visit_intrinsic_id(self, name, *_): + assert name == "intr" + + def visit_float_lit(self, value, *_): + assert value == 1.0 + + def visit_bool_lit(self, value, *_): + assert value + + def visit_int_lit(self, value, *_): + assert value == 1337 + + def visit_tensor_lit(self, data, *_): + 
assert data[0] == IntLit(1) + assert data[1] == IntLit(2) + assert data[2] == IntLit(3) + + def visit_product_lit(self, fields, *_): + assert list(fields) == [IntLit(1), IntLit(2)] + + def visit_cast(self, target, node, *_): + assert target == IntType(12) + assert node == FloatLit(145.0) + + def visit_param(self, ident, typ, *_): + assert isinstance(ident, ir.LocalId) + assert ident.name == "x" + assert typ == TensorType(IntType(32), ShapeSeq([])) + + def visit_function(self, params, ret_type, body, *_): + assert len(params) == 2 + assert ret_type == TensorType(IntType(32), ShapeSeq([])) + assert body == IntLit(12) + + def visit_call(self, fn, call_args, *_): + assert len(fn.params) == 1 + assert fn.ret_type == ProductType([]) + assert len(call_args) == 1 + assert call_args[0] == ProductLit([]) + + def visit_debug(self, node): + assert False + + def visit_unary_op(self, op, op_arg, *_): + assert op == ir.UOp.SQ + assert op_arg == IntLit(12) + + def visit_binary_op(self, op, left, right, *_): + assert op == ir.BOp.SUB + assert left == IntLit(15) + assert right == IntLit(3) + + def visit_reverse(self, node, *_): + assert node == IntLit(1) + + def visit_zero(self, typ, *_): + assert typ == TensorType(IntType(32), ShapeSeq([])) + + def visit_projection(self, tupl, field, *_): + assert tupl == ProductLit([IntLit(3), BoolLit(False)]) + assert field == 1 + + def visit_if(self, guard, true_b, false_b, *_): + assert guard == BoolLit(True) + assert true_b == IntLit(3) + assert false_b == IntLit(9) + + def visit_let(self, ident, typ, value, body, *_): + assert body == ident + assert typ == TensorType(IntType(32), ShapeSeq([])) + assert value == IntLit(12) + + def visit_ref(self, expr, *_): + assert expr == ProductLit([]) + + def visit_val_ref(self, ref, *_): + assert ref == Ref(IntLit(12)) + + def visit_set_ref(self, ref, val, *_): + assert ref == Ref(ProductLit([IntLit(5), IntLit(7)])) + assert val == ProductLit([IntLit(6), IntLit(6)]) + + def visit_gradient(self, node, *_): + assert node == FloatLit(2.0) + +def test_visit_local_id(): + tv = SVisitor() + tv.run(LocalId("foo")) + +def test_visit_global_id(): + tv = SVisitor() + tv.run(GlobalId("glob")) + +def test_visit_intrinsic_id(): + tv = SVisitor() + tv.run(IntrinsicId("intr")) + +def test_visit_float_lit(): + tv = SVisitor() + tv.run(FloatLit(1.0)) + +def test_visit_bool_lit(): + tv = SVisitor() + tv.run(BoolLit(True)) + +def test_visit_int_lit(): + tv = SVisitor() + tv.run(IntLit(1337)) + +def test_visit_tensor_lit(): + tv = SVisitor() + data = [IntLit(1), IntLit(2), IntLit(3)] + tv.run(TensorLit(data)) + +def test_visit_product_lit(): + tv = SVisitor() + data = [IntLit(1), IntLit(2)] + tv.run(ProductLit(data)) + +def test_visit_cast(): + tv = SVisitor() + tv.run(Cast(IntType(12), FloatLit(145.0))) + +def test_visit_param(): + tv = SVisitor() + param = Param(LocalId("x"), TensorType(IntType(32), ShapeSeq([]))) + tv.run(param) + +def test_visit_function(): + tv = SVisitor() + params = [ + Param(LocalId("a"), TensorType(IntType(32), ShapeSeq([]))), + Param(LocalId("b"), TensorType(IntType(32), ShapeSeq([]))) + ] + ret_type = TensorType(IntType(32), ShapeSeq([])) + body = IntLit(12) + tv.run(Function(params, ret_type, body)) + +def test_visit_call(): + tv = SVisitor() + fn = Function([Param(LocalId("_"), ProductType([]))], ProductType([]), ProductLit([])) + args = [ProductLit([])] + call = Call(fn, args) + tv.run(call) + +# def test_visit_debug(): +# tv = SVisitor() +# debug = Debug() +# tv.run(debug) + +def test_visit_unary_op(): + tv = 
SVisitor() + unary_op = UnaryOp(ir.UOp.SQ, IntLit(12)) + tv.run(unary_op) + +def test_visit_binary_op(): + tv = SVisitor() + binary_op = BinaryOp(ir.BOp.SUB, IntLit(15), IntLit(3)) + tv.run(binary_op) + +def test_visit_reverse(): + tv = SVisitor() + tv.run(Reverse(IntLit(1))) + +def test_visit_gradient(): + tv = SVisitor() + tv.run(Gradient(FloatLit(2.0))) + +def test_visit_zero(): + tv = SVisitor() + tv.run(Zero(TensorType(IntType(32), ShapeSeq([])))) + +def test_visit_projection(): + tv = SVisitor() + tv.run(Projection(ProductLit([IntLit(3), BoolLit(False)]), 1)) + +def test_visit_if(): + tv = SVisitor() + tv.run(If(BoolLit(True), IntLit(3), IntLit(9))) + +def test_visit_let(): + tv = SVisitor() + ident = LocalId("x") + tv.run(Let(ident, TensorType(IntType(32), ShapeSeq([])), IntLit(12), ident)) + +def test_visit_ref(): + tv = SVisitor() + tv.run(Ref(ProductLit([]))) + +def test_visit_val_ref(): + tv = SVisitor() + tv.run(ValRef(Ref(IntLit(12)))) + +def test_visit_set_ref(): + tv = SVisitor() + tv.run(SetRef( + Ref(ProductLit([IntLit(5), IntLit(7)])), + ProductLit([IntLit(6), IntLit(6)]) + ))