Skip to content

Commit

Permalink
Fix for dilation2d
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun committed Mar 18, 2020
1 parent e13c422 commit 449138c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,12 +409,12 @@ def _impl(inputs, attr, params, mod):

# Dilation2d
def _dilation2d():
def _impl(inputs, attr, params):
def _impl(inputs, attr, params, mod):
if 'data_format' not in attr:
attr['data_format'] = 'NHWC'

input_shape = attr['_input_shapes'][inputs[0]]
weights_shape = attr['_input_shapes'][inputs[1]]
input_shape = _infer_shape(inputs[0], mod)
weights_shape = _infer_shape(inputs[1], mod)

if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
Expand Down
4 changes: 2 additions & 2 deletions tests/python/frontend/tensorflow/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def test_vanilla_loop_bound():
dtype = "float32"
dname = "data"
np_data = np.random.uniform(size=dshape).astype(dtype)
data = tf.compat.v1.placeholder(shape=dshape, dtype=dtype, name=dname)
data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
x = tf.slice(data, [1, 4], [1, 4])
outer = x + 5.0
def body(x, y):
Expand All @@ -339,7 +339,7 @@ def test_nested_loop_bound():
dtype = "float32"
dname = "data"
np_data = np.random.uniform(size=dshape).astype(dtype)
data = tf.compat.v1.placeholder(shape=dshape, dtype=dtype, name=dname)
data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
x = tf.slice(data, [1, 4], [1, 4])
outer = x + 5.0
def body(x, y):
Expand Down

0 comments on commit 449138c

Please sign in to comment.