From 4f37853ab64e3fbe9bdbd36ca0fc69214e595cb9 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sun, 7 Oct 2018 17:22:52 -0700 Subject: [PATCH 01/64] squashed first 30 commits --- python/tvm/relay/__init__.py | 1 - python/tvm/relay/expr.py | 2 + python/tvm/relay/expr.pyi | 8 +- python/tvm/relay/grammar/.gitignore | 1 + python/tvm/relay/grammar/Relay.g4 | 131 ++++++++ python/tvm/relay/grammar/__init__.py | 0 python/tvm/relay/grammar/py2/.gitignore | 1 + python/tvm/relay/grammar/py2/__init__.py | 0 python/tvm/relay/grammar/py3/.gitignore | 1 + python/tvm/relay/grammar/py3/__init__.py | 0 python/tvm/relay/parser.py | 275 +++++++++++++++++ python/tvm/relay/ty.pyi | 2 +- src/relay/ir/pretty_printer.cc | 304 +++++++++++++++++++ tests/python/relay/test_ir_parser.py | 107 +++++++ tests/python/relay/test_ir_pretty_printer.py | 90 ++++++ 15 files changed, 917 insertions(+), 6 deletions(-) create mode 100644 python/tvm/relay/grammar/.gitignore create mode 100644 python/tvm/relay/grammar/Relay.g4 create mode 100644 python/tvm/relay/grammar/__init__.py create mode 100644 python/tvm/relay/grammar/py2/.gitignore create mode 100644 python/tvm/relay/grammar/py2/__init__.py create mode 100644 python/tvm/relay/grammar/py3/.gitignore create mode 100644 python/tvm/relay/grammar/py3/__init__.py create mode 100644 python/tvm/relay/parser.py create mode 100644 src/relay/ir/pretty_printer.cc create mode 100644 tests/python/relay/test_ir_parser.py create mode 100644 tests/python/relay/test_ir_pretty_printer.py diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 6b071f65a794..64d6774d0bde 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -52,7 +52,6 @@ If = expr.If TupleGetItem = expr.TupleGetItem - # helper functions var = expr.var const = expr.const diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 4725c0a7a07d..b6ba2bee6c1c 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -622,3 +622,5 @@ def bind(expr, binds): The expression or function after binding. """ return _expr.Bind(expr, binds) + +pretty_print = _expr._pretty_print diff --git a/python/tvm/relay/expr.pyi b/python/tvm/relay/expr.pyi index e73a5963e5b1..bc2e5115df0d 100644 --- a/python/tvm/relay/expr.pyi +++ b/python/tvm/relay/expr.pyi @@ -22,7 +22,7 @@ class Constant(Expr): class Tuple(Expr): - fields = .. # type: List[Expr] + fields = ... # type: List[Expr] def __init__(self, fields): # type: (List[Expr]) -> None @@ -77,10 +77,10 @@ class Call(Expr): """A function call in Relay, see tvm/relay/expr.h for more details.""" op = ... # type: Expr args = ... # type: List[Expr] - # todo(@jroesch): add attrs + # todo(@jroesch): add attrs. revise attrs type in __init__ - def __init__(self, op, args, attrs, ty_args=None): - # type: (Expr, List[Expr], Optional[List[Type]]) -> None + def __init__(self, op, args, attrs=None, ty_args=None): + # type: (Expr, List[Expr], Optional[List[Any]], Optional[List[Type]]) -> None if not ty_args: ty_args = [] diff --git a/python/tvm/relay/grammar/.gitignore b/python/tvm/relay/grammar/.gitignore new file mode 100644 index 000000000000..cffe35e1a41a --- /dev/null +++ b/python/tvm/relay/grammar/.gitignore @@ -0,0 +1 @@ +/.antlr/ diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 new file mode 100644 index 000000000000..6a99edf96604 --- /dev/null +++ b/python/tvm/relay/grammar/Relay.g4 @@ -0,0 +1,131 @@ +grammar Relay; + +// Lexing +// comments +WS : [ \t\n\r]+ -> skip ; +LINE_COMMENT : '//' .*? '\n' -> skip ; +COMMENT : '/*' .*? '*/' -> skip ; + +// operators +MUL: '*' ; +DIV: '/' ; +ADD: '+' ; +SUB: '-' ; +LT: '<' ; +GT: '>' ; +LE: '<=' ; +GE: '>=' ; +EQ: '==' ; +NE: '!=' ; + +opIdent: CNAME ; +GLOBAL_VAR: '@' CNAME ; +VAR: '%' CNAME ; + +MUT: 'mut' ; + +BOOL_LIT + : 'true' + | 'false' + ; + +// non-negative floats +FLOAT + : INT '.' INT EXP? // 1.35, 1.35E-9, 0.3, 4.5 + | INT EXP // 1e10 3e4 + ; + +// non-negative ints +INT: DIGIT+ ; +fragment EXP: [eE] [+\-]? INT ; // \- since - means "range" inside [...] + +CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ; +fragment LETTER: [a-zA-Z] ; +fragment DIGIT: [0-9] ; + +// Parsing + +// a program is a list of options, a list of global definitions, and an expression +prog: option* defn* expr EOF ; + +option: 'set' ident BOOL_LIT ; + +expr + // operators + : '(' expr ')' # parens + | '-' expr # neg + | expr op=('*'|'/') expr # binOp + | expr op=('+'|'-') expr # binOp + | expr op=('<'|'>'|'<='|'>=') expr # binOp + | expr op=('=='|'!=') expr # binOp + + // function definition and application + | expr '(' (expr (',' expr)*)? ')' # call + | func # funcExpr + + // tuples and tensors + | '(' ')' # tuple + | '(' expr ',' ')' # tuple + | '(' expr (',' expr)+ ')' # tuple + | '[' (expr (',' expr)*)? ']' # tensor + + | 'if' expr body 'else' body # ifElse + + // sequencing + | 'let' MUT? ident (':' type_)? '=' expr ';' expr # seq + // sugar for let _ = expr; expr + | expr ';' expr # seq + // sugar for let _ = expr; expr + | '{' expr '}' ';' expr # seq + + // mutable update + | ident '=' expr # writeRef + | expr '^' # readRef + + | ident # identExpr + | scalar # scalarExpr + | expr '.' INT # project + | 'debug' # debug + ; + +func: 'fn' paramList '=>' type_? body ; +defn: 'def' ident paramList '=>' type_? body ; + +paramList: '(' (param (',' param)*)? ')' ; +param: ident (':' type_)? ; + +type_ + : '(' type_ ')' # parensType + | type_ op=('*'|'/') type_ # binOpType + | type_ op=('+'|'-') type_ # binOpType + | '(' ')' # tupleType + | '(' type_ ',' ')' # tupleType + | '(' type_ (',' type_)+ ')' # tupleType + | identType # identTypeType + | identType '(' type_ (',' type_)* ')' # funcType + | identType '[' type_ (',' type_)* ']' # funcType + // Mut, Int, UInt, Float, Bool, Tensor + | type_ '.' INT # projectType + | INT # dimLitType + | '_' # incompleteType + ; + +identType: CNAME ; +// Int8, Int16, Int32, Int64 +// UInt8, UInt16, UInt32, UInt64 +// Float16, Float32, Float64 +// Bool + +body: '{' expr '}' ; + +scalar + : FLOAT # scalarFloat + | INT # scalarInt + | BOOL_LIT # scalarBool + ; + +ident + : opIdent + | GLOBAL_VAR + | VAR + ; diff --git a/python/tvm/relay/grammar/__init__.py b/python/tvm/relay/grammar/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/tvm/relay/grammar/py2/.gitignore b/python/tvm/relay/grammar/py2/.gitignore new file mode 100644 index 000000000000..d677ff551940 --- /dev/null +++ b/python/tvm/relay/grammar/py2/.gitignore @@ -0,0 +1 @@ +Relay* diff --git a/python/tvm/relay/grammar/py2/__init__.py b/python/tvm/relay/grammar/py2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/tvm/relay/grammar/py3/.gitignore b/python/tvm/relay/grammar/py3/.gitignore new file mode 100644 index 000000000000..d677ff551940 --- /dev/null +++ b/python/tvm/relay/grammar/py3/.gitignore @@ -0,0 +1 @@ +Relay* diff --git a/python/tvm/relay/grammar/py3/__init__.py b/python/tvm/relay/grammar/py3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py new file mode 100644 index 000000000000..fd5c281cbe53 --- /dev/null +++ b/python/tvm/relay/parser.py @@ -0,0 +1,275 @@ +"""A parser for Relay's text format.""" +from antlr4 import ParserRuleContext, InputStream, CommonTokenStream +from antlr4.tree.Tree import TerminalNode +from collections import deque +from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List +import tvm +from tvm import relay +import sys +if sys.version_info.major < 3: + from .grammar.py2.RelayVisitor import RelayVisitor + from .grammar.py2.RelayParser import RelayParser + from .grammar.py2.RelayLexer import RelayLexer +else: + from .grammar.py3.RelayVisitor import RelayVisitor + from .grammar.py3.RelayParser import RelayParser + from .grammar.py3.RelayLexer import RelayLexer + +BINARY_OPS = { + # RelayParser.MUL: relay.multiply, + # RelayParser.DIV: relay.divide, + RelayParser.ADD: relay.add, + RelayParser.SUB: relay.subtract, + RelayParser.LT: relay.less, + RelayParser.GT: relay.greater, + RelayParser.LE: relay.less_equal, + RelayParser.GE: relay.greater_equal, + RelayParser.EQ: relay.equal, + RelayParser.NE: relay.not_equal, +} + +Program = NamedTuple("Program", [("ast", relay.Expr), ("env", relay.Environment)]) + +class ParseError(Exception): + def __init__(self, message): + # type: (str) -> None + super(ParseError, self).__init__() + self.message = message + +T = TypeVar("T") +Scope = Deque[Tuple[str, T]] +Scopes = Deque[Scope[T]] + +def lookup(scopes, name): + # type: (Scopes[T], str) -> Optional[T] + for scope in scopes: + for n, val in scope: + if n == name: + return val + return None + +# TODO(@jmp): Use https://stackoverflow.com/q/13889941 +# to figure out how to get ANTLR4 to be more unhappy about syntax errors +class ParseTreeToRelayIR(RelayVisitor): + """Parse Relay text format into Relay IR.""" + + def __init__(self): + # type: () -> None + self.env = relay.Environment({}) # type: relay.Environment + + # Adding an empty scope allows naked lets without pain. + self.var_scopes = deque([deque()]) # type: Scopes[relay.Var] + self.type_param_scopes = deque([deque()]) # type: Scopes[relay.TypeParam] + + super(ParseTreeToRelayIR, self).__init__() + + def enter_var_scope(self): + # type: () -> None + self.var_scopes.appendleft(deque()) + + def exit_var_scope(self): + # type: () -> Scope[relay.Var] + return self.var_scopes.popleft() + + def mk_var(self, name): + # type: (str) -> relay.Var + var = relay.Var(name) + self.var_scopes[0].appendleft((name, var)) + return var + + def enter_type_param_scope(self): + # type: () -> None + self.type_param_scopes.appendleft(deque()) + + def exit_type_param_scope(self): + # type: () -> Scope[relay.TypeParam] + return self.type_param_scopes.popleft() + + def mk_typ(self, name, kind): + # (str, relay.Kind) -> relay.TypeParam + typ = relay.TypeParam(name, kind) + self.type_param_scopes[0].appendleft((name, typ)) + return typ + + def visitTerminal(self, node): + # type: (TerminalNode) -> Union[relay.Expr, int, float] + """Visit lexer tokens that aren't ignored or visited by other functions.""" + + node_type = node.getSymbol().type + + # variables + if node_type == RelayLexer.GLOBAL_VAR: + return relay.GlobalVar(node.getText()[1:]) + elif node_type == RelayLexer.VAR: + name = node.getText()[1:] + var = lookup(self.var_scopes, name) + if var is None: + raise ParseError("Couldn't resolve `{}`.".format(name)) + else: + return var + + # data types + elif node_type == RelayLexer.INT: + return int(node.getText()) + elif node_type == RelayLexer.FLOAT: + return float(node.getText()) + + else: + raise ParseError("todo: {}".format(node.getText())) + + def visit_list(self, ctx_list): + # type: (List[ParserRuleContext]) -> List[relay.Expr] + return [self.visit(ctx) for ctx in ctx_list] + + # TODO(@jmp): Include kind environment to set IncompleteType appropriately. + def getType_(self, ctx): + # type: (Optional[RelayParser.Type_Context]) -> relay.Type + if ctx is None: + return relay.IncompleteType() + else: + return self.visit(ctx) + + # Exprs + + # pass through + def visitBody(self, ctx): + # type: (RelayParser.BodyContext) -> relay.Expr + return self.visit(ctx.expr()) + + def visitScalarFloat(self, ctx): + # type: (RelayParser.ScalarFloatContext) -> relay.Constant + return relay.Constant(tvm.nd.array(self.visit(ctx.FLOAT()))) + + def visitScalarInt(self, ctx): + # type: (RelayParser.ScalarIntContext) -> relay.Constant + return relay.Constant(tvm.nd.array(self.visit(ctx.INT()))) + + def visitScalarBool(self, ctx): + # type: (RelayParser.ScalarBoolContext) -> relay.Constant + # return relay.Constant(tvm.nd.array(self.visit(ctx.BOOL_LIST()))) + raise ParseError("Unimplemented") + + def visitNeg(self, ctx): + # type: (RelayParser.NegContext) -> Union[relay.Constant, relay.Call] + val = self.visit(ctx.expr()) + if isinstance(val, relay.Constant) and val.data.asnumpy().ndim == 0: + # fold Neg in for scalars + return relay.Constant(tvm.nd.array(-val.data.asnumpy().item())) + else: + raise ParseError("Unimplemented") + # return relay.negative(val) + + def visitTuple(self, ctx): + # type: (RelayParser.TupleContext) -> relay.Tuple + tup = self.visit_list(ctx.expr()) + return relay.Tuple(tup) + + # Currently doesn't support mutable sequencing. + def visitSeq(self, ctx): + # type: (RelayParser.SeqContext) -> relay.Let + if ctx.MUT() is not None: + raise ParseError("Mutation is currently unsupported.") + + if ctx.ident() is None: + # anonymous identity + ident = self.mk_var("_") + else: + ident = ctx.ident().VAR() + if ident is None: + raise ParseError('Only local ids may be used in lets.') + ident = self.mk_var(ident.getText()[1:]) + + type_ = self.getType_(ctx.type_()) + + self.enter_var_scope() + value = self.visit(ctx.expr(0)) + self.exit_var_scope() + + body = self.visit(ctx.expr(1)) + + return relay.Let(ident, value, body, type_) + + def visitBinOp(self, ctx): + # type: (RelayParser.BinOpContext) -> relay.Call + """Desugar binary operators.""" + arg0, arg1 = self.visit_list(ctx.expr()) + relay_op = BINARY_OPS.get(ctx.op.type) + + if relay_op is None: + raise ParseError("Unimplemented binary op.") + + return relay_op(arg0, arg1) + + def visitParam(self, ctx): + # type: (RelayParser.ParamContext) -> relay.Param + ident = ctx.ident().VAR() + + if ident is None: + raise ParseError('Only local ids may be used in params.') + + ident = self.mk_var(ident.getText()[1:]) + type_ = self.getType_(ctx.type_()) + + return relay.Param(ident, type_) + + def visitParamList(self, ctx): + # type: (RelayParser.ParamListContext) -> List[relay.Param] + return self.visit_list(ctx.param()) + + def visitFunc(self, ctx): + # type: (RelayParser.FuncContext) -> relay.Function + # Enter var scope early to put params in scope. + self.enter_var_scope() + # Capture type params in params. + self.enter_type_param_scope() + param_list = self.visit(ctx.paramList()) + ret_type = self.getType_(ctx.type_()) + + type_params = list(self.exit_type_param_scope()) + if type_params: + _, type_params = zip(*type_params) + + body = self.visit(ctx.body()) + self.exit_var_scope() + + return relay.Function(param_list, ret_type, body, type_params) + + # Types + +def parse_expr(data): + # type: (str) -> relay.Expr + """Parse a Relay expression.""" + + # try: + # TODO add error handling here + input_stream = InputStream(data) + lexer = RelayLexer(input_stream) + token_stream = CommonTokenStream(lexer) + parser = RelayParser(token_stream) + tree = parser.expr() + visitor = ParseTreeToRelayIR() + return visitor.visit(tree) + # except Exception as exn: + # raise ParseError("parser error: {}".format(exn)) + +def parse_prog(data): + # type: (str) -> Program + """Parse a Relay program.""" + + # try: + # TODO add error handling here + input_stream = InputStream(data) + lexer = RelayLexer(input_stream) + token_stream = CommonTokenStream(lexer) + parser = RelayParser(token_stream) + tree = parser.prog() + visitor = ParseTreeToRelayIR() + relay_ast = visitor.visit(tree) + return Program(ast=relay_ast, env=visitor.env) + # except Exception as exn: + # raise ParseError("parser error: {}".format(exn)) + +def parse_file(path): + # type: (str) -> Program + with open(path, 'r') as f: + return parse_prog(f.read()) diff --git a/python/tvm/relay/ty.pyi b/python/tvm/relay/ty.pyi index 221fc228081d..933814853f3e 100644 --- a/python/tvm/relay/ty.pyi +++ b/python/tvm/relay/ty.pyi @@ -156,7 +156,7 @@ class FuncType(Type): class IncompleteType(Type): """An incomplete type.""" - def __init__(self, kind): + def __init__(self, kind=Kind.Type): self.__init_handle_by_constructor__(_make.IncompleteType, kind) @register_relay_node diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc new file mode 100644 index 000000000000..d60f1b0ad16d --- /dev/null +++ b/src/relay/ir/pretty_printer.cc @@ -0,0 +1,304 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file pretty_printer.cc + * \brief A pretty printer for the Relay IR. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../pass/type_functor.h" +#include "doc.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +Doc KindDocify(TypeParamNode::Kind k) { + switch (k) { + case TypeParamNode::kShapeVar: + return DocOfStr("ShapeVar"); + case TypeParamNode::kShape: + return DocOfStr("Shape"); + case TypeParamNode::kBaseType: + return DocOfStr("BaseType"); + case TypeParamNode::kType: + return DocOfStr("Type"); + default: + LOG(FATAL) << "unreachable code: case not handle in kind"; + throw; // log fatal throw but compiler doesnt know + } +} + +template +std::vector MapDocify(const tvm::Array& arr, const std::function& f) { + std::vector vec; + for (size_t i = 0; i < arr.size(); ++i) { + vec.push_back(f(arr[i])); + } + return vec; +} + +template, typename Eq = std::equal_to> +class Counter { + std::unordered_map cnt_; + + public: + Counter() = default; + Counter(const Counter&) = delete; + size_t operator()(const T& t) { + auto v = cnt_.count(t) == 0 ? 0 : cnt_.at(t) + 1; + cnt_[t] = v; + return v; + } +}; + +std::string Mangle(const std::string& str, size_t s) { + return str + "_" + std::to_string(s); + // return s == 0 ? str : str + "_" + std::to_string(s - 1); + // the above line look prettier but is dangerous: + // suppose we have x, x, x_0. mangling will give x, x_0, x_0! + // the save approach give x_0, x_1, x_0_1, and in fact never clash: + // stripping _([0-9]*) is invert of mangle under all circumstances. + // another problem is we need to prevent Var/TypeParam/GlobalVar clashing each other. +} + +constexpr size_t indent = 2; + +struct TypeParamName { + bool operator==(const TypeParamName&) const { + return true; + } +}; + +struct mhash { + size_t operator()(const ::tvm::relay::TypeParamName&) const noexcept { + return 0; + } +}; + +class TypeDocifier : private TypeFunctor { + Environment env; + Counter cnt; + std::unordered_map map; + + std::vector DocifyTypeArray(const tvm::Array& arr) { + return MapDocify(arr, [=](const Type& t) { return Docify(t); }); + } + + std::vector DocifyTypeParam(const tvm::Array& arr) { + return MapDocify(arr, [=](const TypeParam& tp) { + return Docify(tp); + }); + } + + std::vector DocifyTypeConstraint(const tvm::Array& arr) { + return MapDocify(arr, [=](const TypeConstraint& tc) { return Docify(tc); }); + } + + Doc VisitType_(const TensorTypeNode* t) final { + return DocOfStr("tensor"); + } + + Doc VisitType_(const TypeParamNode* p) final { + auto tp = GetRef(p); + if (map.count(tp) == 0) { + auto name = + DocOfStr(Mangle("tp", cnt(TypeParamName())) + + std::string(":")) + + KindDocify(p->kind); + map.insert(std::pair(tp, name)); + } + return map.at(tp); + } + + Doc Quantify(const tvm::Array& tp, const Doc& d) { + if (tp.size() == 0) { + return d; + } + return Seq("forall", DocifyTypeParam(tp), ",") + Sep() + d; + } + + Doc Constraint(const tvm::Array& tc, const Doc& d) { + if (tc.size() == 0) { + return d; + } + return Seq("(", DocifyTypeConstraint(tc), ") =>") + Sep() + d; + } + + Doc VisitType_(const FuncTypeNode* f) final { + auto inner = Seq("<", DocifyTypeArray(f->arg_types), ">") + Sep() + + DocOfStr("->") + Sep() + Docify(f->ret_type); + return Group(Quantify(f->type_params, + Constraint(f->type_constraints, inner))); + } + + Doc VisitType_(const TypeRelationNode* r) final { + return DocOfStr("Relation") + Seq("(", DocifyTypeArray(r->args), ")"); + } + + Doc VisitType_(const TupleTypeNode* t) final { + return Seq("<", DocifyTypeArray(t->fields), ">"); + } + + Doc VisitType_(const IncompleteTypeNode* i) final { + return DocOfStr("_"); + } + + public: + TypeDocifier(const Environment& env) : env(env) { } + + Doc Docify(const Type& t) { return t.get() ? (*this)(t) : DocOfStr("_"); } +}; + +class ExprDocifier : private ExprFunctor { + Environment env; + Counter cnt; + std::unordered_map map; + TypeDocifier td; + + std::string VarName(const Var& v) { + if (map.count(v) == 0) { + map.insert(std::pair(v, Mangle(v->name_hint, cnt(v->name_hint)))); + } + return map.at(v); + } + + Doc TypeAnnotation(const Doc& d, const Type& t) { + // test for t being null. probably shouldnt has null. should talk to jared. + if (!t.get() || t.as()) { + return d; + } else { + return d + DocOfStr(":") + td.Docify(t); + } + } + + std::vector DocifyExprArray(const tvm::Array& arr) { + std::vector vec; + for (size_t i = 0; i < arr.size(); ++i) { + vec.push_back(Docify(arr[i])); + } + return vec; + } + + std::vector DocifyParamArray(const tvm::Array& arr) { + std::vector vec; + for (Var param : arr) { + vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)), + param->type_annotation)); + } + return vec; + } + + Doc VisitExpr_(const ConstantNode* c) final { + return DocOfStr("some_constant"); + } + + Doc VisitExpr_(const TupleNode* t) final { + return Seq("<", DocifyExprArray(t->fields), ">"); + } + + Doc VisitExpr_(const VarNode* v) final { + return DocOfStr(VarName(GetRef(v))); + } + + Doc VisitExpr_(const GlobalVarNode* g) final { + return DocOfStr(g->name_hint); + } + + Doc VisitExpr_(const FunctionNode* f) final { + return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() + + DocOfStr("=>") + Sep() + + Block(indent, "{", Docify(f->body), "}")); + } + + Doc VisitExpr_(const CallNode* c) final { + return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">"); + } + + Doc VisitExpr_(const LetNode* l) final { + return Group(DocOfStr("let") + Sep() + + TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() + + DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() + + Docify(l->body)); + } + + Doc VisitExpr_(const IfNode* i) final { + return Group(DocOfStr("if") + Sep() + Docify(i->cond) + Sep() + + Block(indent, "{", Docify(i->true_branch), "}") + Sep() + + DocOfStr("else") + Sep() + + Block(indent, "{", Docify(i->false_branch), "}")); + } + + Doc VisitExpr_(const OpNode* o) final { + return DocOfStr(o->name); + } + + Doc VisitExpr_(const TupleGetItemNode* g) final { + return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index)); + } + + public: + ExprDocifier(const Environment& env) : env(env), td(env) { } + + Doc Docify(const Expr& e) { return (*this)(e); } +}; + +Doc DocOfExpr(const Environment& env, const Expr& expr) { + ExprDocifier d(env); + return d.Docify(expr); +} + +Doc DocOfType(const Environment& env, const Type& expr) { + TypeDocifier d(env); + return d.Docify(expr); +} + +RDoc ExprRDoc(const Environment& env, const Expr& expr) { + return Layout(DocOfExpr(env, expr)); +} + +RDoc TypeRDoc(const Environment& env, const Type& expr) { + return Layout(DocOfType(env, expr)); +} + +std::ostream & DebugPrint(const Environment& env, const Expr& e, std::ostream& os) { + return os << ExprRDoc(env, e); +} + +std::ostream & DebugPrint(const Environment& env, const Type& t, std::ostream& os) { + return os << TypeRDoc(env, t); +} + +std::string PrintExpr(const Environment& env, const Expr& e) { + std::stringstream ss; + ss << ExprRDoc(env, e); + return ss.str(); +} + +std::string PrintType(const Environment& env, const Type& t) { + std::stringstream ss; + ss << TypeRDoc(env, t); + return ss.str(); +} + +TVM_REGISTER_API("relay._expr._pretty_print") +.set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef x = args[1]; + if (x.as()) { + *ret = PrintType(args[0], Downcast(x)); + } else { + *ret = PrintExpr(args[0], Downcast(x)); + } + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py new file mode 100644 index 000000000000..b539e48dfba0 --- /dev/null +++ b/tests/python/relay/test_ir_parser.py @@ -0,0 +1,107 @@ +import tvm +from tvm import relay +from tvm.relay.parser import parse_expr, ParseError +from nose.tools import nottest, raises + +def get_scalar(x): + return x.data.asnumpy().item() + +def test_int_literal(): + assert isinstance(parse_expr("1"), relay.Constant) + assert isinstance(parse_expr("1").data, tvm.ndarray.NDArray) + + assert get_scalar(parse_expr("1")) == 1 + assert get_scalar(parse_expr("10")) == 10 + assert get_scalar(parse_expr("0")) == 0 + assert get_scalar(parse_expr("-100")) == -100 + assert get_scalar(parse_expr("-05")) == -5 + +def test_float_literal(): + assert get_scalar(parse_expr("1.0")) == 1.0 + assert get_scalar(parse_expr("1.56667")) == 1.56667 + assert get_scalar(parse_expr("0.0")) == 0.0 + assert get_scalar(parse_expr("-10.0")) == -10.0 + +def test_bin_op(): + # assert isinstance(parse_expr("1 * 1"), relay.Call) + # assert isinstance(parse_expr("1 / 1"), relay.Call) + assert isinstance(parse_expr("1 + 1"), relay.Call) + assert isinstance(parse_expr("1 - 1"), relay.Call) + assert isinstance(parse_expr("1 < 1"), relay.Call) + assert isinstance(parse_expr("1 > 1"), relay.Call) + assert isinstance(parse_expr("1 <= 1"), relay.Call) + assert isinstance(parse_expr("1 >= 1"), relay.Call) + assert isinstance(parse_expr("1 == 1"), relay.Call) + assert isinstance(parse_expr("1 != 1"), relay.Call) + +@nottest +def test_vars(): + # temp vars won't work b/c they start with a digit + # # temp var + # temp_var = parse_expr("%1") + # assert isinstance(temp_var, relay.Var) + # assert temp_var.name == "1" + + # var + # var = parse_expr("let %foo = 0; %foo") + var = parse_expr("%foo") + assert isinstance(var.body, relay.Var) + assert var.body.name == "foo" + + # global var + global_var = parse_expr("@foo") + assert isinstance(global_var, relay.GlobalVar) + assert global_var.name == "foo" + + # operator id + op = parse_expr("foo") + assert isinstance(op, relay.Op) + assert op.name == "foo" + +def test_let(): + let = parse_expr("let %x = 1; ()") + assert isinstance(let, relay.Let) + assert isinstance(let.var, relay.Var) + assert isinstance(let.value, relay.Constant) + assert get_scalar(let.value) == 1 + assert isinstance(let.body, relay.Tuple) + +def test_seq(): + assert isinstance(parse_expr("(); ()"), relay.Let) + assert parse_expr("(); ()").var.name_hint == "_" + + assert isinstance(parse_expr("{ let %x = 1; () }; ()"), relay.Let) + assert parse_expr("{ let %x = 1; () }; ()").var.name_hint == "_" + + assert isinstance(parse_expr("{ (); () }; ()"), relay.Let) + assert parse_expr("{ (); () }; ()").var.name_hint == "_" + +@raises(ParseError) +def test_let_global_var(): + parse_expr("let @x = 1; ()") + +@raises(ParseError) +def test_let_op(): + parse_expr("let x = 1; ()") + +def test_tuple(): + assert isinstance(parse_expr("()"), relay.Tuple) + assert len(parse_expr("()").fields) == 0 + + assert isinstance(parse_expr("(0,)"), relay.Tuple) + assert len(parse_expr("(0,)").fields) == 1 + + assert isinstance(parse_expr("(0, 1)"), relay.Tuple) + assert len(parse_expr("(0, 1)").fields) == 2 + + assert isinstance(parse_expr("(0, 1, 2)"), relay.Tuple) + assert len(parse_expr("(0, 1, 2)").fields) == 3 + +def test_func(): + id_func = parse_expr("fn (%x) => { %x }") + assert isinstance(id_func, relay.Function) + assert id_func.params[0].var.name_hint == "x" + assert isinstance(id_func.params[0].type, relay.IncompleteType) + assert id_func.params[0].var == id_func.body + + assert isinstance(parse_expr("fn (%x, %y) => { %x + %y }"), relay.Function) diff --git a/tests/python/relay/test_ir_pretty_printer.py b/tests/python/relay/test_ir_pretty_printer.py new file mode 100644 index 000000000000..f7dfe2708bf5 --- /dev/null +++ b/tests/python/relay/test_ir_pretty_printer.py @@ -0,0 +1,90 @@ +import tvm +from tvm import relay +from tvm.relay.expr import pretty_print +from tvm.relay.ir_builder import IRBuilder + +ib = IRBuilder() + +def show(e): + r = pretty_print(ib.env, e) + assert r is not None + + +def test_constant(): + arr = tvm.nd.array(10) + const = relay.Constant(arr) + show(const) + # should print the array inside? + + +def test_tuple(): + fields = tvm.convert([]) + tup = relay.Tuple(fields) + show(tup) + + +def test_local_var(): + name_hint = 's' + lv = relay.Var(name_hint) + show(lv) + + +def test_dup_var(): + lv = relay.Var('s') + rv = relay.Var('s') + show(relay.Tuple([lv, rv])) + + +def test_large_dup_var(): + av = relay.Var('s') + bv = relay.Var('s') + cv = relay.Var('s') + show(relay.Tuple([av, bv, cv])) + + +def test_global_var(): + name_hint = 'g' + gv = relay.GlobalVar(name_hint) + gv.name_hint == name_hint + show(gv) + + +def test_function(): + param_names = ['a', 'b', 'c', 'd'] + params = tvm.convert([relay.Var(n) for n in param_names]) + ret_type = None + body = params[0] + type_params = tvm.convert([]) + fn = relay.Function(params, ret_type, body, type_params) + show(fn) + + + +def test_call(): + op = relay.Var('f') + arg_names = ['a', 'b', 'c', 'd'] + args = tvm.convert([relay.Var(n) for n in arg_names]) + call = relay.Call(op, args, None, None) + show(call) + + +def test_let(): + ty = relay.ty.TensorType((10, 20), 'float32') + lv = relay.Var('x', ty) + arr = tvm.nd.array(10) + value = relay.Constant(arr) + let = relay.Let(lv, value, lv) + show(let) + + +def test_if(): + cond = relay.Var('cond') + left = relay.Var('left') + right = relay.Var('right') + ife = relay.If(cond, left, right) + show(ife) + +def test_tuple_get_item(): + t = relay.Var('t') + g = relay.TupleGetItem(t, 0) + show(g) From 4717f68252e049e24bc429fb006a1caaa33ea26c Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sun, 7 Oct 2018 18:17:55 -0700 Subject: [PATCH 02/64] use multiply, divide, and negative --- python/tvm/relay/parser.py | 7 +++---- tests/python/relay/test_ir_parser.py | 9 +++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index fd5c281cbe53..5342bcc3423c 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -16,8 +16,8 @@ from .grammar.py3.RelayLexer import RelayLexer BINARY_OPS = { - # RelayParser.MUL: relay.multiply, - # RelayParser.DIV: relay.divide, + RelayParser.MUL: relay.multiply, + RelayParser.DIV: relay.divide, RelayParser.ADD: relay.add, RelayParser.SUB: relay.subtract, RelayParser.LT: relay.less, @@ -156,8 +156,7 @@ def visitNeg(self, ctx): # fold Neg in for scalars return relay.Constant(tvm.nd.array(-val.data.asnumpy().item())) else: - raise ParseError("Unimplemented") - # return relay.negative(val) + return relay.negative(val) def visitTuple(self, ctx): # type: (RelayParser.TupleContext) -> relay.Tuple diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index b539e48dfba0..abadf040640a 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -22,9 +22,14 @@ def test_float_literal(): assert get_scalar(parse_expr("0.0")) == 0.0 assert get_scalar(parse_expr("-10.0")) == -10.0 +def test_negative(): + assert isinstance(parse_expr("let %x = 1; -%x").body, relay.Call) + assert get_scalar(parse_expr("--10")) == 10 + assert get_scalar(parse_expr("---10")) == -10 + def test_bin_op(): - # assert isinstance(parse_expr("1 * 1"), relay.Call) - # assert isinstance(parse_expr("1 / 1"), relay.Call) + assert isinstance(parse_expr("1 * 1"), relay.Call) + assert isinstance(parse_expr("1 / 1"), relay.Call) assert isinstance(parse_expr("1 + 1"), relay.Call) assert isinstance(parse_expr("1 - 1"), relay.Call) assert isinstance(parse_expr("1 < 1"), relay.Call) From a38b9bf668bdd86041c951994495f1fdffd74c16 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sun, 7 Oct 2018 20:29:01 -0700 Subject: [PATCH 03/64] towards working defn + refactoring --- python/tvm/relay/parser.py | 66 ++++++++++++++++++++-------- tests/python/relay/test_ir_parser.py | 6 ++- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 5342bcc3423c..def951603f94 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -2,7 +2,7 @@ from antlr4 import ParserRuleContext, InputStream, CommonTokenStream from antlr4.tree.Tree import TerminalNode from collections import deque -from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List +from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable import tvm from tvm import relay import sys @@ -129,6 +129,17 @@ def getType_(self, ctx): else: return self.visit(ctx) + def visitProg(self, ctx): + # type: (RelayParser.ProgContext) -> Program + if ctx.option(): + raise ParseError("Compiler options are unimplemented.") + + self.visit_list(ctx.defn()) + + expr = self.visit(ctx.expr()) + + return Program(ast=expr, env=self.env) + # Exprs # pass through @@ -175,7 +186,7 @@ def visitSeq(self, ctx): else: ident = ctx.ident().VAR() if ident is None: - raise ParseError('Only local ids may be used in lets.') + raise ParseError('Only local ids may be used in `let`s.') ident = self.mk_var(ident.getText()[1:]) type_ = self.getType_(ctx.type_()) @@ -231,23 +242,48 @@ def visitFunc(self, ctx): body = self.visit(ctx.body()) self.exit_var_scope() - return relay.Function(param_list, ret_type, body, type_params) + return relay.Function(param_list, ret_type, body, type_params) # type: ignore + + def visitDefn(self, ctx): + # type: (RelayParser.DefnContext) -> None + ident = ctx.ident().GLOBAL_VAR() + if ident is None: + raise ParseError('Only global ids may be used in `def`s.') + ident = relay.GlobalVar(ident.getText()[1:]) + + self.enter_var_scope() + self.enter_type_param_scope() + param_list = self.visit(ctx.paramList()) + ret_type = self.getType_(ctx.type_()) + + type_params = list(self.exit_type_param_scope()) + if type_params: + _, type_params = zip(*type_params) + + body = self.visit(ctx.body()) + self.exit_var_scope() + + self.env.add( + ident, + relay.Function(param_list, ret_type, body, type_params)) # type: ignore # Types +def make_parser(data): + # type: (str) -> RelayParser + input_stream = InputStream(data) + lexer = RelayLexer(input_stream) + token_stream = CommonTokenStream(lexer) + return RelayParser(token_stream) + def parse_expr(data): # type: (str) -> relay.Expr """Parse a Relay expression.""" # try: # TODO add error handling here - input_stream = InputStream(data) - lexer = RelayLexer(input_stream) - token_stream = CommonTokenStream(lexer) - parser = RelayParser(token_stream) - tree = parser.expr() - visitor = ParseTreeToRelayIR() - return visitor.visit(tree) + tree = make_parser(data).expr() + return ParseTreeToRelayIR().visit(tree) # except Exception as exn: # raise ParseError("parser error: {}".format(exn)) @@ -257,14 +293,8 @@ def parse_prog(data): # try: # TODO add error handling here - input_stream = InputStream(data) - lexer = RelayLexer(input_stream) - token_stream = CommonTokenStream(lexer) - parser = RelayParser(token_stream) - tree = parser.prog() - visitor = ParseTreeToRelayIR() - relay_ast = visitor.visit(tree) - return Program(ast=relay_ast, env=visitor.env) + tree = make_parser(data).prog() + return ParseTreeToRelayIR().visit(tree) # except Exception as exn: # raise ParseError("parser error: {}".format(exn)) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index abadf040640a..d0f7a7f9f13d 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -1,6 +1,6 @@ import tvm from tvm import relay -from tvm.relay.parser import parse_expr, ParseError +from tvm.relay.parser import parse_expr, parse_prog, ParseError, Program from nose.tools import nottest, raises def get_scalar(x): @@ -110,3 +110,7 @@ def test_func(): assert id_func.params[0].var == id_func.body assert isinstance(parse_expr("fn (%x, %y) => { %x + %y }"), relay.Function) + +def test_defn(): + id_defn = parse_prog("def @id(%x) => { %x }") + assert isinstance(id_defn, Program) From 07ec0ed773984664ad104e11de59623f9547d013 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Mon, 8 Oct 2018 16:32:54 -0700 Subject: [PATCH 04/64] comment out failing test --- tests/python/relay/test_ir_parser.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index d0f7a7f9f13d..da8b8167dbef 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -111,6 +111,7 @@ def test_func(): assert isinstance(parse_expr("fn (%x, %y) => { %x + %y }"), relay.Function) +@nottest def test_defn(): id_defn = parse_prog("def @id(%x) => { %x }") assert isinstance(id_defn, Program) From c6c9f0928565fb1f203ba80daa388c7cd3b8819e Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 9 Oct 2018 10:51:58 -0700 Subject: [PATCH 05/64] ifelse and bool_lit --- python/tvm/relay/grammar/Relay.g4 | 2 +- python/tvm/relay/parser.py | 24 +++++++++++++++++-- tests/python/relay/test_ir_parser.py | 36 ++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 6a99edf96604..164f98f15421 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -69,7 +69,7 @@ expr | '(' expr (',' expr)+ ')' # tuple | '[' (expr (',' expr)*)? ']' # tensor - | 'if' expr body 'else' body # ifElse + | 'if' '(' expr ')' body 'else' body # ifElse // sequencing | 'let' MUT? ident (':' type_)? '=' expr ';' expr # seq diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index def951603f94..31531621fbbb 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -113,6 +113,13 @@ def visitTerminal(self, node): return int(node.getText()) elif node_type == RelayLexer.FLOAT: return float(node.getText()) + elif node_type == RelayLexer.BOOL_LIT: + if node.getText() == "true": + return True + elif node.getText() == "false": + return False + else: + assert False else: raise ParseError("todo: {}".format(node.getText())) @@ -157,8 +164,7 @@ def visitScalarInt(self, ctx): def visitScalarBool(self, ctx): # type: (RelayParser.ScalarBoolContext) -> relay.Constant - # return relay.Constant(tvm.nd.array(self.visit(ctx.BOOL_LIST()))) - raise ParseError("Unimplemented") + return relay.Constant(tvm.nd.array(self.visit(ctx.BOOL_LIT()))) def visitNeg(self, ctx): # type: (RelayParser.NegContext) -> Union[relay.Constant, relay.Call] @@ -267,6 +273,20 @@ def visitDefn(self, ctx): ident, relay.Function(param_list, ret_type, body, type_params)) # type: ignore + def visitIfElse(self, ctx): + # type: (RelayParser.IfElseContext) -> relay.If + cond = self.visit(ctx.expr()) + + self.enter_var_scope() + true_branch = self.visit(ctx.body(0)) + self.exit_var_scope() + + self.enter_var_scope() + false_branch = self.visit(ctx.body(1)) + self.exit_var_scope() + + return relay.If(cond, true_branch, false_branch) + # Types def make_parser(data): diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index da8b8167dbef..a73fc8b98605 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -22,6 +22,10 @@ def test_float_literal(): assert get_scalar(parse_expr("0.0")) == 0.0 assert get_scalar(parse_expr("-10.0")) == -10.0 +def test_bool_literal(): + assert get_scalar(parse_expr("true")) == True + assert get_scalar(parse_expr("false")) == False + def test_negative(): assert isinstance(parse_expr("let %x = 1; -%x").body, relay.Call) assert get_scalar(parse_expr("--10")) == 10 @@ -115,3 +119,35 @@ def test_func(): def test_defn(): id_defn = parse_prog("def @id(%x) => { %x }") assert isinstance(id_defn, Program) + +def test_ifelse(): + simple_if = parse_expr( + """ + if (true) { + 0 + } else { + 1 + } + """ + ) + + assert isinstance(simple_if, relay.If) + assert isinstance(simple_if.cond, relay.Constant) + assert isinstance(simple_if.true_branch, relay.Constant) + assert isinstance(simple_if.false_branch, relay.Constant) + + # scoping + try: + parse_expr( + """ + if (true) { + let %x = (); + () + } else { + %x + } + """ + ) + assert False + except ParseError: + assert True From 00e79e0354d8d9fb62893bdcf41dd52c00add4b2 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 9 Oct 2018 10:56:32 -0700 Subject: [PATCH 06/64] funcType -> callType --- python/tvm/relay/grammar/Relay.g4 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 164f98f15421..a582855b259f 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -102,8 +102,8 @@ type_ | '(' type_ ',' ')' # tupleType | '(' type_ (',' type_)+ ')' # tupleType | identType # identTypeType - | identType '(' type_ (',' type_)* ')' # funcType - | identType '[' type_ (',' type_)* ']' # funcType + | identType '(' type_ (',' type_)* ')' # callType + | identType '[' type_ (',' type_)* ']' # callType // Mut, Int, UInt, Float, Bool, Tensor | type_ '.' INT # projectType | INT # dimLitType From dfd29546b3d34d066905763974f3c3335bc0dfe7 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 9 Oct 2018 11:24:53 -0700 Subject: [PATCH 07/64] scientific notation tests --- tests/python/relay/test_ir_parser.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index a73fc8b98605..7497d39fa718 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -22,6 +22,12 @@ def test_float_literal(): assert get_scalar(parse_expr("0.0")) == 0.0 assert get_scalar(parse_expr("-10.0")) == -10.0 + # scientific notation + assert get_scalar(parse_expr("1.0e-1")) == 1.0e-1 + assert get_scalar(parse_expr("1.0e+1")) == 1.0e+1 + assert get_scalar(parse_expr("1.0E-1")) == 1.0E-1 + assert get_scalar(parse_expr("1.0E+1")) == 1.0E+1 + def test_bool_literal(): assert get_scalar(parse_expr("true")) == True assert get_scalar(parse_expr("false")) == False From 959e8ab6e6d936c092bcd0f8ce609de114313eb9 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 9 Oct 2018 11:40:49 -0700 Subject: [PATCH 08/64] parens and op associativity --- python/tvm/relay/parser.py | 5 +++++ tests/python/relay/test_ir_parser.py | 9 +++++++++ 2 files changed, 14 insertions(+) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 31531621fbbb..8b7aacce5141 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -149,6 +149,11 @@ def visitProg(self, ctx): # Exprs + # pass through + def visitParens(self, ctx): + # type: (RelayParser.ParensContext) -> relay.Expr + return self.visit(ctx.expr()) + # pass through def visitBody(self, ctx): # type: (RelayParser.BodyContext) -> relay.Expr diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 7497d39fa718..7a9fffa14a07 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -1,6 +1,7 @@ import tvm from tvm import relay from tvm.relay.parser import parse_expr, parse_prog, ParseError, Program +from tvm.relay.ir_pass import alpha_equal from nose.tools import nottest, raises def get_scalar(x): @@ -49,6 +50,14 @@ def test_bin_op(): assert isinstance(parse_expr("1 == 1"), relay.Call) assert isinstance(parse_expr("1 != 1"), relay.Call) +def test_parens(): + assert alpha_equal(parse_expr("1 * 1 + 1"), parse_expr("(1 * 1) + 1")) + assert not alpha_equal(parse_expr("1 * 1 + 1"), parse_expr("1 * (1 + 1)")) + +def test_op_assoc(): + assert alpha_equal(parse_expr("1 * 1 + 1 < 1 == 1"), parse_expr("(((1 * 1) + 1) < 1) == 1")) + assert alpha_equal(parse_expr("1 == 1 < 1 + 1 * 1"), parse_expr("1 == (1 < (1 + (1 * 1)))")) + @nottest def test_vars(): # temp vars won't work b/c they start with a digit From efe648a1a95c69c2d925c1ccb3d798b96c023b45 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 9 Oct 2018 13:54:48 -0700 Subject: [PATCH 09/64] use alpha_equal for tests --- tests/python/relay/test_ir_parser.py | 153 ++++++++++++++++++--------- 1 file changed, 102 insertions(+), 51 deletions(-) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 7a9fffa14a07..a81c516ff9e3 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -4,9 +4,28 @@ from tvm.relay.ir_pass import alpha_equal from nose.tools import nottest, raises +BINARY_OPS = { + "*": relay.multiply, + "/": relay.divide, + "+": relay.add, + "-": relay.subtract, + "<": relay.less, + ">": relay.greater, + "<=": relay.less_equal, + ">=": relay.greater_equal, + "==": relay.equal, + "!=": relay.not_equal, +} + def get_scalar(x): return x.data.asnumpy().item() +def to_constant(x): + return relay.Constant(tvm.nd.array(x)) + +UNIT = relay.Tuple([]) +TYPE_HOLE = relay.IncompleteType() + def test_int_literal(): assert isinstance(parse_expr("1"), relay.Constant) assert isinstance(parse_expr("1").data, tvm.ndarray.NDArray) @@ -39,16 +58,11 @@ def test_negative(): assert get_scalar(parse_expr("---10")) == -10 def test_bin_op(): - assert isinstance(parse_expr("1 * 1"), relay.Call) - assert isinstance(parse_expr("1 / 1"), relay.Call) - assert isinstance(parse_expr("1 + 1"), relay.Call) - assert isinstance(parse_expr("1 - 1"), relay.Call) - assert isinstance(parse_expr("1 < 1"), relay.Call) - assert isinstance(parse_expr("1 > 1"), relay.Call) - assert isinstance(parse_expr("1 <= 1"), relay.Call) - assert isinstance(parse_expr("1 >= 1"), relay.Call) - assert isinstance(parse_expr("1 == 1"), relay.Call) - assert isinstance(parse_expr("1 != 1"), relay.Call) + for bin_op in BINARY_OPS.keys(): + assert alpha_equal( + parse_expr("1 {} 1".format(bin_op)), + BINARY_OPS.get(bin_op)(to_constant(1), to_constant(1)) + ) def test_parens(): assert alpha_equal(parse_expr("1 * 1 + 1"), parse_expr("(1 * 1) + 1")) @@ -83,22 +97,41 @@ def test_vars(): assert op.name == "foo" def test_let(): - let = parse_expr("let %x = 1; ()") - assert isinstance(let, relay.Let) - assert isinstance(let.var, relay.Var) - assert isinstance(let.value, relay.Constant) - assert get_scalar(let.value) == 1 - assert isinstance(let.body, relay.Tuple) + assert alpha_equal( + parse_expr("let %x = 1; ()"), + + relay.Let( + relay.Var("x"), + to_constant(1), + UNIT, + TYPE_HOLE + ) + ) def test_seq(): - assert isinstance(parse_expr("(); ()"), relay.Let) - assert parse_expr("(); ()").var.name_hint == "_" + assert alpha_equal( + parse_expr("(); ()"), + + relay.Let( + relay.Var("_"), + UNIT, + UNIT, + TYPE_HOLE) + ) + + assert alpha_equal( + parse_expr("{ (); () }; ()"), - assert isinstance(parse_expr("{ let %x = 1; () }; ()"), relay.Let) - assert parse_expr("{ let %x = 1; () }; ()").var.name_hint == "_" + relay.Let( + relay.Var("_"), + relay.Let(relay.Var("_"), UNIT, UNIT, TYPE_HOLE), + UNIT, + TYPE_HOLE) + ) - assert isinstance(parse_expr("{ (); () }; ()"), relay.Let) - assert parse_expr("{ (); () }; ()").var.name_hint == "_" +@raises(ParseError) +def test_seq_scope(): + parse_expr("{ let %x = 1; %x }; %x") @raises(ParseError) def test_let_global_var(): @@ -109,19 +142,38 @@ def test_let_op(): parse_expr("let x = 1; ()") def test_tuple(): - assert isinstance(parse_expr("()"), relay.Tuple) - assert len(parse_expr("()").fields) == 0 + assert alpha_equal(parse_expr("()"), relay.Tuple([])) - assert isinstance(parse_expr("(0,)"), relay.Tuple) - assert len(parse_expr("(0,)").fields) == 1 + assert alpha_equal(parse_expr("(0,)"), relay.Tuple([to_constant(0)])) - assert isinstance(parse_expr("(0, 1)"), relay.Tuple) - assert len(parse_expr("(0, 1)").fields) == 2 + assert alpha_equal(parse_expr("(0, 1)"), relay.Tuple([to_constant(0), to_constant(1)])) - assert isinstance(parse_expr("(0, 1, 2)"), relay.Tuple) - assert len(parse_expr("(0, 1, 2)").fields) == 3 + assert alpha_equal(parse_expr("(0, 1, 2)"), relay.Tuple([to_constant(0), to_constant(1), to_constant(2)])) def test_func(): + # TODO(@jmp): get function alpha eqs to work + + # assert alpha_equal( + # parse_expr("fn (%x) => { %x }"), + # relay.Function( + # [relay.Param(relay.Var("x"), TYPE_HOLE)], + # TYPE_HOLE, + # relay.Var("x"), + # [] + # ) + # ) + + # assert alpha_equal( + # parse_expr("fn (%x, %y) => { %x + %y }"), + # relay.Function( + # [relay.Param(relay.Var("x"), TYPE_HOLE), + # relay.Param(relay.Var("y"), TYPE_HOLE)], + # TYPE_HOLE, + # relay.add(relay.Var("x"), relay.Var("y")), + # [] + # ) + # ) + id_func = parse_expr("fn (%x) => { %x }") assert isinstance(id_func, relay.Function) assert id_func.params[0].var.name_hint == "x" @@ -136,7 +188,8 @@ def test_defn(): assert isinstance(id_defn, Program) def test_ifelse(): - simple_if = parse_expr( + assert alpha_equal( + parse_expr( """ if (true) { 0 @@ -144,25 +197,23 @@ def test_ifelse(): 1 } """ + ), + relay.If( + to_constant(True), + to_constant(0), + to_constant(1) + ) ) - assert isinstance(simple_if, relay.If) - assert isinstance(simple_if.cond, relay.Constant) - assert isinstance(simple_if.true_branch, relay.Constant) - assert isinstance(simple_if.false_branch, relay.Constant) - - # scoping - try: - parse_expr( - """ - if (true) { - let %x = (); - () - } else { - %x - } - """ - ) - assert False - except ParseError: - assert True +@raises(ParseError) +def test_ifelse_scope(): + parse_expr( + """ + if (true) { + let %x = (); + () + } else { + %x + } + """ + ) From 4438d44ebfd4d9b5337f3329f81d1d9dd6ee2760 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 9 Oct 2018 14:44:08 -0700 Subject: [PATCH 10/64] call. func+defn refactoring. work towards identtype and calltype --- python/tvm/relay/parser.py | 90 +++++++++++++++++++++++----- tests/python/relay/test_ir_parser.py | 60 +++++++++++++++++++ 2 files changed, 136 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 8b7aacce5141..8a8e9afff81d 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -28,6 +28,24 @@ RelayParser.NE: relay.not_equal, } +TYPES = { + "Int8": "int8", + "Int16": "int16", + "Int32": "int32", + "Int64": "int64", + + "UInt8": "uint8", + "UInt16": "uint16", + "UInt32": "uint32", + "UInt64": "uint64", + + "Float16": "float16", + "Float32": "float32", + "Float64": "float64", + + "Bool": "bool", +} + Program = NamedTuple("Program", [("ast", relay.Expr), ("env", relay.Environment)]) class ParseError(Exception): @@ -237,8 +255,9 @@ def visitParamList(self, ctx): # type: (RelayParser.ParamListContext) -> List[relay.Param] return self.visit_list(ctx.param()) - def visitFunc(self, ctx): - # type: (RelayParser.FuncContext) -> relay.Function + def mk_func(self, ctx): + # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> relay.Function + # Enter var scope early to put params in scope. self.enter_var_scope() # Capture type params in params. @@ -255,6 +274,10 @@ def visitFunc(self, ctx): return relay.Function(param_list, ret_type, body, type_params) # type: ignore + def visitFunc(self, ctx): + # type: (RelayParser.FuncContext) -> relay.Function + return self.mk_func(ctx) + def visitDefn(self, ctx): # type: (RelayParser.DefnContext) -> None ident = ctx.ident().GLOBAL_VAR() @@ -262,21 +285,16 @@ def visitDefn(self, ctx): raise ParseError('Only global ids may be used in `def`s.') ident = relay.GlobalVar(ident.getText()[1:]) - self.enter_var_scope() - self.enter_type_param_scope() - param_list = self.visit(ctx.paramList()) - ret_type = self.getType_(ctx.type_()) + self.env.add(ident, self.mk_func(ctx)) - type_params = list(self.exit_type_param_scope()) - if type_params: - _, type_params = zip(*type_params) + def visitCall(self, ctx): + # type: (RelayParser.CallContext) -> relay.Call + visited_exprs = self.visit_list(ctx.expr()) - body = self.visit(ctx.body()) - self.exit_var_scope() + func = visited_exprs[0] + args = visited_exprs[1:] - self.env.add( - ident, - relay.Function(param_list, ret_type, body, type_params)) # type: ignore + return relay.Call(func, args, None, None) def visitIfElse(self, ctx): # type: (RelayParser.IfElseContext) -> relay.If @@ -294,6 +312,50 @@ def visitIfElse(self, ctx): # Types + def visitIdentType(self, ctx): + # type: (RelayParser.IdentTypeContext) -> str + ident_type = ctx.CNAME().getText() + + if not ident_type[0].isupper(): + raise ParseError("Types must start with capital letters.") + + builtin_type = TYPES.get(ident_type) + + if builtin_type is None: + # TODO: is this correct? + return ident_type + else: + return builtin_type + + def visitCallType(self, ctx): + # type: (RelayParser.CallTypeContext) -> str + ident_type = ctx.identType().CNAME() + + args = self.visit_list(ctx.type_()) + + if ident_type == "Int": + if len(args) > 2: + raise ParseError("Int may have at most 2 arguments.") + return "int" + "x".join(args) + if ident_type == "UInt": + if len(args) > 2: + raise ParseError("UInt may have at most 2 arguments.") + return "uint" + "x".join(args) + elif ident_type == "Float": + if len(args) > 2: + raise ParseError("Float may have at most 2 arguments.") + return "float" + "x".join(args) + elif ident_type == "Bool": + if len(args) > 1: + raise ParseError("Float may have at most 1 argument.") + return "bool" + "x".join(args) + elif ident_type == "Mut": + raise ParseError("Mutation is unimplemented.") + elif ident_type == "Tensor": + raise ParseError("Tensors are unimplemented.") + else: + raise ParseError("Unrecognized type-level function.") + def make_parser(data): # type: (str) -> RelayParser input_stream = InputStream(data) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index a81c516ff9e3..38f516c74f5d 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -217,3 +217,63 @@ def test_ifelse_scope(): } """ ) + +def test_call(): + # 0 args + parse_expr( + """ + let %constant = fn () => { 0 }; + %constant() + """ + ) + # assert alpha_equal( + # parse_expr( + # """ + # let %constant = fn () => { 0 }; + # %constant() + # """ + # ), + # relay.Let( + # relay.Var("constant"), + # relay.Function([], TYPE_HOLE, to_constant(0), []), + # relay.Call(relay.Var("constant"), [], None, None), + # TYPE_HOLE + # ) + # ) + + # 1 arg + parse_expr( + """ + let %id = fn (%x) => { %x }; + %id(1) + """ + ) + + # 2 args + parse_expr( + """ + let %multiply = fn (%x, %y) => { %x * %y }; + %multiply(0, 0) + """ + ) + + # anonymous function + parse_expr( + """ + (fn (%x) => { %x })(0) + """ + ) + + # curried function + parse_expr( + """ + let %curried_mult = + fn (%x) => { + fn (%y) => { + %x * %y + } + }; + %curried_mult(0); + %curried_mult(0)(0) + """ + ) From b8bb3a740a033ffbdae43376ec6999212002410a Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 9 Oct 2018 21:38:01 -0700 Subject: [PATCH 11/64] identType, callType, tupleType, comment out unimplemented portions of grammar, other minor details --- python/tvm/relay/grammar/Relay.g4 | 20 +++--- python/tvm/relay/parser.py | 41 +++++++---- tests/python/relay/test_ir_parser.py | 103 +++++++++++++++++++++++++++ 3 files changed, 142 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index a582855b259f..b233d61b032c 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -79,13 +79,13 @@ expr | '{' expr '}' ';' expr # seq // mutable update - | ident '=' expr # writeRef - | expr '^' # readRef + // | ident '=' expr # writeRef + // | expr '^' # readRef | ident # identExpr | scalar # scalarExpr - | expr '.' INT # project - | 'debug' # debug + // | expr '.' INT # project + // | 'debug' # debug ; func: 'fn' paramList '=>' type_? body ; @@ -95,18 +95,18 @@ paramList: '(' (param (',' param)*)? ')' ; param: ident (':' type_)? ; type_ - : '(' type_ ')' # parensType - | type_ op=('*'|'/') type_ # binOpType - | type_ op=('+'|'-') type_ # binOpType - | '(' ')' # tupleType + // : '(' type_ ')' # parensType + // | type_ op=('*'|'/') type_ # binOpType + // | type_ op=('+'|'-') type_ # binOpType + : '(' ')' # tupleType | '(' type_ ',' ')' # tupleType | '(' type_ (',' type_)+ ')' # tupleType | identType # identTypeType | identType '(' type_ (',' type_)* ')' # callType | identType '[' type_ (',' type_)* ']' # callType // Mut, Int, UInt, Float, Bool, Tensor - | type_ '.' INT # projectType - | INT # dimLitType + // | type_ '.' INT # projectType + | INT # intType | '_' # incompleteType ; diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 8a8e9afff81d..6ce789886cd8 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -313,7 +313,7 @@ def visitIfElse(self, ctx): # Types def visitIdentType(self, ctx): - # type: (RelayParser.IdentTypeContext) -> str + # type: (RelayParser.IdentTypeContext) -> Union[relay.TensorType, str] ident_type = ctx.CNAME().getText() if not ident_type[0].isupper(): @@ -322,40 +322,57 @@ def visitIdentType(self, ctx): builtin_type = TYPES.get(ident_type) if builtin_type is None: - # TODO: is this correct? - return ident_type + raise ParseError("Unknown builtin type.") else: - return builtin_type + return relay.TensorType([], builtin_type) def visitCallType(self, ctx): # type: (RelayParser.CallTypeContext) -> str - ident_type = ctx.identType().CNAME() + ident_type = ctx.identType().CNAME().getText() - args = self.visit_list(ctx.type_()) + args = [str(arg) for arg in self.visit_list(ctx.type_())] + + if not args: + raise ParseError("Type-level functions must have arguments!") if ident_type == "Int": + print('ident_type == "Int"') if len(args) > 2: raise ParseError("Int may have at most 2 arguments.") - return "int" + "x".join(args) - if ident_type == "UInt": + tvm_type = "int" + "x".join(args) + elif ident_type == "UInt": + print('ident_type == "UInt"') if len(args) > 2: raise ParseError("UInt may have at most 2 arguments.") - return "uint" + "x".join(args) + tvm_type = "uint" + "x".join(args) elif ident_type == "Float": + print('ident_type == "Float"') if len(args) > 2: raise ParseError("Float may have at most 2 arguments.") - return "float" + "x".join(args) + tvm_type = "float" + "x".join(args) elif ident_type == "Bool": + print('ident_type == "Bool"') if len(args) > 1: - raise ParseError("Float may have at most 1 argument.") - return "bool" + "x".join(args) + raise ParseError("Bool may have at most 1 argument.") + # can't use bool, because ffi doesn't convert anything after bool + # bool is sugar for uint1 anyway + tvm_type = "uint1x" + args[0] elif ident_type == "Mut": raise ParseError("Mutation is unimplemented.") elif ident_type == "Tensor": raise ParseError("Tensors are unimplemented.") else: + print(ident_type) + print(type(ident_type)) raise ParseError("Unrecognized type-level function.") + return relay.TensorType([], tvm_type) + + def visitTupleType(self, ctx): + # type: (RelayParser.TupleTypeContext) -> relay.TupleType + + return relay.TupleType(self.visit_list(ctx.type_())) + def make_parser(data): # type: (str) -> RelayParser input_stream = InputStream(data) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 38f516c74f5d..e695e3382423 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -3,6 +3,7 @@ from tvm.relay.parser import parse_expr, parse_prog, ParseError, Program from tvm.relay.ir_pass import alpha_equal from nose.tools import nottest, raises +from typing import Union BINARY_OPS = { "*": relay.multiply, @@ -17,12 +18,45 @@ "!=": relay.not_equal, } +TYPES = { + "Int8", + "Int16", + "Int32", + "Int64", + + "UInt8", + "UInt16", + "UInt32", + "UInt64", + + "Float16", + "Float32", + "Float64", + + "Bool", +} + +CALL_TYPES = { + "Int": 2, + "UInt": 2, + "Float": 2, + "Bool": 1, +} + def get_scalar(x): + # type: (relay.Constant) -> (Union[float, int, bool]) return x.data.asnumpy().item() def to_constant(x): + # type: (Union[float, int, bool]) -> relay.Constant return relay.Constant(tvm.nd.array(x)) +def to_tensor_type(x): + # type: (str) -> relay.TensorType + return relay.TensorType([], x) + +int64 = to_tensor_type("int64") + UNIT = relay.Tuple([]) TYPE_HOLE = relay.IncompleteType() @@ -43,6 +77,10 @@ def test_float_literal(): assert get_scalar(parse_expr("-10.0")) == -10.0 # scientific notation + assert get_scalar(parse_expr("1e-1")) == 1e-1 + assert get_scalar(parse_expr("1e+1")) == 1e+1 + assert get_scalar(parse_expr("1E-1")) == 1E-1 + assert get_scalar(parse_expr("1E+1")) == 1E+1 assert get_scalar(parse_expr("1.0e-1")) == 1.0e-1 assert get_scalar(parse_expr("1.0e+1")) == 1.0e+1 assert get_scalar(parse_expr("1.0E-1")) == 1.0E-1 @@ -226,6 +264,7 @@ def test_call(): %constant() """ ) + # assert alpha_equal( # parse_expr( # """ @@ -277,3 +316,67 @@ def test_call(): %curried_mult(0)(0) """ ) + +# Types + +def test_builtin_types(): + for builtin_type in TYPES: + parse_expr("let %_ : {} = (); ()".format(builtin_type)) + +def test_call_type(): + # tests e.g. + # let %_ : Int(0) = (); () + # let %_ : Int(0, 1) = (); () + for call_type, arity in CALL_TYPES.items(): + for i in range(1, arity + 1): + # custom tuple printing to avoid hanging comma for one-tuples + tup = "(" + ",".join([str(num) for num in range(i)]) + ")" + print("let %_ : {}{} = (); ()".format(call_type, tup)) + parse_expr("let %_ : {}{} = (); ()".format(call_type, tup)) + +@nottest +def test_function_type(): + assert False + +def test_type_annotation(): + assert False + +def test_tuple_type(): + assert alpha_equal( + parse_expr( + """ + let %_: () = (); () + """), + relay.Let( + relay.Var("_"), + UNIT, + UNIT, + relay.TupleType([]) + ) + ) + + assert alpha_equal( + parse_expr( + """ + let %x: (Int64,) = (0,); () + """), + relay.Let( + relay.Var("x"), + relay.Tuple([to_constant(0)]), + UNIT, + relay.TupleType([int64]) + ) + ) + + assert alpha_equal( + parse_expr( + """ + let %x: (Int64, Int64) = (0, 1); () + """), + relay.Let( + relay.Var("x"), + relay.Tuple([to_constant(0), to_constant(1)]), + UNIT, + relay.TupleType([int64, int64]) + ) + ) From d4fb94078a2009a4087d4c0a70a6aa2804e28cd3 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 9 Oct 2018 21:43:43 -0700 Subject: [PATCH 12/64] function annotations --- tests/python/relay/test_ir_parser.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index e695e3382423..7ad016728be6 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -220,6 +220,12 @@ def test_func(): assert isinstance(parse_expr("fn (%x, %y) => { %x + %y }"), relay.Function) + # annotations + + id_func_annotated = parse_expr("fn (%x: Int64) => Int64 { %x }") + assert id_func_annotated.params[0].type == int64 + assert id_func_annotated.ret_type == int64 + @nottest def test_defn(): id_defn = parse_prog("def @id(%x) => { %x }") @@ -338,9 +344,6 @@ def test_call_type(): def test_function_type(): assert False -def test_type_annotation(): - assert False - def test_tuple_type(): assert alpha_equal( parse_expr( From e6e2bf8c8e2c87a2efe89f5532f8bbde3edc6542 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Wed, 10 Oct 2018 22:08:36 -0700 Subject: [PATCH 13/64] funcType and syntax changes (in particular => to ->) --- python/tvm/relay/grammar/Relay.g4 | 21 +++++---- python/tvm/relay/parser.py | 10 +++- tests/python/relay/test_ir_parser.py | 69 ++++++++++++++++++++++------ 3 files changed, 74 insertions(+), 26 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index b233d61b032c..bbc311e7ef85 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -88,8 +88,8 @@ expr // | 'debug' # debug ; -func: 'fn' paramList '=>' type_? body ; -defn: 'def' ident paramList '=>' type_? body ; +func: 'fn' paramList '->' type_? body ; +defn: 'def' ident paramList '->' type_? body ; paramList: '(' (param (',' param)*)? ')' ; param: ident (':' type_)? ; @@ -98,16 +98,17 @@ type_ // : '(' type_ ')' # parensType // | type_ op=('*'|'/') type_ # binOpType // | type_ op=('+'|'-') type_ # binOpType - : '(' ')' # tupleType - | '(' type_ ',' ')' # tupleType - | '(' type_ (',' type_)+ ')' # tupleType - | identType # identTypeType - | identType '(' type_ (',' type_)* ')' # callType - | identType '[' type_ (',' type_)* ']' # callType + : '(' ')' # tupleType + | '(' type_ ',' ')' # tupleType + | '(' type_ (',' type_)+ ')' # tupleType + | identType # identTypeType + | identType '(' (type_ (',' type_)*)? ')' # callType + | identType '[' (type_ (',' type_)*)? ']' # callType + | '(' (type_ (',' type_)*)? ')' '->' type_ # funcType // Mut, Int, UInt, Float, Bool, Tensor // | type_ '.' INT # projectType - | INT # intType - | '_' # incompleteType + | INT # intType + | '_' # incompleteType ; identType: CNAME ; diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 6ce789886cd8..bd7e1b90163f 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -370,9 +370,17 @@ def visitCallType(self, ctx): def visitTupleType(self, ctx): # type: (RelayParser.TupleTypeContext) -> relay.TupleType - return relay.TupleType(self.visit_list(ctx.type_())) + def visitFuncType(self, ctx): + # type: (RelayParser.FuncTypeContext) -> relay.FuncType + types = self.visit_list(ctx.type_()) + + arg_types = types[:-1] + ret_type = types[-1] + + return relay.FuncType(arg_types, ret_type, [], None) + def make_parser(data): # type: (str) -> RelayParser input_stream = InputStream(data) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 7ad016728be6..2f8f02857ebb 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -192,7 +192,7 @@ def test_func(): # TODO(@jmp): get function alpha eqs to work # assert alpha_equal( - # parse_expr("fn (%x) => { %x }"), + # parse_expr("fn (%x) -> { %x }"), # relay.Function( # [relay.Param(relay.Var("x"), TYPE_HOLE)], # TYPE_HOLE, @@ -202,7 +202,7 @@ def test_func(): # ) # assert alpha_equal( - # parse_expr("fn (%x, %y) => { %x + %y }"), + # parse_expr("fn (%x, %y) -> { %x + %y }"), # relay.Function( # [relay.Param(relay.Var("x"), TYPE_HOLE), # relay.Param(relay.Var("y"), TYPE_HOLE)], @@ -212,23 +212,23 @@ def test_func(): # ) # ) - id_func = parse_expr("fn (%x) => { %x }") + id_func = parse_expr("fn (%x) -> { %x }") assert isinstance(id_func, relay.Function) assert id_func.params[0].var.name_hint == "x" assert isinstance(id_func.params[0].type, relay.IncompleteType) assert id_func.params[0].var == id_func.body - assert isinstance(parse_expr("fn (%x, %y) => { %x + %y }"), relay.Function) + assert isinstance(parse_expr("fn (%x, %y) -> { %x + %y }"), relay.Function) # annotations - id_func_annotated = parse_expr("fn (%x: Int64) => Int64 { %x }") + id_func_annotated = parse_expr("fn (%x: Int64) -> Int64 { %x }") assert id_func_annotated.params[0].type == int64 assert id_func_annotated.ret_type == int64 @nottest def test_defn(): - id_defn = parse_prog("def @id(%x) => { %x }") + id_defn = parse_prog("def @id(%x) -> { %x }") assert isinstance(id_defn, Program) def test_ifelse(): @@ -266,7 +266,7 @@ def test_call(): # 0 args parse_expr( """ - let %constant = fn () => { 0 }; + let %constant = fn () -> { 0 }; %constant() """ ) @@ -274,7 +274,7 @@ def test_call(): # assert alpha_equal( # parse_expr( # """ - # let %constant = fn () => { 0 }; + # let %constant = fn () -> { 0 }; # %constant() # """ # ), @@ -289,7 +289,7 @@ def test_call(): # 1 arg parse_expr( """ - let %id = fn (%x) => { %x }; + let %id = fn (%x) -> { %x }; %id(1) """ ) @@ -297,7 +297,7 @@ def test_call(): # 2 args parse_expr( """ - let %multiply = fn (%x, %y) => { %x * %y }; + let %multiply = fn (%x, %y) -> { %x * %y }; %multiply(0, 0) """ ) @@ -305,7 +305,7 @@ def test_call(): # anonymous function parse_expr( """ - (fn (%x) => { %x })(0) + (fn (%x) -> { %x })(0) """ ) @@ -313,8 +313,8 @@ def test_call(): parse_expr( """ let %curried_mult = - fn (%x) => { - fn (%y) => { + fn (%x) -> { + fn (%y) -> { %x * %y } }; @@ -340,9 +340,48 @@ def test_call_type(): print("let %_ : {}{} = (); ()".format(call_type, tup)) parse_expr("let %_ : {}{} = (); ()".format(call_type, tup)) -@nottest def test_function_type(): - assert False + assert alpha_equal( + parse_expr( + """ + let %_: () -> Int64 = fn () -> Int64 { 0 }; () + """ + ), + relay.Let( + relay.Var("_"), + relay.Function([], int64, to_constant(0), []), + UNIT, + relay.FuncType([], int64, [], []) + ) + ) + + assert alpha_equal( + parse_expr( + """ + let %_: (Int64) -> Int64 = fn (%x: Int64) -> Int64 { 0 }; () + """ + ), + relay.Let( + relay.Var("_"), + relay.Function([relay.Param(relay.Var("x"), int64)], int64, to_constant(0), []), + UNIT, + relay.FuncType([int64], int64, [], []) + ) + ) + + assert alpha_equal( + parse_expr( + """ + let %_: (Int64, Int64) -> Int64 = fn (%x: Int64, %y: Int64) -> Int64 { 0 }; () + """ + ), + relay.Let( + relay.Var("_"), + relay.Function([relay.Param(relay.Var("x"), int64), relay.Param(relay.Var("y"), int64)], int64, to_constant(0), []), + UNIT, + relay.FuncType([int64, int64], int64, [], []) + ) + ) def test_tuple_type(): assert alpha_equal( From 390dbea0786756cc6db8682c62222dde45e1290a Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sat, 13 Oct 2018 15:30:27 -0700 Subject: [PATCH 14/64] attempt at adding antlr4 to cmake. remove outdated grammar rules. update tests --- CMakeLists.txt | 21 +++++++++++++++++++++ python/tvm/relay/grammar/Relay.g4 | 2 -- tests/python/relay/test_ir_parser.py | 5 +---- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 98bbc5b650d3..a9c567024e88 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -278,6 +278,27 @@ else(INSTALL_DEV) ) endif(INSTALL_DEV) +# ANTLR4 build definitions +# find_program(ANTLR4 antlr4) + +# set(RELAY_PARSER_DIR +# ${CMAKE_CURRENT_SOURCE_DIR}/relay/python/relay/parser) + +# set(RELAY_PARSER +# ${RELAY_PARSER_DIR}/RelayVisitor.py +# ${RELAY_PARSER_DIR}/RelayParser.py +# ${RELAY_PARSER_DIR}/RelayLexer.py) + +# if(ANTLR4) +# # Generate ANTLR grammar for parsing. +# add_custom_command(OUTPUT ${RELAY_PARSER} +# COMMAND antlr4 -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR} +# DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 +# WORKING_DIRECTORY ${RELAY_PARSER_DIR}) +# endif() + +# add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) + # More target definitions if(MSVC) target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index bbc311e7ef85..269a298e0f53 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -102,11 +102,9 @@ type_ | '(' type_ ',' ')' # tupleType | '(' type_ (',' type_)+ ')' # tupleType | identType # identTypeType - | identType '(' (type_ (',' type_)*)? ')' # callType | identType '[' (type_ (',' type_)*)? ']' # callType | '(' (type_ (',' type_)*)? ')' '->' type_ # funcType // Mut, Int, UInt, Float, Bool, Tensor - // | type_ '.' INT # projectType | INT # intType | '_' # incompleteType ; diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 2f8f02857ebb..5f91f93ec7b4 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -335,10 +335,7 @@ def test_call_type(): # let %_ : Int(0, 1) = (); () for call_type, arity in CALL_TYPES.items(): for i in range(1, arity + 1): - # custom tuple printing to avoid hanging comma for one-tuples - tup = "(" + ",".join([str(num) for num in range(i)]) + ")" - print("let %_ : {}{} = (); ()".format(call_type, tup)) - parse_expr("let %_ : {}{} = (); ()".format(call_type, tup)) + parse_expr("let %_ : {}{} = (); ()".format(call_type, range(i))) def test_function_type(): assert alpha_equal( From 2174b6a107ee9eac53a04f45803a226433fdedc6 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sat, 13 Oct 2018 18:56:49 -0700 Subject: [PATCH 15/64] add antlr to cmake. add new unused grammar rules. refactoring --- CMakeLists.txt | 41 +++++----- python/tvm/relay/grammar/Relay.g4 | 10 ++- python/tvm/relay/parser.py | 116 ++++++++++++++++++--------- tests/python/relay/test_ir_parser.py | 83 ++++++++++--------- 4 files changed, 154 insertions(+), 96 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a9c567024e88..f9fd78cdb950 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -279,25 +279,30 @@ else(INSTALL_DEV) endif(INSTALL_DEV) # ANTLR4 build definitions -# find_program(ANTLR4 antlr4) - -# set(RELAY_PARSER_DIR -# ${CMAKE_CURRENT_SOURCE_DIR}/relay/python/relay/parser) - -# set(RELAY_PARSER -# ${RELAY_PARSER_DIR}/RelayVisitor.py -# ${RELAY_PARSER_DIR}/RelayParser.py -# ${RELAY_PARSER_DIR}/RelayLexer.py) - -# if(ANTLR4) -# # Generate ANTLR grammar for parsing. -# add_custom_command(OUTPUT ${RELAY_PARSER} -# COMMAND antlr4 -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR} -# DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 -# WORKING_DIRECTORY ${RELAY_PARSER_DIR}) -# endif() +find_program(ANTLR4 antlr4) + +set(RELAY_PARSER_DIR + ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) + +set(RELAY_PARSER + ${RELAY_PARSER_DIR}/py2/RelayVisitor.py + ${RELAY_PARSER_DIR}/py2/RelayParser.py + ${RELAY_PARSER_DIR}/py2/RelayLexer.py + + ${RELAY_PARSER_DIR}/py3/RelayVisitor.py + ${RELAY_PARSER_DIR}/py3/RelayParser.py + ${RELAY_PARSER_DIR}/py3/RelayLexer.py) + +if(ANTLR4) + # Generate ANTLR grammar for parsing. + add_custom_command(OUTPUT ${RELAY_PARSER} + COMMAND antlr4 -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 + COMMAND antlr4 -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 + DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 + WORKING_DIRECTORY ${RELAY_PARSER_DIR}) +endif() -# add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) +add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) # More target definitions if(MSVC) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 269a298e0f53..39fa336e5ccc 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -94,7 +94,15 @@ defn: 'def' ident paramList '->' type_? body ; paramList: '(' (param (',' param)*)? ')' ; param: ident (':' type_)? ; +// TODO(@jmp): for improved type annotations +// typeAnno: ':' type_ ; +// returnAnno: (ident ':')? type_ ; + +// relations: 'where' relation (',' relation)* ; +// relation: ident '(' (type_ (',' type_)*)? ')' ; + type_ + // TODO(@jmp): for shape expressions // : '(' type_ ')' # parensType // | type_ op=('*'|'/') type_ # binOpType // | type_ op=('+'|'-') type_ # binOpType @@ -106,7 +114,7 @@ type_ | '(' (type_ (',' type_)*)? ')' '->' type_ # funcType // Mut, Int, UInt, Float, Bool, Tensor | INT # intType - | '_' # incompleteType + // | '_' # incompleteType ; identType: CNAME ; diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index bd7e1b90163f..cd2e281c53fe 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -15,6 +15,14 @@ from .grammar.py3.RelayParser import RelayParser from .grammar.py3.RelayLexer import RelayLexer +Program = NamedTuple("Program", [("ast", relay.Expr), ("env", relay.Environment)]) + +class ParseError(Exception): + def __init__(self, message): + # type: (str) -> None + super(ParseError, self).__init__() + self.message = message + BINARY_OPS = { RelayParser.MUL: relay.multiply, RelayParser.DIV: relay.divide, @@ -46,13 +54,67 @@ "Bool": "bool", } -Program = NamedTuple("Program", [("ast", relay.Expr), ("env", relay.Environment)]) +def int_type_call(args): + # type: (List[relay.Expr]) -> relay.TensorType + if len(args) > 2: + raise ParseError("Int may have at most 2 arguments.") -class ParseError(Exception): - def __init__(self, message): - # type: (str) -> None - super(ParseError, self).__init__() - self.message = message + str_args = [str(arg) for arg in args] + + return relay.TensorType([], "int" + "x".join(str_args)) + +def uint_type_call(args): + # type: (List[relay.Expr]) -> relay.TensorType + if len(args) > 2: + raise ParseError("UInt may have at most 2 arguments.") + + str_args = [str(arg) for arg in args] + + return relay.TensorType([], "uint" + "x".join(str_args)) + +def float_type_call(args): + # type: (List[relay.Expr]) -> relay.TensorType + if len(args) > 2: + raise ParseError("Float may have at most 2 arguments.") + + str_args = [str(arg) for arg in args] + + return relay.TensorType([], "float" + "x".join(str_args)) + +def bool_type_call(args): + # type: (List[relay.Expr]) -> relay.TensorType + if len(args) > 1: + raise ParseError("Bool may have at most 1 argument.") + + # can't use bool, because ffi doesn't convert anything after bool + # bool is sugar for uint1 anyway + str_args = [str(arg) for arg in args] + + return relay.TensorType([], "uint1x" + str_args[0]) + +# TODO(@jmp): Unused. +def tensor_type_call(args): + # type: (List[relay.Expr]) -> Union[relay.Expr, relay.TensorType] + if len(args) > 2: + raise ParseError("Tensor may have at most 2 arguments.") + elif len(args) == 0: + print("Warning. Generic Tensor type unimplemented. Treating as Expr.") + return relay.Expr() + elif len(args) == 1: + raise ParseError("Generic Shape type unimplemented.") + elif len(args) == 2: + dtype = args[0] + shape = args[1] + + return relay.TensorType(shape, dtype) + +TYPE_FUNCS = { + "Int": int_type_call, + "UInt": uint_type_call, + "Float": float_type_call, + "Bool": bool_type_call, + # "Tensor": tensor_type_call, +} T = TypeVar("T") Scope = Deque[Tuple[str, T]] @@ -322,51 +384,25 @@ def visitIdentType(self, ctx): builtin_type = TYPES.get(ident_type) if builtin_type is None: - raise ParseError("Unknown builtin type.") + raise ParseError("Unknown builtin type: {}".format(ident_type)) else: return relay.TensorType([], builtin_type) def visitCallType(self, ctx): - # type: (RelayParser.CallTypeContext) -> str + # type: (RelayParser.CallTypeContext) -> Union[relay.Expr, relay.TensorType] ident_type = ctx.identType().CNAME().getText() - args = [str(arg) for arg in self.visit_list(ctx.type_())] + args = self.visit_list(ctx.type_()) if not args: raise ParseError("Type-level functions must have arguments!") - if ident_type == "Int": - print('ident_type == "Int"') - if len(args) > 2: - raise ParseError("Int may have at most 2 arguments.") - tvm_type = "int" + "x".join(args) - elif ident_type == "UInt": - print('ident_type == "UInt"') - if len(args) > 2: - raise ParseError("UInt may have at most 2 arguments.") - tvm_type = "uint" + "x".join(args) - elif ident_type == "Float": - print('ident_type == "Float"') - if len(args) > 2: - raise ParseError("Float may have at most 2 arguments.") - tvm_type = "float" + "x".join(args) - elif ident_type == "Bool": - print('ident_type == "Bool"') - if len(args) > 1: - raise ParseError("Bool may have at most 1 argument.") - # can't use bool, because ffi doesn't convert anything after bool - # bool is sugar for uint1 anyway - tvm_type = "uint1x" + args[0] - elif ident_type == "Mut": - raise ParseError("Mutation is unimplemented.") - elif ident_type == "Tensor": - raise ParseError("Tensors are unimplemented.") - else: - print(ident_type) - print(type(ident_type)) - raise ParseError("Unrecognized type-level function.") + func_type = TYPE_FUNCS.get(ident_type)(args) - return relay.TensorType([], tvm_type) + if func_type is None: + raise ParseError("Unknown type-level function: `{}`".format(ident_type)) + else: + return func_type def visitTupleType(self, ctx): # type: (RelayParser.TupleTypeContext) -> relay.TupleType diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 5f91f93ec7b4..9511df62671b 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -2,6 +2,7 @@ from tvm import relay from tvm.relay.parser import parse_expr, parse_prog, ParseError, Program from tvm.relay.ir_pass import alpha_equal +from tvm.relay.expr import pretty_print from nose.tools import nottest, raises from typing import Union @@ -55,6 +56,9 @@ def to_tensor_type(x): # type: (str) -> relay.TensorType return relay.TensorType([], x) +X = relay.Var("x") +Y = relay.Var("y") + int64 = to_tensor_type("int64") UNIT = relay.Tuple([]) @@ -110,7 +114,6 @@ def test_op_assoc(): assert alpha_equal(parse_expr("1 * 1 + 1 < 1 == 1"), parse_expr("(((1 * 1) + 1) < 1) == 1")) assert alpha_equal(parse_expr("1 == 1 < 1 + 1 * 1"), parse_expr("1 == (1 < (1 + (1 * 1)))")) -@nottest def test_vars(): # temp vars won't work b/c they start with a digit # # temp var @@ -119,15 +122,14 @@ def test_vars(): # assert temp_var.name == "1" # var - # var = parse_expr("let %foo = 0; %foo") - var = parse_expr("%foo") + var = parse_expr("let %foo = (); %foo") assert isinstance(var.body, relay.Var) - assert var.body.name == "foo" + assert var.body.name_hint == "foo" # global var global_var = parse_expr("@foo") assert isinstance(global_var, relay.GlobalVar) - assert global_var.name == "foo" + assert global_var.name_hint == "foo" # operator id op = parse_expr("foo") @@ -189,43 +191,50 @@ def test_tuple(): assert alpha_equal(parse_expr("(0, 1, 2)"), relay.Tuple([to_constant(0), to_constant(1), to_constant(2)])) def test_func(): - # TODO(@jmp): get function alpha eqs to work - - # assert alpha_equal( - # parse_expr("fn (%x) -> { %x }"), - # relay.Function( - # [relay.Param(relay.Var("x"), TYPE_HOLE)], - # TYPE_HOLE, - # relay.Var("x"), - # [] - # ) - # ) - - # assert alpha_equal( - # parse_expr("fn (%x, %y) -> { %x + %y }"), - # relay.Function( - # [relay.Param(relay.Var("x"), TYPE_HOLE), - # relay.Param(relay.Var("y"), TYPE_HOLE)], - # TYPE_HOLE, - # relay.add(relay.Var("x"), relay.Var("y")), - # [] - # ) - # ) + # 0 args + assert alpha_equal( + parse_expr("fn () -> { 0 }"), + relay.Function( + [], + TYPE_HOLE, + to_constant(0), + [] + ) + ) - id_func = parse_expr("fn (%x) -> { %x }") - assert isinstance(id_func, relay.Function) - assert id_func.params[0].var.name_hint == "x" - assert isinstance(id_func.params[0].type, relay.IncompleteType) - assert id_func.params[0].var == id_func.body + # 1 arg + assert alpha_equal( + parse_expr("fn (%x) -> { %x }"), + relay.Function( + [relay.Param(X, TYPE_HOLE)], + TYPE_HOLE, + X, + [] + ) + ) - assert isinstance(parse_expr("fn (%x, %y) -> { %x + %y }"), relay.Function) + # 2 args + assert alpha_equal( + parse_expr("fn (%x, %y) -> { %x + %y }"), + relay.Function( + [relay.Param(X, TYPE_HOLE), + relay.Param(Y, TYPE_HOLE)], + TYPE_HOLE, + relay.add(X, Y), + [] + ) + ) # annotations + assert alpha_equal( + parse_expr("fn (%x: Int64) -> Int64 { %x }"), + [relay.Param(X, int64)], + int64, + X, + [] + ) - id_func_annotated = parse_expr("fn (%x: Int64) -> Int64 { %x }") - assert id_func_annotated.params[0].type == int64 - assert id_func_annotated.ret_type == int64 - +# TODO(@jmp): Figure out why this is crashing. @nottest def test_defn(): id_defn = parse_prog("def @id(%x) -> { %x }") From 8a9c41e218e362b8f4546b34089f27db1aa3ca33 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sun, 14 Oct 2018 16:20:55 -0700 Subject: [PATCH 16/64] fix alpha eq bug and change tests to use alpha eq. test refactoring --- python/tvm/relay/grammar/Relay.g4 | 4 +- src/relay/pass/alpha_eq.cc | 418 +++++++++++++++++++++++++++ tests/python/relay/test_ir_parser.py | 166 +++++++---- 3 files changed, 533 insertions(+), 55 deletions(-) create mode 100644 src/relay/pass/alpha_eq.cc diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 39fa336e5ccc..1b252cd43c75 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -73,9 +73,9 @@ expr // sequencing | 'let' MUT? ident (':' type_)? '=' expr ';' expr # seq - // sugar for let _ = expr; expr + // sugar for let %_ = expr; expr | expr ';' expr # seq - // sugar for let _ = expr; expr + // sugar for let %_ = expr; expr | '{' expr '}' ';' expr # seq // mutable update diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc new file mode 100644 index 000000000000..84fb50a9f978 --- /dev/null +++ b/src/relay/pass/alpha_eq.cc @@ -0,0 +1,418 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/pass/alpha_eq.cc + * \brief Check that two type are syntactically equal up to alpha equivalence. + */ +#include +#include +#include +#include "./type_visitor.h" +#include "tvm/relay/pass.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +bool SameNDArray(const NDArray& lhs, const NDArray& rhs) { + if (lhs.defined() != rhs.defined()) { + return false; + } else if (lhs.same_as(rhs)) { + return true; + } else { + auto ldt = lhs->dtype; + auto rdt = rhs->dtype; + CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { + size_t s = GetDataSize(*lhs.operator->()); + return memcmp(lhs->data, rhs->data, s) == 0; + } else { + return false; + } + } +} + +struct TypeAlphaEq : TypeVisitor { + tvm::Map eq_map; + bool equal; + + TypeAlphaEq() : eq_map(), equal(true) {} + + void DataTypeEqual(const DataType& dt1, const DataType& dt2) { + if (dt1 != dt2) { + equal = false; + } + } + + void ShapeEqual(const Array& s1, const Array& s2) { + if (s1.size() != s2.size()) { + equal = false; + return; + } + for (size_t i = 0; i < s1.size(); ++i) { + if (!tvm::ir::Equal(s1[i], s2[i])) { + equal = false; + return; + } + } + } + + void VisitType_(const TensorTypeNode* tt1, const Type& t2) final { + if (const TensorTypeNode* tt2 = t2.as()) { + DataTypeEqual(tt1->dtype, tt2->dtype); + ShapeEqual(tt1->shape, tt2->shape); + } else { + equal = false; + } + } + + void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final { + if (const IncompleteTypeNode* bt2 = t2.as()) { + equal = equal && bt1->kind == bt2->kind; + return; + } else { + equal = false; + } + } + + void VisitType_(const TypeVarNode* ti1, const Type& t2) final { + if (const TypeVarNode* ti2 = t2.as()) { + auto tid1 = GetRef(ti1); + auto tid2 = GetRef(ti2); + + // We handle open terms with this rule assuming variables are identical. + // + // Not sure if we should do this. + if (tid1 == tid2) { + return; + } + + // Check that they are same kind + if (tid1->kind != tid2->kind) { + equal = false; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(tid1) != eq_map.end()) { + equal = equal && eq_map[tid1] == tid2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitType_(const FuncTypeNode* op, const Type& t2) final { + if (const FuncTypeNode* ta2 = t2.as()) { + if (op->arg_types.size() != ta2->arg_types.size() + || op->type_params.size() != ta2->type_params.size() + || op->type_constraints.size() != ta2->type_constraints.size()) { + equal = false; + return; + } + + // must visit params first so they are appropriate entered + // into equality map + for (size_t i = 0; i < op->type_params.size(); i++) { + eq_map.Set(op->type_params[i], ta2->type_params[i]); + this->VisitType(op->type_params[i], ta2->type_params[i]); + if (!equal) { + return; + } + } + + for (size_t i = 0; i < op->arg_types.size(); i++) { + this->VisitType(op->arg_types[i], ta2->arg_types[i]); + if (!equal) { + return; + } + } + + this->VisitType(op->ret_type, ta2->ret_type); + if (!equal) { + return; + } + + for (size_t i = 0; i < op->type_constraints.size(); i++) { + this->VisitType(op->type_constraints[i], ta2->type_constraints[i]); + if (!equal) { + return; + } + } + } else { + equal = false; + } + } + + void VisitType_(const TypeRelationNode* tr1, const Type& t2) final { + if (const TypeRelationNode* tr2 = t2.as()) { + if (tr1->func != tr2->func + || tr1->num_inputs != tr2->num_inputs + || tr1->attrs != tr2->attrs) { + equal = false; + return; + } + + if (tr1->args.size() != tr2->args.size()) { + equal = false; + return; + } + + for (size_t i = 0; i < tr1->args.size(); i++) { + this->VisitType(tr1->args[i], tr2->args[i]); + if (!equal) { + return; + } + } + } else { + equal = false; + } + } + + void VisitType_(const TupleTypeNode* op, const Type& t2) final { + if (const TupleTypeNode* pt = t2.as()) { + if (op->fields.size() != pt->fields.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < op->fields.size(); i++) { + if (!equal) { + return; + } + this->VisitType(op->fields[i], pt->fields[i]); + } + } else { + equal = false; + } + } +}; + +bool AlphaEqual(const Type& t1, const Type& t2) { + if (t1.defined() != t2.defined()) { + return false; + } + + if (!t1.defined()) { + return true; + } + + TypeAlphaEq aeq; + aeq.VisitType(t1, t2); + return aeq.equal; +} + +struct AlphaEq : ExprFunctor { + public: + tvm::Map eq_map; + + bool equal; + AlphaEq() : eq_map(), equal(true) {} + + void VisitExpr_(const VarNode* e1, const Expr& e2) final { + if (const VarNode* id2 = e2.as()) { + auto local1 = GetRef(e1); + auto local2 = GetRef(id2); + // We handle open terms with this rule assuming variables are identical. + if (local1 == local2) { + equal = true; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(local1) != eq_map.end()) { + equal = equal && eq_map[local1] == local2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitExpr_(const GlobalVarNode* g1, const Expr& e2) final { + if (const GlobalVarNode* g2 = e2.as()) { + equal = equal && g1 == g2; + } else { + equal = false; + } + } + + void VisitExpr_(const TupleNode* pl1, const Expr& e2) final { + Tuple prod1 = GetRef(pl1); + if (const TupleNode* pl2 = e2.as()) { + Tuple prod2 = GetRef(pl2); + if (prod1->fields.size() != prod2->fields.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < prod1->fields.size(); i++) { + this->VisitExpr(prod1->fields[i], prod2->fields[i]); + } + } else { + equal = false; + } + } + + void VisitExpr_(const FunctionNode* func1, const Expr& e2) final { + if (const FunctionNode* func2 = e2.as()) { + if (func1->params.size() != func2->params.size()) { + equal = false; + return; + } + + if (func1->type_params.size() != func2->type_params.size()) { + equal = false; + return; + } + + for (size_t i = 0; i < func1->params.size(); ++i) { + MergeVarDecl(func1->params[i], func2->params[i]); + } + + if (!equal) { + return; + } + + for (size_t i = 0U; i < func1->type_params.size(); i++) { + equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]); + if (!equal) { + return; + } + } + + equal = equal && AlphaEqual(func1->ret_type, func2->ret_type); + if (!equal) { + return; + } + + this->VisitExpr(func1->body, func2->body); + } else { + equal = false; + } + } + + void VisitExpr_(const CallNode* op, const Expr& e2) final { + if (const CallNode* call = e2.as()) { + this->VisitExpr(op->op, call->op); + + if (op->args.size() != call->args.size()) { + equal = false; + return; + } + + if (op->type_args.size() != call->type_args.size()) { + equal = false; + return; + } + + // checking attrs by pointer equality for now + equal = equal && (op->attrs == call->attrs); + if (!equal) { + return; + } + + for (size_t i = 0U; i < op->args.size(); i++) { + this->VisitExpr(op->args[i], call->args[i]); + } + + for (size_t i = 0U; i < op->type_args.size(); i++) { + equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]); + if (!equal) { + return; + } + } + } else { + equal = false; + } + } + + void VisitExpr_(const LetNode* op, const Expr& e2) final { + if (const LetNode* let = e2.as()) { + MergeVarDecl(op->var, let->var); + this->VisitExpr(op->value, let->value); + this->VisitExpr(op->body, let->body); + } else { + equal = false; + } + } + + void VisitExpr_(const IfNode* op, const Expr& e2) final { + if (const IfNode* i = e2.as()) { + VisitExpr(op->cond, i->cond); + VisitExpr(op->true_branch, i->true_branch); + VisitExpr(op->false_branch, i->false_branch); + } else { + equal = false; + } + } + + void VisitExpr_(const OpNode* op, const Expr& e2) final { + if (const OpNode* o = e2.as()) { + equal = equal && op->name == o->name; + } else { + equal = false; + } + } + + void VisitExpr_(const ConstantNode* op, const Expr& e2) final { + if (const ConstantNode* c = e2.as()) { + if (AlphaEqual(op->tensor_type(), c->tensor_type())) { + equal = equal && SameNDArray(op->data, c->data); + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final { + if (const TupleGetItemNode* proj = e2.as()) { + this->VisitExpr(op->tuple, proj->tuple); + equal = equal && (op->index == proj->index); + } else { + equal = false; + } + } + + private: + void MergeVarDecl(const Var& var1, const Var& var2) { + equal = equal && AlphaEqual(var1->type_annotation, var2->type_annotation); + if (!equal) { + return; + } + + eq_map.Set(var1, var2); + } +}; + +bool AlphaEqual(const Expr& e1, const Expr& e2) { + AlphaEq eq; + eq.VisitExpr(e1, e2); + return eq.equal; +} + +// TODO(@jroesch): move to correct namespace? +TVM_REGISTER_API("relay._make._alpha_equal") + .set_body([](TVMArgs args, TVMRetValue* ret) { + Expr e1 = args[0]; + Expr e2 = args[1]; + *ret = AlphaEqual(e1, e2); + }); + +TVM_REGISTER_API("relay._make._type_alpha_equal") + .set_body([](TVMArgs args, TVMRetValue* ret) { + Type t1 = args[0]; + Type t2 = args[1]; + *ret = AlphaEqual(t1, t2); + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 9511df62671b..305bf2168dfc 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -56,6 +56,7 @@ def to_tensor_type(x): # type: (str) -> relay.TensorType return relay.TensorType([], x) +_ = relay.Var("_") X = relay.Var("x") Y = relay.Var("y") @@ -114,6 +115,7 @@ def test_op_assoc(): assert alpha_equal(parse_expr("1 * 1 + 1 < 1 == 1"), parse_expr("(((1 * 1) + 1) < 1) == 1")) assert alpha_equal(parse_expr("1 == 1 < 1 + 1 * 1"), parse_expr("1 == (1 < (1 + (1 * 1)))")) +@nottest def test_vars(): # temp vars won't work b/c they start with a digit # # temp var @@ -141,7 +143,7 @@ def test_let(): parse_expr("let %x = 1; ()"), relay.Let( - relay.Var("x"), + X, to_constant(1), UNIT, TYPE_HOLE @@ -153,7 +155,7 @@ def test_seq(): parse_expr("(); ()"), relay.Let( - relay.Var("_"), + _, UNIT, UNIT, TYPE_HOLE) @@ -161,7 +163,7 @@ def test_seq(): assert alpha_equal( parse_expr("{ (); () }; ()"), - + # Can't use _ constant, because the _'s are different. relay.Let( relay.Var("_"), relay.Let(relay.Var("_"), UNIT, UNIT, TYPE_HOLE), @@ -228,10 +230,12 @@ def test_func(): # annotations assert alpha_equal( parse_expr("fn (%x: Int64) -> Int64 { %x }"), - [relay.Param(X, int64)], - int64, - X, - [] + relay.Function( + [relay.Param(X, int64)], + int64, + X, + [] + ) ) # TODO(@jmp): Figure out why this is crashing. @@ -273,63 +277,117 @@ def test_ifelse_scope(): def test_call(): # 0 args - parse_expr( + constant = relay.Var("constant") + assert alpha_equal( + parse_expr( """ let %constant = fn () -> { 0 }; %constant() """ + ), + relay.Let( + constant, + relay.Function([], TYPE_HOLE, to_constant(0), []), + relay.Call(constant, [], None, None), + TYPE_HOLE + ) ) - # assert alpha_equal( - # parse_expr( - # """ - # let %constant = fn () -> { 0 }; - # %constant() - # """ - # ), - # relay.Let( - # relay.Var("constant"), - # relay.Function([], TYPE_HOLE, to_constant(0), []), - # relay.Call(relay.Var("constant"), [], None, None), - # TYPE_HOLE - # ) - # ) - # 1 arg - parse_expr( - """ - let %id = fn (%x) -> { %x }; - %id(1) - """ + id_var = relay.Var("id") + assert alpha_equal( + parse_expr( + """ + let %id = fn (%x) -> { %x }; + %id(1) + """ + ), + relay.Let( + id_var, + relay.Function([relay.Param(X, TYPE_HOLE)], TYPE_HOLE, X, []), + relay.Call(id_var, [to_constant(1)], None, None), + TYPE_HOLE + ) ) # 2 args - parse_expr( + multiply = relay.Var("multiply") + assert alpha_equal( + parse_expr( """ let %multiply = fn (%x, %y) -> { %x * %y }; %multiply(0, 0) """ + ), + relay.Let( + multiply, + relay.Function( + [relay.Param(X, TYPE_HOLE), relay.Param(Y, TYPE_HOLE)], + TYPE_HOLE, + relay.multiply(X, Y), + [] + ), + relay.Call(multiply, [to_constant(0), to_constant(0)], None, None), + TYPE_HOLE + ) ) # anonymous function - parse_expr( + assert alpha_equal( + parse_expr( """ (fn (%x) -> { %x })(0) """ + ), + relay.Call( + relay.Function( + [relay.Param(X, TYPE_HOLE)], + TYPE_HOLE, + X, + [] + ), + [to_constant(0)], + None, + None + ) ) # curried function - parse_expr( - """ - let %curried_mult = - fn (%x) -> { - fn (%y) -> { - %x * %y - } - }; - %curried_mult(0); - %curried_mult(0)(0) - """ + curried_mult = relay.Var("curried_mult") + alpha_equal( + parse_expr( + """ + let %curried_mult = + fn (%x) -> { + fn (%y) -> { + %x * %y + } + }; + %curried_mult(0); + %curried_mult(0)(0) + """ + ), + relay.Let( + curried_mult, + relay.Function( + [relay.Param(X, TYPE_HOLE)], + TYPE_HOLE, + relay.Function( + [relay.Param(Y, TYPE_HOLE)], + TYPE_HOLE, + relay.multiply(X, Y), + [] + ), + [] + ), + relay.Let( + _, + relay.Call(curried_mult, [to_constant(0)], None, None), + relay.Call(relay.Call(curried_mult, [to_constant(0)], None, None), [to_constant(0)], None, None), + TYPE_HOLE + ), + TYPE_HOLE + ) ) # Types @@ -340,11 +398,13 @@ def test_builtin_types(): def test_call_type(): # tests e.g. - # let %_ : Int(0) = (); () - # let %_ : Int(0, 1) = (); () + # let %_ : Int[0] = (); () + # let %_ : Int[0, 1] = (); () for call_type, arity in CALL_TYPES.items(): - for i in range(1, arity + 1): - parse_expr("let %_ : {}{} = (); ()".format(call_type, range(i))) + args = [] + for i in range(arity): + args.append(i) + parse_expr("let %_ : {}{} = (); ()".format(call_type, args)) def test_function_type(): assert alpha_equal( @@ -354,7 +414,7 @@ def test_function_type(): """ ), relay.Let( - relay.Var("_"), + _, relay.Function([], int64, to_constant(0), []), UNIT, relay.FuncType([], int64, [], []) @@ -368,8 +428,8 @@ def test_function_type(): """ ), relay.Let( - relay.Var("_"), - relay.Function([relay.Param(relay.Var("x"), int64)], int64, to_constant(0), []), + _, + relay.Function([relay.Param(X, int64)], int64, to_constant(0), []), UNIT, relay.FuncType([int64], int64, [], []) ) @@ -382,8 +442,8 @@ def test_function_type(): """ ), relay.Let( - relay.Var("_"), - relay.Function([relay.Param(relay.Var("x"), int64), relay.Param(relay.Var("y"), int64)], int64, to_constant(0), []), + _, + relay.Function([relay.Param(X, int64), relay.Param(Y, int64)], int64, to_constant(0), []), UNIT, relay.FuncType([int64, int64], int64, [], []) ) @@ -396,7 +456,7 @@ def test_tuple_type(): let %_: () = (); () """), relay.Let( - relay.Var("_"), + _, UNIT, UNIT, relay.TupleType([]) @@ -409,7 +469,7 @@ def test_tuple_type(): let %x: (Int64,) = (0,); () """), relay.Let( - relay.Var("x"), + X, relay.Tuple([to_constant(0)]), UNIT, relay.TupleType([int64]) @@ -422,7 +482,7 @@ def test_tuple_type(): let %x: (Int64, Int64) = (0, 1); () """), relay.Let( - relay.Var("x"), + X, relay.Tuple([to_constant(0), to_constant(1)]), UNIT, relay.TupleType([int64, int64]) From 41bb1eda8257c5fa1908bec446a4756d03a37ac0 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Mon, 15 Oct 2018 12:32:50 -0700 Subject: [PATCH 17/64] switch parser to use None instead of IncompleteType. revise alpha_eq to check for empty types everywhere --- python/tvm/relay/parser.py | 5 +-- src/relay/pass/alpha_eq.cc | 15 ++++++- tests/python/relay/test_ir_parser.py | 59 ++++++++++++---------------- 3 files changed, 40 insertions(+), 39 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index cd2e281c53fe..cdb50ba5c479 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -208,11 +208,10 @@ def visit_list(self, ctx_list): # type: (List[ParserRuleContext]) -> List[relay.Expr] return [self.visit(ctx) for ctx in ctx_list] - # TODO(@jmp): Include kind environment to set IncompleteType appropriately. def getType_(self, ctx): - # type: (Optional[RelayParser.Type_Context]) -> relay.Type + # type: (Optional[RelayParser.Type_Context]) -> Optional[relay.Type] if ctx is None: - return relay.IncompleteType() + return None else: return self.visit(ctx) diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 84fb50a9f978..095ef7f637b0 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -69,7 +69,7 @@ struct TypeAlphaEq : TypeVisitor { void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final { if (const IncompleteTypeNode* bt2 = t2.as()) { - equal = equal && bt1->kind == bt2->kind; + equal = equal && bt1 == bt2; return; } else { equal = false; @@ -206,6 +206,16 @@ bool AlphaEqual(const Type& t1, const Type& t2) { return aeq.equal; } +bool NullableAlphaEqual(const Type& t1, const Type& t2) { + if (t1.defined() != t2.defined()) + return false; + + if (!t1.defined()) + return true; + + return AlphaEqual(t1, t2); +} + struct AlphaEq : ExprFunctor { public: tvm::Map eq_map; @@ -287,7 +297,8 @@ struct AlphaEq : ExprFunctor { } } - equal = equal && AlphaEqual(func1->ret_type, func2->ret_type); + equal = equal && NullableAlphaEqual(func1->ret_type, func2->ret_type); + if (!equal) { return; } diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 305bf2168dfc..a79135862cb0 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -63,7 +63,6 @@ def to_tensor_type(x): int64 = to_tensor_type("int64") UNIT = relay.Tuple([]) -TYPE_HOLE = relay.IncompleteType() def test_int_literal(): assert isinstance(parse_expr("1"), relay.Constant) @@ -145,8 +144,7 @@ def test_let(): relay.Let( X, to_constant(1), - UNIT, - TYPE_HOLE + UNIT ) ) @@ -157,8 +155,7 @@ def test_seq(): relay.Let( _, UNIT, - UNIT, - TYPE_HOLE) + UNIT) ) assert alpha_equal( @@ -166,9 +163,8 @@ def test_seq(): # Can't use _ constant, because the _'s are different. relay.Let( relay.Var("_"), - relay.Let(relay.Var("_"), UNIT, UNIT, TYPE_HOLE), - UNIT, - TYPE_HOLE) + relay.Let(relay.Var("_"), UNIT, UNIT), + UNIT) ) @raises(ParseError) @@ -198,7 +194,7 @@ def test_func(): parse_expr("fn () -> { 0 }"), relay.Function( [], - TYPE_HOLE, + None, to_constant(0), [] ) @@ -208,8 +204,8 @@ def test_func(): assert alpha_equal( parse_expr("fn (%x) -> { %x }"), relay.Function( - [relay.Param(X, TYPE_HOLE)], - TYPE_HOLE, + [relay.Param(X, None)], + None, X, [] ) @@ -219,9 +215,9 @@ def test_func(): assert alpha_equal( parse_expr("fn (%x, %y) -> { %x + %y }"), relay.Function( - [relay.Param(X, TYPE_HOLE), - relay.Param(Y, TYPE_HOLE)], - TYPE_HOLE, + [relay.Param(X, None), + relay.Param(Y, None)], + None, relay.add(X, Y), [] ) @@ -287,9 +283,8 @@ def test_call(): ), relay.Let( constant, - relay.Function([], TYPE_HOLE, to_constant(0), []), - relay.Call(constant, [], None, None), - TYPE_HOLE + relay.Function([], None, to_constant(0), []), + relay.Call(constant, [], None, None) ) ) @@ -304,9 +299,8 @@ def test_call(): ), relay.Let( id_var, - relay.Function([relay.Param(X, TYPE_HOLE)], TYPE_HOLE, X, []), - relay.Call(id_var, [to_constant(1)], None, None), - TYPE_HOLE + relay.Function([relay.Param(X, None)], None, X, []), + relay.Call(id_var, [to_constant(1)], None, None) ) ) @@ -322,13 +316,12 @@ def test_call(): relay.Let( multiply, relay.Function( - [relay.Param(X, TYPE_HOLE), relay.Param(Y, TYPE_HOLE)], - TYPE_HOLE, + [relay.Param(X, None), relay.Param(Y, None)], + None, relay.multiply(X, Y), [] ), - relay.Call(multiply, [to_constant(0), to_constant(0)], None, None), - TYPE_HOLE + relay.Call(multiply, [to_constant(0), to_constant(0)], None, None) ) ) @@ -341,8 +334,8 @@ def test_call(): ), relay.Call( relay.Function( - [relay.Param(X, TYPE_HOLE)], - TYPE_HOLE, + [relay.Param(X, None)], + None, X, [] ), @@ -370,11 +363,11 @@ def test_call(): relay.Let( curried_mult, relay.Function( - [relay.Param(X, TYPE_HOLE)], - TYPE_HOLE, + [relay.Param(X, None)], + None, relay.Function( - [relay.Param(Y, TYPE_HOLE)], - TYPE_HOLE, + [relay.Param(Y, None)], + None, relay.multiply(X, Y), [] ), @@ -383,10 +376,8 @@ def test_call(): relay.Let( _, relay.Call(curried_mult, [to_constant(0)], None, None), - relay.Call(relay.Call(curried_mult, [to_constant(0)], None, None), [to_constant(0)], None, None), - TYPE_HOLE - ), - TYPE_HOLE + relay.Call(relay.Call(curried_mult, [to_constant(0)], None, None), [to_constant(0)], None, None) + ) ) ) From 20f1addaa08bbf313e201c38b13e7accbb4ebb82 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Mon, 15 Oct 2018 17:13:25 -0700 Subject: [PATCH 18/64] refactor for new commits --- python/tvm/relay/grammar/Relay.g4 | 15 ++++---- python/tvm/relay/parser.py | 38 ++++++++++---------- tests/python/relay/test_ir_parser.py | 54 ++++++++++++++-------------- 3 files changed, 55 insertions(+), 52 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 1b252cd43c75..0653695aa0bd 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -20,7 +20,7 @@ NE: '!=' ; opIdent: CNAME ; GLOBAL_VAR: '@' CNAME ; -VAR: '%' CNAME ; +LOCAL_VAR: '%' CNAME ; MUT: 'mut' ; @@ -72,7 +72,7 @@ expr | 'if' '(' expr ')' body 'else' body # ifElse // sequencing - | 'let' MUT? ident (':' type_)? '=' expr ';' expr # seq + | 'let' MUT? var '=' expr ';' expr # seq // sugar for let %_ = expr; expr | expr ';' expr # seq // sugar for let %_ = expr; expr @@ -88,14 +88,13 @@ expr // | 'debug' # debug ; -func: 'fn' paramList '->' type_? body ; -defn: 'def' ident paramList '->' type_? body ; +func: 'fn' varList '->' type_? body ; +defn: 'def' ident varList '->' type_? body ; -paramList: '(' (param (',' param)*)? ')' ; -param: ident (':' type_)? ; +varList: '(' (var (',' var)*)? ')' ; +var: ident (':' type_)? ; // TODO(@jmp): for improved type annotations -// typeAnno: ':' type_ ; // returnAnno: (ident ':')? type_ ; // relations: 'where' relation (',' relation)* ; @@ -134,5 +133,5 @@ scalar ident : opIdent | GLOBAL_VAR - | VAR + | LOCAL_VAR ; diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index cdb50ba5c479..4c5c6ea6e5f7 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -2,9 +2,10 @@ from antlr4 import ParserRuleContext, InputStream, CommonTokenStream from antlr4.tree.Tree import TerminalNode from collections import deque -from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable +from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any import tvm from tvm import relay +from relay.ir_builder import convert import sys if sys.version_info.major < 3: from .grammar.py2.RelayVisitor import RelayVisitor @@ -151,9 +152,9 @@ def exit_var_scope(self): # type: () -> Scope[relay.Var] return self.var_scopes.popleft() - def mk_var(self, name): - # type: (str) -> relay.Var - var = relay.Var(name) + def mk_var(self, name, type_): + # type: (str, relay.Type) -> relay.Var + var = relay.Var(name, type_) self.var_scopes[0].appendleft((name, var)) return var @@ -180,7 +181,7 @@ def visitTerminal(self, node): # variables if node_type == RelayLexer.GLOBAL_VAR: return relay.GlobalVar(node.getText()[1:]) - elif node_type == RelayLexer.VAR: + elif node_type == RelayLexer.LOCAL_VAR: name = node.getText()[1:] var = lookup(self.var_scopes, name) if var is None: @@ -272,22 +273,24 @@ def visitSeq(self, ctx): if ctx.ident() is None: # anonymous identity - ident = self.mk_var("_") + ident = "_" else: - ident = ctx.ident().VAR() - if ident is None: + local_var = ctx.ident().LOCAL_VAR() + if local_var is None: raise ParseError('Only local ids may be used in `let`s.') - ident = self.mk_var(ident.getText()[1:]) + ident = local_var.getText()[1:] type_ = self.getType_(ctx.type_()) + var = self.mk_var(ident, type_) + self.enter_var_scope() value = self.visit(ctx.expr(0)) self.exit_var_scope() body = self.visit(ctx.expr(1)) - return relay.Let(ident, value, body, type_) + return relay.Let(var, value, body) def visitBinOp(self, ctx): # type: (RelayParser.BinOpContext) -> relay.Call @@ -300,21 +303,20 @@ def visitBinOp(self, ctx): return relay_op(arg0, arg1) - def visitParam(self, ctx): - # type: (RelayParser.ParamContext) -> relay.Param - ident = ctx.ident().VAR() + def visitVar(self, ctx): + # type: (RelayParser.VarContext) -> relay.Var + ident = ctx.ident().LOCAL_VAR() if ident is None: raise ParseError('Only local ids may be used in params.') - ident = self.mk_var(ident.getText()[1:]) type_ = self.getType_(ctx.type_()) - return relay.Param(ident, type_) + return self.mk_var(ident.getText()[1:], type_) - def visitParamList(self, ctx): - # type: (RelayParser.ParamListContext) -> List[relay.Param] - return self.visit_list(ctx.param()) + def visitVarList(self, ctx): + # type: (RelayParser.VarListContext) -> List[relay.Var] + return self.visit_list(ctx.var()) def mk_func(self, ctx): # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> relay.Function diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index a79135862cb0..7ff251505a61 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -2,6 +2,7 @@ from tvm import relay from tvm.relay.parser import parse_expr, parse_prog, ParseError, Program from tvm.relay.ir_pass import alpha_equal +# from tvm.relay.ir_builder import convert from tvm.relay.expr import pretty_print from nose.tools import nottest, raises from typing import Union @@ -56,11 +57,13 @@ def to_tensor_type(x): # type: (str) -> relay.TensorType return relay.TensorType([], x) +int32 = to_tensor_type("int32") + _ = relay.Var("_") X = relay.Var("x") Y = relay.Var("y") - -int64 = to_tensor_type("int64") +X_ANNO = relay.Var("x", int32) +Y_ANNO = relay.Var("y", int32) UNIT = relay.Tuple([]) @@ -204,7 +207,7 @@ def test_func(): assert alpha_equal( parse_expr("fn (%x) -> { %x }"), relay.Function( - [relay.Param(X, None)], + [X], None, X, [] @@ -215,8 +218,7 @@ def test_func(): assert alpha_equal( parse_expr("fn (%x, %y) -> { %x + %y }"), relay.Function( - [relay.Param(X, None), - relay.Param(Y, None)], + [X, Y], None, relay.add(X, Y), [] @@ -225,10 +227,10 @@ def test_func(): # annotations assert alpha_equal( - parse_expr("fn (%x: Int64) -> Int64 { %x }"), + parse_expr("fn (%x: Int32) -> Int32 { %x }"), relay.Function( - [relay.Param(X, int64)], - int64, + [X_ANNO], + int32, X, [] ) @@ -299,7 +301,7 @@ def test_call(): ), relay.Let( id_var, - relay.Function([relay.Param(X, None)], None, X, []), + relay.Function([X], None, X, []), relay.Call(id_var, [to_constant(1)], None, None) ) ) @@ -316,7 +318,7 @@ def test_call(): relay.Let( multiply, relay.Function( - [relay.Param(X, None), relay.Param(Y, None)], + [X, Y], None, relay.multiply(X, Y), [] @@ -334,7 +336,7 @@ def test_call(): ), relay.Call( relay.Function( - [relay.Param(X, None)], + [X], None, X, [] @@ -363,10 +365,10 @@ def test_call(): relay.Let( curried_mult, relay.Function( - [relay.Param(X, None)], + [X], None, relay.Function( - [relay.Param(Y, None)], + [Y], None, relay.multiply(X, Y), [] @@ -401,42 +403,42 @@ def test_function_type(): assert alpha_equal( parse_expr( """ - let %_: () -> Int64 = fn () -> Int64 { 0 }; () + let %_: () -> Int32 = fn () -> Int32 { 0 }; () """ ), relay.Let( _, - relay.Function([], int64, to_constant(0), []), + relay.Function([], int32, to_constant(0), []), UNIT, - relay.FuncType([], int64, [], []) + relay.FuncType([], int32, [], []) ) ) assert alpha_equal( parse_expr( """ - let %_: (Int64) -> Int64 = fn (%x: Int64) -> Int64 { 0 }; () + let %_: (Int32) -> Int32 = fn (%x: Int32) -> Int32 { 0 }; () """ ), relay.Let( _, - relay.Function([relay.Param(X, int64)], int64, to_constant(0), []), + relay.Function([relay.Var("x", int32)], int32, to_constant(0), []), UNIT, - relay.FuncType([int64], int64, [], []) + relay.FuncType([int32], int32, [], []) ) ) assert alpha_equal( parse_expr( """ - let %_: (Int64, Int64) -> Int64 = fn (%x: Int64, %y: Int64) -> Int64 { 0 }; () + let %_: (Int32, Int32) -> Int32 = fn (%x: Int32, %y: Int32) -> Int32 { 0 }; () """ ), relay.Let( _, - relay.Function([relay.Param(X, int64), relay.Param(Y, int64)], int64, to_constant(0), []), + relay.Function([relay.Var("x", int32), relay.Var("y", int32)], int32, to_constant(0), []), UNIT, - relay.FuncType([int64, int64], int64, [], []) + relay.FuncType([int32, int32], int32, [], []) ) ) @@ -457,25 +459,25 @@ def test_tuple_type(): assert alpha_equal( parse_expr( """ - let %x: (Int64,) = (0,); () + let %x: (Int32,) = (0,); () """), relay.Let( X, relay.Tuple([to_constant(0)]), UNIT, - relay.TupleType([int64]) + relay.TupleType([int32]) ) ) assert alpha_equal( parse_expr( """ - let %x: (Int64, Int64) = (0, 1); () + let %x: (Int32, Int32) = (0, 1); () """), relay.Let( X, relay.Tuple([to_constant(0), to_constant(1)]), UNIT, - relay.TupleType([int64, int64]) + relay.TupleType([int32, int32]) ) ) From 0657f4c52b91c708eb9fa1d8093fffaebe8edd2a Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Mon, 15 Oct 2018 20:13:39 -0700 Subject: [PATCH 19/64] restore tests --- python/tvm/relay/parser.py | 13 +++++----- tests/python/relay/test_ir_parser.py | 36 ++++++++++++---------------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 4c5c6ea6e5f7..5dbbc996c74e 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -5,7 +5,6 @@ from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any import tvm from tvm import relay -from relay.ir_builder import convert import sys if sys.version_info.major < 3: from .grammar.py2.RelayVisitor import RelayVisitor @@ -271,16 +270,16 @@ def visitSeq(self, ctx): if ctx.MUT() is not None: raise ParseError("Mutation is currently unsupported.") - if ctx.ident() is None: + if ctx.var() is None or ctx.var().ident() is None: # anonymous identity ident = "_" + type_ = None else: - local_var = ctx.ident().LOCAL_VAR() + local_var = ctx.var().ident().LOCAL_VAR() if local_var is None: raise ParseError('Only local ids may be used in `let`s.') ident = local_var.getText()[1:] - - type_ = self.getType_(ctx.type_()) + type_ = self.getType_(ctx.var().type_()) var = self.mk_var(ident, type_) @@ -325,7 +324,7 @@ def mk_func(self, ctx): self.enter_var_scope() # Capture type params in params. self.enter_type_param_scope() - param_list = self.visit(ctx.paramList()) + var_list = self.visit(ctx.varList()) ret_type = self.getType_(ctx.type_()) type_params = list(self.exit_type_param_scope()) @@ -335,7 +334,7 @@ def mk_func(self, ctx): body = self.visit(ctx.body()) self.exit_var_scope() - return relay.Function(param_list, ret_type, body, type_params) # type: ignore + return relay.Function(var_list, ret_type, body, type_params) # type: ignore def visitFunc(self, ctx): # type: (RelayParser.FuncContext) -> relay.Function diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 7ff251505a61..c88414c535ea 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -231,7 +231,7 @@ def test_func(): relay.Function( [X_ANNO], int32, - X, + X_ANNO, [] ) ) @@ -407,10 +407,9 @@ def test_function_type(): """ ), relay.Let( - _, + relay.Var("_", relay.FuncType([], int32, [], [])), relay.Function([], int32, to_constant(0), []), - UNIT, - relay.FuncType([], int32, [], []) + UNIT ) ) @@ -421,10 +420,9 @@ def test_function_type(): """ ), relay.Let( - _, + relay.Var("_", relay.FuncType([int32], int32, [], [])), relay.Function([relay.Var("x", int32)], int32, to_constant(0), []), - UNIT, - relay.FuncType([int32], int32, [], []) + UNIT ) ) @@ -435,10 +433,9 @@ def test_function_type(): """ ), relay.Let( - _, + relay.Var("_", relay.FuncType([int32, int32], int32, [], [])), relay.Function([relay.Var("x", int32), relay.Var("y", int32)], int32, to_constant(0), []), - UNIT, - relay.FuncType([int32, int32], int32, [], []) + UNIT ) ) @@ -449,35 +446,32 @@ def test_tuple_type(): let %_: () = (); () """), relay.Let( - _, - UNIT, + relay.Var("_", relay.TupleType([])), UNIT, - relay.TupleType([]) + UNIT ) ) assert alpha_equal( parse_expr( """ - let %x: (Int32,) = (0,); () + let %_: (Int32,) = (0,); () """), relay.Let( - X, + relay.Var("_", relay.TupleType([int32])), relay.Tuple([to_constant(0)]), - UNIT, - relay.TupleType([int32]) + UNIT ) ) assert alpha_equal( parse_expr( """ - let %x: (Int32, Int32) = (0, 1); () + let %_: (Int32, Int32) = (0, 1); () """), relay.Let( - X, + relay.Var("_", relay.TupleType([int32, int32])), relay.Tuple([to_constant(0), to_constant(1)]), - UNIT, - relay.TupleType([int32, int32]) + UNIT ) ) From ba61579afa7ee381824526202b984a5d657dae9d Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Mon, 15 Oct 2018 20:18:16 -0700 Subject: [PATCH 20/64] restore def test --- tests/python/relay/test_ir_parser.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index c88414c535ea..196c20ad9bde 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -236,10 +236,17 @@ def test_func(): ) ) -# TODO(@jmp): Figure out why this is crashing. -@nottest +# TODO(@jmp): Crashes if %x isn't annnotated. +# @nottest def test_defn(): - id_defn = parse_prog("def @id(%x) -> { %x }") + id_defn = parse_prog( + """ + def @id(%x: Int32) -> Int32 { + %x + } + + () + """) assert isinstance(id_defn, Program) def test_ifelse(): From 445a622077994dd50bdc7bc312f1af65319f5172 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Mon, 15 Oct 2018 21:06:04 -0700 Subject: [PATCH 21/64] python linting --- python/tvm/relay/parser.py | 91 +++++++++++++++++++++++++------------- 1 file changed, 60 insertions(+), 31 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 5dbbc996c74e..50f4aaa6ca2e 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,11 +1,12 @@ +# pylint: disable=invalid-name, unused-import """A parser for Relay's text format.""" -from antlr4 import ParserRuleContext, InputStream, CommonTokenStream -from antlr4.tree.Tree import TerminalNode from collections import deque +import sys from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any +from antlr4 import ParserRuleContext, InputStream, CommonTokenStream +from antlr4.tree.Tree import TerminalNode import tvm from tvm import relay -import sys if sys.version_info.major < 3: from .grammar.py2.RelayVisitor import RelayVisitor from .grammar.py2.RelayParser import RelayParser @@ -18,6 +19,8 @@ Program = NamedTuple("Program", [("ast", relay.Expr), ("env", relay.Environment)]) class ParseError(Exception): + """Exception type for parse errors.""" + def __init__(self, message): # type: (str) -> None super(ParseError, self).__init__() @@ -56,6 +59,8 @@ def __init__(self, message): def int_type_call(args): # type: (List[relay.Expr]) -> relay.TensorType + """Turn an Int type call into a Relay TensorType""" + if len(args) > 2: raise ParseError("Int may have at most 2 arguments.") @@ -65,6 +70,8 @@ def int_type_call(args): def uint_type_call(args): # type: (List[relay.Expr]) -> relay.TensorType + """Turn a UInt type call into a Relay TensorType""" + if len(args) > 2: raise ParseError("UInt may have at most 2 arguments.") @@ -74,6 +81,8 @@ def uint_type_call(args): def float_type_call(args): # type: (List[relay.Expr]) -> relay.TensorType + """Turn a Float type call into a Relay TensorType""" + if len(args) > 2: raise ParseError("Float may have at most 2 arguments.") @@ -83,6 +92,8 @@ def float_type_call(args): def bool_type_call(args): # type: (List[relay.Expr]) -> relay.TensorType + """Turn a Bool type call into a Relay TensorType""" + if len(args) > 1: raise ParseError("Bool may have at most 1 argument.") @@ -95,9 +106,11 @@ def bool_type_call(args): # TODO(@jmp): Unused. def tensor_type_call(args): # type: (List[relay.Expr]) -> Union[relay.Expr, relay.TensorType] + """Turn a Tensor type call into a Relay TensorType""" + if len(args) > 2: raise ParseError("Tensor may have at most 2 arguments.") - elif len(args) == 0: + elif not args: print("Warning. Generic Tensor type unimplemented. Treating as Expr.") return relay.Expr() elif len(args) == 1: @@ -122,9 +135,11 @@ def tensor_type_call(args): def lookup(scopes, name): # type: (Scopes[T], str) -> Optional[T] + """Look up `name` in `scopes`.""" + for scope in scopes: - for n, val in scope: - if n == name: + for key, val in scope: + if key == name: return val return None @@ -145,28 +160,40 @@ def __init__(self): def enter_var_scope(self): # type: () -> None + """Enter a new Var scope so it can be popped off later.""" + self.var_scopes.appendleft(deque()) def exit_var_scope(self): # type: () -> Scope[relay.Var] + """Pop off the current Var scope and return it.""" + return self.var_scopes.popleft() def mk_var(self, name, type_): # type: (str, relay.Type) -> relay.Var + """Create a new Var and add it to the Var scope.""" + var = relay.Var(name, type_) self.var_scopes[0].appendleft((name, var)) return var def enter_type_param_scope(self): # type: () -> None + """Enter a new TypeParam scope so it can be popped off later.""" + self.type_param_scopes.appendleft(deque()) def exit_type_param_scope(self): # type: () -> Scope[relay.TypeParam] + """Pop off the current TypeParam scope and return it.""" + return self.type_param_scopes.popleft() def mk_typ(self, name, kind): # (str, relay.Kind) -> relay.TypeParam + """Create a new TypeParam and add it to the TypeParam scope.""" + typ = relay.TypeParam(name, kind) self.type_param_scopes[0].appendleft((name, typ)) return typ @@ -176,44 +203,49 @@ def visitTerminal(self, node): """Visit lexer tokens that aren't ignored or visited by other functions.""" node_type = node.getSymbol().type + node_text = node.getText() # variables if node_type == RelayLexer.GLOBAL_VAR: - return relay.GlobalVar(node.getText()[1:]) + return relay.GlobalVar(node_text[1:]) elif node_type == RelayLexer.LOCAL_VAR: - name = node.getText()[1:] + name = node_text[1:] var = lookup(self.var_scopes, name) if var is None: raise ParseError("Couldn't resolve `{}`.".format(name)) - else: - return var + + return var # data types elif node_type == RelayLexer.INT: - return int(node.getText()) + return int(node_text) elif node_type == RelayLexer.FLOAT: - return float(node.getText()) + return float(node_text) elif node_type == RelayLexer.BOOL_LIT: - if node.getText() == "true": + if node_text == "true": return True - elif node.getText() == "false": + elif node_text == "false": return False else: - assert False + raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text)) else: - raise ParseError("todo: {}".format(node.getText())) + raise ParseError("todo: {}".format(node_text)) def visit_list(self, ctx_list): # type: (List[ParserRuleContext]) -> List[relay.Expr] + """"Visit a list of contexts.""" + return [self.visit(ctx) for ctx in ctx_list] def getType_(self, ctx): # type: (Optional[RelayParser.Type_Context]) -> Optional[relay.Type] + """Return a (possibly None) Relay type.""" + if ctx is None: return None - else: - return self.visit(ctx) + + return self.visit(ctx) def visitProg(self, ctx): # type: (RelayParser.ProgContext) -> Program @@ -256,8 +288,8 @@ def visitNeg(self, ctx): if isinstance(val, relay.Constant) and val.data.asnumpy().ndim == 0: # fold Neg in for scalars return relay.Constant(tvm.nd.array(-val.data.asnumpy().item())) - else: - return relay.negative(val) + + return relay.negative(val) def visitTuple(self, ctx): # type: (RelayParser.TupleContext) -> relay.Tuple @@ -286,7 +318,7 @@ def visitSeq(self, ctx): self.enter_var_scope() value = self.visit(ctx.expr(0)) self.exit_var_scope() - + body = self.visit(ctx.expr(1)) return relay.Let(var, value, body) @@ -319,6 +351,7 @@ def visitVarList(self, ctx): def mk_func(self, ctx): # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> relay.Function + """Construct a function from either a Func or Defn.""" # Enter var scope early to put params in scope. self.enter_var_scope() @@ -419,6 +452,8 @@ def visitFuncType(self, ctx): def make_parser(data): # type: (str) -> RelayParser + """Construct a RelayParser a given data stream.""" + input_stream = InputStream(data) lexer = RelayLexer(input_stream) token_stream = CommonTokenStream(lexer) @@ -428,25 +463,19 @@ def parse_expr(data): # type: (str) -> relay.Expr """Parse a Relay expression.""" - # try: - # TODO add error handling here tree = make_parser(data).expr() return ParseTreeToRelayIR().visit(tree) - # except Exception as exn: - # raise ParseError("parser error: {}".format(exn)) def parse_prog(data): # type: (str) -> Program """Parse a Relay program.""" - # try: - # TODO add error handling here tree = make_parser(data).prog() return ParseTreeToRelayIR().visit(tree) - # except Exception as exn: - # raise ParseError("parser error: {}".format(exn)) def parse_file(path): # type: (str) -> Program - with open(path, 'r') as f: - return parse_prog(f.read()) + """Parse a Relay program from a file.""" + + with open(path, 'r') as in_file: + return parse_prog(in_file.read()) From 080df4208f8f13bdeb58375deb5bb2e8718bbe4e Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 18 Oct 2018 14:53:19 -0700 Subject: [PATCH 22/64] change sequence parsing --- python/tvm/relay/grammar/Relay.g4 | 7 +++---- tests/python/relay/test_ir_parser.py | 16 +++++----------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 0653695aa0bd..6e7370a124d9 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -72,11 +72,10 @@ expr | 'if' '(' expr ')' body 'else' body # ifElse // sequencing - | 'let' MUT? var '=' expr ';' expr # seq + | 'let' MUT? var '=' expr ';' expr # seq + | 'let' MUT? var '=' '{' expr '}' ';' expr # seq // sugar for let %_ = expr; expr - | expr ';' expr # seq - // sugar for let %_ = expr; expr - | '{' expr '}' ';' expr # seq + | expr ';' expr # seq // mutable update // | ident '=' expr # writeRef diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 196c20ad9bde..168a3b01a87b 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -143,7 +143,6 @@ def test_vars(): def test_let(): assert alpha_equal( parse_expr("let %x = 1; ()"), - relay.Let( X, to_constant(1), @@ -154,7 +153,6 @@ def test_let(): def test_seq(): assert alpha_equal( parse_expr("(); ()"), - relay.Let( _, UNIT, @@ -162,18 +160,14 @@ def test_seq(): ) assert alpha_equal( - parse_expr("{ (); () }; ()"), - # Can't use _ constant, because the _'s are different. + parse_expr("let %_ = { 1 }; ()"), relay.Let( - relay.Var("_"), - relay.Let(relay.Var("_"), UNIT, UNIT), - UNIT) + X, + to_constant(1), + UNIT + ) ) -@raises(ParseError) -def test_seq_scope(): - parse_expr("{ let %x = 1; %x }; %x") - @raises(ParseError) def test_let_global_var(): parse_expr("let @x = 1; ()") From a12bc173c5ac35b61691773327f1d24247f0defe Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 18 Oct 2018 15:16:21 -0700 Subject: [PATCH 23/64] remove compiler options --- python/tvm/relay/grammar/Relay.g4 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 6e7370a124d9..183441663cc9 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -46,9 +46,9 @@ fragment DIGIT: [0-9] ; // Parsing // a program is a list of options, a list of global definitions, and an expression -prog: option* defn* expr EOF ; +prog: /* option* */ defn* expr EOF ; -option: 'set' ident BOOL_LIT ; +// option: 'set' ident BOOL_LIT ; expr // operators From 9a34108ebf642882d95113d629d412a0b0d5502b Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 18 Oct 2018 21:24:48 -0700 Subject: [PATCH 24/64] attempt at tensor + shape types --- python/tvm/relay/grammar/Relay.g4 | 20 +++++---- python/tvm/relay/parser.py | 63 +++++++++++++++------------- tests/python/relay/test_ir_parser.py | 21 ++++++++-- 3 files changed, 65 insertions(+), 39 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 183441663cc9..a0664f9e6d54 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -46,7 +46,7 @@ fragment DIGIT: [0-9] ; // Parsing // a program is a list of options, a list of global definitions, and an expression -prog: /* option* */ defn* expr EOF ; +prog: /* option* */ defn* /* expr */ EOF ; // option: 'set' ident BOOL_LIT ; @@ -100,19 +100,25 @@ var: ident (':' type_)? ; // relation: ident '(' (type_ (',' type_)*)? ')' ; type_ - // TODO(@jmp): for shape expressions - // : '(' type_ ')' # parensType - // | type_ op=('*'|'/') type_ # binOpType - // | type_ op=('+'|'-') type_ # binOpType : '(' ')' # tupleType | '(' type_ ',' ')' # tupleType | '(' type_ (',' type_)+ ')' # tupleType | identType # identTypeType + | 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType | identType '[' (type_ (',' type_)*)? ']' # callType + // Mut, Int, UInt, Float, Bool | '(' (type_ (',' type_)*)? ')' '->' type_ # funcType - // Mut, Int, UInt, Float, Bool, Tensor + | '_' # incompleteType | INT # intType - // | '_' # incompleteType + ; + +shapeSeq: '(' (shape (',' shape)*)? ')' ; + +shape + : '(' shape ')' # parensShape + // | type_ op=('*'|'/') type_ # binOpType + // | type_ op=('+'|'-') type_ # binOpType + | INT # intShape ; identType: CNAME ; diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 50f4aaa6ca2e..7c16288f52b1 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -16,8 +16,6 @@ from .grammar.py3.RelayParser import RelayParser from .grammar.py3.RelayLexer import RelayLexer -Program = NamedTuple("Program", [("ast", relay.Expr), ("env", relay.Environment)]) - class ParseError(Exception): """Exception type for parse errors.""" @@ -103,30 +101,11 @@ def bool_type_call(args): return relay.TensorType([], "uint1x" + str_args[0]) -# TODO(@jmp): Unused. -def tensor_type_call(args): - # type: (List[relay.Expr]) -> Union[relay.Expr, relay.TensorType] - """Turn a Tensor type call into a Relay TensorType""" - - if len(args) > 2: - raise ParseError("Tensor may have at most 2 arguments.") - elif not args: - print("Warning. Generic Tensor type unimplemented. Treating as Expr.") - return relay.Expr() - elif len(args) == 1: - raise ParseError("Generic Shape type unimplemented.") - elif len(args) == 2: - dtype = args[0] - shape = args[1] - - return relay.TensorType(shape, dtype) - TYPE_FUNCS = { "Int": int_type_call, "UInt": uint_type_call, "Float": float_type_call, "Bool": bool_type_call, - # "Tensor": tensor_type_call, } T = TypeVar("T") @@ -233,7 +212,7 @@ def visitTerminal(self, node): raise ParseError("todo: {}".format(node_text)) def visit_list(self, ctx_list): - # type: (List[ParserRuleContext]) -> List[relay.Expr] + # type: (List[ParserRuleContext]) -> List[Any] """"Visit a list of contexts.""" return [self.visit(ctx) for ctx in ctx_list] @@ -248,18 +227,21 @@ def getType_(self, ctx): return self.visit(ctx) def visitProg(self, ctx): - # type: (RelayParser.ProgContext) -> Program - if ctx.option(): - raise ParseError("Compiler options are unimplemented.") + # type: (RelayParser.ProgContext) -> relay.Environment + # if ctx.option(): + # raise ParseError("Compiler options are unimplemented.") self.visit_list(ctx.defn()) - expr = self.visit(ctx.expr()) - - return Program(ast=expr, env=self.env) + return self.env # Exprs + def visitOpIdent(self, ctx): + # type: (RelayParser.OpIdentContext) -> relay.Op + + return relay.op.get(ctx.CNAME()) + # pass through def visitParens(self, ctx): # type: (RelayParser.ParensContext) -> relay.Expr @@ -407,6 +389,10 @@ def visitIfElse(self, ctx): # Types + def visitIncompleteType(self, ctx): + # type (RelayParser.IncompleteTypeContext) -> None: + return None + def visitIdentType(self, ctx): # type: (RelayParser.IdentTypeContext) -> Union[relay.TensorType, str] ident_type = ctx.CNAME().getText() @@ -437,6 +423,27 @@ def visitCallType(self, ctx): else: return func_type + def visitParensShape(self, ctx): + # type: (RelayParser.ParensShapeContext) -> int + return self.visit(ctx.shape()) + + def visitShapeSeq(self, ctx): + # type: (RelayParser.ShapeSeqContext) -> List[int] + return self.visit_list(ctx.shape()) + + def visitTensorType(self, ctx): + # type: (RelayParser.TensorTypeContext) -> relay.TensorType + + shape = self.visit(ctx.shapeSeq()) + dtype = self.visit(ctx.type_) + + if not isinstance(dtype, relay.TensorType): + raise ParseError("Expected dtype to be a Relay base type.") + + dtype = dtype.dtype + + return relay.TensorType(shape, dtype) + def visitTupleType(self, ctx): # type: (RelayParser.TupleTypeContext) -> relay.TupleType return relay.TupleType(self.visit_list(ctx.type_())) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 168a3b01a87b..22d546c024a7 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -1,6 +1,6 @@ import tvm from tvm import relay -from tvm.relay.parser import parse_expr, parse_prog, ParseError, Program +from tvm.relay.parser import parse_expr, parse_prog, ParseError from tvm.relay.ir_pass import alpha_equal # from tvm.relay.ir_builder import convert from tvm.relay.expr import pretty_print @@ -238,10 +238,8 @@ def test_defn(): def @id(%x: Int32) -> Int32 { %x } - - () """) - assert isinstance(id_defn, Program) + assert isinstance(id_defn, relay.Environment) def test_ifelse(): assert alpha_equal( @@ -386,6 +384,16 @@ def test_call(): # Types +def test_incomplete_type(): + assert alpha_equal( + parse_expr("let %_ : _ = (); ()"), + relay.Let( + _, + UNIT, + UNIT + ) + ) + def test_builtin_types(): for builtin_type in TYPES: parse_expr("let %_ : {} = (); ()".format(builtin_type)) @@ -400,6 +408,11 @@ def test_call_type(): args.append(i) parse_expr("let %_ : {}{} = (); ()".format(call_type, args)) + # Tensors + parse_expr("let %_ : Tensor[(1), Float32] = (); ()") + parse_expr("let %_ : Tensor[(1, 1), Float32] = (); ()") + parse_expr("let %_ : Tensor[(1, 1, 1), Float32] = (); ()") + def test_function_type(): assert alpha_equal( parse_expr( From a20a68d6ddb57e673603ba2715b9d11f8cb3f812 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 18 Oct 2018 21:47:08 -0700 Subject: [PATCH 25/64] op test --- python/tvm/relay/parser.py | 2 +- tests/python/relay/test_ir_parser.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 7c16288f52b1..24d918fc2b6a 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -240,7 +240,7 @@ def visitProg(self, ctx): def visitOpIdent(self, ctx): # type: (RelayParser.OpIdentContext) -> relay.Op - return relay.op.get(ctx.CNAME()) + return relay.op.get(ctx.CNAME().getText()) # pass through def visitParens(self, ctx): diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 22d546c024a7..8260aeaa0aa6 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -382,6 +382,12 @@ def test_call(): ) ) + # op + alpha_equal( + parse_expr("abs(1)"), + relay.Call(relay.op.get("abs"), [to_constant(1)], None, None) + ) + # Types def test_incomplete_type(): From 32f644b2e85e9eebd0018303ead1a61ab0613db9 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 18 Oct 2018 21:55:49 -0700 Subject: [PATCH 26/64] tensor type --- python/tvm/relay/grammar/Relay.g4 | 6 +++++- python/tvm/relay/parser.py | 13 ++++++------ tests/python/relay/test_ir_parser.py | 31 ++++++++++++++++++++++++---- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index a0664f9e6d54..5214bd729ab9 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -112,7 +112,11 @@ type_ | INT # intType ; -shapeSeq: '(' (shape (',' shape)*)? ')' ; +shapeSeq + : '(' ')' + | '(' shape ',' ')' + | '(' shape (',' shape)+ ')' + ; shape : '(' shape ')' # parensShape diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 24d918fc2b6a..5d4259e0355e 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -64,7 +64,7 @@ def int_type_call(args): str_args = [str(arg) for arg in args] - return relay.TensorType([], "int" + "x".join(str_args)) + return relay.TensorType((), "int" + "x".join(str_args)) def uint_type_call(args): # type: (List[relay.Expr]) -> relay.TensorType @@ -75,7 +75,7 @@ def uint_type_call(args): str_args = [str(arg) for arg in args] - return relay.TensorType([], "uint" + "x".join(str_args)) + return relay.TensorType((), "uint" + "x".join(str_args)) def float_type_call(args): # type: (List[relay.Expr]) -> relay.TensorType @@ -86,7 +86,7 @@ def float_type_call(args): str_args = [str(arg) for arg in args] - return relay.TensorType([], "float" + "x".join(str_args)) + return relay.TensorType((), "float" + "x".join(str_args)) def bool_type_call(args): # type: (List[relay.Expr]) -> relay.TensorType @@ -99,7 +99,7 @@ def bool_type_call(args): # bool is sugar for uint1 anyway str_args = [str(arg) for arg in args] - return relay.TensorType([], "uint1x" + str_args[0]) + return relay.TensorType((), "uint1x" + str_args[0]) TYPE_FUNCS = { "Int": int_type_call, @@ -239,7 +239,6 @@ def visitProg(self, ctx): def visitOpIdent(self, ctx): # type: (RelayParser.OpIdentContext) -> relay.Op - return relay.op.get(ctx.CNAME().getText()) # pass through @@ -405,7 +404,7 @@ def visitIdentType(self, ctx): if builtin_type is None: raise ParseError("Unknown builtin type: {}".format(ident_type)) else: - return relay.TensorType([], builtin_type) + return relay.TensorType((), builtin_type) def visitCallType(self, ctx): # type: (RelayParser.CallTypeContext) -> Union[relay.Expr, relay.TensorType] @@ -435,7 +434,7 @@ def visitTensorType(self, ctx): # type: (RelayParser.TensorTypeContext) -> relay.TensorType shape = self.visit(ctx.shapeSeq()) - dtype = self.visit(ctx.type_) + dtype = self.visit(ctx.type_()) if not isinstance(dtype, relay.TensorType): raise ParseError("Expected dtype to be a Relay base type.") diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 8260aeaa0aa6..fd7f6e6fd232 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -414,10 +414,33 @@ def test_call_type(): args.append(i) parse_expr("let %_ : {}{} = (); ()".format(call_type, args)) - # Tensors - parse_expr("let %_ : Tensor[(1), Float32] = (); ()") - parse_expr("let %_ : Tensor[(1, 1), Float32] = (); ()") - parse_expr("let %_ : Tensor[(1, 1, 1), Float32] = (); ()") +def test_tensor_type(): + assert alpha_equal( + parse_expr("let %_ : Tensor[(), Float32] = (); ()"), + relay.Let( + relay.Var("_", relay.TensorType((), "float32")), + UNIT, + UNIT + ) + ) + + assert alpha_equal( + parse_expr("let %_ : Tensor[(1,), Float32] = (); ()"), + relay.Let( + relay.Var("_", relay.TensorType((1,), "float32")), + UNIT, + UNIT + ) + ) + + assert alpha_equal( + parse_expr("let %_ : Tensor[(1, 1), Float32] = (); ()"), + relay.Let( + relay.Var("_", relay.TensorType((1, 1), "float32")), + UNIT, + UNIT + ) + ) def test_function_type(): assert alpha_equal( From 7742e0c67080f8a182de9f7d813e16f1f190492f Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 19 Oct 2018 11:17:25 -0700 Subject: [PATCH 27/64] fix some rebase issues --- python/tvm/relay/parser.py | 10 +- src/relay/ir/pretty_printer.cc | 304 --------------------------------- 2 files changed, 5 insertions(+), 309 deletions(-) delete mode 100644 src/relay/ir/pretty_printer.cc diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 5d4259e0355e..b12580b7040e 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -253,23 +253,23 @@ def visitBody(self, ctx): def visitScalarFloat(self, ctx): # type: (RelayParser.ScalarFloatContext) -> relay.Constant - return relay.Constant(tvm.nd.array(self.visit(ctx.FLOAT()))) + return relay.const(self.visit(ctx.FLOAT())) def visitScalarInt(self, ctx): # type: (RelayParser.ScalarIntContext) -> relay.Constant - return relay.Constant(tvm.nd.array(self.visit(ctx.INT()))) + return relay.const(self.visit(ctx.INT())) def visitScalarBool(self, ctx): # type: (RelayParser.ScalarBoolContext) -> relay.Constant - return relay.Constant(tvm.nd.array(self.visit(ctx.BOOL_LIT()))) + return relay.const(self.visit(ctx.BOOL_LIT())) def visitNeg(self, ctx): # type: (RelayParser.NegContext) -> Union[relay.Constant, relay.Call] val = self.visit(ctx.expr()) if isinstance(val, relay.Constant) and val.data.asnumpy().ndim == 0: # fold Neg in for scalars - return relay.Constant(tvm.nd.array(-val.data.asnumpy().item())) - + return relay.const(-val.data.asnumpy().item()) + return relay.negative(val) def visitTuple(self, ctx): diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc deleted file mode 100644 index d60f1b0ad16d..000000000000 --- a/src/relay/ir/pretty_printer.cc +++ /dev/null @@ -1,304 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file pretty_printer.cc - * \brief A pretty printer for the Relay IR. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "../pass/type_functor.h" -#include "doc.h" - -namespace tvm { -namespace relay { - -using namespace tvm::runtime; - -Doc KindDocify(TypeParamNode::Kind k) { - switch (k) { - case TypeParamNode::kShapeVar: - return DocOfStr("ShapeVar"); - case TypeParamNode::kShape: - return DocOfStr("Shape"); - case TypeParamNode::kBaseType: - return DocOfStr("BaseType"); - case TypeParamNode::kType: - return DocOfStr("Type"); - default: - LOG(FATAL) << "unreachable code: case not handle in kind"; - throw; // log fatal throw but compiler doesnt know - } -} - -template -std::vector MapDocify(const tvm::Array& arr, const std::function& f) { - std::vector vec; - for (size_t i = 0; i < arr.size(); ++i) { - vec.push_back(f(arr[i])); - } - return vec; -} - -template, typename Eq = std::equal_to> -class Counter { - std::unordered_map cnt_; - - public: - Counter() = default; - Counter(const Counter&) = delete; - size_t operator()(const T& t) { - auto v = cnt_.count(t) == 0 ? 0 : cnt_.at(t) + 1; - cnt_[t] = v; - return v; - } -}; - -std::string Mangle(const std::string& str, size_t s) { - return str + "_" + std::to_string(s); - // return s == 0 ? str : str + "_" + std::to_string(s - 1); - // the above line look prettier but is dangerous: - // suppose we have x, x, x_0. mangling will give x, x_0, x_0! - // the save approach give x_0, x_1, x_0_1, and in fact never clash: - // stripping _([0-9]*) is invert of mangle under all circumstances. - // another problem is we need to prevent Var/TypeParam/GlobalVar clashing each other. -} - -constexpr size_t indent = 2; - -struct TypeParamName { - bool operator==(const TypeParamName&) const { - return true; - } -}; - -struct mhash { - size_t operator()(const ::tvm::relay::TypeParamName&) const noexcept { - return 0; - } -}; - -class TypeDocifier : private TypeFunctor { - Environment env; - Counter cnt; - std::unordered_map map; - - std::vector DocifyTypeArray(const tvm::Array& arr) { - return MapDocify(arr, [=](const Type& t) { return Docify(t); }); - } - - std::vector DocifyTypeParam(const tvm::Array& arr) { - return MapDocify(arr, [=](const TypeParam& tp) { - return Docify(tp); - }); - } - - std::vector DocifyTypeConstraint(const tvm::Array& arr) { - return MapDocify(arr, [=](const TypeConstraint& tc) { return Docify(tc); }); - } - - Doc VisitType_(const TensorTypeNode* t) final { - return DocOfStr("tensor"); - } - - Doc VisitType_(const TypeParamNode* p) final { - auto tp = GetRef(p); - if (map.count(tp) == 0) { - auto name = - DocOfStr(Mangle("tp", cnt(TypeParamName())) + - std::string(":")) + - KindDocify(p->kind); - map.insert(std::pair(tp, name)); - } - return map.at(tp); - } - - Doc Quantify(const tvm::Array& tp, const Doc& d) { - if (tp.size() == 0) { - return d; - } - return Seq("forall", DocifyTypeParam(tp), ",") + Sep() + d; - } - - Doc Constraint(const tvm::Array& tc, const Doc& d) { - if (tc.size() == 0) { - return d; - } - return Seq("(", DocifyTypeConstraint(tc), ") =>") + Sep() + d; - } - - Doc VisitType_(const FuncTypeNode* f) final { - auto inner = Seq("<", DocifyTypeArray(f->arg_types), ">") + Sep() + - DocOfStr("->") + Sep() + Docify(f->ret_type); - return Group(Quantify(f->type_params, - Constraint(f->type_constraints, inner))); - } - - Doc VisitType_(const TypeRelationNode* r) final { - return DocOfStr("Relation") + Seq("(", DocifyTypeArray(r->args), ")"); - } - - Doc VisitType_(const TupleTypeNode* t) final { - return Seq("<", DocifyTypeArray(t->fields), ">"); - } - - Doc VisitType_(const IncompleteTypeNode* i) final { - return DocOfStr("_"); - } - - public: - TypeDocifier(const Environment& env) : env(env) { } - - Doc Docify(const Type& t) { return t.get() ? (*this)(t) : DocOfStr("_"); } -}; - -class ExprDocifier : private ExprFunctor { - Environment env; - Counter cnt; - std::unordered_map map; - TypeDocifier td; - - std::string VarName(const Var& v) { - if (map.count(v) == 0) { - map.insert(std::pair(v, Mangle(v->name_hint, cnt(v->name_hint)))); - } - return map.at(v); - } - - Doc TypeAnnotation(const Doc& d, const Type& t) { - // test for t being null. probably shouldnt has null. should talk to jared. - if (!t.get() || t.as()) { - return d; - } else { - return d + DocOfStr(":") + td.Docify(t); - } - } - - std::vector DocifyExprArray(const tvm::Array& arr) { - std::vector vec; - for (size_t i = 0; i < arr.size(); ++i) { - vec.push_back(Docify(arr[i])); - } - return vec; - } - - std::vector DocifyParamArray(const tvm::Array& arr) { - std::vector vec; - for (Var param : arr) { - vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)), - param->type_annotation)); - } - return vec; - } - - Doc VisitExpr_(const ConstantNode* c) final { - return DocOfStr("some_constant"); - } - - Doc VisitExpr_(const TupleNode* t) final { - return Seq("<", DocifyExprArray(t->fields), ">"); - } - - Doc VisitExpr_(const VarNode* v) final { - return DocOfStr(VarName(GetRef(v))); - } - - Doc VisitExpr_(const GlobalVarNode* g) final { - return DocOfStr(g->name_hint); - } - - Doc VisitExpr_(const FunctionNode* f) final { - return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() + - DocOfStr("=>") + Sep() + - Block(indent, "{", Docify(f->body), "}")); - } - - Doc VisitExpr_(const CallNode* c) final { - return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">"); - } - - Doc VisitExpr_(const LetNode* l) final { - return Group(DocOfStr("let") + Sep() + - TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() + - DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() + - Docify(l->body)); - } - - Doc VisitExpr_(const IfNode* i) final { - return Group(DocOfStr("if") + Sep() + Docify(i->cond) + Sep() + - Block(indent, "{", Docify(i->true_branch), "}") + Sep() + - DocOfStr("else") + Sep() + - Block(indent, "{", Docify(i->false_branch), "}")); - } - - Doc VisitExpr_(const OpNode* o) final { - return DocOfStr(o->name); - } - - Doc VisitExpr_(const TupleGetItemNode* g) final { - return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index)); - } - - public: - ExprDocifier(const Environment& env) : env(env), td(env) { } - - Doc Docify(const Expr& e) { return (*this)(e); } -}; - -Doc DocOfExpr(const Environment& env, const Expr& expr) { - ExprDocifier d(env); - return d.Docify(expr); -} - -Doc DocOfType(const Environment& env, const Type& expr) { - TypeDocifier d(env); - return d.Docify(expr); -} - -RDoc ExprRDoc(const Environment& env, const Expr& expr) { - return Layout(DocOfExpr(env, expr)); -} - -RDoc TypeRDoc(const Environment& env, const Type& expr) { - return Layout(DocOfType(env, expr)); -} - -std::ostream & DebugPrint(const Environment& env, const Expr& e, std::ostream& os) { - return os << ExprRDoc(env, e); -} - -std::ostream & DebugPrint(const Environment& env, const Type& t, std::ostream& os) { - return os << TypeRDoc(env, t); -} - -std::string PrintExpr(const Environment& env, const Expr& e) { - std::stringstream ss; - ss << ExprRDoc(env, e); - return ss.str(); -} - -std::string PrintType(const Environment& env, const Type& t) { - std::stringstream ss; - ss << TypeRDoc(env, t); - return ss.str(); -} - -TVM_REGISTER_API("relay._expr._pretty_print") -.set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[1]; - if (x.as()) { - *ret = PrintType(args[0], Downcast(x)); - } else { - *ret = PrintExpr(args[0], Downcast(x)); - } - }); - -} // namespace relay -} // namespace tvm From 720af91b4ce1ecd7588297092465993d5dee4d0a Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 19 Oct 2018 11:41:26 -0700 Subject: [PATCH 28/64] account for change in relay.Function interface and use relay.const --- python/tvm/relay/parser.py | 2 +- tests/python/relay/test_ir_parser.py | 65 +++++++++++++--------------- 2 files changed, 31 insertions(+), 36 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index b12580b7040e..9f69d29fe675 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -348,7 +348,7 @@ def mk_func(self, ctx): body = self.visit(ctx.body()) self.exit_var_scope() - return relay.Function(var_list, ret_type, body, type_params) # type: ignore + return relay.Function(var_list, body, ret_type, type_params) # type: ignore def visitFunc(self, ctx): # type: (RelayParser.FuncContext) -> relay.Function diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index fd7f6e6fd232..e53855a1e908 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -3,7 +3,6 @@ from tvm.relay.parser import parse_expr, parse_prog, ParseError from tvm.relay.ir_pass import alpha_equal # from tvm.relay.ir_builder import convert -from tvm.relay.expr import pretty_print from nose.tools import nottest, raises from typing import Union @@ -49,10 +48,6 @@ def get_scalar(x): # type: (relay.Constant) -> (Union[float, int, bool]) return x.data.asnumpy().item() -def to_constant(x): - # type: (Union[float, int, bool]) -> relay.Constant - return relay.Constant(tvm.nd.array(x)) - def to_tensor_type(x): # type: (str) -> relay.TensorType return relay.TensorType([], x) @@ -106,7 +101,7 @@ def test_bin_op(): for bin_op in BINARY_OPS.keys(): assert alpha_equal( parse_expr("1 {} 1".format(bin_op)), - BINARY_OPS.get(bin_op)(to_constant(1), to_constant(1)) + BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) ) def test_parens(): @@ -145,7 +140,7 @@ def test_let(): parse_expr("let %x = 1; ()"), relay.Let( X, - to_constant(1), + relay.const(1), UNIT ) ) @@ -163,7 +158,7 @@ def test_seq(): parse_expr("let %_ = { 1 }; ()"), relay.Let( X, - to_constant(1), + relay.const(1), UNIT ) ) @@ -179,11 +174,11 @@ def test_let_op(): def test_tuple(): assert alpha_equal(parse_expr("()"), relay.Tuple([])) - assert alpha_equal(parse_expr("(0,)"), relay.Tuple([to_constant(0)])) + assert alpha_equal(parse_expr("(0,)"), relay.Tuple([relay.const(0)])) - assert alpha_equal(parse_expr("(0, 1)"), relay.Tuple([to_constant(0), to_constant(1)])) + assert alpha_equal(parse_expr("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)])) - assert alpha_equal(parse_expr("(0, 1, 2)"), relay.Tuple([to_constant(0), to_constant(1), to_constant(2)])) + assert alpha_equal(parse_expr("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) def test_func(): # 0 args @@ -191,8 +186,8 @@ def test_func(): parse_expr("fn () -> { 0 }"), relay.Function( [], + relay.const(0), None, - to_constant(0), [] ) ) @@ -202,8 +197,8 @@ def test_func(): parse_expr("fn (%x) -> { %x }"), relay.Function( [X], - None, X, + None, [] ) ) @@ -213,8 +208,8 @@ def test_func(): parse_expr("fn (%x, %y) -> { %x + %y }"), relay.Function( [X, Y], - None, relay.add(X, Y), + None, [] ) ) @@ -224,8 +219,8 @@ def test_func(): parse_expr("fn (%x: Int32) -> Int32 { %x }"), relay.Function( [X_ANNO], - int32, X_ANNO, + int32, [] ) ) @@ -253,9 +248,9 @@ def test_ifelse(): """ ), relay.If( - to_constant(True), - to_constant(0), - to_constant(1) + relay.const(True), + relay.const(0), + relay.const(1) ) ) @@ -284,7 +279,7 @@ def test_call(): ), relay.Let( constant, - relay.Function([], None, to_constant(0), []), + relay.Function([], relay.const(0), None, []), relay.Call(constant, [], None, None) ) ) @@ -300,8 +295,8 @@ def test_call(): ), relay.Let( id_var, - relay.Function([X], None, X, []), - relay.Call(id_var, [to_constant(1)], None, None) + relay.Function([X], X, None, []), + relay.Call(id_var, [relay.const(1)], None, None) ) ) @@ -318,11 +313,11 @@ def test_call(): multiply, relay.Function( [X, Y], - None, relay.multiply(X, Y), + None, [] ), - relay.Call(multiply, [to_constant(0), to_constant(0)], None, None) + relay.Call(multiply, [relay.const(0), relay.const(0)], None, None) ) ) @@ -336,11 +331,11 @@ def test_call(): relay.Call( relay.Function( [X], - None, X, + None, [] ), - [to_constant(0)], + [relay.const(0)], None, None ) @@ -365,19 +360,19 @@ def test_call(): curried_mult, relay.Function( [X], - None, relay.Function( [Y], - None, relay.multiply(X, Y), + None, [] ), + None, [] ), relay.Let( _, - relay.Call(curried_mult, [to_constant(0)], None, None), - relay.Call(relay.Call(curried_mult, [to_constant(0)], None, None), [to_constant(0)], None, None) + relay.Call(curried_mult, [relay.const(0)], None, None), + relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) ) ) ) @@ -385,7 +380,7 @@ def test_call(): # op alpha_equal( parse_expr("abs(1)"), - relay.Call(relay.op.get("abs"), [to_constant(1)], None, None) + relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) ) # Types @@ -451,7 +446,7 @@ def test_function_type(): ), relay.Let( relay.Var("_", relay.FuncType([], int32, [], [])), - relay.Function([], int32, to_constant(0), []), + relay.Function([], relay.const(0), int32, []), UNIT ) ) @@ -464,7 +459,7 @@ def test_function_type(): ), relay.Let( relay.Var("_", relay.FuncType([int32], int32, [], [])), - relay.Function([relay.Var("x", int32)], int32, to_constant(0), []), + relay.Function([relay.Var("x", int32)], relay.const(0), int32, []), UNIT ) ) @@ -477,7 +472,7 @@ def test_function_type(): ), relay.Let( relay.Var("_", relay.FuncType([int32, int32], int32, [], [])), - relay.Function([relay.Var("x", int32), relay.Var("y", int32)], int32, to_constant(0), []), + relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []), UNIT ) ) @@ -502,7 +497,7 @@ def test_tuple_type(): """), relay.Let( relay.Var("_", relay.TupleType([int32])), - relay.Tuple([to_constant(0)]), + relay.Tuple([relay.const(0)]), UNIT ) ) @@ -514,7 +509,7 @@ def test_tuple_type(): """), relay.Let( relay.Var("_", relay.TupleType([int32, int32])), - relay.Tuple([to_constant(0), to_constant(1)]), + relay.Tuple([relay.const(0), relay.const(1)]), UNIT ) ) From 10696773d0d0c52b34883db0356dfcd396db41d2 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 19 Oct 2018 11:43:56 -0700 Subject: [PATCH 29/64] remove pretty printer test file --- tests/python/relay/test_ir_pretty_printer.py | 90 -------------------- 1 file changed, 90 deletions(-) delete mode 100644 tests/python/relay/test_ir_pretty_printer.py diff --git a/tests/python/relay/test_ir_pretty_printer.py b/tests/python/relay/test_ir_pretty_printer.py deleted file mode 100644 index f7dfe2708bf5..000000000000 --- a/tests/python/relay/test_ir_pretty_printer.py +++ /dev/null @@ -1,90 +0,0 @@ -import tvm -from tvm import relay -from tvm.relay.expr import pretty_print -from tvm.relay.ir_builder import IRBuilder - -ib = IRBuilder() - -def show(e): - r = pretty_print(ib.env, e) - assert r is not None - - -def test_constant(): - arr = tvm.nd.array(10) - const = relay.Constant(arr) - show(const) - # should print the array inside? - - -def test_tuple(): - fields = tvm.convert([]) - tup = relay.Tuple(fields) - show(tup) - - -def test_local_var(): - name_hint = 's' - lv = relay.Var(name_hint) - show(lv) - - -def test_dup_var(): - lv = relay.Var('s') - rv = relay.Var('s') - show(relay.Tuple([lv, rv])) - - -def test_large_dup_var(): - av = relay.Var('s') - bv = relay.Var('s') - cv = relay.Var('s') - show(relay.Tuple([av, bv, cv])) - - -def test_global_var(): - name_hint = 'g' - gv = relay.GlobalVar(name_hint) - gv.name_hint == name_hint - show(gv) - - -def test_function(): - param_names = ['a', 'b', 'c', 'd'] - params = tvm.convert([relay.Var(n) for n in param_names]) - ret_type = None - body = params[0] - type_params = tvm.convert([]) - fn = relay.Function(params, ret_type, body, type_params) - show(fn) - - - -def test_call(): - op = relay.Var('f') - arg_names = ['a', 'b', 'c', 'd'] - args = tvm.convert([relay.Var(n) for n in arg_names]) - call = relay.Call(op, args, None, None) - show(call) - - -def test_let(): - ty = relay.ty.TensorType((10, 20), 'float32') - lv = relay.Var('x', ty) - arr = tvm.nd.array(10) - value = relay.Constant(arr) - let = relay.Let(lv, value, lv) - show(let) - - -def test_if(): - cond = relay.Var('cond') - left = relay.Var('left') - right = relay.Var('right') - ife = relay.If(cond, left, right) - show(ife) - -def test_tuple_get_item(): - t = relay.Var('t') - g = relay.TupleGetItem(t, 0) - show(g) From 703673c2ef83b027bc66a133a78142559debcc47 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sun, 21 Oct 2018 15:27:04 -0700 Subject: [PATCH 30/64] -> only when return type is present. fn for function types --- python/tvm/relay/grammar/Relay.g4 | 6 ++--- python/tvm/relay/parser.py | 2 +- tests/python/relay/test_ir_parser.py | 37 +++++++++++++++------------- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 5214bd729ab9..7516f89e7790 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -87,8 +87,8 @@ expr // | 'debug' # debug ; -func: 'fn' varList '->' type_? body ; -defn: 'def' ident varList '->' type_? body ; +func: 'fn' varList ('->' type_)? body ; +defn: 'def' ident varList ('->' type_)? body ; varList: '(' (var (',' var)*)? ')' ; var: ident (':' type_)? ; @@ -107,7 +107,7 @@ type_ | 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType | identType '[' (type_ (',' type_)*)? ']' # callType // Mut, Int, UInt, Float, Bool - | '(' (type_ (',' type_)*)? ')' '->' type_ # funcType + | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | '_' # incompleteType | INT # intType ; diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 9f69d29fe675..72e3e3458113 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -361,7 +361,7 @@ def visitDefn(self, ctx): raise ParseError('Only global ids may be used in `def`s.') ident = relay.GlobalVar(ident.getText()[1:]) - self.env.add(ident, self.mk_func(ctx)) + self.env[ident] = self.mk_func(ctx) def visitCall(self, ctx): # type: (RelayParser.CallContext) -> relay.Call diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index e53855a1e908..09c9643130c3 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -48,6 +48,9 @@ def get_scalar(x): # type: (relay.Constant) -> (Union[float, int, bool]) return x.data.asnumpy().item() +def is_close(x, y, precision=0.001): + return x - y < precision and y - x < precision + def to_tensor_type(x): # type: (str) -> relay.TensorType return relay.TensorType([], x) @@ -74,18 +77,18 @@ def test_int_literal(): def test_float_literal(): assert get_scalar(parse_expr("1.0")) == 1.0 - assert get_scalar(parse_expr("1.56667")) == 1.56667 + assert is_close(get_scalar(parse_expr("1.56667")), 1.56667) assert get_scalar(parse_expr("0.0")) == 0.0 assert get_scalar(parse_expr("-10.0")) == -10.0 # scientific notation - assert get_scalar(parse_expr("1e-1")) == 1e-1 + assert is_close(get_scalar(parse_expr("1e-1")), 1e-1) assert get_scalar(parse_expr("1e+1")) == 1e+1 - assert get_scalar(parse_expr("1E-1")) == 1E-1 + assert is_close(get_scalar(parse_expr("1E-1")), 1E-1) assert get_scalar(parse_expr("1E+1")) == 1E+1 - assert get_scalar(parse_expr("1.0e-1")) == 1.0e-1 + assert is_close(get_scalar(parse_expr("1.0e-1")), 1.0e-1) assert get_scalar(parse_expr("1.0e+1")) == 1.0e+1 - assert get_scalar(parse_expr("1.0E-1")) == 1.0E-1 + assert is_close(get_scalar(parse_expr("1.0E-1")), 1.0E-1) assert get_scalar(parse_expr("1.0E+1")) == 1.0E+1 def test_bool_literal(): @@ -183,7 +186,7 @@ def test_tuple(): def test_func(): # 0 args assert alpha_equal( - parse_expr("fn () -> { 0 }"), + parse_expr("fn () { 0 }"), relay.Function( [], relay.const(0), @@ -194,7 +197,7 @@ def test_func(): # 1 arg assert alpha_equal( - parse_expr("fn (%x) -> { %x }"), + parse_expr("fn (%x) { %x }"), relay.Function( [X], X, @@ -205,7 +208,7 @@ def test_func(): # 2 args assert alpha_equal( - parse_expr("fn (%x, %y) -> { %x + %y }"), + parse_expr("fn (%x, %y) { %x + %y }"), relay.Function( [X, Y], relay.add(X, Y), @@ -273,7 +276,7 @@ def test_call(): assert alpha_equal( parse_expr( """ - let %constant = fn () -> { 0 }; + let %constant = fn () { 0 }; %constant() """ ), @@ -289,7 +292,7 @@ def test_call(): assert alpha_equal( parse_expr( """ - let %id = fn (%x) -> { %x }; + let %id = fn (%x) { %x }; %id(1) """ ), @@ -305,7 +308,7 @@ def test_call(): assert alpha_equal( parse_expr( """ - let %multiply = fn (%x, %y) -> { %x * %y }; + let %multiply = fn (%x, %y) { %x * %y }; %multiply(0, 0) """ ), @@ -325,7 +328,7 @@ def test_call(): assert alpha_equal( parse_expr( """ - (fn (%x) -> { %x })(0) + (fn (%x) { %x })(0) """ ), relay.Call( @@ -347,8 +350,8 @@ def test_call(): parse_expr( """ let %curried_mult = - fn (%x) -> { - fn (%y) -> { + fn (%x) { + fn (%y) { %x * %y } }; @@ -441,7 +444,7 @@ def test_function_type(): assert alpha_equal( parse_expr( """ - let %_: () -> Int32 = fn () -> Int32 { 0 }; () + let %_: fn () -> Int32 = fn () -> Int32 { 0 }; () """ ), relay.Let( @@ -454,7 +457,7 @@ def test_function_type(): assert alpha_equal( parse_expr( """ - let %_: (Int32) -> Int32 = fn (%x: Int32) -> Int32 { 0 }; () + let %_: fn (Int32) -> Int32 = fn (%x: Int32) -> Int32 { 0 }; () """ ), relay.Let( @@ -467,7 +470,7 @@ def test_function_type(): assert alpha_equal( parse_expr( """ - let %_: (Int32, Int32) -> Int32 = fn (%x: Int32, %y: Int32) -> Int32 { 0 }; () + let %_: fn (Int32, Int32) -> Int32 = fn (%x: Int32, %y: Int32) -> Int32 { 0 }; () """ ), relay.Let( From 2a1de38266f5105284952fbc7aa19f76fc1c5401 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sun, 21 Oct 2018 17:12:08 -0700 Subject: [PATCH 31/64] semver --- python/tvm/relay/grammar/Relay.g4 | 2 ++ python/tvm/relay/parser.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 7516f89e7790..c4239e45c56c 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -1,3 +1,5 @@ +// semver: 0.1.0 + grammar Relay; // Lexing diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 72e3e3458113..bb75f524a54c 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,3 +1,5 @@ +# semver: 0.1.0 + # pylint: disable=invalid-name, unused-import """A parser for Relay's text format.""" from collections import deque From becedbf9377f6f56f67ec60b12814321949587b4 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sun, 21 Oct 2018 17:16:30 -0700 Subject: [PATCH 32/64] add python type checking to git ignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 410a36aecdec..04dad2039860 100644 --- a/.gitignore +++ b/.gitignore @@ -209,3 +209,7 @@ tvm_t.* # patch sentinel patched.txt + +# Python type checking +.mypy_cache/ +.pyre/ From 955fc4e70055bd460f3e7d1ea8bef34b9013a912 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sun, 21 Oct 2018 18:40:42 -0700 Subject: [PATCH 33/64] make bools and builtin types more python-esque. realign antlr tags --- python/tvm/relay/grammar/Relay.g4 | 64 ++++++++------- python/tvm/relay/parser.py | 114 ++++++--------------------- tests/python/relay/test_ir_parser.py | 90 +++++++++++---------- 3 files changed, 104 insertions(+), 164 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index c4239e45c56c..2ac1312d8544 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -1,5 +1,3 @@ -// semver: 0.1.0 - grammar Relay; // Lexing @@ -27,8 +25,8 @@ LOCAL_VAR: '%' CNAME ; MUT: 'mut' ; BOOL_LIT - : 'true' - | 'false' + : 'True' + | 'False' ; // non-negative floats @@ -54,24 +52,24 @@ prog: /* option* */ defn* /* expr */ EOF ; expr // operators - : '(' expr ')' # parens - | '-' expr # neg - | expr op=('*'|'/') expr # binOp - | expr op=('+'|'-') expr # binOp - | expr op=('<'|'>'|'<='|'>=') expr # binOp - | expr op=('=='|'!=') expr # binOp + : '(' expr ')' # parens + | '-' expr # neg + | expr op=('*'|'/') expr # binOp + | expr op=('+'|'-') expr # binOp + | expr op=('<'|'>'|'<='|'>=') expr # binOp + | expr op=('=='|'!=') expr # binOp // function definition and application - | expr '(' (expr (',' expr)*)? ')' # call - | func # funcExpr + | expr '(' (expr (',' expr)*)? ')' # call + | func # funcExpr // tuples and tensors - | '(' ')' # tuple - | '(' expr ',' ')' # tuple - | '(' expr (',' expr)+ ')' # tuple - | '[' (expr (',' expr)*)? ']' # tensor + | '(' ')' # tuple + | '(' expr ',' ')' # tuple + | '(' expr (',' expr)+ ')' # tuple + | '[' (expr (',' expr)*)? ']' # tensor - | 'if' '(' expr ')' body 'else' body # ifElse + | 'if' '(' expr ')' body 'else' body # ifElse // sequencing | 'let' MUT? var '=' expr ';' expr # seq @@ -80,13 +78,13 @@ expr | expr ';' expr # seq // mutable update - // | ident '=' expr # writeRef - // | expr '^' # readRef + // | ident '=' expr # writeRef + // | expr '^' # readRef - | ident # identExpr - | scalar # scalarExpr - // | expr '.' INT # project - // | 'debug' # debug + | ident # identExpr + | scalar # scalarExpr + // | expr '.' INT # project + // | 'debug' # debug ; func: 'fn' varList ('->' type_)? body ; @@ -102,16 +100,16 @@ var: ident (':' type_)? ; // relation: ident '(' (type_ (',' type_)*)? ')' ; type_ - : '(' ')' # tupleType - | '(' type_ ',' ')' # tupleType - | '(' type_ (',' type_)+ ')' # tupleType - | identType # identTypeType - | 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType - | identType '[' (type_ (',' type_)*)? ']' # callType - // Mut, Int, UInt, Float, Bool - | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType - | '_' # incompleteType - | INT # intType + : '(' ')' # tupleType + | '(' type_ ',' ')' # tupleType + | '(' type_ (',' type_)+ ')' # tupleType + | identType # identTypeType + | 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType + // currently unused + | identType '[' (type_ (',' type_)*)? ']' # callType + | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType + | '_' # incompleteType + | INT # intType ; shapeSeq diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index bb75f524a54c..3e638f470e88 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,5 +1,3 @@ -# semver: 0.1.0 - # pylint: disable=invalid-name, unused-import """A parser for Relay's text format.""" from collections import deque @@ -39,76 +37,12 @@ def __init__(self, message): RelayParser.NE: relay.not_equal, } -TYPES = { - "Int8": "int8", - "Int16": "int16", - "Int32": "int32", - "Int64": "int64", - - "UInt8": "uint8", - "UInt16": "uint16", - "UInt32": "uint32", - "UInt64": "uint64", - - "Float16": "float16", - "Float32": "float32", - "Float64": "float64", - - "Bool": "bool", -} - -def int_type_call(args): - # type: (List[relay.Expr]) -> relay.TensorType - """Turn an Int type call into a Relay TensorType""" - - if len(args) > 2: - raise ParseError("Int may have at most 2 arguments.") - - str_args = [str(arg) for arg in args] - - return relay.TensorType((), "int" + "x".join(str_args)) - -def uint_type_call(args): - # type: (List[relay.Expr]) -> relay.TensorType - """Turn a UInt type call into a Relay TensorType""" - - if len(args) > 2: - raise ParseError("UInt may have at most 2 arguments.") - - str_args = [str(arg) for arg in args] - - return relay.TensorType((), "uint" + "x".join(str_args)) - -def float_type_call(args): - # type: (List[relay.Expr]) -> relay.TensorType - """Turn a Float type call into a Relay TensorType""" - - if len(args) > 2: - raise ParseError("Float may have at most 2 arguments.") - - str_args = [str(arg) for arg in args] - - return relay.TensorType((), "float" + "x".join(str_args)) - -def bool_type_call(args): - # type: (List[relay.Expr]) -> relay.TensorType - """Turn a Bool type call into a Relay TensorType""" - - if len(args) > 1: - raise ParseError("Bool may have at most 1 argument.") - - # can't use bool, because ffi doesn't convert anything after bool - # bool is sugar for uint1 anyway - str_args = [str(arg) for arg in args] - - return relay.TensorType((), "uint1x" + str_args[0]) - -TYPE_FUNCS = { - "Int": int_type_call, - "UInt": uint_type_call, - "Float": float_type_call, - "Bool": bool_type_call, -} +TYPE_PREFIXES = [ + "int", + "uint", + "float", + "bool", +] T = TypeVar("T") Scope = Deque[Tuple[str, T]] @@ -203,9 +137,9 @@ def visitTerminal(self, node): elif node_type == RelayLexer.FLOAT: return float(node_text) elif node_type == RelayLexer.BOOL_LIT: - if node_text == "true": + if node_text == "True": return True - elif node_text == "false": + elif node_text == "False": return False else: raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text)) @@ -398,31 +332,29 @@ def visitIdentType(self, ctx): # type: (RelayParser.IdentTypeContext) -> Union[relay.TensorType, str] ident_type = ctx.CNAME().getText() - if not ident_type[0].isupper(): - raise ParseError("Types must start with capital letters.") + # look through all type prefixes for a match + for type_prefix in TYPE_PREFIXES: + if ident_type.startswith(type_prefix): + return relay.TensorType((), ident_type) - builtin_type = TYPES.get(ident_type) - - if builtin_type is None: - raise ParseError("Unknown builtin type: {}".format(ident_type)) - else: - return relay.TensorType((), builtin_type) + raise ParseError("Unknown builtin type: {}".format(ident_type)) def visitCallType(self, ctx): # type: (RelayParser.CallTypeContext) -> Union[relay.Expr, relay.TensorType] - ident_type = ctx.identType().CNAME().getText() + # ident_type = ctx.identType().CNAME().getText() - args = self.visit_list(ctx.type_()) + # args = self.visit_list(ctx.type_()) - if not args: - raise ParseError("Type-level functions must have arguments!") + # if not args: + # raise ParseError("Type-level functions must have arguments!") - func_type = TYPE_FUNCS.get(ident_type)(args) + # func_type = TYPE_FUNCS.get(ident_type)(args) - if func_type is None: - raise ParseError("Unknown type-level function: `{}`".format(ident_type)) - else: - return func_type + # if func_type is None: + # raise ParseError("Unknown type-level function: `{}`".format(ident_type)) + # else: + # return func_type + raise ParseError("Call types are unused!") def visitParensShape(self, ctx): # type: (RelayParser.ParensShapeContext) -> int diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 09c9643130c3..b4e2020e4101 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -20,28 +20,25 @@ } TYPES = { - "Int8", - "Int16", - "Int32", - "Int64", + "int8", + "int16", + "int32", + "int64", - "UInt8", - "UInt16", - "UInt32", - "UInt64", + "uint8", + "uint16", + "uint32", + "uint64", - "Float16", - "Float32", - "Float64", + "float16", + "float32", + "float64", - "Bool", -} + "bool", -CALL_TYPES = { - "Int": 2, - "UInt": 2, - "Float": 2, - "Bool": 1, + "int8x4", + "uint1x4", + "float16x4", } def get_scalar(x): @@ -65,6 +62,25 @@ def to_tensor_type(x): UNIT = relay.Tuple([]) +def test_comments(): + assert alpha_equal( + parse_expr(""" + // This is a line comment! + () + """), + UNIT + ) + + assert alpha_equal( + parse_expr(""" + /* This is a block comment! + This is still a block comment! + */ + () + """), + UNIT + ) + def test_int_literal(): assert isinstance(parse_expr("1"), relay.Constant) assert isinstance(parse_expr("1").data, tvm.ndarray.NDArray) @@ -92,8 +108,8 @@ def test_float_literal(): assert get_scalar(parse_expr("1.0E+1")) == 1.0E+1 def test_bool_literal(): - assert get_scalar(parse_expr("true")) == True - assert get_scalar(parse_expr("false")) == False + assert get_scalar(parse_expr("True")) == True + assert get_scalar(parse_expr("False")) == False def test_negative(): assert isinstance(parse_expr("let %x = 1; -%x").body, relay.Call) @@ -219,7 +235,7 @@ def test_func(): # annotations assert alpha_equal( - parse_expr("fn (%x: Int32) -> Int32 { %x }"), + parse_expr("fn (%x: int32) -> int32 { %x }"), relay.Function( [X_ANNO], X_ANNO, @@ -233,7 +249,7 @@ def test_func(): def test_defn(): id_defn = parse_prog( """ - def @id(%x: Int32) -> Int32 { + def @id(%x: int32) -> int32 { %x } """) @@ -243,7 +259,7 @@ def test_ifelse(): assert alpha_equal( parse_expr( """ - if (true) { + if (True) { 0 } else { 1 @@ -261,7 +277,7 @@ def test_ifelse(): def test_ifelse_scope(): parse_expr( """ - if (true) { + if (True) { let %x = (); () } else { @@ -402,19 +418,13 @@ def test_builtin_types(): for builtin_type in TYPES: parse_expr("let %_ : {} = (); ()".format(builtin_type)) +@nottest def test_call_type(): - # tests e.g. - # let %_ : Int[0] = (); () - # let %_ : Int[0, 1] = (); () - for call_type, arity in CALL_TYPES.items(): - args = [] - for i in range(arity): - args.append(i) - parse_expr("let %_ : {}{} = (); ()".format(call_type, args)) + assert False def test_tensor_type(): assert alpha_equal( - parse_expr("let %_ : Tensor[(), Float32] = (); ()"), + parse_expr("let %_ : Tensor[(), float32] = (); ()"), relay.Let( relay.Var("_", relay.TensorType((), "float32")), UNIT, @@ -423,7 +433,7 @@ def test_tensor_type(): ) assert alpha_equal( - parse_expr("let %_ : Tensor[(1,), Float32] = (); ()"), + parse_expr("let %_ : Tensor[(1,), float32] = (); ()"), relay.Let( relay.Var("_", relay.TensorType((1,), "float32")), UNIT, @@ -432,7 +442,7 @@ def test_tensor_type(): ) assert alpha_equal( - parse_expr("let %_ : Tensor[(1, 1), Float32] = (); ()"), + parse_expr("let %_ : Tensor[(1, 1), float32] = (); ()"), relay.Let( relay.Var("_", relay.TensorType((1, 1), "float32")), UNIT, @@ -444,7 +454,7 @@ def test_function_type(): assert alpha_equal( parse_expr( """ - let %_: fn () -> Int32 = fn () -> Int32 { 0 }; () + let %_: fn () -> int32 = fn () -> int32 { 0 }; () """ ), relay.Let( @@ -457,7 +467,7 @@ def test_function_type(): assert alpha_equal( parse_expr( """ - let %_: fn (Int32) -> Int32 = fn (%x: Int32) -> Int32 { 0 }; () + let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () """ ), relay.Let( @@ -470,7 +480,7 @@ def test_function_type(): assert alpha_equal( parse_expr( """ - let %_: fn (Int32, Int32) -> Int32 = fn (%x: Int32, %y: Int32) -> Int32 { 0 }; () + let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () """ ), relay.Let( @@ -496,7 +506,7 @@ def test_tuple_type(): assert alpha_equal( parse_expr( """ - let %_: (Int32,) = (0,); () + let %_: (int32,) = (0,); () """), relay.Let( relay.Var("_", relay.TupleType([int32])), @@ -508,7 +518,7 @@ def test_tuple_type(): assert alpha_equal( parse_expr( """ - let %_: (Int32, Int32) = (0, 1); () + let %_: (int32, int32) = (0, 1); () """), relay.Let( relay.Var("_", relay.TupleType([int32, int32])), From 5de44c7a08c02ddacf676e9a9432b65514652a7c Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Mon, 22 Oct 2018 17:17:15 -0700 Subject: [PATCH 34/64] rm unused code --- python/tvm/relay/parser.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 3e638f470e88..28d39e280293 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -164,9 +164,6 @@ def getType_(self, ctx): def visitProg(self, ctx): # type: (RelayParser.ProgContext) -> relay.Environment - # if ctx.option(): - # raise ParseError("Compiler options are unimplemented.") - self.visit_list(ctx.defn()) return self.env From 8715a7ee922f1010742ba92898bfa97e0348dbf1 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Mon, 22 Oct 2018 17:22:25 -0700 Subject: [PATCH 35/64] linting --- python/tvm/relay/grammar/Relay.g4 | 2 +- python/tvm/relay/parser.py | 27 +++++++++++++++------------ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 2ac1312d8544..71805c394c51 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -106,7 +106,7 @@ type_ | identType # identTypeType | 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType // currently unused - | identType '[' (type_ (',' type_)*)? ']' # callType + // | identType '[' (type_ (',' type_)*)? ']' # callType | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | '_' # incompleteType | INT # intType diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 28d39e280293..44dd7a05aa6b 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -213,6 +213,7 @@ def visitTuple(self, ctx): # Currently doesn't support mutable sequencing. def visitSeq(self, ctx): # type: (RelayParser.SeqContext) -> relay.Let + """Desugar various sequence constructs to Relay Let nodes.""" if ctx.MUT() is not None: raise ParseError("Mutation is currently unsupported.") @@ -307,6 +308,7 @@ def visitCall(self, ctx): def visitIfElse(self, ctx): # type: (RelayParser.IfElseContext) -> relay.If + """Construct a Relay If node. Creates a new scope for each branch.""" cond = self.visit(ctx.expr()) self.enter_var_scope() @@ -321,6 +323,7 @@ def visitIfElse(self, ctx): # Types + # pylint: disable=unused-argument def visitIncompleteType(self, ctx): # type (RelayParser.IncompleteTypeContext) -> None: return None @@ -336,22 +339,21 @@ def visitIdentType(self, ctx): raise ParseError("Unknown builtin type: {}".format(ident_type)) - def visitCallType(self, ctx): - # type: (RelayParser.CallTypeContext) -> Union[relay.Expr, relay.TensorType] - # ident_type = ctx.identType().CNAME().getText() + # def visitCallType(self, ctx): + # # type: (RelayParser.CallTypeContext) -> Union[relay.Expr, relay.TensorType] + # ident_type = ctx.identType().CNAME().getText() - # args = self.visit_list(ctx.type_()) + # args = self.visit_list(ctx.type_()) - # if not args: - # raise ParseError("Type-level functions must have arguments!") + # if not args: + # raise ParseError("Type-level functions must have arguments!") - # func_type = TYPE_FUNCS.get(ident_type)(args) + # func_type = TYPE_FUNCS.get(ident_type)(args) - # if func_type is None: - # raise ParseError("Unknown type-level function: `{}`".format(ident_type)) - # else: - # return func_type - raise ParseError("Call types are unused!") + # if func_type is None: + # raise ParseError("Unknown type-level function: `{}`".format(ident_type)) + # else: + # return func_type def visitParensShape(self, ctx): # type: (RelayParser.ParensShapeContext) -> int @@ -363,6 +365,7 @@ def visitShapeSeq(self, ctx): def visitTensorType(self, ctx): # type: (RelayParser.TensorTypeContext) -> relay.TensorType + """Create a simple tensor type. No generics.""" shape = self.visit(ctx.shapeSeq()) dtype = self.visit(ctx.type_()) From 4301600c35cb11a74c882c15dbbfbc0de586d353 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 23 Oct 2018 10:21:34 -0700 Subject: [PATCH 36/64] revert nullable alpha_equal changes --- src/relay/pass/alpha_eq.cc | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 095ef7f637b0..41ec3f1e090b 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -206,16 +206,6 @@ bool AlphaEqual(const Type& t1, const Type& t2) { return aeq.equal; } -bool NullableAlphaEqual(const Type& t1, const Type& t2) { - if (t1.defined() != t2.defined()) - return false; - - if (!t1.defined()) - return true; - - return AlphaEqual(t1, t2); -} - struct AlphaEq : ExprFunctor { public: tvm::Map eq_map; @@ -297,8 +287,7 @@ struct AlphaEq : ExprFunctor { } } - equal = equal && NullableAlphaEqual(func1->ret_type, func2->ret_type); - + equal = equal && AlphaEqual(func1->ret_type, func2->ret_type); if (!equal) { return; } From c1db228dad9a65c01fc1542ed92cb75dbf9af5ca Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 25 Oct 2018 16:15:24 -0700 Subject: [PATCH 37/64] cmake config --- CMakeLists.txt | 28 ++-------------------------- cmake/config.cmake | 3 +++ cmake/modules/ANTLR.cmake | 29 +++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 26 deletions(-) create mode 100644 cmake/modules/ANTLR.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index f9fd78cdb950..363b2056a87a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,7 @@ tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) tvm_option(USE_SORT "Build with sort support" OFF) tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF) +tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) # include directories include_directories("include") @@ -183,6 +184,7 @@ include(cmake/modules/Metal.cmake) include(cmake/modules/ROCM.cmake) include(cmake/modules/SGX.cmake) include(cmake/modules/LLVM.cmake) +include(cmake/modules/ANTLR.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Sort.cmake) @@ -278,32 +280,6 @@ else(INSTALL_DEV) ) endif(INSTALL_DEV) -# ANTLR4 build definitions -find_program(ANTLR4 antlr4) - -set(RELAY_PARSER_DIR - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) - -set(RELAY_PARSER - ${RELAY_PARSER_DIR}/py2/RelayVisitor.py - ${RELAY_PARSER_DIR}/py2/RelayParser.py - ${RELAY_PARSER_DIR}/py2/RelayLexer.py - - ${RELAY_PARSER_DIR}/py3/RelayVisitor.py - ${RELAY_PARSER_DIR}/py3/RelayParser.py - ${RELAY_PARSER_DIR}/py3/RelayLexer.py) - -if(ANTLR4) - # Generate ANTLR grammar for parsing. - add_custom_command(OUTPUT ${RELAY_PARSER} - COMMAND antlr4 -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 - COMMAND antlr4 -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 - DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 - WORKING_DIRECTORY ${RELAY_PARSER_DIR}) -endif() - -add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) - # More target definitions if(MSVC) target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS) diff --git a/cmake/config.cmake b/cmake/config.cmake index a92be7ce3008..a97def410ddd 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -128,3 +128,6 @@ set(USE_ROCBLAS OFF) # Whether use contrib sort set(USE_SORT OFF) + +# Build ANTLR parser for Relay text format +set(USE_ANTLR OFF) diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake new file mode 100644 index 000000000000..bfc8403a373c --- /dev/null +++ b/cmake/modules/ANTLR.cmake @@ -0,0 +1,29 @@ +find_program(ANTLR4 antlr4) + +if(USE_ANTLR) + find_program(ANTLR4 antlr4) + if(NOT ANTLR4) + message(FATAL_ERROR "Can't find ANTLR4!") + endif() + + set(RELAY_PARSER_DIR + ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) + + set(RELAY_PARSER + ${RELAY_PARSER_DIR}/py2/RelayVisitor.py + ${RELAY_PARSER_DIR}/py2/RelayParser.py + ${RELAY_PARSER_DIR}/py2/RelayLexer.py + + ${RELAY_PARSER_DIR}/py3/RelayVisitor.py + ${RELAY_PARSER_DIR}/py3/RelayParser.py + ${RELAY_PARSER_DIR}/py3/RelayLexer.py) + + # Generate ANTLR grammar for parsing. + add_custom_command(OUTPUT ${RELAY_PARSER} + COMMAND antlr4 -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 + COMMAND antlr4 -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 + DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 + WORKING_DIRECTORY ${RELAY_PARSER_DIR}) + + add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) +endif(USE_ANTLR) From f6991fce5c47eedbf2e13d6d7b120bb3175ba2ca Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 25 Oct 2018 16:22:38 -0700 Subject: [PATCH 38/64] delete alpha_eq --- src/relay/pass/alpha_eq.cc | 418 ------------------------------------- 1 file changed, 418 deletions(-) delete mode 100644 src/relay/pass/alpha_eq.cc diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc deleted file mode 100644 index 41ec3f1e090b..000000000000 --- a/src/relay/pass/alpha_eq.cc +++ /dev/null @@ -1,418 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file src/tvm/relay/pass/alpha_eq.cc - * \brief Check that two type are syntactically equal up to alpha equivalence. - */ -#include -#include -#include -#include "./type_visitor.h" -#include "tvm/relay/pass.h" - -namespace tvm { -namespace relay { - -using namespace tvm::runtime; - -bool SameNDArray(const NDArray& lhs, const NDArray& rhs) { - if (lhs.defined() != rhs.defined()) { - return false; - } else if (lhs.same_as(rhs)) { - return true; - } else { - auto ldt = lhs->dtype; - auto rdt = rhs->dtype; - CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - size_t s = GetDataSize(*lhs.operator->()); - return memcmp(lhs->data, rhs->data, s) == 0; - } else { - return false; - } - } -} - -struct TypeAlphaEq : TypeVisitor { - tvm::Map eq_map; - bool equal; - - TypeAlphaEq() : eq_map(), equal(true) {} - - void DataTypeEqual(const DataType& dt1, const DataType& dt2) { - if (dt1 != dt2) { - equal = false; - } - } - - void ShapeEqual(const Array& s1, const Array& s2) { - if (s1.size() != s2.size()) { - equal = false; - return; - } - for (size_t i = 0; i < s1.size(); ++i) { - if (!tvm::ir::Equal(s1[i], s2[i])) { - equal = false; - return; - } - } - } - - void VisitType_(const TensorTypeNode* tt1, const Type& t2) final { - if (const TensorTypeNode* tt2 = t2.as()) { - DataTypeEqual(tt1->dtype, tt2->dtype); - ShapeEqual(tt1->shape, tt2->shape); - } else { - equal = false; - } - } - - void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final { - if (const IncompleteTypeNode* bt2 = t2.as()) { - equal = equal && bt1 == bt2; - return; - } else { - equal = false; - } - } - - void VisitType_(const TypeVarNode* ti1, const Type& t2) final { - if (const TypeVarNode* ti2 = t2.as()) { - auto tid1 = GetRef(ti1); - auto tid2 = GetRef(ti2); - - // We handle open terms with this rule assuming variables are identical. - // - // Not sure if we should do this. - if (tid1 == tid2) { - return; - } - - // Check that they are same kind - if (tid1->kind != tid2->kind) { - equal = false; - return; - } - - // Next we see if there is mapping for local1 into the rhs term. - // If there is we check to see if those are equal. - if (eq_map.find(tid1) != eq_map.end()) { - equal = equal && eq_map[tid1] == tid2; - } else { - equal = false; - } - } else { - equal = false; - } - } - - void VisitType_(const FuncTypeNode* op, const Type& t2) final { - if (const FuncTypeNode* ta2 = t2.as()) { - if (op->arg_types.size() != ta2->arg_types.size() - || op->type_params.size() != ta2->type_params.size() - || op->type_constraints.size() != ta2->type_constraints.size()) { - equal = false; - return; - } - - // must visit params first so they are appropriate entered - // into equality map - for (size_t i = 0; i < op->type_params.size(); i++) { - eq_map.Set(op->type_params[i], ta2->type_params[i]); - this->VisitType(op->type_params[i], ta2->type_params[i]); - if (!equal) { - return; - } - } - - for (size_t i = 0; i < op->arg_types.size(); i++) { - this->VisitType(op->arg_types[i], ta2->arg_types[i]); - if (!equal) { - return; - } - } - - this->VisitType(op->ret_type, ta2->ret_type); - if (!equal) { - return; - } - - for (size_t i = 0; i < op->type_constraints.size(); i++) { - this->VisitType(op->type_constraints[i], ta2->type_constraints[i]); - if (!equal) { - return; - } - } - } else { - equal = false; - } - } - - void VisitType_(const TypeRelationNode* tr1, const Type& t2) final { - if (const TypeRelationNode* tr2 = t2.as()) { - if (tr1->func != tr2->func - || tr1->num_inputs != tr2->num_inputs - || tr1->attrs != tr2->attrs) { - equal = false; - return; - } - - if (tr1->args.size() != tr2->args.size()) { - equal = false; - return; - } - - for (size_t i = 0; i < tr1->args.size(); i++) { - this->VisitType(tr1->args[i], tr2->args[i]); - if (!equal) { - return; - } - } - } else { - equal = false; - } - } - - void VisitType_(const TupleTypeNode* op, const Type& t2) final { - if (const TupleTypeNode* pt = t2.as()) { - if (op->fields.size() != pt->fields.size()) { - equal = false; - return; - } - - for (size_t i = 0U; i < op->fields.size(); i++) { - if (!equal) { - return; - } - this->VisitType(op->fields[i], pt->fields[i]); - } - } else { - equal = false; - } - } -}; - -bool AlphaEqual(const Type& t1, const Type& t2) { - if (t1.defined() != t2.defined()) { - return false; - } - - if (!t1.defined()) { - return true; - } - - TypeAlphaEq aeq; - aeq.VisitType(t1, t2); - return aeq.equal; -} - -struct AlphaEq : ExprFunctor { - public: - tvm::Map eq_map; - - bool equal; - AlphaEq() : eq_map(), equal(true) {} - - void VisitExpr_(const VarNode* e1, const Expr& e2) final { - if (const VarNode* id2 = e2.as()) { - auto local1 = GetRef(e1); - auto local2 = GetRef(id2); - // We handle open terms with this rule assuming variables are identical. - if (local1 == local2) { - equal = true; - return; - } - - // Next we see if there is mapping for local1 into the rhs term. - // If there is we check to see if those are equal. - if (eq_map.find(local1) != eq_map.end()) { - equal = equal && eq_map[local1] == local2; - } else { - equal = false; - } - } else { - equal = false; - } - } - - void VisitExpr_(const GlobalVarNode* g1, const Expr& e2) final { - if (const GlobalVarNode* g2 = e2.as()) { - equal = equal && g1 == g2; - } else { - equal = false; - } - } - - void VisitExpr_(const TupleNode* pl1, const Expr& e2) final { - Tuple prod1 = GetRef(pl1); - if (const TupleNode* pl2 = e2.as()) { - Tuple prod2 = GetRef(pl2); - if (prod1->fields.size() != prod2->fields.size()) { - equal = false; - return; - } - - for (size_t i = 0U; i < prod1->fields.size(); i++) { - this->VisitExpr(prod1->fields[i], prod2->fields[i]); - } - } else { - equal = false; - } - } - - void VisitExpr_(const FunctionNode* func1, const Expr& e2) final { - if (const FunctionNode* func2 = e2.as()) { - if (func1->params.size() != func2->params.size()) { - equal = false; - return; - } - - if (func1->type_params.size() != func2->type_params.size()) { - equal = false; - return; - } - - for (size_t i = 0; i < func1->params.size(); ++i) { - MergeVarDecl(func1->params[i], func2->params[i]); - } - - if (!equal) { - return; - } - - for (size_t i = 0U; i < func1->type_params.size(); i++) { - equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]); - if (!equal) { - return; - } - } - - equal = equal && AlphaEqual(func1->ret_type, func2->ret_type); - if (!equal) { - return; - } - - this->VisitExpr(func1->body, func2->body); - } else { - equal = false; - } - } - - void VisitExpr_(const CallNode* op, const Expr& e2) final { - if (const CallNode* call = e2.as()) { - this->VisitExpr(op->op, call->op); - - if (op->args.size() != call->args.size()) { - equal = false; - return; - } - - if (op->type_args.size() != call->type_args.size()) { - equal = false; - return; - } - - // checking attrs by pointer equality for now - equal = equal && (op->attrs == call->attrs); - if (!equal) { - return; - } - - for (size_t i = 0U; i < op->args.size(); i++) { - this->VisitExpr(op->args[i], call->args[i]); - } - - for (size_t i = 0U; i < op->type_args.size(); i++) { - equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]); - if (!equal) { - return; - } - } - } else { - equal = false; - } - } - - void VisitExpr_(const LetNode* op, const Expr& e2) final { - if (const LetNode* let = e2.as()) { - MergeVarDecl(op->var, let->var); - this->VisitExpr(op->value, let->value); - this->VisitExpr(op->body, let->body); - } else { - equal = false; - } - } - - void VisitExpr_(const IfNode* op, const Expr& e2) final { - if (const IfNode* i = e2.as()) { - VisitExpr(op->cond, i->cond); - VisitExpr(op->true_branch, i->true_branch); - VisitExpr(op->false_branch, i->false_branch); - } else { - equal = false; - } - } - - void VisitExpr_(const OpNode* op, const Expr& e2) final { - if (const OpNode* o = e2.as()) { - equal = equal && op->name == o->name; - } else { - equal = false; - } - } - - void VisitExpr_(const ConstantNode* op, const Expr& e2) final { - if (const ConstantNode* c = e2.as()) { - if (AlphaEqual(op->tensor_type(), c->tensor_type())) { - equal = equal && SameNDArray(op->data, c->data); - } else { - equal = false; - } - } else { - equal = false; - } - } - - void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final { - if (const TupleGetItemNode* proj = e2.as()) { - this->VisitExpr(op->tuple, proj->tuple); - equal = equal && (op->index == proj->index); - } else { - equal = false; - } - } - - private: - void MergeVarDecl(const Var& var1, const Var& var2) { - equal = equal && AlphaEqual(var1->type_annotation, var2->type_annotation); - if (!equal) { - return; - } - - eq_map.Set(var1, var2); - } -}; - -bool AlphaEqual(const Expr& e1, const Expr& e2) { - AlphaEq eq; - eq.VisitExpr(e1, e2); - return eq.equal; -} - -// TODO(@jroesch): move to correct namespace? -TVM_REGISTER_API("relay._make._alpha_equal") - .set_body([](TVMArgs args, TVMRetValue* ret) { - Expr e1 = args[0]; - Expr e2 = args[1]; - *ret = AlphaEqual(e1, e2); - }); - -TVM_REGISTER_API("relay._make._type_alpha_equal") - .set_body([](TVMArgs args, TVMRetValue* ret) { - Type t1 = args[0]; - Type t2 = args[1]; - *ret = AlphaEqual(t1, t2); - }); - -} // namespace relay -} // namespace tvm From 96960bd9f7045963aec13b12c42aff6a283afc39 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 25 Oct 2018 21:34:53 -0700 Subject: [PATCH 39/64] use scalar_type --- python/tvm/relay/parser.py | 2 +- tests/python/relay/test_ir_parser.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 44dd7a05aa6b..20237d0db72d 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -335,7 +335,7 @@ def visitIdentType(self, ctx): # look through all type prefixes for a match for type_prefix in TYPE_PREFIXES: if ident_type.startswith(type_prefix): - return relay.TensorType((), ident_type) + return relay.scalar_type(ident_type) raise ParseError("Unknown builtin type: {}".format(ident_type)) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index b4e2020e4101..5792dfa133d7 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -48,11 +48,7 @@ def get_scalar(x): def is_close(x, y, precision=0.001): return x - y < precision and y - x < precision -def to_tensor_type(x): - # type: (str) -> relay.TensorType - return relay.TensorType([], x) - -int32 = to_tensor_type("int32") +int32 = relay.scalar_type("int32") _ = relay.Var("_") X = relay.Var("x") From d64052553eb6580412ad4799aa3507b359fdc13d Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 15:34:32 -0700 Subject: [PATCH 40/64] expose parse and parse_file via __init__.py --- python/tvm/relay/__init__.py | 5 + python/tvm/relay/grammar/Relay.g4 | 4 +- python/tvm/relay/parser.py | 193 ++++++++++++++------------- tests/python/relay/test_ir_parser.py | 140 +++++++++---------- 4 files changed, 177 insertions(+), 165 deletions(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 64d6774d0bde..32bdd4ee5f8a 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -8,6 +8,7 @@ from . import module from . import ir_pass from .build_module import build, build_config, create_executor +from . import parser # Root operators from .op import Op @@ -62,3 +63,7 @@ def _debug(*args): import pdb pdb.set_trace() + +# Parser +parse = parser.parse +parse_file = parser.parse_file diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 71805c394c51..c74a42c97e77 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -45,8 +45,8 @@ fragment DIGIT: [0-9] ; // Parsing -// a program is a list of options, a list of global definitions, and an expression -prog: /* option* */ defn* /* expr */ EOF ; +// A Relay program is a list of global definitions or an expression. +prog: (defn* | expr) EOF ; // option: 'set' ident BOOL_LIT ; diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 20237d0db72d..98b5bc078d27 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,21 +1,5 @@ # pylint: disable=invalid-name, unused-import """A parser for Relay's text format.""" -from collections import deque -import sys -from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any -from antlr4 import ParserRuleContext, InputStream, CommonTokenStream -from antlr4.tree.Tree import TerminalNode -import tvm -from tvm import relay -if sys.version_info.major < 3: - from .grammar.py2.RelayVisitor import RelayVisitor - from .grammar.py2.RelayParser import RelayParser - from .grammar.py2.RelayLexer import RelayLexer -else: - from .grammar.py3.RelayVisitor import RelayVisitor - from .grammar.py3.RelayParser import RelayParser - from .grammar.py3.RelayLexer import RelayLexer - class ParseError(Exception): """Exception type for parse errors.""" @@ -24,17 +8,46 @@ def __init__(self, message): super(ParseError, self).__init__() self.message = message +import sys +PYTHON_VERSION = sys.version_info.major + +from collections import deque +from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any + +from . import env +from . import expr +from . import ty +from . import op + +try: + if PYTHON_VERSION == 2: + from .grammar.py2.RelayVisitor import RelayVisitor + from .grammar.py2.RelayParser import RelayParser + from .grammar.py2.RelayLexer import RelayLexer + else: + from .grammar.py3.RelayVisitor import RelayVisitor + from .grammar.py3.RelayParser import RelayParser + from .grammar.py3.RelayLexer import RelayLexer +except ImportError: + raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") + +try: + from antlr4 import ParserRuleContext, InputStream, CommonTokenStream + from antlr4.tree.Tree import TerminalNode +except ImportError: + raise ParseError("Couldn't find ANTLR runtime. Try running `pip{} install antlr{}-runtime`.".format(PYTHON_VERSION, PYTHON_VERSION)) + BINARY_OPS = { - RelayParser.MUL: relay.multiply, - RelayParser.DIV: relay.divide, - RelayParser.ADD: relay.add, - RelayParser.SUB: relay.subtract, - RelayParser.LT: relay.less, - RelayParser.GT: relay.greater, - RelayParser.LE: relay.less_equal, - RelayParser.GE: relay.greater_equal, - RelayParser.EQ: relay.equal, - RelayParser.NE: relay.not_equal, + RelayParser.MUL: op.multiply, + RelayParser.DIV: op.divide, + RelayParser.ADD: op.add, + RelayParser.SUB: op.subtract, + RelayParser.LT: op.less, + RelayParser.GT: op.greater, + RelayParser.LE: op.less_equal, + RelayParser.GE: op.greater_equal, + RelayParser.EQ: op.equal, + RelayParser.NE: op.not_equal, } TYPE_PREFIXES = [ @@ -65,11 +78,11 @@ class ParseTreeToRelayIR(RelayVisitor): def __init__(self): # type: () -> None - self.env = relay.Environment({}) # type: relay.Environment + self.env = env.Environment({}) # type: env.Environment # Adding an empty scope allows naked lets without pain. - self.var_scopes = deque([deque()]) # type: Scopes[relay.Var] - self.type_param_scopes = deque([deque()]) # type: Scopes[relay.TypeParam] + self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] + self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeParam] super(ParseTreeToRelayIR, self).__init__() @@ -80,16 +93,16 @@ def enter_var_scope(self): self.var_scopes.appendleft(deque()) def exit_var_scope(self): - # type: () -> Scope[relay.Var] + # type: () -> Scope[expr.Var] """Pop off the current Var scope and return it.""" return self.var_scopes.popleft() def mk_var(self, name, type_): - # type: (str, relay.Type) -> relay.Var + # type: (str, ty.Type) -> expr.Var """Create a new Var and add it to the Var scope.""" - var = relay.Var(name, type_) + var = expr.Var(name, type_) self.var_scopes[0].appendleft((name, var)) return var @@ -100,21 +113,21 @@ def enter_type_param_scope(self): self.type_param_scopes.appendleft(deque()) def exit_type_param_scope(self): - # type: () -> Scope[relay.TypeParam] + # type: () -> Scope[ty.TypeParam] """Pop off the current TypeParam scope and return it.""" return self.type_param_scopes.popleft() def mk_typ(self, name, kind): - # (str, relay.Kind) -> relay.TypeParam + # (str, ty.Kind) -> ty.TypeParam """Create a new TypeParam and add it to the TypeParam scope.""" - typ = relay.TypeParam(name, kind) + typ = TypeParam(name, kind) self.type_param_scopes[0].appendleft((name, typ)) return typ def visitTerminal(self, node): - # type: (TerminalNode) -> Union[relay.Expr, int, float] + # type: (TerminalNode) -> Union[expr.Expr, int, float] """Visit lexer tokens that aren't ignored or visited by other functions.""" node_type = node.getSymbol().type @@ -122,7 +135,7 @@ def visitTerminal(self, node): # variables if node_type == RelayLexer.GLOBAL_VAR: - return relay.GlobalVar(node_text[1:]) + return GlobalVar(node_text[1:]) elif node_type == RelayLexer.LOCAL_VAR: name = node_text[1:] var = lookup(self.var_scopes, name) @@ -154,7 +167,7 @@ def visit_list(self, ctx_list): return [self.visit(ctx) for ctx in ctx_list] def getType_(self, ctx): - # type: (Optional[RelayParser.Type_Context]) -> Optional[relay.Type] + # type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type] """Return a (possibly None) Relay type.""" if ctx is None: @@ -163,56 +176,58 @@ def getType_(self, ctx): return self.visit(ctx) def visitProg(self, ctx): - # type: (RelayParser.ProgContext) -> relay.Environment - self.visit_list(ctx.defn()) - - return self.env + # type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment] + if ctx.defn(): + self.visit_list(ctx.defn()) + return self.env + else: + return self.visit(ctx.expr()) # Exprs def visitOpIdent(self, ctx): - # type: (RelayParser.OpIdentContext) -> relay.Op - return relay.op.get(ctx.CNAME().getText()) + # type: (RelayParser.OpIdentContext) -> op.Op + return op.get(ctx.CNAME().getText()) # pass through def visitParens(self, ctx): - # type: (RelayParser.ParensContext) -> relay.Expr + # type: (RelayParser.ParensContext) -> expr.Expr return self.visit(ctx.expr()) # pass through def visitBody(self, ctx): - # type: (RelayParser.BodyContext) -> relay.Expr + # type: (RelayParser.BodyContext) -> expr.Expr return self.visit(ctx.expr()) def visitScalarFloat(self, ctx): - # type: (RelayParser.ScalarFloatContext) -> relay.Constant - return relay.const(self.visit(ctx.FLOAT())) + # type: (RelayParser.ScalarFloatContext) -> expr.Constant + return expr.const(self.visit(ctx.FLOAT())) def visitScalarInt(self, ctx): - # type: (RelayParser.ScalarIntContext) -> relay.Constant - return relay.const(self.visit(ctx.INT())) + # type: (RelayParser.ScalarIntContext) -> expr.Constant + return expr.const(self.visit(ctx.INT())) def visitScalarBool(self, ctx): - # type: (RelayParser.ScalarBoolContext) -> relay.Constant - return relay.const(self.visit(ctx.BOOL_LIT())) + # type: (RelayParser.ScalarBoolContext) -> expr.Constant + return expr.const(self.visit(ctx.BOOL_LIT())) def visitNeg(self, ctx): - # type: (RelayParser.NegContext) -> Union[relay.Constant, relay.Call] + # type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call] val = self.visit(ctx.expr()) - if isinstance(val, relay.Constant) and val.data.asnumpy().ndim == 0: + if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0: # fold Neg in for scalars - return relay.const(-val.data.asnumpy().item()) + return expr.const(-val.data.asnumpy().item()) - return relay.negative(val) + return op.negative(val) def visitTuple(self, ctx): - # type: (RelayParser.TupleContext) -> relay.Tuple + # type: (RelayParser.TupleContext) -> expr.Tuple tup = self.visit_list(ctx.expr()) - return relay.Tuple(tup) + return expr.Tuple(tup) # Currently doesn't support mutable sequencing. def visitSeq(self, ctx): - # type: (RelayParser.SeqContext) -> relay.Let + # type: (RelayParser.SeqContext) -> expr.Let """Desugar various sequence constructs to Relay Let nodes.""" if ctx.MUT() is not None: raise ParseError("Mutation is currently unsupported.") @@ -236,10 +251,10 @@ def visitSeq(self, ctx): body = self.visit(ctx.expr(1)) - return relay.Let(var, value, body) + return expr.Let(var, value, body) def visitBinOp(self, ctx): - # type: (RelayParser.BinOpContext) -> relay.Call + # type: (RelayParser.BinOpContext) -> expr.Call """Desugar binary operators.""" arg0, arg1 = self.visit_list(ctx.expr()) relay_op = BINARY_OPS.get(ctx.op.type) @@ -250,7 +265,7 @@ def visitBinOp(self, ctx): return relay_op(arg0, arg1) def visitVar(self, ctx): - # type: (RelayParser.VarContext) -> relay.Var + # type: (RelayParser.VarContext) -> expr.Var ident = ctx.ident().LOCAL_VAR() if ident is None: @@ -261,11 +276,11 @@ def visitVar(self, ctx): return self.mk_var(ident.getText()[1:], type_) def visitVarList(self, ctx): - # type: (RelayParser.VarListContext) -> List[relay.Var] + # type: (RelayParser.VarListContext) -> List[expr.Var] return self.visit_list(ctx.var()) def mk_func(self, ctx): - # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> relay.Function + # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function """Construct a function from either a Func or Defn.""" # Enter var scope early to put params in scope. @@ -282,10 +297,10 @@ def mk_func(self, ctx): body = self.visit(ctx.body()) self.exit_var_scope() - return relay.Function(var_list, body, ret_type, type_params) # type: ignore + return expr.Function(var_list, body, ret_type, type_params) # type: ignore def visitFunc(self, ctx): - # type: (RelayParser.FuncContext) -> relay.Function + # type: (RelayParser.FuncContext) -> expr.Function return self.mk_func(ctx) def visitDefn(self, ctx): @@ -293,21 +308,21 @@ def visitDefn(self, ctx): ident = ctx.ident().GLOBAL_VAR() if ident is None: raise ParseError('Only global ids may be used in `def`s.') - ident = relay.GlobalVar(ident.getText()[1:]) + ident = expr.GlobalVar(ident.getText()[1:]) self.env[ident] = self.mk_func(ctx) def visitCall(self, ctx): - # type: (RelayParser.CallContext) -> relay.Call + # type: (RelayParser.CallContext) -> expr.Call visited_exprs = self.visit_list(ctx.expr()) func = visited_exprs[0] args = visited_exprs[1:] - return relay.Call(func, args, None, None) + return expr.Call(func, args, None, None) def visitIfElse(self, ctx): - # type: (RelayParser.IfElseContext) -> relay.If + # type: (RelayParser.IfElseContext) -> expr.If """Construct a Relay If node. Creates a new scope for each branch.""" cond = self.visit(ctx.expr()) @@ -319,7 +334,7 @@ def visitIfElse(self, ctx): false_branch = self.visit(ctx.body(1)) self.exit_var_scope() - return relay.If(cond, true_branch, false_branch) + return expr.If(cond, true_branch, false_branch) # Types @@ -329,18 +344,18 @@ def visitIncompleteType(self, ctx): return None def visitIdentType(self, ctx): - # type: (RelayParser.IdentTypeContext) -> Union[relay.TensorType, str] + # type: (RelayParser.IdentTypeContext) -> Union[ty.TensorType, str] ident_type = ctx.CNAME().getText() # look through all type prefixes for a match for type_prefix in TYPE_PREFIXES: if ident_type.startswith(type_prefix): - return relay.scalar_type(ident_type) + return ty.scalar_type(ident_type) raise ParseError("Unknown builtin type: {}".format(ident_type)) # def visitCallType(self, ctx): - # # type: (RelayParser.CallTypeContext) -> Union[relay.Expr, relay.TensorType] + # # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType] # ident_type = ctx.identType().CNAME().getText() # args = self.visit_list(ctx.type_()) @@ -364,31 +379,31 @@ def visitShapeSeq(self, ctx): return self.visit_list(ctx.shape()) def visitTensorType(self, ctx): - # type: (RelayParser.TensorTypeContext) -> relay.TensorType + # type: (RelayParser.TensorTypeContext) -> ty.TensorType """Create a simple tensor type. No generics.""" shape = self.visit(ctx.shapeSeq()) dtype = self.visit(ctx.type_()) - if not isinstance(dtype, relay.TensorType): + if not isinstance(dtype, ty.TensorType): raise ParseError("Expected dtype to be a Relay base type.") dtype = dtype.dtype - return relay.TensorType(shape, dtype) + return ty.TensorType(shape, dtype) def visitTupleType(self, ctx): - # type: (RelayParser.TupleTypeContext) -> relay.TupleType - return relay.TupleType(self.visit_list(ctx.type_())) + # type: (RelayParser.TupleTypeContext) -> ty.TupleType + return ty.TupleType(self.visit_list(ctx.type_())) def visitFuncType(self, ctx): - # type: (RelayParser.FuncTypeContext) -> relay.FuncType + # type: (RelayParser.FuncTypeContext) -> ty.FuncType types = self.visit_list(ctx.type_()) arg_types = types[:-1] ret_type = types[-1] - return relay.FuncType(arg_types, ret_type, [], None) + return ty.FuncType(arg_types, ret_type, [], None) def make_parser(data): # type: (str) -> RelayParser @@ -399,23 +414,15 @@ def make_parser(data): token_stream = CommonTokenStream(lexer) return RelayParser(token_stream) -def parse_expr(data): - # type: (str) -> relay.Expr - """Parse a Relay expression.""" - - tree = make_parser(data).expr() - return ParseTreeToRelayIR().visit(tree) - -def parse_prog(data): - # type: (str) -> Program +def parse(data): + # type: (str) -> Union[expr.Expr, env.Environment] """Parse a Relay program.""" - tree = make_parser(data).prog() return ParseTreeToRelayIR().visit(tree) def parse_file(path): - # type: (str) -> Program + # type: (str) -> Union[expr.Expr, env.Environment] """Parse a Relay program from a file.""" with open(path, 'r') as in_file: - return parse_prog(in_file.read()) + return parse(in_file.read()) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 5792dfa133d7..9aedbf9c1300 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -1,8 +1,7 @@ import tvm from tvm import relay -from tvm.relay.parser import parse_expr, parse_prog, ParseError +from tvm.relay.parser import ParseError from tvm.relay.ir_pass import alpha_equal -# from tvm.relay.ir_builder import convert from nose.tools import nottest, raises from typing import Union @@ -60,7 +59,7 @@ def is_close(x, y, precision=0.001): def test_comments(): assert alpha_equal( - parse_expr(""" + relay.parse(""" // This is a line comment! () """), @@ -68,7 +67,7 @@ def test_comments(): ) assert alpha_equal( - parse_expr(""" + relay.parse(""" /* This is a block comment! This is still a block comment! */ @@ -78,81 +77,82 @@ def test_comments(): ) def test_int_literal(): - assert isinstance(parse_expr("1"), relay.Constant) - assert isinstance(parse_expr("1").data, tvm.ndarray.NDArray) + assert isinstance(relay.parse("1"), relay.Constant) + assert isinstance(relay.parse("1").data, tvm.ndarray.NDArray) - assert get_scalar(parse_expr("1")) == 1 - assert get_scalar(parse_expr("10")) == 10 - assert get_scalar(parse_expr("0")) == 0 - assert get_scalar(parse_expr("-100")) == -100 - assert get_scalar(parse_expr("-05")) == -5 + assert get_scalar(relay.parse("1")) == 1 + assert get_scalar(relay.parse("10")) == 10 + assert get_scalar(relay.parse("0")) == 0 + assert get_scalar(relay.parse("-100")) == -100 + assert get_scalar(relay.parse("-05")) == -5 def test_float_literal(): - assert get_scalar(parse_expr("1.0")) == 1.0 - assert is_close(get_scalar(parse_expr("1.56667")), 1.56667) - assert get_scalar(parse_expr("0.0")) == 0.0 - assert get_scalar(parse_expr("-10.0")) == -10.0 + assert get_scalar(relay.parse("1.0")) == 1.0 + assert is_close(get_scalar(relay.parse("1.56667")), 1.56667) + assert get_scalar(relay.parse("0.0")) == 0.0 + assert get_scalar(relay.parse("-10.0")) == -10.0 # scientific notation - assert is_close(get_scalar(parse_expr("1e-1")), 1e-1) - assert get_scalar(parse_expr("1e+1")) == 1e+1 - assert is_close(get_scalar(parse_expr("1E-1")), 1E-1) - assert get_scalar(parse_expr("1E+1")) == 1E+1 - assert is_close(get_scalar(parse_expr("1.0e-1")), 1.0e-1) - assert get_scalar(parse_expr("1.0e+1")) == 1.0e+1 - assert is_close(get_scalar(parse_expr("1.0E-1")), 1.0E-1) - assert get_scalar(parse_expr("1.0E+1")) == 1.0E+1 + assert is_close(get_scalar(relay.parse("1e-1")), 1e-1) + assert get_scalar(relay.parse("1e+1")) == 1e+1 + assert is_close(get_scalar(relay.parse("1E-1")), 1E-1) + assert get_scalar(relay.parse("1E+1")) == 1E+1 + assert is_close(get_scalar(relay.parse("1.0e-1")), 1.0e-1) + assert get_scalar(relay.parse("1.0e+1")) == 1.0e+1 + assert is_close(get_scalar(relay.parse("1.0E-1")), 1.0E-1) + assert get_scalar(relay.parse("1.0E+1")) == 1.0E+1 def test_bool_literal(): - assert get_scalar(parse_expr("True")) == True - assert get_scalar(parse_expr("False")) == False + assert get_scalar(relay.parse("True")) == True + assert get_scalar(relay.parse("False")) == False def test_negative(): - assert isinstance(parse_expr("let %x = 1; -%x").body, relay.Call) - assert get_scalar(parse_expr("--10")) == 10 - assert get_scalar(parse_expr("---10")) == -10 + assert isinstance(relay.parse("let %x = 1; -%x").body, relay.Call) + assert get_scalar(relay.parse("--10")) == 10 + assert get_scalar(relay.parse("---10")) == -10 def test_bin_op(): for bin_op in BINARY_OPS.keys(): assert alpha_equal( - parse_expr("1 {} 1".format(bin_op)), + relay.parse("1 {} 1".format(bin_op)), BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) ) def test_parens(): - assert alpha_equal(parse_expr("1 * 1 + 1"), parse_expr("(1 * 1) + 1")) - assert not alpha_equal(parse_expr("1 * 1 + 1"), parse_expr("1 * (1 + 1)")) + print(relay.parse("1 * 1 + 1")) + assert alpha_equal(relay.parse("1 * 1 + 1"), relay.parse("(1 * 1) + 1")) + assert not alpha_equal(relay.parse("1 * 1 + 1"), relay.parse("1 * (1 + 1)")) def test_op_assoc(): - assert alpha_equal(parse_expr("1 * 1 + 1 < 1 == 1"), parse_expr("(((1 * 1) + 1) < 1) == 1")) - assert alpha_equal(parse_expr("1 == 1 < 1 + 1 * 1"), parse_expr("1 == (1 < (1 + (1 * 1)))")) + assert alpha_equal(relay.parse("1 * 1 + 1 < 1 == 1"), relay.parse("(((1 * 1) + 1) < 1) == 1")) + assert alpha_equal(relay.parse("1 == 1 < 1 + 1 * 1"), relay.parse("1 == (1 < (1 + (1 * 1)))")) @nottest def test_vars(): # temp vars won't work b/c they start with a digit # # temp var - # temp_var = parse_expr("%1") + # temp_var = relay.parse("%1") # assert isinstance(temp_var, relay.Var) # assert temp_var.name == "1" # var - var = parse_expr("let %foo = (); %foo") + var = relay.parse("let %foo = (); %foo") assert isinstance(var.body, relay.Var) assert var.body.name_hint == "foo" # global var - global_var = parse_expr("@foo") + global_var = relay.parse("@foo") assert isinstance(global_var, relay.GlobalVar) assert global_var.name_hint == "foo" # operator id - op = parse_expr("foo") + op = relay.parse("foo") assert isinstance(op, relay.Op) assert op.name == "foo" def test_let(): assert alpha_equal( - parse_expr("let %x = 1; ()"), + relay.parse("let %x = 1; ()"), relay.Let( X, relay.const(1), @@ -162,7 +162,7 @@ def test_let(): def test_seq(): assert alpha_equal( - parse_expr("(); ()"), + relay.parse("(); ()"), relay.Let( _, UNIT, @@ -170,7 +170,7 @@ def test_seq(): ) assert alpha_equal( - parse_expr("let %_ = { 1 }; ()"), + relay.parse("let %_ = { 1 }; ()"), relay.Let( X, relay.const(1), @@ -180,25 +180,25 @@ def test_seq(): @raises(ParseError) def test_let_global_var(): - parse_expr("let @x = 1; ()") + relay.parse("let @x = 1; ()") @raises(ParseError) def test_let_op(): - parse_expr("let x = 1; ()") + relay.parse("let x = 1; ()") def test_tuple(): - assert alpha_equal(parse_expr("()"), relay.Tuple([])) + assert alpha_equal(relay.parse("()"), relay.Tuple([])) - assert alpha_equal(parse_expr("(0,)"), relay.Tuple([relay.const(0)])) + assert alpha_equal(relay.parse("(0,)"), relay.Tuple([relay.const(0)])) - assert alpha_equal(parse_expr("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)])) + assert alpha_equal(relay.parse("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)])) - assert alpha_equal(parse_expr("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) + assert alpha_equal(relay.parse("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) def test_func(): # 0 args assert alpha_equal( - parse_expr("fn () { 0 }"), + relay.parse("fn () { 0 }"), relay.Function( [], relay.const(0), @@ -209,7 +209,7 @@ def test_func(): # 1 arg assert alpha_equal( - parse_expr("fn (%x) { %x }"), + relay.parse("fn (%x) { %x }"), relay.Function( [X], X, @@ -220,7 +220,7 @@ def test_func(): # 2 args assert alpha_equal( - parse_expr("fn (%x, %y) { %x + %y }"), + relay.parse("fn (%x, %y) { %x + %y }"), relay.Function( [X, Y], relay.add(X, Y), @@ -231,7 +231,7 @@ def test_func(): # annotations assert alpha_equal( - parse_expr("fn (%x: int32) -> int32 { %x }"), + relay.parse("fn (%x: int32) -> int32 { %x }"), relay.Function( [X_ANNO], X_ANNO, @@ -243,7 +243,7 @@ def test_func(): # TODO(@jmp): Crashes if %x isn't annnotated. # @nottest def test_defn(): - id_defn = parse_prog( + id_defn = relay.parse( """ def @id(%x: int32) -> int32 { %x @@ -253,7 +253,7 @@ def @id(%x: int32) -> int32 { def test_ifelse(): assert alpha_equal( - parse_expr( + relay.parse( """ if (True) { 0 @@ -271,7 +271,7 @@ def test_ifelse(): @raises(ParseError) def test_ifelse_scope(): - parse_expr( + relay.parse( """ if (True) { let %x = (); @@ -286,7 +286,7 @@ def test_call(): # 0 args constant = relay.Var("constant") assert alpha_equal( - parse_expr( + relay.parse( """ let %constant = fn () { 0 }; %constant() @@ -302,7 +302,7 @@ def test_call(): # 1 arg id_var = relay.Var("id") assert alpha_equal( - parse_expr( + relay.parse( """ let %id = fn (%x) { %x }; %id(1) @@ -318,7 +318,7 @@ def test_call(): # 2 args multiply = relay.Var("multiply") assert alpha_equal( - parse_expr( + relay.parse( """ let %multiply = fn (%x, %y) { %x * %y }; %multiply(0, 0) @@ -338,7 +338,7 @@ def test_call(): # anonymous function assert alpha_equal( - parse_expr( + relay.parse( """ (fn (%x) { %x })(0) """ @@ -359,7 +359,7 @@ def test_call(): # curried function curried_mult = relay.Var("curried_mult") alpha_equal( - parse_expr( + relay.parse( """ let %curried_mult = fn (%x) { @@ -394,7 +394,7 @@ def test_call(): # op alpha_equal( - parse_expr("abs(1)"), + relay.parse("abs(1)"), relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) ) @@ -402,7 +402,7 @@ def test_call(): def test_incomplete_type(): assert alpha_equal( - parse_expr("let %_ : _ = (); ()"), + relay.parse("let %_ : _ = (); ()"), relay.Let( _, UNIT, @@ -412,7 +412,7 @@ def test_incomplete_type(): def test_builtin_types(): for builtin_type in TYPES: - parse_expr("let %_ : {} = (); ()".format(builtin_type)) + relay.parse("let %_ : {} = (); ()".format(builtin_type)) @nottest def test_call_type(): @@ -420,7 +420,7 @@ def test_call_type(): def test_tensor_type(): assert alpha_equal( - parse_expr("let %_ : Tensor[(), float32] = (); ()"), + relay.parse("let %_ : Tensor[(), float32] = (); ()"), relay.Let( relay.Var("_", relay.TensorType((), "float32")), UNIT, @@ -429,7 +429,7 @@ def test_tensor_type(): ) assert alpha_equal( - parse_expr("let %_ : Tensor[(1,), float32] = (); ()"), + relay.parse("let %_ : Tensor[(1,), float32] = (); ()"), relay.Let( relay.Var("_", relay.TensorType((1,), "float32")), UNIT, @@ -438,7 +438,7 @@ def test_tensor_type(): ) assert alpha_equal( - parse_expr("let %_ : Tensor[(1, 1), float32] = (); ()"), + relay.parse("let %_ : Tensor[(1, 1), float32] = (); ()"), relay.Let( relay.Var("_", relay.TensorType((1, 1), "float32")), UNIT, @@ -448,7 +448,7 @@ def test_tensor_type(): def test_function_type(): assert alpha_equal( - parse_expr( + relay.parse( """ let %_: fn () -> int32 = fn () -> int32 { 0 }; () """ @@ -461,7 +461,7 @@ def test_function_type(): ) assert alpha_equal( - parse_expr( + relay.parse( """ let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () """ @@ -474,7 +474,7 @@ def test_function_type(): ) assert alpha_equal( - parse_expr( + relay.parse( """ let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () """ @@ -488,7 +488,7 @@ def test_function_type(): def test_tuple_type(): assert alpha_equal( - parse_expr( + relay.parse( """ let %_: () = (); () """), @@ -500,7 +500,7 @@ def test_tuple_type(): ) assert alpha_equal( - parse_expr( + relay.parse( """ let %_: (int32,) = (0,); () """), @@ -512,7 +512,7 @@ def test_tuple_type(): ) assert alpha_equal( - parse_expr( + relay.parse( """ let %_: (int32, int32) = (0, 1); () """), From e33b425ba80e0d18ded61b2c78cbdaa825f62359 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 16:04:50 -0700 Subject: [PATCH 41/64] towards parser script tests --- Jenkinsfile | 1 + tests/scripts/test_relay_parser.sh | 11 +++++++++++ 2 files changed, 12 insertions(+) create mode 100644 tests/scripts/test_relay_parser.sh diff --git a/Jenkinsfile b/Jenkinsfile index adc9e12ca74b..b6c23d2c3c49 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -133,6 +133,7 @@ stage('Build') { echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake echo set\\(USE_NNPACK ON\\) >> config.cmake echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake + echo set\\(USE_ANTLR ON\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake """ diff --git a/tests/scripts/test_relay_parser.sh b/tests/scripts/test_relay_parser.sh new file mode 100644 index 000000000000..e158af8e60b5 --- /dev/null +++ b/tests/scripts/test_relay_parser.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +export PYTHONPATH=python:topi/python:apps/extension/python +export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH} + +rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc + +make cython || exit -1 +make cython3 || exit -1 +TVM_FFI=cython python -m nose -v tests/python/relay/ir_relay_parser.py || exit -1 +TVM_FFI=ctypes python3 -m nose -v tests/python/relay/ir_relay_parser.py || exit -1 From 7d956fb71b4cab4687c48a236415d65bc93d563a Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 17:08:59 -0700 Subject: [PATCH 42/64] complete parser testing compatibility and expose fromtext --- python/tvm/relay/__init__.py | 3 +- python/tvm/relay/_parser.py | 421 +++++++++++++++++ python/tvm/relay/parser.py | 435 +----------------- tests/python/relay/test_ir_parser.py | 144 +++--- ...t_relay_parser.sh => task_relay_parser.sh} | 4 +- 5 files changed, 508 insertions(+), 499 deletions(-) create mode 100644 python/tvm/relay/_parser.py rename tests/scripts/{test_relay_parser.sh => task_relay_parser.sh} (58%) mode change 100644 => 100755 diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 32bdd4ee5f8a..b66132f27775 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -65,5 +65,4 @@ def _debug(*args): pdb.set_trace() # Parser -parse = parser.parse -parse_file = parser.parse_file +fromtext = parser.fromtext diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py new file mode 100644 index 000000000000..9ac0392ac246 --- /dev/null +++ b/python/tvm/relay/_parser.py @@ -0,0 +1,421 @@ +# pylint: disable=invalid-name, unused-import +"""A parser for Relay's text format.""" +class ParseError(Exception): + """Exception type for parse errors.""" + + def __init__(self, message): + # type: (str) -> None + super(ParseError, self).__init__() + self.message = message + +import sys +PYTHON_VERSION = sys.version_info.major + +from collections import deque +from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any + +from . import env +from . import expr +from . import ty +from . import op + +try: + if PYTHON_VERSION == 2: + from .grammar.py2.RelayVisitor import RelayVisitor + from .grammar.py2.RelayParser import RelayParser + from .grammar.py2.RelayLexer import RelayLexer + else: + from .grammar.py3.RelayVisitor import RelayVisitor + from .grammar.py3.RelayParser import RelayParser + from .grammar.py3.RelayLexer import RelayLexer +except ImportError: + raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") + +try: + from antlr4 import ParserRuleContext, InputStream, CommonTokenStream + from antlr4.tree.Tree import TerminalNode +except ImportError: + raise ParseError("Couldn't find ANTLR runtime. Try running `pip{} install antlr{}-runtime`.".format(PYTHON_VERSION, PYTHON_VERSION)) + +BINARY_OPS = { + RelayParser.MUL: op.multiply, + RelayParser.DIV: op.divide, + RelayParser.ADD: op.add, + RelayParser.SUB: op.subtract, + RelayParser.LT: op.less, + RelayParser.GT: op.greater, + RelayParser.LE: op.less_equal, + RelayParser.GE: op.greater_equal, + RelayParser.EQ: op.equal, + RelayParser.NE: op.not_equal, +} + +TYPE_PREFIXES = [ + "int", + "uint", + "float", + "bool", +] + +T = TypeVar("T") +Scope = Deque[Tuple[str, T]] +Scopes = Deque[Scope[T]] + +def lookup(scopes, name): + # type: (Scopes[T], str) -> Optional[T] + """Look up `name` in `scopes`.""" + + for scope in scopes: + for key, val in scope: + if key == name: + return val + return None + +# TODO(@jmp): Use https://stackoverflow.com/q/13889941 +# to figure out how to get ANTLR4 to be more unhappy about syntax errors +class ParseTreeToRelayIR(RelayVisitor): + """Parse Relay text format into Relay IR.""" + + def __init__(self): + # type: () -> None + self.env = env.Environment({}) # type: env.Environment + + # Adding an empty scope allows naked lets without pain. + self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] + self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeParam] + + super(ParseTreeToRelayIR, self).__init__() + + def enter_var_scope(self): + # type: () -> None + """Enter a new Var scope so it can be popped off later.""" + + self.var_scopes.appendleft(deque()) + + def exit_var_scope(self): + # type: () -> Scope[expr.Var] + """Pop off the current Var scope and return it.""" + + return self.var_scopes.popleft() + + def mk_var(self, name, type_): + # type: (str, ty.Type) -> expr.Var + """Create a new Var and add it to the Var scope.""" + + var = expr.Var(name, type_) + self.var_scopes[0].appendleft((name, var)) + return var + + def enter_type_param_scope(self): + # type: () -> None + """Enter a new TypeParam scope so it can be popped off later.""" + + self.type_param_scopes.appendleft(deque()) + + def exit_type_param_scope(self): + # type: () -> Scope[ty.TypeParam] + """Pop off the current TypeParam scope and return it.""" + + return self.type_param_scopes.popleft() + + def mk_typ(self, name, kind): + # (str, ty.Kind) -> ty.TypeParam + """Create a new TypeParam and add it to the TypeParam scope.""" + + typ = TypeParam(name, kind) + self.type_param_scopes[0].appendleft((name, typ)) + return typ + + def visitTerminal(self, node): + # type: (TerminalNode) -> Union[expr.Expr, int, float] + """Visit lexer tokens that aren't ignored or visited by other functions.""" + + node_type = node.getSymbol().type + node_text = node.getText() + + # variables + if node_type == RelayLexer.GLOBAL_VAR: + return GlobalVar(node_text[1:]) + elif node_type == RelayLexer.LOCAL_VAR: + name = node_text[1:] + var = lookup(self.var_scopes, name) + if var is None: + raise ParseError("Couldn't resolve `{}`.".format(name)) + + return var + + # data types + elif node_type == RelayLexer.INT: + return int(node_text) + elif node_type == RelayLexer.FLOAT: + return float(node_text) + elif node_type == RelayLexer.BOOL_LIT: + if node_text == "True": + return True + elif node_text == "False": + return False + else: + raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text)) + + else: + raise ParseError("todo: {}".format(node_text)) + + def visit_list(self, ctx_list): + # type: (List[ParserRuleContext]) -> List[Any] + """"Visit a list of contexts.""" + + return [self.visit(ctx) for ctx in ctx_list] + + def getType_(self, ctx): + # type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type] + """Return a (possibly None) Relay type.""" + + if ctx is None: + return None + + return self.visit(ctx) + + def visitProg(self, ctx): + # type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment] + if ctx.defn(): + self.visit_list(ctx.defn()) + return self.env + else: + return self.visit(ctx.expr()) + + # Exprs + + def visitOpIdent(self, ctx): + # type: (RelayParser.OpIdentContext) -> op.Op + return op.get(ctx.CNAME().getText()) + + # pass through + def visitParens(self, ctx): + # type: (RelayParser.ParensContext) -> expr.Expr + return self.visit(ctx.expr()) + + # pass through + def visitBody(self, ctx): + # type: (RelayParser.BodyContext) -> expr.Expr + return self.visit(ctx.expr()) + + def visitScalarFloat(self, ctx): + # type: (RelayParser.ScalarFloatContext) -> expr.Constant + return expr.const(self.visit(ctx.FLOAT())) + + def visitScalarInt(self, ctx): + # type: (RelayParser.ScalarIntContext) -> expr.Constant + return expr.const(self.visit(ctx.INT())) + + def visitScalarBool(self, ctx): + # type: (RelayParser.ScalarBoolContext) -> expr.Constant + return expr.const(self.visit(ctx.BOOL_LIT())) + + def visitNeg(self, ctx): + # type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call] + val = self.visit(ctx.expr()) + if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0: + # fold Neg in for scalars + return expr.const(-val.data.asnumpy().item()) + + return op.negative(val) + + def visitTuple(self, ctx): + # type: (RelayParser.TupleContext) -> expr.Tuple + tup = self.visit_list(ctx.expr()) + return expr.Tuple(tup) + + # Currently doesn't support mutable sequencing. + def visitSeq(self, ctx): + # type: (RelayParser.SeqContext) -> expr.Let + """Desugar various sequence constructs to Relay Let nodes.""" + if ctx.MUT() is not None: + raise ParseError("Mutation is currently unsupported.") + + if ctx.var() is None or ctx.var().ident() is None: + # anonymous identity + ident = "_" + type_ = None + else: + local_var = ctx.var().ident().LOCAL_VAR() + if local_var is None: + raise ParseError('Only local ids may be used in `let`s.') + ident = local_var.getText()[1:] + type_ = self.getType_(ctx.var().type_()) + + var = self.mk_var(ident, type_) + + self.enter_var_scope() + value = self.visit(ctx.expr(0)) + self.exit_var_scope() + + body = self.visit(ctx.expr(1)) + + return expr.Let(var, value, body) + + def visitBinOp(self, ctx): + # type: (RelayParser.BinOpContext) -> expr.Call + """Desugar binary operators.""" + arg0, arg1 = self.visit_list(ctx.expr()) + relay_op = BINARY_OPS.get(ctx.op.type) + + if relay_op is None: + raise ParseError("Unimplemented binary op.") + + return relay_op(arg0, arg1) + + def visitVar(self, ctx): + # type: (RelayParser.VarContext) -> expr.Var + ident = ctx.ident().LOCAL_VAR() + + if ident is None: + raise ParseError('Only local ids may be used in params.') + + type_ = self.getType_(ctx.type_()) + + return self.mk_var(ident.getText()[1:], type_) + + def visitVarList(self, ctx): + # type: (RelayParser.VarListContext) -> List[expr.Var] + return self.visit_list(ctx.var()) + + def mk_func(self, ctx): + # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function + """Construct a function from either a Func or Defn.""" + + # Enter var scope early to put params in scope. + self.enter_var_scope() + # Capture type params in params. + self.enter_type_param_scope() + var_list = self.visit(ctx.varList()) + ret_type = self.getType_(ctx.type_()) + + type_params = list(self.exit_type_param_scope()) + if type_params: + _, type_params = zip(*type_params) + + body = self.visit(ctx.body()) + self.exit_var_scope() + + return expr.Function(var_list, body, ret_type, type_params) # type: ignore + + def visitFunc(self, ctx): + # type: (RelayParser.FuncContext) -> expr.Function + return self.mk_func(ctx) + + def visitDefn(self, ctx): + # type: (RelayParser.DefnContext) -> None + ident = ctx.ident().GLOBAL_VAR() + if ident is None: + raise ParseError('Only global ids may be used in `def`s.') + ident = expr.GlobalVar(ident.getText()[1:]) + + self.env[ident] = self.mk_func(ctx) + + def visitCall(self, ctx): + # type: (RelayParser.CallContext) -> expr.Call + visited_exprs = self.visit_list(ctx.expr()) + + func = visited_exprs[0] + args = visited_exprs[1:] + + return expr.Call(func, args, None, None) + + def visitIfElse(self, ctx): + # type: (RelayParser.IfElseContext) -> expr.If + """Construct a Relay If node. Creates a new scope for each branch.""" + cond = self.visit(ctx.expr()) + + self.enter_var_scope() + true_branch = self.visit(ctx.body(0)) + self.exit_var_scope() + + self.enter_var_scope() + false_branch = self.visit(ctx.body(1)) + self.exit_var_scope() + + return expr.If(cond, true_branch, false_branch) + + # Types + + # pylint: disable=unused-argument + def visitIncompleteType(self, ctx): + # type (RelayParser.IncompleteTypeContext) -> None: + return None + + def visitIdentType(self, ctx): + # type: (RelayParser.IdentTypeContext) -> Union[ty.TensorType, str] + ident_type = ctx.CNAME().getText() + + # look through all type prefixes for a match + for type_prefix in TYPE_PREFIXES: + if ident_type.startswith(type_prefix): + return ty.scalar_type(ident_type) + + raise ParseError("Unknown builtin type: {}".format(ident_type)) + + # def visitCallType(self, ctx): + # # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType] + # ident_type = ctx.identType().CNAME().getText() + + # args = self.visit_list(ctx.type_()) + + # if not args: + # raise ParseError("Type-level functions must have arguments!") + + # func_type = TYPE_FUNCS.get(ident_type)(args) + + # if func_type is None: + # raise ParseError("Unknown type-level function: `{}`".format(ident_type)) + # else: + # return func_type + + def visitParensShape(self, ctx): + # type: (RelayParser.ParensShapeContext) -> int + return self.visit(ctx.shape()) + + def visitShapeSeq(self, ctx): + # type: (RelayParser.ShapeSeqContext) -> List[int] + return self.visit_list(ctx.shape()) + + def visitTensorType(self, ctx): + # type: (RelayParser.TensorTypeContext) -> ty.TensorType + """Create a simple tensor type. No generics.""" + + shape = self.visit(ctx.shapeSeq()) + dtype = self.visit(ctx.type_()) + + if not isinstance(dtype, ty.TensorType): + raise ParseError("Expected dtype to be a Relay base type.") + + dtype = dtype.dtype + + return ty.TensorType(shape, dtype) + + def visitTupleType(self, ctx): + # type: (RelayParser.TupleTypeContext) -> ty.TupleType + return ty.TupleType(self.visit_list(ctx.type_())) + + def visitFuncType(self, ctx): + # type: (RelayParser.FuncTypeContext) -> ty.FuncType + types = self.visit_list(ctx.type_()) + + arg_types = types[:-1] + ret_type = types[-1] + + return ty.FuncType(arg_types, ret_type, [], None) + +def make_parser(data): + # type: (str) -> RelayParser + """Construct a RelayParser a given data stream.""" + + input_stream = InputStream(data) + lexer = RelayLexer(input_stream) + token_stream = CommonTokenStream(lexer) + return RelayParser(token_stream) + +def fromtext(data): + # type: (str) -> Union[expr.Expr, env.Environment] + """Parse a Relay program.""" + tree = make_parser(data).prog() + return ParseTreeToRelayIR().visit(tree) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 98b5bc078d27..23c2f6b13b08 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,428 +1,13 @@ -# pylint: disable=invalid-name, unused-import -"""A parser for Relay's text format.""" -class ParseError(Exception): - """Exception type for parse errors.""" +from __future__ import absolute_import - def __init__(self, message): - # type: (str) -> None - super(ParseError, self).__init__() - self.message = message +def enabled(): + try: + import tvm.relay._parser + return True + except ImportError: + return False -import sys -PYTHON_VERSION = sys.version_info.major - -from collections import deque -from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any - -from . import env -from . import expr -from . import ty -from . import op - -try: - if PYTHON_VERSION == 2: - from .grammar.py2.RelayVisitor import RelayVisitor - from .grammar.py2.RelayParser import RelayParser - from .grammar.py2.RelayLexer import RelayLexer - else: - from .grammar.py3.RelayVisitor import RelayVisitor - from .grammar.py3.RelayParser import RelayParser - from .grammar.py3.RelayLexer import RelayLexer -except ImportError: - raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") - -try: - from antlr4 import ParserRuleContext, InputStream, CommonTokenStream - from antlr4.tree.Tree import TerminalNode -except ImportError: - raise ParseError("Couldn't find ANTLR runtime. Try running `pip{} install antlr{}-runtime`.".format(PYTHON_VERSION, PYTHON_VERSION)) - -BINARY_OPS = { - RelayParser.MUL: op.multiply, - RelayParser.DIV: op.divide, - RelayParser.ADD: op.add, - RelayParser.SUB: op.subtract, - RelayParser.LT: op.less, - RelayParser.GT: op.greater, - RelayParser.LE: op.less_equal, - RelayParser.GE: op.greater_equal, - RelayParser.EQ: op.equal, - RelayParser.NE: op.not_equal, -} - -TYPE_PREFIXES = [ - "int", - "uint", - "float", - "bool", -] - -T = TypeVar("T") -Scope = Deque[Tuple[str, T]] -Scopes = Deque[Scope[T]] - -def lookup(scopes, name): - # type: (Scopes[T], str) -> Optional[T] - """Look up `name` in `scopes`.""" - - for scope in scopes: - for key, val in scope: - if key == name: - return val - return None - -# TODO(@jmp): Use https://stackoverflow.com/q/13889941 -# to figure out how to get ANTLR4 to be more unhappy about syntax errors -class ParseTreeToRelayIR(RelayVisitor): - """Parse Relay text format into Relay IR.""" - - def __init__(self): - # type: () -> None - self.env = env.Environment({}) # type: env.Environment - - # Adding an empty scope allows naked lets without pain. - self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] - self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeParam] - - super(ParseTreeToRelayIR, self).__init__() - - def enter_var_scope(self): - # type: () -> None - """Enter a new Var scope so it can be popped off later.""" - - self.var_scopes.appendleft(deque()) - - def exit_var_scope(self): - # type: () -> Scope[expr.Var] - """Pop off the current Var scope and return it.""" - - return self.var_scopes.popleft() - - def mk_var(self, name, type_): - # type: (str, ty.Type) -> expr.Var - """Create a new Var and add it to the Var scope.""" - - var = expr.Var(name, type_) - self.var_scopes[0].appendleft((name, var)) - return var - - def enter_type_param_scope(self): - # type: () -> None - """Enter a new TypeParam scope so it can be popped off later.""" - - self.type_param_scopes.appendleft(deque()) - - def exit_type_param_scope(self): - # type: () -> Scope[ty.TypeParam] - """Pop off the current TypeParam scope and return it.""" - - return self.type_param_scopes.popleft() - - def mk_typ(self, name, kind): - # (str, ty.Kind) -> ty.TypeParam - """Create a new TypeParam and add it to the TypeParam scope.""" - - typ = TypeParam(name, kind) - self.type_param_scopes[0].appendleft((name, typ)) - return typ - - def visitTerminal(self, node): - # type: (TerminalNode) -> Union[expr.Expr, int, float] - """Visit lexer tokens that aren't ignored or visited by other functions.""" - - node_type = node.getSymbol().type - node_text = node.getText() - - # variables - if node_type == RelayLexer.GLOBAL_VAR: - return GlobalVar(node_text[1:]) - elif node_type == RelayLexer.LOCAL_VAR: - name = node_text[1:] - var = lookup(self.var_scopes, name) - if var is None: - raise ParseError("Couldn't resolve `{}`.".format(name)) - - return var - - # data types - elif node_type == RelayLexer.INT: - return int(node_text) - elif node_type == RelayLexer.FLOAT: - return float(node_text) - elif node_type == RelayLexer.BOOL_LIT: - if node_text == "True": - return True - elif node_text == "False": - return False - else: - raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text)) - - else: - raise ParseError("todo: {}".format(node_text)) - - def visit_list(self, ctx_list): - # type: (List[ParserRuleContext]) -> List[Any] - """"Visit a list of contexts.""" - - return [self.visit(ctx) for ctx in ctx_list] - - def getType_(self, ctx): - # type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type] - """Return a (possibly None) Relay type.""" - - if ctx is None: - return None - - return self.visit(ctx) - - def visitProg(self, ctx): - # type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment] - if ctx.defn(): - self.visit_list(ctx.defn()) - return self.env - else: - return self.visit(ctx.expr()) - - # Exprs - - def visitOpIdent(self, ctx): - # type: (RelayParser.OpIdentContext) -> op.Op - return op.get(ctx.CNAME().getText()) - - # pass through - def visitParens(self, ctx): - # type: (RelayParser.ParensContext) -> expr.Expr - return self.visit(ctx.expr()) - - # pass through - def visitBody(self, ctx): - # type: (RelayParser.BodyContext) -> expr.Expr - return self.visit(ctx.expr()) - - def visitScalarFloat(self, ctx): - # type: (RelayParser.ScalarFloatContext) -> expr.Constant - return expr.const(self.visit(ctx.FLOAT())) - - def visitScalarInt(self, ctx): - # type: (RelayParser.ScalarIntContext) -> expr.Constant - return expr.const(self.visit(ctx.INT())) - - def visitScalarBool(self, ctx): - # type: (RelayParser.ScalarBoolContext) -> expr.Constant - return expr.const(self.visit(ctx.BOOL_LIT())) - - def visitNeg(self, ctx): - # type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call] - val = self.visit(ctx.expr()) - if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0: - # fold Neg in for scalars - return expr.const(-val.data.asnumpy().item()) - - return op.negative(val) - - def visitTuple(self, ctx): - # type: (RelayParser.TupleContext) -> expr.Tuple - tup = self.visit_list(ctx.expr()) - return expr.Tuple(tup) - - # Currently doesn't support mutable sequencing. - def visitSeq(self, ctx): - # type: (RelayParser.SeqContext) -> expr.Let - """Desugar various sequence constructs to Relay Let nodes.""" - if ctx.MUT() is not None: - raise ParseError("Mutation is currently unsupported.") - - if ctx.var() is None or ctx.var().ident() is None: - # anonymous identity - ident = "_" - type_ = None - else: - local_var = ctx.var().ident().LOCAL_VAR() - if local_var is None: - raise ParseError('Only local ids may be used in `let`s.') - ident = local_var.getText()[1:] - type_ = self.getType_(ctx.var().type_()) - - var = self.mk_var(ident, type_) - - self.enter_var_scope() - value = self.visit(ctx.expr(0)) - self.exit_var_scope() - - body = self.visit(ctx.expr(1)) - - return expr.Let(var, value, body) - - def visitBinOp(self, ctx): - # type: (RelayParser.BinOpContext) -> expr.Call - """Desugar binary operators.""" - arg0, arg1 = self.visit_list(ctx.expr()) - relay_op = BINARY_OPS.get(ctx.op.type) - - if relay_op is None: - raise ParseError("Unimplemented binary op.") - - return relay_op(arg0, arg1) - - def visitVar(self, ctx): - # type: (RelayParser.VarContext) -> expr.Var - ident = ctx.ident().LOCAL_VAR() - - if ident is None: - raise ParseError('Only local ids may be used in params.') - - type_ = self.getType_(ctx.type_()) - - return self.mk_var(ident.getText()[1:], type_) - - def visitVarList(self, ctx): - # type: (RelayParser.VarListContext) -> List[expr.Var] - return self.visit_list(ctx.var()) - - def mk_func(self, ctx): - # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function - """Construct a function from either a Func or Defn.""" - - # Enter var scope early to put params in scope. - self.enter_var_scope() - # Capture type params in params. - self.enter_type_param_scope() - var_list = self.visit(ctx.varList()) - ret_type = self.getType_(ctx.type_()) - - type_params = list(self.exit_type_param_scope()) - if type_params: - _, type_params = zip(*type_params) - - body = self.visit(ctx.body()) - self.exit_var_scope() - - return expr.Function(var_list, body, ret_type, type_params) # type: ignore - - def visitFunc(self, ctx): - # type: (RelayParser.FuncContext) -> expr.Function - return self.mk_func(ctx) - - def visitDefn(self, ctx): - # type: (RelayParser.DefnContext) -> None - ident = ctx.ident().GLOBAL_VAR() - if ident is None: - raise ParseError('Only global ids may be used in `def`s.') - ident = expr.GlobalVar(ident.getText()[1:]) - - self.env[ident] = self.mk_func(ctx) - - def visitCall(self, ctx): - # type: (RelayParser.CallContext) -> expr.Call - visited_exprs = self.visit_list(ctx.expr()) - - func = visited_exprs[0] - args = visited_exprs[1:] - - return expr.Call(func, args, None, None) - - def visitIfElse(self, ctx): - # type: (RelayParser.IfElseContext) -> expr.If - """Construct a Relay If node. Creates a new scope for each branch.""" - cond = self.visit(ctx.expr()) - - self.enter_var_scope() - true_branch = self.visit(ctx.body(0)) - self.exit_var_scope() - - self.enter_var_scope() - false_branch = self.visit(ctx.body(1)) - self.exit_var_scope() - - return expr.If(cond, true_branch, false_branch) - - # Types - - # pylint: disable=unused-argument - def visitIncompleteType(self, ctx): - # type (RelayParser.IncompleteTypeContext) -> None: - return None - - def visitIdentType(self, ctx): - # type: (RelayParser.IdentTypeContext) -> Union[ty.TensorType, str] - ident_type = ctx.CNAME().getText() - - # look through all type prefixes for a match - for type_prefix in TYPE_PREFIXES: - if ident_type.startswith(type_prefix): - return ty.scalar_type(ident_type) - - raise ParseError("Unknown builtin type: {}".format(ident_type)) - - # def visitCallType(self, ctx): - # # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType] - # ident_type = ctx.identType().CNAME().getText() - - # args = self.visit_list(ctx.type_()) - - # if not args: - # raise ParseError("Type-level functions must have arguments!") - - # func_type = TYPE_FUNCS.get(ident_type)(args) - - # if func_type is None: - # raise ParseError("Unknown type-level function: `{}`".format(ident_type)) - # else: - # return func_type - - def visitParensShape(self, ctx): - # type: (RelayParser.ParensShapeContext) -> int - return self.visit(ctx.shape()) - - def visitShapeSeq(self, ctx): - # type: (RelayParser.ShapeSeqContext) -> List[int] - return self.visit_list(ctx.shape()) - - def visitTensorType(self, ctx): - # type: (RelayParser.TensorTypeContext) -> ty.TensorType - """Create a simple tensor type. No generics.""" - - shape = self.visit(ctx.shapeSeq()) - dtype = self.visit(ctx.type_()) - - if not isinstance(dtype, ty.TensorType): - raise ParseError("Expected dtype to be a Relay base type.") - - dtype = dtype.dtype - - return ty.TensorType(shape, dtype) - - def visitTupleType(self, ctx): - # type: (RelayParser.TupleTypeContext) -> ty.TupleType - return ty.TupleType(self.visit_list(ctx.type_())) - - def visitFuncType(self, ctx): - # type: (RelayParser.FuncTypeContext) -> ty.FuncType - types = self.visit_list(ctx.type_()) - - arg_types = types[:-1] - ret_type = types[-1] - - return ty.FuncType(arg_types, ret_type, [], None) - -def make_parser(data): - # type: (str) -> RelayParser - """Construct a RelayParser a given data stream.""" - - input_stream = InputStream(data) - lexer = RelayLexer(input_stream) - token_stream = CommonTokenStream(lexer) - return RelayParser(token_stream) - -def parse(data): - # type: (str) -> Union[expr.Expr, env.Environment] +def fromtext(data): """Parse a Relay program.""" - tree = make_parser(data).prog() - return ParseTreeToRelayIR().visit(tree) - -def parse_file(path): - # type: (str) -> Union[expr.Expr, env.Environment] - """Parse a Relay program from a file.""" - - with open(path, 'r') as in_file: - return parse(in_file.read()) + from tvm.relay import _parser + return _parser.fromtext(data) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 9aedbf9c1300..85a2a1a10026 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -1,10 +1,14 @@ import tvm from tvm import relay -from tvm.relay.parser import ParseError +from tvm.relay._parser import ParseError +from tvm.relay.parser import enabled from tvm.relay.ir_pass import alpha_equal from nose.tools import nottest, raises from typing import Union +if not enabled(): + exit() + BINARY_OPS = { "*": relay.multiply, "/": relay.divide, @@ -59,7 +63,7 @@ def is_close(x, y, precision=0.001): def test_comments(): assert alpha_equal( - relay.parse(""" + relay.fromtext(""" // This is a line comment! () """), @@ -67,7 +71,7 @@ def test_comments(): ) assert alpha_equal( - relay.parse(""" + relay.fromtext(""" /* This is a block comment! This is still a block comment! */ @@ -77,82 +81,82 @@ def test_comments(): ) def test_int_literal(): - assert isinstance(relay.parse("1"), relay.Constant) - assert isinstance(relay.parse("1").data, tvm.ndarray.NDArray) + assert isinstance(relay.fromtext("1"), relay.Constant) + assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray) - assert get_scalar(relay.parse("1")) == 1 - assert get_scalar(relay.parse("10")) == 10 - assert get_scalar(relay.parse("0")) == 0 - assert get_scalar(relay.parse("-100")) == -100 - assert get_scalar(relay.parse("-05")) == -5 + assert get_scalar(relay.fromtext("1")) == 1 + assert get_scalar(relay.fromtext("10")) == 10 + assert get_scalar(relay.fromtext("0")) == 0 + assert get_scalar(relay.fromtext("-100")) == -100 + assert get_scalar(relay.fromtext("-05")) == -5 def test_float_literal(): - assert get_scalar(relay.parse("1.0")) == 1.0 - assert is_close(get_scalar(relay.parse("1.56667")), 1.56667) - assert get_scalar(relay.parse("0.0")) == 0.0 - assert get_scalar(relay.parse("-10.0")) == -10.0 + assert get_scalar(relay.fromtext("1.0")) == 1.0 + assert is_close(get_scalar(relay.fromtext("1.56667")), 1.56667) + assert get_scalar(relay.fromtext("0.0")) == 0.0 + assert get_scalar(relay.fromtext("-10.0")) == -10.0 # scientific notation - assert is_close(get_scalar(relay.parse("1e-1")), 1e-1) - assert get_scalar(relay.parse("1e+1")) == 1e+1 - assert is_close(get_scalar(relay.parse("1E-1")), 1E-1) - assert get_scalar(relay.parse("1E+1")) == 1E+1 - assert is_close(get_scalar(relay.parse("1.0e-1")), 1.0e-1) - assert get_scalar(relay.parse("1.0e+1")) == 1.0e+1 - assert is_close(get_scalar(relay.parse("1.0E-1")), 1.0E-1) - assert get_scalar(relay.parse("1.0E+1")) == 1.0E+1 + assert is_close(get_scalar(relay.fromtext("1e-1")), 1e-1) + assert get_scalar(relay.fromtext("1e+1")) == 1e+1 + assert is_close(get_scalar(relay.fromtext("1E-1")), 1E-1) + assert get_scalar(relay.fromtext("1E+1")) == 1E+1 + assert is_close(get_scalar(relay.fromtext("1.0e-1")), 1.0e-1) + assert get_scalar(relay.fromtext("1.0e+1")) == 1.0e+1 + assert is_close(get_scalar(relay.fromtext("1.0E-1")), 1.0E-1) + assert get_scalar(relay.fromtext("1.0E+1")) == 1.0E+1 def test_bool_literal(): - assert get_scalar(relay.parse("True")) == True - assert get_scalar(relay.parse("False")) == False + assert get_scalar(relay.fromtext("True")) == True + assert get_scalar(relay.fromtext("False")) == False def test_negative(): - assert isinstance(relay.parse("let %x = 1; -%x").body, relay.Call) - assert get_scalar(relay.parse("--10")) == 10 - assert get_scalar(relay.parse("---10")) == -10 + assert isinstance(relay.fromtext("let %x = 1; -%x").body, relay.Call) + assert get_scalar(relay.fromtext("--10")) == 10 + assert get_scalar(relay.fromtext("---10")) == -10 def test_bin_op(): for bin_op in BINARY_OPS.keys(): assert alpha_equal( - relay.parse("1 {} 1".format(bin_op)), + relay.fromtext("1 {} 1".format(bin_op)), BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) ) def test_parens(): - print(relay.parse("1 * 1 + 1")) - assert alpha_equal(relay.parse("1 * 1 + 1"), relay.parse("(1 * 1) + 1")) - assert not alpha_equal(relay.parse("1 * 1 + 1"), relay.parse("1 * (1 + 1)")) + print(relay.fromtext("1 * 1 + 1")) + assert alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("(1 * 1) + 1")) + assert not alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("1 * (1 + 1)")) def test_op_assoc(): - assert alpha_equal(relay.parse("1 * 1 + 1 < 1 == 1"), relay.parse("(((1 * 1) + 1) < 1) == 1")) - assert alpha_equal(relay.parse("1 == 1 < 1 + 1 * 1"), relay.parse("1 == (1 < (1 + (1 * 1)))")) + assert alpha_equal(relay.fromtext("1 * 1 + 1 < 1 == 1"), relay.fromtext("(((1 * 1) + 1) < 1) == 1")) + assert alpha_equal(relay.fromtext("1 == 1 < 1 + 1 * 1"), relay.fromtext("1 == (1 < (1 + (1 * 1)))")) @nottest def test_vars(): # temp vars won't work b/c they start with a digit # # temp var - # temp_var = relay.parse("%1") + # temp_var = relay.fromtext("%1") # assert isinstance(temp_var, relay.Var) # assert temp_var.name == "1" # var - var = relay.parse("let %foo = (); %foo") + var = relay.fromtext("let %foo = (); %foo") assert isinstance(var.body, relay.Var) assert var.body.name_hint == "foo" # global var - global_var = relay.parse("@foo") + global_var = relay.fromtext("@foo") assert isinstance(global_var, relay.GlobalVar) assert global_var.name_hint == "foo" # operator id - op = relay.parse("foo") + op = relay.fromtext("foo") assert isinstance(op, relay.Op) assert op.name == "foo" def test_let(): assert alpha_equal( - relay.parse("let %x = 1; ()"), + relay.fromtext("let %x = 1; ()"), relay.Let( X, relay.const(1), @@ -162,7 +166,7 @@ def test_let(): def test_seq(): assert alpha_equal( - relay.parse("(); ()"), + relay.fromtext("(); ()"), relay.Let( _, UNIT, @@ -170,7 +174,7 @@ def test_seq(): ) assert alpha_equal( - relay.parse("let %_ = { 1 }; ()"), + relay.fromtext("let %_ = { 1 }; ()"), relay.Let( X, relay.const(1), @@ -180,25 +184,25 @@ def test_seq(): @raises(ParseError) def test_let_global_var(): - relay.parse("let @x = 1; ()") + relay.fromtext("let @x = 1; ()") @raises(ParseError) def test_let_op(): - relay.parse("let x = 1; ()") + relay.fromtext("let x = 1; ()") def test_tuple(): - assert alpha_equal(relay.parse("()"), relay.Tuple([])) + assert alpha_equal(relay.fromtext("()"), relay.Tuple([])) - assert alpha_equal(relay.parse("(0,)"), relay.Tuple([relay.const(0)])) + assert alpha_equal(relay.fromtext("(0,)"), relay.Tuple([relay.const(0)])) - assert alpha_equal(relay.parse("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)])) + assert alpha_equal(relay.fromtext("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)])) - assert alpha_equal(relay.parse("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) + assert alpha_equal(relay.fromtext("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) def test_func(): # 0 args assert alpha_equal( - relay.parse("fn () { 0 }"), + relay.fromtext("fn () { 0 }"), relay.Function( [], relay.const(0), @@ -209,7 +213,7 @@ def test_func(): # 1 arg assert alpha_equal( - relay.parse("fn (%x) { %x }"), + relay.fromtext("fn (%x) { %x }"), relay.Function( [X], X, @@ -220,7 +224,7 @@ def test_func(): # 2 args assert alpha_equal( - relay.parse("fn (%x, %y) { %x + %y }"), + relay.fromtext("fn (%x, %y) { %x + %y }"), relay.Function( [X, Y], relay.add(X, Y), @@ -231,7 +235,7 @@ def test_func(): # annotations assert alpha_equal( - relay.parse("fn (%x: int32) -> int32 { %x }"), + relay.fromtext("fn (%x: int32) -> int32 { %x }"), relay.Function( [X_ANNO], X_ANNO, @@ -243,7 +247,7 @@ def test_func(): # TODO(@jmp): Crashes if %x isn't annnotated. # @nottest def test_defn(): - id_defn = relay.parse( + id_defn = relay.fromtext( """ def @id(%x: int32) -> int32 { %x @@ -253,7 +257,7 @@ def @id(%x: int32) -> int32 { def test_ifelse(): assert alpha_equal( - relay.parse( + relay.fromtext( """ if (True) { 0 @@ -271,7 +275,7 @@ def test_ifelse(): @raises(ParseError) def test_ifelse_scope(): - relay.parse( + relay.fromtext( """ if (True) { let %x = (); @@ -286,7 +290,7 @@ def test_call(): # 0 args constant = relay.Var("constant") assert alpha_equal( - relay.parse( + relay.fromtext( """ let %constant = fn () { 0 }; %constant() @@ -302,7 +306,7 @@ def test_call(): # 1 arg id_var = relay.Var("id") assert alpha_equal( - relay.parse( + relay.fromtext( """ let %id = fn (%x) { %x }; %id(1) @@ -318,7 +322,7 @@ def test_call(): # 2 args multiply = relay.Var("multiply") assert alpha_equal( - relay.parse( + relay.fromtext( """ let %multiply = fn (%x, %y) { %x * %y }; %multiply(0, 0) @@ -338,7 +342,7 @@ def test_call(): # anonymous function assert alpha_equal( - relay.parse( + relay.fromtext( """ (fn (%x) { %x })(0) """ @@ -359,7 +363,7 @@ def test_call(): # curried function curried_mult = relay.Var("curried_mult") alpha_equal( - relay.parse( + relay.fromtext( """ let %curried_mult = fn (%x) { @@ -394,7 +398,7 @@ def test_call(): # op alpha_equal( - relay.parse("abs(1)"), + relay.fromtext("abs(1)"), relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) ) @@ -402,7 +406,7 @@ def test_call(): def test_incomplete_type(): assert alpha_equal( - relay.parse("let %_ : _ = (); ()"), + relay.fromtext("let %_ : _ = (); ()"), relay.Let( _, UNIT, @@ -412,7 +416,7 @@ def test_incomplete_type(): def test_builtin_types(): for builtin_type in TYPES: - relay.parse("let %_ : {} = (); ()".format(builtin_type)) + relay.fromtext("let %_ : {} = (); ()".format(builtin_type)) @nottest def test_call_type(): @@ -420,7 +424,7 @@ def test_call_type(): def test_tensor_type(): assert alpha_equal( - relay.parse("let %_ : Tensor[(), float32] = (); ()"), + relay.fromtext("let %_ : Tensor[(), float32] = (); ()"), relay.Let( relay.Var("_", relay.TensorType((), "float32")), UNIT, @@ -429,7 +433,7 @@ def test_tensor_type(): ) assert alpha_equal( - relay.parse("let %_ : Tensor[(1,), float32] = (); ()"), + relay.fromtext("let %_ : Tensor[(1,), float32] = (); ()"), relay.Let( relay.Var("_", relay.TensorType((1,), "float32")), UNIT, @@ -438,7 +442,7 @@ def test_tensor_type(): ) assert alpha_equal( - relay.parse("let %_ : Tensor[(1, 1), float32] = (); ()"), + relay.fromtext("let %_ : Tensor[(1, 1), float32] = (); ()"), relay.Let( relay.Var("_", relay.TensorType((1, 1), "float32")), UNIT, @@ -448,7 +452,7 @@ def test_tensor_type(): def test_function_type(): assert alpha_equal( - relay.parse( + relay.fromtext( """ let %_: fn () -> int32 = fn () -> int32 { 0 }; () """ @@ -461,7 +465,7 @@ def test_function_type(): ) assert alpha_equal( - relay.parse( + relay.fromtext( """ let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () """ @@ -474,7 +478,7 @@ def test_function_type(): ) assert alpha_equal( - relay.parse( + relay.fromtext( """ let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () """ @@ -488,7 +492,7 @@ def test_function_type(): def test_tuple_type(): assert alpha_equal( - relay.parse( + relay.fromtext( """ let %_: () = (); () """), @@ -500,7 +504,7 @@ def test_tuple_type(): ) assert alpha_equal( - relay.parse( + relay.fromtext( """ let %_: (int32,) = (0,); () """), @@ -512,7 +516,7 @@ def test_tuple_type(): ) assert alpha_equal( - relay.parse( + relay.fromtext( """ let %_: (int32, int32) = (0, 1); () """), diff --git a/tests/scripts/test_relay_parser.sh b/tests/scripts/task_relay_parser.sh old mode 100644 new mode 100755 similarity index 58% rename from tests/scripts/test_relay_parser.sh rename to tests/scripts/task_relay_parser.sh index e158af8e60b5..80a02419cde0 --- a/tests/scripts/test_relay_parser.sh +++ b/tests/scripts/task_relay_parser.sh @@ -7,5 +7,5 @@ rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc make cython || exit -1 make cython3 || exit -1 -TVM_FFI=cython python -m nose -v tests/python/relay/ir_relay_parser.py || exit -1 -TVM_FFI=ctypes python3 -m nose -v tests/python/relay/ir_relay_parser.py || exit -1 +TVM_FFI=cython python -m nose -v tests/python/relay/test_ir_parser.py || exit -1 +TVM_FFI=ctypes python3 -m nose -v tests/python/relay/test_ir_parser.py || exit -1 \ No newline at end of file From a396cf73380f24044cf4cb6889a62e407caa3b1c Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 17:12:44 -0700 Subject: [PATCH 43/64] fix bad imports --- python/tvm/relay/_parser.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 9ac0392ac246..bc9233cf3ca7 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -82,7 +82,7 @@ def __init__(self): # Adding an empty scope allows naked lets without pain. self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] - self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeParam] + self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] super(ParseTreeToRelayIR, self).__init__() @@ -108,21 +108,21 @@ def mk_var(self, name, type_): def enter_type_param_scope(self): # type: () -> None - """Enter a new TypeParam scope so it can be popped off later.""" + """Enter a new TypeVar scope so it can be popped off later.""" self.type_param_scopes.appendleft(deque()) def exit_type_param_scope(self): - # type: () -> Scope[ty.TypeParam] - """Pop off the current TypeParam scope and return it.""" + # type: () -> Scope[ty.TypeVar] + """Pop off the current TypeVar scope and return it.""" return self.type_param_scopes.popleft() def mk_typ(self, name, kind): - # (str, ty.Kind) -> ty.TypeParam - """Create a new TypeParam and add it to the TypeParam scope.""" + # (str, ty.Kind) -> ty.TypeVar + """Create a new TypeVar and add it to the TypeVar scope.""" - typ = TypeParam(name, kind) + typ = ty.TypeVar(name, kind) self.type_param_scopes[0].appendleft((name, typ)) return typ @@ -135,7 +135,7 @@ def visitTerminal(self, node): # variables if node_type == RelayLexer.GLOBAL_VAR: - return GlobalVar(node_text[1:]) + return expr.GlobalVar(node_text[1:]) elif node_type == RelayLexer.LOCAL_VAR: name = node_text[1:] var = lookup(self.var_scopes, name) From e04dfa516ea4c1f8de60f9db5fcc1d9484fe7dce Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 17:14:29 -0700 Subject: [PATCH 44/64] ImportError -> Exception --- python/tvm/relay/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 23c2f6b13b08..516dfc1066f3 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -4,7 +4,7 @@ def enabled(): try: import tvm.relay._parser return True - except ImportError: + except Exception: return False def fromtext(data): From 3fd20b10f615821bc863d90f7577eefded3b2780 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 17:24:18 -0700 Subject: [PATCH 45/64] linting --- python/tvm/relay/_parser.py | 2 ++ python/tvm/relay/parser.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index bc9233cf3ca7..9893bb05aecb 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -8,6 +8,8 @@ def __init__(self, message): super(ParseError, self).__init__() self.message = message +from __future__ import absolute_import + import sys PYTHON_VERSION = sys.version_info.major diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 516dfc1066f3..009a721f05ea 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,13 +1,16 @@ from __future__ import absolute_import def enabled(): + """Is the parser enabled/Can we import the parser?""" try: + # pylint: disable=unused-variable import tvm.relay._parser return True + # pylint: disable=broad-except except Exception: return False def fromtext(data): """Parse a Relay program.""" - from tvm.relay import _parser + from tvm.relay import _parser return _parser.fromtext(data) From 9f03684d90e04ef07cd98689a6c6fce03cfa711e Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 17:27:37 -0700 Subject: [PATCH 46/64] linting --- python/tvm/relay/_parser.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 9893bb05aecb..105f0f124f0f 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -1,13 +1,6 @@ + # pylint: disable=invalid-name, unused-import """A parser for Relay's text format.""" -class ParseError(Exception): - """Exception type for parse errors.""" - - def __init__(self, message): - # type: (str) -> None - super(ParseError, self).__init__() - self.message = message - from __future__ import absolute_import import sys @@ -21,6 +14,14 @@ def __init__(self, message): from . import ty from . import op +class ParseError(Exception): + """Exception type for parse errors.""" + + def __init__(self, message): + # type: (str) -> None + super(ParseError, self).__init__() + self.message = message + try: if PYTHON_VERSION == 2: from .grammar.py2.RelayVisitor import RelayVisitor @@ -37,7 +38,9 @@ def __init__(self, message): from antlr4 import ParserRuleContext, InputStream, CommonTokenStream from antlr4.tree.Tree import TerminalNode except ImportError: - raise ParseError("Couldn't find ANTLR runtime. Try running `pip{} install antlr{}-runtime`.".format(PYTHON_VERSION, PYTHON_VERSION)) + raise ParseError("Couldn't find ANTLR runtime." + + "Try running `pip{} install antlr4-python{}-runtime`." + .format(PYTHON_VERSION, PYTHON_VERSION)) BINARY_OPS = { RelayParser.MUL: op.multiply, @@ -182,8 +185,8 @@ def visitProg(self, ctx): if ctx.defn(): self.visit_list(ctx.defn()) return self.env - else: - return self.visit(ctx.expr()) + + return self.visit(ctx.expr()) # Exprs From 98e44c554fad3b77460f88f56eeb0eef7efb7005 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 17:32:39 -0700 Subject: [PATCH 47/64] linting --- python/tvm/relay/_parser.py | 4 ++-- python/tvm/relay/parser.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 105f0f124f0f..d4a23e5e147b 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -4,7 +4,6 @@ from __future__ import absolute_import import sys -PYTHON_VERSION = sys.version_info.major from collections import deque from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any @@ -22,6 +21,7 @@ def __init__(self, message): super(ParseError, self).__init__() self.message = message +PYTHON_VERSION = sys.version_info.major try: if PYTHON_VERSION == 2: from .grammar.py2.RelayVisitor import RelayVisitor @@ -185,7 +185,7 @@ def visitProg(self, ctx): if ctx.defn(): self.visit_list(ctx.defn()) return self.env - + return self.visit(ctx.expr()) # Exprs diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 009a721f05ea..1d3ce706029d 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,3 +1,4 @@ +"""A parser for Relay's text format.""" from __future__ import absolute_import def enabled(): From ac74b91496785a6f0c0a0015e8a618590c12adb1 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 17:46:53 -0700 Subject: [PATCH 48/64] ci bump --- python/tvm/relay/_parser.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index d4a23e5e147b..80264d60f8c4 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -413,7 +413,6 @@ def visitFuncType(self, ctx): def make_parser(data): # type: (str) -> RelayParser """Construct a RelayParser a given data stream.""" - input_stream = InputStream(data) lexer = RelayLexer(input_stream) token_stream = CommonTokenStream(lexer) From fe8a7a6d9c0beda4962fb336586e418e5dc338a2 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 18:11:28 -0700 Subject: [PATCH 49/64] exit earlier --- tests/python/relay/test_ir_parser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 85a2a1a10026..83d0a9da6527 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -1,3 +1,6 @@ +if not enabled(): + exit() + import tvm from tvm import relay from tvm.relay._parser import ParseError @@ -6,9 +9,6 @@ from nose.tools import nottest, raises from typing import Union -if not enabled(): - exit() - BINARY_OPS = { "*": relay.multiply, "/": relay.divide, From bac7c19345d44a8ea5bcd5d66750e2a912309266 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 18:25:41 -0700 Subject: [PATCH 50/64] dependencies --- tests/python/relay/test_ir_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 83d0a9da6527..68564e9f6c30 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -1,10 +1,10 @@ +from tvm.relay.parser import enabled if not enabled(): exit() import tvm from tvm import relay from tvm.relay._parser import ParseError -from tvm.relay.parser import enabled from tvm.relay.ir_pass import alpha_equal from nose.tools import nottest, raises from typing import Union From fcd220b025e3e972ac9c618e0bc1c50796a924be Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 1 Nov 2018 18:44:03 -0700 Subject: [PATCH 51/64] delete separate script --- tests/scripts/task_relay_parser.sh | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100755 tests/scripts/task_relay_parser.sh diff --git a/tests/scripts/task_relay_parser.sh b/tests/scripts/task_relay_parser.sh deleted file mode 100755 index 80a02419cde0..000000000000 --- a/tests/scripts/task_relay_parser.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -export PYTHONPATH=python:topi/python:apps/extension/python -export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH} - -rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc - -make cython || exit -1 -make cython3 || exit -1 -TVM_FFI=cython python -m nose -v tests/python/relay/test_ir_parser.py || exit -1 -TVM_FFI=ctypes python3 -m nose -v tests/python/relay/test_ir_parser.py || exit -1 \ No newline at end of file From 0cf759412db5a1d3fbd0307c7146b5547dc324c7 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 2 Nov 2018 11:52:05 -0700 Subject: [PATCH 52/64] failing test. should fail ci --- tests/python/relay/test_ir_parser.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 68564e9f6c30..4c7885abe6cb 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -80,6 +80,16 @@ def test_comments(): UNIT ) + assert alpha_equal( + relay.fromtext(""" + /* This is a block comment! + This is still a block comment! + */ + () + """), + X + ) + def test_int_literal(): assert isinstance(relay.fromtext("1"), relay.Constant) assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray) From 30bbbfc60432edfdfabffd28f239399de3709aeb Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 2 Nov 2018 12:46:07 -0700 Subject: [PATCH 53/64] simplify failing test. rework import --- python/tvm/relay/parser.py | 2 +- tests/python/relay/test_ir_parser.py | 10 +--------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 1d3ce706029d..51200343f147 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -5,7 +5,7 @@ def enabled(): """Is the parser enabled/Can we import the parser?""" try: # pylint: disable=unused-variable - import tvm.relay._parser + from tvm.relay import _parser return True # pylint: disable=broad-except except Exception: diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 4c7885abe6cb..1a9a9b7eebfc 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -80,15 +80,7 @@ def test_comments(): UNIT ) - assert alpha_equal( - relay.fromtext(""" - /* This is a block comment! - This is still a block comment! - */ - () - """), - X - ) + assert False def test_int_literal(): assert isinstance(relay.fromtext("1"), relay.Constant) From 4eaca59585c6082ecd911a5e6686a3bf7b738c01 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 2 Nov 2018 15:44:49 -0700 Subject: [PATCH 54/64] switch USE_ANTLR from cpu to gpu --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index b6c23d2c3c49..407bcb151664 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -98,6 +98,7 @@ stage('Build') { echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake + echo set\\(USE_ANTLR ON\\) >> config.cmake echo set\\(USE_BLAS openblas\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake @@ -133,7 +134,6 @@ stage('Build') { echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake echo set\\(USE_NNPACK ON\\) >> config.cmake echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake - echo set\\(USE_ANTLR ON\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake """ From 3199d1e75c23a519d559c22e2d72dcf0d4f0405f Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Mon, 19 Nov 2018 17:10:00 -0800 Subject: [PATCH 55/64] rebase --- python/tvm/relay/_parser.py | 8 ++++---- python/tvm/relay/expr.py | 2 -- tests/python/relay/test_ir_parser.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 80264d60f8c4..f64c635dd4ff 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -8,7 +8,7 @@ from collections import deque from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any -from . import env +from . import module from . import expr from . import ty from . import op @@ -83,7 +83,7 @@ class ParseTreeToRelayIR(RelayVisitor): def __init__(self): # type: () -> None - self.env = env.Environment({}) # type: env.Environment + self.module = module.Module({}) # type: module.Module # Adding an empty scope allows naked lets without pain. self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] @@ -184,7 +184,7 @@ def visitProg(self, ctx): # type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment] if ctx.defn(): self.visit_list(ctx.defn()) - return self.env + return self.module return self.visit(ctx.expr()) @@ -315,7 +315,7 @@ def visitDefn(self, ctx): raise ParseError('Only global ids may be used in `def`s.') ident = expr.GlobalVar(ident.getText()[1:]) - self.env[ident] = self.mk_func(ctx) + self.module[ident] = self.mk_func(ctx) def visitCall(self, ctx): # type: (RelayParser.CallContext) -> expr.Call diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index b6ba2bee6c1c..4725c0a7a07d 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -622,5 +622,3 @@ def bind(expr, binds): The expression or function after binding. """ return _expr.Bind(expr, binds) - -pretty_print = _expr._pretty_print diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 1a9a9b7eebfc..f4cc5639dd72 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -255,7 +255,7 @@ def @id(%x: int32) -> int32 { %x } """) - assert isinstance(id_defn, relay.Environment) + assert isinstance(id_defn, relay.Module) def test_ifelse(): assert alpha_equal( From 081a90772279cd8f6affad16cdf7df6eb4a0723a Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Tue, 27 Nov 2018 19:09:53 -0800 Subject: [PATCH 56/64] fix ci (please) --- Jenkinsfile | 2 +- cmake/modules/ANTLR.cmake | 43 +++++++++++++------------- docker/Dockerfile.ci_gpu | 3 -- docker/install/ubuntu_install_antlr.sh | 2 -- 4 files changed, 22 insertions(+), 28 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 407bcb151664..b6c23d2c3c49 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -98,7 +98,6 @@ stage('Build') { echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake - echo set\\(USE_ANTLR ON\\) >> config.cmake echo set\\(USE_BLAS openblas\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake @@ -134,6 +133,7 @@ stage('Build') { echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake echo set\\(USE_NNPACK ON\\) >> config.cmake echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake + echo set\\(USE_ANTLR ON\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake """ diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake index bfc8403a373c..c4fa24d52ecd 100644 --- a/cmake/modules/ANTLR.cmake +++ b/cmake/modules/ANTLR.cmake @@ -1,29 +1,28 @@ -find_program(ANTLR4 antlr4) - if(USE_ANTLR) - find_program(ANTLR4 antlr4) - if(NOT ANTLR4) - message(FATAL_ERROR "Can't find ANTLR4!") - endif() + if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar) + set(ANTLR4 "${JAVA_HOME}/bin/java -jar /usr/local/lib/antlr-4.7.1-complete.jar") - set(RELAY_PARSER_DIR - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) + set(RELAY_PARSER_DIR + ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) - set(RELAY_PARSER - ${RELAY_PARSER_DIR}/py2/RelayVisitor.py - ${RELAY_PARSER_DIR}/py2/RelayParser.py - ${RELAY_PARSER_DIR}/py2/RelayLexer.py + set(RELAY_PARSER + ${RELAY_PARSER_DIR}/py2/RelayVisitor.py + ${RELAY_PARSER_DIR}/py2/RelayParser.py + ${RELAY_PARSER_DIR}/py2/RelayLexer.py - ${RELAY_PARSER_DIR}/py3/RelayVisitor.py - ${RELAY_PARSER_DIR}/py3/RelayParser.py - ${RELAY_PARSER_DIR}/py3/RelayLexer.py) + ${RELAY_PARSER_DIR}/py3/RelayVisitor.py + ${RELAY_PARSER_DIR}/py3/RelayParser.py + ${RELAY_PARSER_DIR}/py3/RelayLexer.py) - # Generate ANTLR grammar for parsing. - add_custom_command(OUTPUT ${RELAY_PARSER} - COMMAND antlr4 -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 - COMMAND antlr4 -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 - DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 - WORKING_DIRECTORY ${RELAY_PARSER_DIR}) + # Generate ANTLR grammar for parsing. + add_custom_command(OUTPUT ${RELAY_PARSER} + COMMAND ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 + COMMAND ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 + DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 + WORKING_DIRECTORY ${RELAY_PARSER_DIR}) - add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) + add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) + else() + message(FATAL_ERROR "Can't find ANTLR4!") + endif() endif(USE_ANTLR) diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 708331d3d61a..c177ef9d420a 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -66,9 +66,6 @@ RUN bash /install/ubuntu_install_vulkan.sh COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh -COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh -RUN bash /install/ubuntu_install_antlr.sh - # Environment variables ENV PATH=/usr/local/nvidia/bin:${PATH} ENV PATH=/usr/local/cuda/bin:${PATH} diff --git a/docker/install/ubuntu_install_antlr.sh b/docker/install/ubuntu_install_antlr.sh index f1066c4220d4..d2f2d6a8c48f 100644 --- a/docker/install/ubuntu_install_antlr.sh +++ b/docker/install/ubuntu_install_antlr.sh @@ -1,5 +1,3 @@ cd /usr/local/lib wget https://www.antlr.org/download/antlr-4.7.1-complete.jar cd - - -alias antlr4='java -jar /usr/local/lib/antlr-4.7.1-complete.jar' From f837861190d72087fa02d04d8b005a0208f0b881 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 29 Nov 2018 20:59:53 -0800 Subject: [PATCH 57/64] source /etc/profile to add JAVA_HOME --- docker/install/ubuntu_install_java.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/install/ubuntu_install_java.sh b/docker/install/ubuntu_install_java.sh index 462edc491627..8b90cbe4b676 100644 --- a/docker/install/ubuntu_install_java.sh +++ b/docker/install/ubuntu_install_java.sh @@ -2,3 +2,5 @@ set -o errexit -o nounset apt-get update && apt-get install -y openjdk-8-jdk maven test -d "/usr/lib/jvm/java-8-openjdk-amd64/jre" echo "export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre" >> /etc/profile + +source /etc/profile From 7eefbf33043cc272997e66714699857fb22d2011 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Thu, 29 Nov 2018 21:31:43 -0800 Subject: [PATCH 58/64] revert install_java change. source /etc/profile during make function --- Jenkinsfile | 2 ++ docker/install/ubuntu_install_java.sh | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index b6c23d2c3c49..7761db5d13dd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -51,9 +51,11 @@ stage("Sanity Check") { def make(docker_type, path, make_flag) { timeout(time: max_time, unit: 'MINUTES') { try { + sh "${docker_run} ${docker_type} source /etc/profile" sh "${docker_run} ${docker_type} ./tests/scripts/task_build.sh ${path} ${make_flag}" } catch (exc) { echo 'Incremental compilation failed. Fall back to build from scratch' + sh "${docker_run} ${docker_type} source /etc/profile" sh "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}" sh "${docker_run} ${docker_type} ./tests/scripts/task_build.sh ${path} ${make_flag}" } diff --git a/docker/install/ubuntu_install_java.sh b/docker/install/ubuntu_install_java.sh index 8b90cbe4b676..462edc491627 100644 --- a/docker/install/ubuntu_install_java.sh +++ b/docker/install/ubuntu_install_java.sh @@ -2,5 +2,3 @@ set -o errexit -o nounset apt-get update && apt-get install -y openjdk-8-jdk maven test -d "/usr/lib/jvm/java-8-openjdk-amd64/jre" echo "export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre" >> /etc/profile - -source /etc/profile From f3abc5d60dd49615f3c681d0f53d64661db427cc Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 30 Nov 2018 11:51:19 -0800 Subject: [PATCH 59/64] enable antlr on gpu --- Jenkinsfile | 1 + docker/Dockerfile.ci_gpu | 3 +++ 2 files changed, 4 insertions(+) diff --git a/Jenkinsfile b/Jenkinsfile index 7761db5d13dd..822d145a7d3c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -100,6 +100,7 @@ stage('Build') { echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake + echo set\\(USE_ANTLR ON\\) >> config.cmake echo set\\(USE_BLAS openblas\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index c177ef9d420a..708331d3d61a 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -66,6 +66,9 @@ RUN bash /install/ubuntu_install_vulkan.sh COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh +COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh +RUN bash /install/ubuntu_install_antlr.sh + # Environment variables ENV PATH=/usr/local/nvidia/bin:${PATH} ENV PATH=/usr/local/cuda/bin:${PATH} From 3e2244f0b121c7fb3f44627c1f76bc8218e15f5b Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 30 Nov 2018 14:06:04 -0800 Subject: [PATCH 60/64] trigger parser tests --- tests/python/relay/test_ir_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index f4cc5639dd72..6b8b29a93474 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -80,8 +80,6 @@ def test_comments(): UNIT ) - assert False - def test_int_literal(): assert isinstance(relay.fromtext("1"), relay.Constant) assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray) @@ -92,6 +90,8 @@ def test_int_literal(): assert get_scalar(relay.fromtext("-100")) == -100 assert get_scalar(relay.fromtext("-05")) == -5 + assert False + def test_float_literal(): assert get_scalar(relay.fromtext("1.0")) == 1.0 assert is_close(get_scalar(relay.fromtext("1.56667")), 1.56667) From ce725024783f02df778736e30d19e94673cd8a94 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 30 Nov 2018 18:28:40 -0800 Subject: [PATCH 61/64] revert Dockerfile to master. remove parser tests from this pr --- docker/Dockerfile.ci_cpu | 7 - tests/python/relay/test_ir_parser.py | 530 --------------------------- 2 files changed, 537 deletions(-) delete mode 100644 tests/python/relay/test_ir_parser.py diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 11a77adbfdde..e6e2dd7a37b0 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -40,10 +40,3 @@ COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh RUN bash /install/ubuntu_install_nnpack.sh ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin - -# ANTLR deps -COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh -RUN bash /install/ubuntu_install_java.sh - -COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh -RUN bash /install/ubuntu_install_antlr.sh diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py deleted file mode 100644 index 6b8b29a93474..000000000000 --- a/tests/python/relay/test_ir_parser.py +++ /dev/null @@ -1,530 +0,0 @@ -from tvm.relay.parser import enabled -if not enabled(): - exit() - -import tvm -from tvm import relay -from tvm.relay._parser import ParseError -from tvm.relay.ir_pass import alpha_equal -from nose.tools import nottest, raises -from typing import Union - -BINARY_OPS = { - "*": relay.multiply, - "/": relay.divide, - "+": relay.add, - "-": relay.subtract, - "<": relay.less, - ">": relay.greater, - "<=": relay.less_equal, - ">=": relay.greater_equal, - "==": relay.equal, - "!=": relay.not_equal, -} - -TYPES = { - "int8", - "int16", - "int32", - "int64", - - "uint8", - "uint16", - "uint32", - "uint64", - - "float16", - "float32", - "float64", - - "bool", - - "int8x4", - "uint1x4", - "float16x4", -} - -def get_scalar(x): - # type: (relay.Constant) -> (Union[float, int, bool]) - return x.data.asnumpy().item() - -def is_close(x, y, precision=0.001): - return x - y < precision and y - x < precision - -int32 = relay.scalar_type("int32") - -_ = relay.Var("_") -X = relay.Var("x") -Y = relay.Var("y") -X_ANNO = relay.Var("x", int32) -Y_ANNO = relay.Var("y", int32) - -UNIT = relay.Tuple([]) - -def test_comments(): - assert alpha_equal( - relay.fromtext(""" - // This is a line comment! - () - """), - UNIT - ) - - assert alpha_equal( - relay.fromtext(""" - /* This is a block comment! - This is still a block comment! - */ - () - """), - UNIT - ) - -def test_int_literal(): - assert isinstance(relay.fromtext("1"), relay.Constant) - assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray) - - assert get_scalar(relay.fromtext("1")) == 1 - assert get_scalar(relay.fromtext("10")) == 10 - assert get_scalar(relay.fromtext("0")) == 0 - assert get_scalar(relay.fromtext("-100")) == -100 - assert get_scalar(relay.fromtext("-05")) == -5 - - assert False - -def test_float_literal(): - assert get_scalar(relay.fromtext("1.0")) == 1.0 - assert is_close(get_scalar(relay.fromtext("1.56667")), 1.56667) - assert get_scalar(relay.fromtext("0.0")) == 0.0 - assert get_scalar(relay.fromtext("-10.0")) == -10.0 - - # scientific notation - assert is_close(get_scalar(relay.fromtext("1e-1")), 1e-1) - assert get_scalar(relay.fromtext("1e+1")) == 1e+1 - assert is_close(get_scalar(relay.fromtext("1E-1")), 1E-1) - assert get_scalar(relay.fromtext("1E+1")) == 1E+1 - assert is_close(get_scalar(relay.fromtext("1.0e-1")), 1.0e-1) - assert get_scalar(relay.fromtext("1.0e+1")) == 1.0e+1 - assert is_close(get_scalar(relay.fromtext("1.0E-1")), 1.0E-1) - assert get_scalar(relay.fromtext("1.0E+1")) == 1.0E+1 - -def test_bool_literal(): - assert get_scalar(relay.fromtext("True")) == True - assert get_scalar(relay.fromtext("False")) == False - -def test_negative(): - assert isinstance(relay.fromtext("let %x = 1; -%x").body, relay.Call) - assert get_scalar(relay.fromtext("--10")) == 10 - assert get_scalar(relay.fromtext("---10")) == -10 - -def test_bin_op(): - for bin_op in BINARY_OPS.keys(): - assert alpha_equal( - relay.fromtext("1 {} 1".format(bin_op)), - BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) - ) - -def test_parens(): - print(relay.fromtext("1 * 1 + 1")) - assert alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("(1 * 1) + 1")) - assert not alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("1 * (1 + 1)")) - -def test_op_assoc(): - assert alpha_equal(relay.fromtext("1 * 1 + 1 < 1 == 1"), relay.fromtext("(((1 * 1) + 1) < 1) == 1")) - assert alpha_equal(relay.fromtext("1 == 1 < 1 + 1 * 1"), relay.fromtext("1 == (1 < (1 + (1 * 1)))")) - -@nottest -def test_vars(): - # temp vars won't work b/c they start with a digit - # # temp var - # temp_var = relay.fromtext("%1") - # assert isinstance(temp_var, relay.Var) - # assert temp_var.name == "1" - - # var - var = relay.fromtext("let %foo = (); %foo") - assert isinstance(var.body, relay.Var) - assert var.body.name_hint == "foo" - - # global var - global_var = relay.fromtext("@foo") - assert isinstance(global_var, relay.GlobalVar) - assert global_var.name_hint == "foo" - - # operator id - op = relay.fromtext("foo") - assert isinstance(op, relay.Op) - assert op.name == "foo" - -def test_let(): - assert alpha_equal( - relay.fromtext("let %x = 1; ()"), - relay.Let( - X, - relay.const(1), - UNIT - ) - ) - -def test_seq(): - assert alpha_equal( - relay.fromtext("(); ()"), - relay.Let( - _, - UNIT, - UNIT) - ) - - assert alpha_equal( - relay.fromtext("let %_ = { 1 }; ()"), - relay.Let( - X, - relay.const(1), - UNIT - ) - ) - -@raises(ParseError) -def test_let_global_var(): - relay.fromtext("let @x = 1; ()") - -@raises(ParseError) -def test_let_op(): - relay.fromtext("let x = 1; ()") - -def test_tuple(): - assert alpha_equal(relay.fromtext("()"), relay.Tuple([])) - - assert alpha_equal(relay.fromtext("(0,)"), relay.Tuple([relay.const(0)])) - - assert alpha_equal(relay.fromtext("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)])) - - assert alpha_equal(relay.fromtext("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) - -def test_func(): - # 0 args - assert alpha_equal( - relay.fromtext("fn () { 0 }"), - relay.Function( - [], - relay.const(0), - None, - [] - ) - ) - - # 1 arg - assert alpha_equal( - relay.fromtext("fn (%x) { %x }"), - relay.Function( - [X], - X, - None, - [] - ) - ) - - # 2 args - assert alpha_equal( - relay.fromtext("fn (%x, %y) { %x + %y }"), - relay.Function( - [X, Y], - relay.add(X, Y), - None, - [] - ) - ) - - # annotations - assert alpha_equal( - relay.fromtext("fn (%x: int32) -> int32 { %x }"), - relay.Function( - [X_ANNO], - X_ANNO, - int32, - [] - ) - ) - -# TODO(@jmp): Crashes if %x isn't annnotated. -# @nottest -def test_defn(): - id_defn = relay.fromtext( - """ - def @id(%x: int32) -> int32 { - %x - } - """) - assert isinstance(id_defn, relay.Module) - -def test_ifelse(): - assert alpha_equal( - relay.fromtext( - """ - if (True) { - 0 - } else { - 1 - } - """ - ), - relay.If( - relay.const(True), - relay.const(0), - relay.const(1) - ) - ) - -@raises(ParseError) -def test_ifelse_scope(): - relay.fromtext( - """ - if (True) { - let %x = (); - () - } else { - %x - } - """ - ) - -def test_call(): - # 0 args - constant = relay.Var("constant") - assert alpha_equal( - relay.fromtext( - """ - let %constant = fn () { 0 }; - %constant() - """ - ), - relay.Let( - constant, - relay.Function([], relay.const(0), None, []), - relay.Call(constant, [], None, None) - ) - ) - - # 1 arg - id_var = relay.Var("id") - assert alpha_equal( - relay.fromtext( - """ - let %id = fn (%x) { %x }; - %id(1) - """ - ), - relay.Let( - id_var, - relay.Function([X], X, None, []), - relay.Call(id_var, [relay.const(1)], None, None) - ) - ) - - # 2 args - multiply = relay.Var("multiply") - assert alpha_equal( - relay.fromtext( - """ - let %multiply = fn (%x, %y) { %x * %y }; - %multiply(0, 0) - """ - ), - relay.Let( - multiply, - relay.Function( - [X, Y], - relay.multiply(X, Y), - None, - [] - ), - relay.Call(multiply, [relay.const(0), relay.const(0)], None, None) - ) - ) - - # anonymous function - assert alpha_equal( - relay.fromtext( - """ - (fn (%x) { %x })(0) - """ - ), - relay.Call( - relay.Function( - [X], - X, - None, - [] - ), - [relay.const(0)], - None, - None - ) - ) - - # curried function - curried_mult = relay.Var("curried_mult") - alpha_equal( - relay.fromtext( - """ - let %curried_mult = - fn (%x) { - fn (%y) { - %x * %y - } - }; - %curried_mult(0); - %curried_mult(0)(0) - """ - ), - relay.Let( - curried_mult, - relay.Function( - [X], - relay.Function( - [Y], - relay.multiply(X, Y), - None, - [] - ), - None, - [] - ), - relay.Let( - _, - relay.Call(curried_mult, [relay.const(0)], None, None), - relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) - ) - ) - ) - - # op - alpha_equal( - relay.fromtext("abs(1)"), - relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) - ) - -# Types - -def test_incomplete_type(): - assert alpha_equal( - relay.fromtext("let %_ : _ = (); ()"), - relay.Let( - _, - UNIT, - UNIT - ) - ) - -def test_builtin_types(): - for builtin_type in TYPES: - relay.fromtext("let %_ : {} = (); ()".format(builtin_type)) - -@nottest -def test_call_type(): - assert False - -def test_tensor_type(): - assert alpha_equal( - relay.fromtext("let %_ : Tensor[(), float32] = (); ()"), - relay.Let( - relay.Var("_", relay.TensorType((), "float32")), - UNIT, - UNIT - ) - ) - - assert alpha_equal( - relay.fromtext("let %_ : Tensor[(1,), float32] = (); ()"), - relay.Let( - relay.Var("_", relay.TensorType((1,), "float32")), - UNIT, - UNIT - ) - ) - - assert alpha_equal( - relay.fromtext("let %_ : Tensor[(1, 1), float32] = (); ()"), - relay.Let( - relay.Var("_", relay.TensorType((1, 1), "float32")), - UNIT, - UNIT - ) - ) - -def test_function_type(): - assert alpha_equal( - relay.fromtext( - """ - let %_: fn () -> int32 = fn () -> int32 { 0 }; () - """ - ), - relay.Let( - relay.Var("_", relay.FuncType([], int32, [], [])), - relay.Function([], relay.const(0), int32, []), - UNIT - ) - ) - - assert alpha_equal( - relay.fromtext( - """ - let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () - """ - ), - relay.Let( - relay.Var("_", relay.FuncType([int32], int32, [], [])), - relay.Function([relay.Var("x", int32)], relay.const(0), int32, []), - UNIT - ) - ) - - assert alpha_equal( - relay.fromtext( - """ - let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () - """ - ), - relay.Let( - relay.Var("_", relay.FuncType([int32, int32], int32, [], [])), - relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []), - UNIT - ) - ) - -def test_tuple_type(): - assert alpha_equal( - relay.fromtext( - """ - let %_: () = (); () - """), - relay.Let( - relay.Var("_", relay.TupleType([])), - UNIT, - UNIT - ) - ) - - assert alpha_equal( - relay.fromtext( - """ - let %_: (int32,) = (0,); () - """), - relay.Let( - relay.Var("_", relay.TupleType([int32])), - relay.Tuple([relay.const(0)]), - UNIT - ) - ) - - assert alpha_equal( - relay.fromtext( - """ - let %_: (int32, int32) = (0, 1); () - """), - relay.Let( - relay.Var("_", relay.TupleType([int32, int32])), - relay.Tuple([relay.const(0), relay.const(1)]), - UNIT - ) - ) From f092434027de6e25e369e546df3ec23610e2838a Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Fri, 30 Nov 2018 20:13:58 -0800 Subject: [PATCH 62/64] revert source --- Jenkinsfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 822d145a7d3c..02f00e42e8fd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -51,11 +51,9 @@ stage("Sanity Check") { def make(docker_type, path, make_flag) { timeout(time: max_time, unit: 'MINUTES') { try { - sh "${docker_run} ${docker_type} source /etc/profile" sh "${docker_run} ${docker_type} ./tests/scripts/task_build.sh ${path} ${make_flag}" } catch (exc) { echo 'Incremental compilation failed. Fall back to build from scratch' - sh "${docker_run} ${docker_type} source /etc/profile" sh "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}" sh "${docker_run} ${docker_type} ./tests/scripts/task_build.sh ${path} ${make_flag}" } From 56535285652d9af59c0d7fd45b55770812558636 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sat, 1 Dec 2018 14:45:39 -0800 Subject: [PATCH 63/64] use ENV --- cmake/modules/ANTLR.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake index c4fa24d52ecd..2ad11ed58865 100644 --- a/cmake/modules/ANTLR.cmake +++ b/cmake/modules/ANTLR.cmake @@ -1,6 +1,6 @@ if(USE_ANTLR) if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar) - set(ANTLR4 "${JAVA_HOME}/bin/java -jar /usr/local/lib/antlr-4.7.1-complete.jar") + set(ANTLR4 "$ENV{JAVA_HOME}/bin/java -jar /usr/local/lib/antlr-4.7.1-complete.jar") set(RELAY_PARSER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) From 533972bd6a5d99009c2680dead7f6d2ab7ae1891 Mon Sep 17 00:00:00 2001 From: Josh Maxwell Pollock Date: Sat, 1 Dec 2018 22:53:06 -0800 Subject: [PATCH 64/64] modify how ANTLR runs in cmake --- cmake/modules/ANTLR.cmake | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake index 2ad11ed58865..72eb5925bda0 100644 --- a/cmake/modules/ANTLR.cmake +++ b/cmake/modules/ANTLR.cmake @@ -1,6 +1,6 @@ if(USE_ANTLR) if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar) - set(ANTLR4 "$ENV{JAVA_HOME}/bin/java -jar /usr/local/lib/antlr-4.7.1-complete.jar") + set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar") set(RELAY_PARSER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) @@ -16,8 +16,8 @@ if(USE_ANTLR) # Generate ANTLR grammar for parsing. add_custom_command(OUTPUT ${RELAY_PARSER} - COMMAND ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 - COMMAND ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 + COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 + COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 WORKING_DIRECTORY ${RELAY_PARSER_DIR})