Skip to content

Commit

Permalink
[FRONTEND][TENSORFLOW] Use input shapes directly instead of 1-element…
Browse files Browse the repository at this point in the history
… lists (apache#2242)
  • Loading branch information
alexeyr authored and srkreddy1238 committed Dec 29, 2018
1 parent 6d1f4c0 commit f6c3f99
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False

input_shape = attr['_input_shapes'][inputs[0]][0]
input_shape = attr['_input_shapes'][inputs[0]]

if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
Expand All @@ -132,7 +132,7 @@ def _impl(inputs, attr, params):
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))

if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
tmp_shape = attr['_input_shapes'][inputs[0]][0]
tmp_shape = attr['_input_shapes'][inputs[0]]
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
attr['data_format'] = "NCHW"
Expand Down Expand Up @@ -185,13 +185,13 @@ def _impl(inputs, attr, params):

# NCHW Layout require weights transpose
if attr['data_format'] == 'NCHW':
tmp_shape = attr['_input_shapes'][inputs[1]][0]
tmp_shape = attr['_input_shapes'][inputs[1]]
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
attr['_input_shapes'][inputs[1]] = [tmp_shape]
attr['_input_shapes'][inputs[1]] = tmp_shape

input_shape = attr['_input_shapes'][inputs[0]][0]
weights_shape = attr['_input_shapes'][inputs[1]][0]
input_shape = attr['_input_shapes'][inputs[0]]
weights_shape = attr['_input_shapes'][inputs[1]]

if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
Expand Down Expand Up @@ -484,7 +484,7 @@ def _impl(inputs, attr, params):

def _shape():
def _impl(inputs, attr, params):
return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
return _impl

def _fill():
Expand Down Expand Up @@ -565,7 +565,7 @@ def _impl(inputs, attr, params):
new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape[0])
data_dim = len(data_shape)
stride_dim = len(stride)

def _transform_mask(stride_dim, ellipsis_mask):
Expand Down Expand Up @@ -596,7 +596,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
+ new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index):
m_begin[final_index] = 0
m_end[final_index] = data_shape[0][final_index]
m_end[final_index] = data_shape[final_index]
m_stride[final_index] = 1
fshape_indices.append(final_index)
final_index += 1
Expand All @@ -606,19 +606,19 @@ def _transform_mask(stride_dim, ellipsis_mask):
if final_index == len(m_begin):
break
if mask & begin_mask:
m_begin[final_index] = data_shape[0][final_index] \
m_begin[final_index] = data_shape[final_index] \
if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \
else data_shape[0][final_index]
else data_shape[final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
if mask & shrink_axis_mask:
#Tensorflow make axis with shrink_axis_mask as dimension 1
m_begin[final_index] = data_shape[0][final_index] + begin[index] \
m_begin[final_index] = data_shape[final_index] + begin[index] \
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
Expand Down Expand Up @@ -684,8 +684,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
forget_bias = attr.pop('forget_bias')
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size, input_size = input_shape[0][0], input_shape[0][1]
num_hidden_layers = weight_shape[0][1]
batch_size, input_size = input_shape[0], input_shape[1]
num_hidden_layers = weight_shape[1]
num_hidden = num_hidden_layers // 4

in_data = _sym.reshape(in_data,
Expand Down Expand Up @@ -741,11 +741,10 @@ def _impl(inputs, attr, params):

def _rank():
def _impl(inputs, attr, params):
input_shapes = attr['_input_shapes'][inputs[0]]
assert len(inputs) == 1
input_shape = attr['_input_shapes'][inputs[0]]

name = attr["_node_name"]
params[name] = tvm.nd.array([len(input_shapes[0])])
params[name] = tvm.nd.array([len(input_shape)])
return _sym.Variable(name=name, shape=params[name].shape)
return _impl

Expand Down Expand Up @@ -829,7 +828,7 @@ def _unpack():
def _impl(inputs, attr, params):
input_node = inputs[0]
axis = attr['axis']
input_shape = attr['_input_shapes'][input_node][0]
input_shape = attr['_input_shapes'][input_node]
axis_length = input_shape[axis]
if axis_length < 0:
raise TypeError("Unstack with unknown axis length")
Expand Down Expand Up @@ -1018,8 +1017,8 @@ def _LSTMBlockCellWrapper(inputs, attr, params,
"""LSTM cell warapper to prepare the inputs"""
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size = input_shape[0][0]
num_hidden = weight_shape[0][1] // 4
batch_size = input_shape[0]
num_hidden = weight_shape[1] // 4

if layer == 0:
#Create initial states placeholder in case of first layer
Expand Down Expand Up @@ -1240,7 +1239,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
tensor_slot = 0
input_shape = self._output_shapes[node_name][0]
inputs.append(in_sym)
input_shapes[in_sym] = [input_shape]
input_shapes[in_sym] = input_shape
# This means the node is 1d in NNVM and 0d in TF.
# See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
Expand Down

0 comments on commit f6c3f99

Please sign in to comment.