Skip to content

Commit

Permalink
[Relay][Frontend] Keras Support (apache#2336)
Browse files Browse the repository at this point in the history
  • Loading branch information
Huyuwei authored and tqchen committed Jan 5, 2019
1 parent 41656e1 commit 174418b
Show file tree
Hide file tree
Showing 5 changed files with 999 additions and 8 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from __future__ import absolute_import

from .mxnet import from_mxnet
from .keras import from_keras
27 changes: 25 additions & 2 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Common utilities"""
from __future__ import absolute_import as _abs
from .. import expr as _expr


class RequiredAttr(object):
Expand Down Expand Up @@ -181,8 +182,6 @@ def get_int_list(self, key, default=RequiredAttr()):
raise AttributeError("Required attribute {} not found.".format(key))
return default



def get_bool(self, key, default=RequiredAttr()):
"""Get bool tuple attribute
Expand All @@ -204,3 +203,27 @@ def get_bool(self, key, default=RequiredAttr()):
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default


class ExprTable(object):
"""Table storing Relay expressions by names."""
def __init__(self):
self.exprs = {}
self.params = {}
self.const_ctr = 1

def new_const(self, value, shape=None, dtype="float32"):
name = "_param_%d" % (self.const_ctr)
if hasattr(value, "shape"):
shape = value.shape
self.const_ctr += 1
self.params[name] = value
self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype)
return self.exprs[name]

def get_expr(self, name):
return self.exprs[name]

def set_expr(self, name, expr):
assert isinstance(expr, _expr.Expr)
self.exprs[name] = expr
Loading

0 comments on commit 174418b

Please sign in to comment.