Skip to content

Commit

Permalink
[Relay] Move prelude to text format (#3939)
Browse files Browse the repository at this point in the history
* Fix parser

* Doc fix

* Add module utility functions necessary for prelude

* Implement prelude in text format

* Remove programmatically constructed prelude defs

* Fix 0-arity type conses in pretty printer and test

* Make prelude loading backwards-compatible

* Fix patterns

* Improve some prelude defs

* Fix `ImportFromStd`

It needs to also follow the "add unchecked, add checked" pattern

* Lint roller

* Woops

* Address feedback

* Fix `test_list_constructor` VM test

* Fix `test_adt.py` failures
  • Loading branch information
weberlo authored and jroesch committed Sep 29, 2019
1 parent 9b46ace commit 2dac17d
Show file tree
Hide file tree
Showing 15 changed files with 1,166 additions and 1,102 deletions.
2 changes: 1 addition & 1 deletion include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class Var;
/*!
* \brief A variable node in the IR.
*
* A vraible is uniquely identified by its address.
* A variable is uniquely identified by its address.
*
* Each variable is only binded once in the following nodes:
* - Allocate
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
LOG(FATAL) << "Do not have a default for " << op->type_key();
throw;
}

private:
Expand Down
44 changes: 38 additions & 6 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,34 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);

/*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
*
* It does not do type inference as Add does.
*/
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);

/*!
* \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition.
* \param type The type definition.
* \param type The ADT.
* \param update Controls whether you can replace a definition in the
* environment.
*/
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type);
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);

/*!
* \brief Add a function to the global environment.
* \brief Add a type definition to the global environment.
* \param var The name of the global function.
* \param func The function.
* \param type The ADT.
* \param update Controls whether you can replace a definition in the
* environment.
*
* It does not do type inference as Add does.
* It does not do type inference as AddDef does.
*/
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
TVM_DLL void AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false);

