Skip to content

Commit

Permalink
Frontend WIP (apache#25)
Browse files Browse the repository at this point in the history
* parse normal call args

* pythonic list map

* Adds WIP type annnotation stuff

* Fleshes out LeNet test

* adds keyword arg parsing

* adds .mypy_cache to .gitignore

* linting

* more linting

* keyword args map keys are now LocalIds

* linting

* fix final linting errors
  • Loading branch information
joshpoll authored and jroesch committed Aug 16, 2018
1 parent 4d6315f commit feed7c7
Showing 1 changed file with 90 additions and 9 deletions.
99 changes: 90 additions & 9 deletions relay/python/relay/relay.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pylint: disable=superfluous-parens
"""A decorator for rewriting Python code into Relay."""
import ast
from ast import literal_eval
import inspect
# from typing import Dict, List
from collections import OrderedDict
Expand Down Expand Up @@ -31,18 +32,64 @@ def compile_args_to_params(args):
params = []
for arg in args.args:
ident = arg.arg
arg_ty = None # Fix me
# All arguments must have types.
assert arg.annotation is not None
arg_ty = relay_type_from_annotation(arg.annotation)
param = Param(LocalId(ident), arg_ty)
params.append(param)
return params

# TODO: Return actual type nodes (i.e., not always bools), once they're added.
class TypeToRelay(ast.NodeVisitor):
"""Compiles a Python type to a Relay type."""

def generic_visit(self, node):
return BoolType()

#pylint: disable=invalid-name
def visit_Subscript(self, node):
name = node.value.id
if name == "Tensor":
tensor_params = node.slice.value.elts
assert len(tensor_params) == 2
# dtype = tensor_params[0].id
# shape = list(map(lambda x: int(x.node), tensor_params[1].elts))
return BoolType()
else:
raise Exception("Expected \"Tensor\"; got \"%s\"" % name)

#pylint: disable=invalid-name
def visit_Attribute(self, node):
typ = node.attr
raise Exception("Unknown Relay type \"%s\"" % typ)

#pylint: disable=invalid-name
def visit_Name(self, node):
"""Visit names"""
typ = node.id
if typ == "int":
return BoolType()
elif typ == "float":
return BoolType()
elif typ == "bool":
return BoolType()
elif typ == "str":
return BoolType()
else:
raise Exception("Unsupported Python builtin \"%s\"" % typ)


def relay_type_from_annotation(annotation):
return TypeToRelay().visit(annotation)


# We inherit from NodeVisitor to write a pass over the AST.
#
# Process a single definition and produce a single Relay Defunc.


class DefToRelay(ast.NodeVisitor):
"""Compile a single Python def to a Realy definition."""
"""Compiles a Python definition to a Relay definition."""
# local_scopes: List[Dict[LocalId, Expr]]

def __init__(self, python_def):
Expand All @@ -61,21 +108,56 @@ def visit_Return(self, return_node):
else:
raise Exception("return must have a value")

#pylint: disable=invalid-name
def visit_Num(self, num_node):
literal = literal_eval(num_node)
if isinstance(literal, int):
return IntLit(literal)
elif isinstance(literal, float):
return FloatLit(literal)
else:
raise Exception("unknown numeric literal")

#pylint: disable=invalid-name
def visit_Str(self, str_node):
s = str_node.s
return String(s)

#pylint: disable=invalid-name
def visit_List(self, list_node):
python_list = list_node.elts
relay_list = [self.visit(elt) for elt in python_list]
return TensorLit(relay_list)

#pylint: disable=invalid-name
def visit_NameConstant(self, nc_node):
singleton = nc_node.value
if singleton is None:
raise Exception("relay decorator does not support None")
else:
return BoolLit(singleton)

def visit_Call(self, call_node):
"""Transform a Python call into a Relay call"""
func = call_node.func
# args = call_node.args
args = call_node.args
# keywords = call_node.keywords

if isinstance(func, ast.Attribute):
if func.value.id == 'relay':
relay_func = IntrinsicId(func.attr)
else:
raise Exception(
"only supported namespace is relay right now") # improve me
"only supported namespace is relay right now") # improve me
else:
raise Exception("unsupported calls")
# Todo(jroesch): Handle args
return Call(relay_func, [])

# TODO(joshpoll): Handle args
relay_args = [self.visit(arg) for arg in args]

# relay_kwds = {LocalId(kwd.arg): self.visit(kwd.value) for kwd in keywords}

return Call(relay_func, relay_args)

def visit_Assign(self, assign_node):
targets = assign_node.targets
Expand Down Expand Up @@ -115,13 +197,12 @@ def run(self):
params = compile_args_to_params(args)
relay_body = self.compile_stmt_seq_to_body(body)
func = Function(params, relay_body)
defunc = Defn(GlobalId(name), None, func)
defunc = Defn(GlobalId(name), BoolType(), func)
return defunc


def compile_def_to_defn(func):
def_to_relay = DefToRelay(func)
return def_to_relay.run()
return DefToRelay(func).run()


def get_env():
Expand Down

0 comments on commit feed7c7

Please sign in to comment.