Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Move prelude to text format #3939

Merged
merged 15 commits into from
Sep 29, 2019
2 changes: 1 addition & 1 deletion include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class Var;
/*!
* \brief A variable node in the IR.
*
* A vraible is uniquely identified by its address.
* A variable is uniquely identified by its address.
*
* Each variable is only binded once in the following nodes:
* - Allocate
Expand Down
40 changes: 36 additions & 4 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,34 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);

/*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
*
* It does not do type inference as Add does.
*/
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);

/*!
* \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition.
* \param type The type definition.
* \param update Controls whether you can replace a definition in the
* environment.
*/
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type);
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);

/*!
* \brief Add a function to the global environment.
* \brief Add a type definition to the global environment.
* \param var The name of the global function.
* \param func The function.
* \param update Controls whether you can replace a definition in the
* environment.
*
* It does not do type inference as Add does.
* It does not do type inference as AddDef does.
*/
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
TVM_DLL void AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false);

/*!
* \brief Update a function in the global environment.
Expand All @@ -110,6 +123,13 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL void Update(const GlobalVar& var, const Function& func);

/*!
* \brief Update a type definition in the global environment.
* \param var The name of the global type definition to update.
* \param func The new function.
weberlo marked this conversation as resolved.
Show resolved Hide resolved
*/
TVM_DLL void UpdateDef(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to update.
Expand All @@ -130,13 +150,25 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;

/*!
* \brief Collect all global vars defined in this module.
* \returns An array of global vars
*/
tvm::Array<GlobalVar> GetGlobalVars() const;

/*!
* \brief Look up a global function by its name.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;

/*!
* \brief Collect all global type vars defined in this module.
* \returns An array of global type vars
*/
tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;

/*!
* \brief Look up a global function by its variable.
* \param var The global var to lookup.
Expand Down
119 changes: 55 additions & 64 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def mk_global_typ_var(self, name, kind):
self.global_type_vars[name] = typ
return typ

# TODO: rethink whether we should have type constructors mixed with type vars.
# TODO(weberlo): rethink whether we should have type constructors mixed with type vars.
def mk_global_typ_cons(self, name, cons):
self._check_existing_typ_expr(name, cons)
self.global_type_vars[name] = cons
Expand Down Expand Up @@ -291,11 +291,15 @@ def visitGeneralIdent(self, ctx):
if name.startswith(type_prefix):
return ty.scalar_type(name)
# Next, look it up in the local then global type params.
type_param = lookup(self.type_var_scopes, name)
if type_param is None:
type_param = self.global_type_vars.get(name, None)
if type_param is not None:
return type_param
type_expr = lookup(self.type_var_scopes, name)
if type_expr is None:
type_expr = self.global_type_vars.get(name, None)
if type_expr is not None:
# Zero-arity constructor calls fall into the general ident case, so in that case,
# we construct a constructor call with no args.
if isinstance(type_expr, adt.Constructor) and not len(type_expr.inputs):
type_expr = expr.Call(type_expr, [])
return type_expr
# Check if it's an operator.
op_name = ".".join([name.getText() for name in ctx.CNAME()])
if op_name in FUNC_OPS:
Expand All @@ -321,14 +325,12 @@ def visitGraphVar(self, ctx):

def visit_list(self, ctx_list) -> List[Any]:
""""Visit a list of contexts."""
# type: RelayParser.ContextParserRuleContext
assert isinstance(ctx_list, list)

return [self.visit(ctx) for ctx in ctx_list]

def getTypeExpr(self, ctx) -> Optional[ty.Type]:
def getTypeExpr(self, ctx: Optional[RelayParser.TypeExprContext]) -> Optional[ty.Type]:
"""Return a (possibly None) Relay type."""
# type: : Optional[RelayParser.Type_Context]
if ctx is None:
return None

Expand Down Expand Up @@ -360,6 +362,10 @@ def visitOpIdent(self, ctx) -> op.Op:
def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr:
return self.visit(ctx.expr())

# pass through
def visitTypeParen(self, ctx: RelayParser.TypeParenContext) -> expr.Expr:
return self.visit(ctx.typeExpr())

# pass through
def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr:
return self.visit(ctx.expr())
Expand Down Expand Up @@ -466,7 +472,7 @@ def mk_func(
type_params = ctx.typeParamList()

if type_params is not None:
type_params = type_params.generalIdent()
type_params = type_params.typeExpr()
assert type_params
for ty_param in type_params:
name = ty_param.getText()
Expand Down Expand Up @@ -498,7 +504,8 @@ def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None:
ident_name = ctx.globalVar().getText()[1:]
ident = self.mk_global_var(ident_name)
self.module[ident] = self.mk_func(ctx)
func = self.mk_func(ctx)
self.module[ident] = func

def handle_adt_header(
self,
Expand All @@ -512,7 +519,7 @@ def handle_adt_header(
type_params = []
else:
type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type)
for type_ident in type_params.generalIdent()]
for type_ident in type_params.typeExpr()]
return adt_var, type_params

def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext):
Expand Down Expand Up @@ -552,8 +559,6 @@ def visitMatch(self, ctx: RelayParser.MatchContext):
else:
raise RuntimeError(f"unknown match type {match_type}")

# TODO: Will need some kind of type checking to know which ADT is being
# matched on.
match_data = self.visit(ctx.expr())
match_clauses = ctx.matchClauseList()
if match_clauses is None:
Expand All @@ -562,39 +567,36 @@ def visitMatch(self, ctx: RelayParser.MatchContext):
match_clauses = match_clauses.matchClause()
parsed_clauses = []
for clause in match_clauses:
constructor_name = clause.constructorName().getText()
constructor = self.global_type_vars[constructor_name]
self.enter_var_scope()
patternList = clause.patternList()
if patternList is None:
patterns = []
else:
patterns = [self.visit(pattern) for pattern in patternList.pattern()]
pattern = self.visit(clause.pattern())
clause_body = self.visit(clause.expr())
self.exit_var_scope()
# TODO: Do we need to pass `None` if it's a 0-arity cons, or is an empty list fine?
parsed_clauses.append(adt.Clause(
adt.PatternConstructor(
constructor,
patterns
),
clause_body
))
parsed_clauses.append(adt.Clause(pattern, clause_body))
return adt.Match(match_data, parsed_clauses, complete=complete_match)

def visitPattern(self, ctx: RelayParser.PatternContext):
text = ctx.getText()
if text == "_":
return adt.PatternWildcard()
elif text.startswith("%"):
text = ctx.localVar().getText()
typ = ctx.typeExpr()
if typ is not None:
typ = self.visit(typ)
var = self.mk_var(text[1:], typ=typ)
return adt.PatternVar(var)
def visitWildcardPattern(self, ctx: RelayParser.WildcardPatternContext):
return adt.PatternWildcard()

def visitVarPattern(self, ctx: RelayParser.VarPatternContext):
text = ctx.localVar().getText()
typ = ctx.typeExpr()
if typ is not None:
typ = self.visit(typ)
var = self.mk_var(text[1:], typ=typ)
return adt.PatternVar(var)

def visitConstructorPattern(self, ctx: RelayParser.ConstructorPatternContext):
constructor_name = ctx.constructorName().getText()
constructor = self.global_type_vars[constructor_name]
pattern_list = ctx.patternList()
if pattern_list is None:
patterns = []
else:
raise ParseError(f"invalid pattern syntax \"{text}\"")
patterns = [self.visit(pattern) for pattern in pattern_list.pattern()]
return adt.PatternConstructor(constructor, patterns)

def visitTuplePattern(self, ctx: RelayParser.TuplePatternContext):
return adt.PatternTuple([self.visit(pattern) for pattern in ctx.patternList().pattern()])

def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext):
return (self.visit_list(ctx.exprList().expr()), None)
Expand All @@ -610,16 +612,14 @@ def call(self, func, args, attrs, type_args):
return expr.Call(func, args, attrs, type_args)

@spanify
def visitCall(self, ctx: RelayParser.CallContext):
# type: (RelayParser.CallContext) -> expr.Call
def visitCall(self, ctx: RelayParser.CallContext) -> expr.Call:
func = self.visit(ctx.expr())
args, attrs = self.visit(ctx.callList())
res = self.call(func, args, attrs, [])
return res

@spanify
def visitIfElse(self, ctx: RelayParser.IfElseContext):
# type: (RelayParser.IfElseContext) -> expr.If
def visitIfElse(self, ctx: RelayParser.IfElseContext) -> expr.If:
"""Construct a Relay If node. Creates a new scope for each branch."""
cond = self.visit(ctx.expr())

Expand All @@ -634,8 +634,7 @@ def visitIfElse(self, ctx: RelayParser.IfElseContext):
return expr.If(cond, true_branch, false_branch)

@spanify
def visitGraph(self, ctx: RelayParser.GraphContext):
# type: (RelayParser.GraphContext) -> expr.Expr
def visitGraph(self, ctx: RelayParser.GraphContext) -> expr.Expr:
"""Visit a graph variable assignment."""
graph_nid = int(ctx.graphVar().getText()[1:])

Expand All @@ -655,28 +654,24 @@ def visitGraph(self, ctx: RelayParser.GraphContext):
# Types

# pylint: disable=unused-argument
def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext):
# type (RelayParser.IncompleteTypeContext) -> None:
def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext) -> None:
return None

def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext):
func = self.visit(ctx.generalIdent())
args = [self.visit(arg) for arg in ctx.typeParamList().generalIdent()]
args = [self.visit(arg) for arg in ctx.typeParamList().typeExpr()]
return ty.TypeCall(func, args)

def visitParensShape(self, ctx: RelayParser.ParensShapeContext):
# type: (RelayParser.ParensShapeContext) -> int
def visitParensShape(self, ctx: RelayParser.ParensShapeContext) -> int:
return self.visit(ctx.shape())

def visitShapeList(self, ctx: RelayParser.ShapeListContext):
# type: (RelayParser.ShapeListContext) -> List[int]
def visitShapeList(self, ctx: RelayParser.ShapeListContext) -> List[int]:
return self.visit_list(ctx.shape())

def visitTensor(self, ctx: RelayParser.TensorContext):
return tuple(self.visit_list(ctx.expr()))

def visitTensorType(self, ctx: RelayParser.TensorTypeContext):
# type: (RelayParser.TensorTypeContext) -> ty.TensorType
def visitTensorType(self, ctx: RelayParser.TensorTypeContext) -> ty.TensorType:
"""Create a simple tensor type. No generics."""

shape = self.visit(ctx.shapeList())
Expand All @@ -689,21 +684,18 @@ def visitTensorType(self, ctx: RelayParser.TensorTypeContext):

return ty.TensorType(shape, dtype)

def visitTupleType(self, ctx: RelayParser.TupleTypeContext):
# type: (RelayParser.TupleTypeContext) -> ty.TupleType
def visitTupleType(self, ctx: RelayParser.TupleTypeContext) -> ty.TupleType:
return ty.TupleType(self.visit_list(ctx.typeExpr()))

def visitFuncType(self, ctx: RelayParser.FuncTypeContext):
# type: (RelayParser.FuncTypeContext) -> ty.FuncType
def visitFuncType(self, ctx: RelayParser.FuncTypeContext) -> ty.FuncType:
types = self.visit_list(ctx.typeExpr())

arg_types = types[:-1]
ret_type = types[-1]

return ty.FuncType(arg_types, ret_type, [], None)

def make_parser(data):
# type: (str) -> RelayParser
def make_parser(data: str) -> RelayParser:
"""Construct a RelayParser a given data stream."""
input_stream = InputStream(data)
lexer = RelayLexer(input_stream)
Expand Down Expand Up @@ -738,8 +730,7 @@ def reportAttemptingFullContext(self,
def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
raise Exception("Context Sensitivity in:\n" + self.text)

def fromtext(data, source_name=None):
# type: (str, str) -> Union[expr.Expr, module.Module]
def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, module.Module]:
"""Parse a Relay program."""
if data == "":
raise ParseError("cannot parse the empty string.")
Expand Down
Loading