Skip to content

Commit

Permalink
[TENSORLFOW] PlaceholderWithDefault (limited) implementation. (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 authored and Wei Chen committed Jun 26, 2019
1 parent 30d026f commit 6bcc40d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,7 +1740,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
for node in graph.node:
node_name_prefix = node.name.rsplit('/', 1)[0]
control_flow_node_map[node_name_prefix].add(node.op)
if node.op == 'Placeholder':
if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault':
# Give priority to user argument.
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
Expand Down Expand Up @@ -1800,7 +1800,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

attr = self._parse_attr(node.attr)

elif node.op != "Placeholder":
elif node.op != "Placeholder" and node.op != 'PlaceholderWithDefault':
# Pass the parsed shapes instead
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

Expand Down Expand Up @@ -1925,7 +1925,7 @@ def _parse_import_prerequisites(self, graph):
"""
missing_operators = set()
for node in graph.node:
if node.op == "Placeholder":
if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
pass
elif node.op == "Const":
pass
Expand Down
19 changes: 19 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1541,6 +1541,24 @@ def test_forward_reduce_prod():
_test_forward_reduce_prod((5, 5), 0, True)
_test_forward_reduce_prod((5, 5), 1, True)


#######################################################################
# PlaceholderWithDefault
# ----------------------
def test_placeholder():
with tf.Graph().as_default():
in_data1 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32)
var1 = tf.Variable(in_data1, name='in1')
var2 = array_ops.placeholder_with_default(var1, None, name='place1')

in_data2 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32)
place1 = array_ops.placeholder(shape=in_data1.shape, dtype=in_data1.dtype, name='in2')

out1 = tf.math.add(var1, var2, name='out1')
out2 = tf.math.add(out1, place1, name='out2')

compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True)

#######################################################################
# Main
# ----
Expand Down Expand Up @@ -1590,6 +1608,7 @@ def test_forward_reduce_prod():
test_forward_multi_input()
test_forward_multi_output()
test_forward_variable()
test_placeholder()

# NN
test_forward_convolution()
Expand Down

0 comments on commit 6bcc40d

Please sign in to comment.