/*!
* \brief Update a function in the global environment.
Expand All @@ -110,6 +123,13 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL void Update(const GlobalVar& var, const Function& func);

/*!
* \brief Update a type definition in the global environment.
* \param var The name of the global type definition to update.
* \param type The new ADT.
*/
TVM_DLL void UpdateDef(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to update.
Expand All @@ -130,13 +150,25 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;

/*!
* \brief Collect all global vars defined in this module.
* \returns An array of global vars
*/
tvm::Array<GlobalVar> GetGlobalVars() const;

/*!
* \brief Look up a global function by its name.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;

/*!
* \brief Collect all global type vars defined in this module.
* \returns An array of global type vars
*/
tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;

/*!
* \brief Look up a global function by its variable.
* \param var The global var to lookup.
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/relay/pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
virtual R VisitPattern_(const PatternTupleNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
LOG(FATAL) << "Do not have a default for " << op->type_key();
throw;
}

private:
Expand Down
121 changes: 56 additions & 65 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def exit_type_param_scope(self) -> Scope[ty.TypeVar]:
def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar:
"""Create a new TypeVar and add it to the TypeVar scope."""
typ = ty.TypeVar(name, kind)
self.type_var_scopes[0].appendleft((name, typ))
self.type_var_scopes[0].append((name, typ))
return typ

def mk_global_typ_var(self, name, kind):
Expand All @@ -242,7 +242,7 @@ def mk_global_typ_var(self, name, kind):
self.global_type_vars[name] = typ
return typ

# TODO: rethink whether we should have type constructors mixed with type vars.
# TODO(weberlo): rethink whether we should have type constructors mixed with type vars.
def mk_global_typ_cons(self, name, cons):
self._check_existing_typ_expr(name, cons)
self.global_type_vars[name] = cons
Expand Down Expand Up @@ -291,11 +291,15 @@ def visitGeneralIdent(self, ctx):
if name.startswith(type_prefix):
return ty.scalar_type(name)
# Next, look it up in the local then global type params.
type_param = lookup(self.type_var_scopes, name)
if type_param is None:
type_param = self.global_type_vars.get(name, None)
if type_param is not None:
return type_param
type_expr = lookup(self.type_var_scopes, name)
if type_expr is None:
type_expr = self.global_type_vars.get(name, None)
if type_expr is not None:
# Zero-arity constructor calls fall into the general ident case, so in that case,
# we construct a constructor call with no args.
if isinstance(type_expr, adt.Constructor) and not type_expr.inputs:
type_expr = expr.Call(type_expr, [])
return type_expr
# Check if it's an operator.
op_name = ".".join([name.getText() for name in ctx.CNAME()])
if op_name in FUNC_OPS:
Expand All @@ -321,14 +325,12 @@ def visitGraphVar(self, ctx):

def visit_list(self, ctx_list) -> List[Any]:
""""Visit a list of contexts."""
# type: RelayParser.ContextParserRuleContext
assert isinstance(ctx_list, list)

return [self.visit(ctx) for ctx in ctx_list]

def getTypeExpr(self, ctx) -> Optional[ty.Type]:
def getTypeExpr(self, ctx: Optional[RelayParser.TypeExprContext]) -> Optional[ty.Type]:
"""Return a (possibly None) Relay type."""
# type: : Optional[RelayParser.Type_Context]
if ctx is None:
return None

Expand Down Expand Up @@ -360,6 +362,10 @@ def visitOpIdent(self, ctx) -> op.Op:
def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr:
return self.visit(ctx.expr())

# pass through
def visitTypeParen(self, ctx: RelayParser.TypeParenContext) -> expr.Expr:
return self.visit(ctx.typeExpr())

# pass through
def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr:
return self.visit(ctx.expr())
Expand Down Expand Up @@ -466,7 +472,7 @@ def mk_func(
type_params = ctx.typeParamList()

if type_params is not None:
type_params = type_params.generalIdent()
type_params = type_params.typeExpr()
assert type_params
for ty_param in type_params:
name = ty_param.getText()
Expand Down Expand Up @@ -498,7 +504,8 @@ def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None:
ident_name = ctx.globalVar().getText()[1:]
ident = self.mk_global_var(ident_name)
self.module[ident] = self.mk_func(ctx)
func = self.mk_func(ctx)
self.module[ident] = func

def handle_adt_header(
self,
Expand All @@ -512,7 +519,7 @@ def handle_adt_header(
type_params = []
else:
type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type)
for type_ident in type_params.generalIdent()]
for type_ident in type_params.typeExpr()]
return adt_var, type_params

def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext):
Expand Down Expand Up @@ -552,8 +559,6 @@ def visitMatch(self, ctx: RelayParser.MatchContext):
else:
raise RuntimeError(f"unknown match type {match_type}")

# TODO: Will need some kind of type checking to know which ADT is being
# matched on.
match_data = self.visit(ctx.expr())
match_clauses = ctx.matchClauseList()
if match_clauses is None:
Expand All @@ -562,39 +567,36 @@ def visitMatch(self, ctx: RelayParser.MatchContext):
match_clauses = match_clauses.matchClause()
parsed_clauses = []
for clause in match_clauses:
constructor_name = clause.constructorName().getText()
constructor = self.global_type_vars[constructor_name]
self.enter_var_scope()
patternList = clause.patternList()
if patternList is None:
patterns = []
else:
patterns = [self.visit(pattern) for pattern in patternList.pattern()]
pattern = self.visit(clause.pattern())
clause_body = self.visit(clause.expr())
self.exit_var_scope()
# TODO: Do we need to pass `None` if it's a 0-arity cons, or is an empty list fine?
parsed_clauses.append(adt.Clause(
adt.PatternConstructor(
constructor,
patterns
),
clause_body
))
parsed_clauses.append(adt.Clause(pattern, clause_body))
return adt.Match(match_data, parsed_clauses, complete=complete_match)

def visitPattern(self, ctx: RelayParser.PatternContext):
text = ctx.getText()
if text == "_":
return adt.PatternWildcard()
elif text.startswith("%"):
text = ctx.localVar().getText()
typ = ctx.typeExpr()
if typ is not None:
typ = self.visit(typ)
var = self.mk_var(text[1:], typ=typ)
return adt.PatternVar(var)
def visitWildcardPattern(self, ctx: RelayParser.WildcardPatternContext):
return adt.PatternWildcard()

def visitVarPattern(self, ctx: RelayParser.VarPatternContext):
text = ctx.localVar().getText()
typ = ctx.typeExpr()
if typ is not None:
typ = self.visit(typ)
var = self.mk_var(text[1:], typ=typ)
return adt.PatternVar(var)

def visitConstructorPattern(self, ctx: RelayParser.ConstructorPatternContext):
constructor_name = ctx.constructorName().getText()
constructor = self.global_type_vars[constructor_name]
pattern_list = ctx.patternList()
if pattern_list is None:
patterns = []
else:
raise ParseError(f"invalid pattern syntax \"{text}\"")
patterns = [self.visit(pattern) for pattern in pattern_list.pattern()]
return adt.PatternConstructor(constructor, patterns)

def visitTuplePattern(self, ctx: RelayParser.TuplePatternContext):
return adt.PatternTuple([self.visit(pattern) for pattern in ctx.patternList().pattern()])

def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext):
return (self.visit_list(ctx.exprList().expr()), None)
Expand All @@ -610,16 +612,14 @@ def call(self, func, args, attrs, type_args):
return expr.Call(func, args, attrs, type_args)

@spanify
def visitCall(self, ctx: RelayParser.CallContext):
# type: (RelayParser.CallContext) -> expr.Call
def visitCall(self, ctx: RelayParser.CallContext) -> expr.Call:
func = self.visit(ctx.expr())
args, attrs = self.visit(ctx.callList())
res = self.call(func, args, attrs, [])
return res

@spanify
def visitIfElse(self, ctx: RelayParser.IfElseContext):
# type: (RelayParser.IfElseContext) -> expr.If
def visitIfElse(self, ctx: RelayParser.IfElseContext) -> expr.If:
"""Construct a Relay If node. Creates a new scope for each branch."""
cond = self.visit(ctx.expr())

Expand All @@ -634,8 +634,7 @@ def visitIfElse(self, ctx: RelayParser.IfElseContext):
return expr.If(cond, true_branch, false_branch)

@spanify
def visitGraph(self, ctx: RelayParser.GraphContext):
# type: (RelayParser.GraphContext) -> expr.Expr
def visitGraph(self, ctx: RelayParser.GraphContext) -> expr.Expr:
"""Visit a graph variable assignment."""
graph_nid = int(ctx.graphVar().getText()[1:])

Expand All @@ -655,28 +654,24 @@ def visitGraph(self, ctx: RelayParser.GraphContext):
# Types

# pylint: disable=unused-argument
def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext):
# type (RelayParser.IncompleteTypeContext) -> None:
def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext) -> None:
return None

def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext):
func = self.visit(ctx.generalIdent())
args = [self.visit(arg) for arg in ctx.typeParamList().generalIdent()]
args = [self.visit(arg) for arg in ctx.typeParamList().typeExpr()]
return ty.TypeCall(func, args)

def visitParensShape(self, ctx: RelayParser.ParensShapeContext):
# type: (RelayParser.ParensShapeContext) -> int
def visitParensShape(self, ctx: RelayParser.ParensShapeContext) -> int:
return self.visit(ctx.shape())

def visitShapeList(self, ctx: RelayParser.ShapeListContext):
# type: (RelayParser.ShapeListContext) -> List[int]
def visitShapeList(self, ctx: RelayParser.ShapeListContext) -> List[int]:
return self.visit_list(ctx.shape())

def visitTensor(self, ctx: RelayParser.TensorContext):
return tuple(self.visit_list(ctx.expr()))

def visitTensorType(self, ctx: RelayParser.TensorTypeContext):
# type: (RelayParser.TensorTypeContext) -> ty.TensorType
def visitTensorType(self, ctx: RelayParser.TensorTypeContext) -> ty.TensorType:
"""Create a simple tensor type. No generics."""

shape = self.visit(ctx.shapeList())
Expand All @@ -689,21 +684,18 @@ def visitTensorType(self, ctx: RelayParser.TensorTypeContext):

return ty.TensorType(shape, dtype)

def visitTupleType(self, ctx: RelayParser.TupleTypeContext):
# type: (RelayParser.TupleTypeContext) -> ty.TupleType
def visitTupleType(self, ctx: RelayParser.TupleTypeContext) -> ty.TupleType:
return ty.TupleType(self.visit_list(ctx.typeExpr()))

def visitFuncType(self, ctx: RelayParser.FuncTypeContext):
# type: (RelayParser.FuncTypeContext) -> ty.FuncType
def visitFuncType(self, ctx: RelayParser.FuncTypeContext) -> ty.FuncType:
types = self.visit_list(ctx.typeExpr())

arg_types = types[:-1]
ret_type = types[-1]

return ty.FuncType(arg_types, ret_type, [], None)

def make_parser(data):
# type: (str) -> RelayParser
def make_parser(data: str) -> RelayParser:
"""Construct a RelayParser a given data stream."""
input_stream = InputStream(data)
lexer = RelayLexer(input_stream)
Expand Down Expand Up @@ -738,8 +730,7 @@ def reportAttemptingFullContext(self,
def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
raise Exception("Context Sensitivity in:\n" + self.text)

def fromtext(data, source_name=None):
# type: (str, str) -> Union[expr.Expr, module.Module]
def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, module.Module]:
"""Parse a Relay program."""
if data == "":
raise ParseError("cannot parse the empty string.")
Expand Down
Loading

0 comments on commit 2dac17d

Please sign in to comment.