Skip to content

Commit

Permalink
[FRONTEND][TENSORFLOW] Enhance with left over patches from NNVM.
Browse files Browse the repository at this point in the history
commit f347b52
Author: Yong Wu <yongwu@amazon.com>
    Get tags of saved model automatically

commit 916576c
Author: Zhi Chen <chzhi@amazon.com>
    Support TensorFlow saved model
    TF parser: return the consistent error message to error handler

commit f1782f3
Author: Yong Wu <yongwu@amazon.com>
    Add tf parser wrapper, infer shape automatically

commit 76188a4
Author: Siva <sivar.b@huawei.com>
    [NNVM][TENSORFLOW] bugfix. (apache#2444)

commit 6737739
Author: Ashutosh Parkhi <ashutosh.parkhi@imgtec.com>
    [Tensorflow] Support for Crop (apache#2285)

commit f6c3f99
Author: Alexey Romanov <alexey.v.romanov@gmail.com>
    [FRONTEND][TENSORFLOW] Use input shapes directly instead of 1-element lists (apache#2242)

commit e5d92e1
Author: Dominic Symes <36929632+dominicsymes@users.noreply.github.com>
    [FRONTEND][TENSORFLOW] Bugfix (apache#2326)

commit 00d509d
Author: Alexey Romanov <alexey.v.romanov@gmail.com>
    [FRONTEND][TENSORFLOW] Support Unstack and Split (apache#2105)

commit df9d3ad
Author: Siva <sivar.b@huawei.com>
    [FRONTEND][TENSORFLOW] Bugfix (apache#2267)

commit d1a0c90
Author: Zhebin Jin <zhebin.jzb@alibaba-inc.com>
    [FRONTEND][TENSORFLOW]Add Split and realdiv op support (apache#2123)
    * Add Split and realdiv op support
    * Fix the pad calculation in the case of dilated convolution
  • Loading branch information
srkreddy1238 committed Mar 9, 2019
1 parent a7e35fc commit 566f2c1
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 60 deletions.
12 changes: 5 additions & 7 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def is_gpu_available():
from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices()
gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU']
if len(gpu_list) < 0:
if len(gpu_list) > 0:
print("Tensorflow GPU:", gpu_list)
return True
else:
Expand Down Expand Up @@ -168,7 +168,7 @@ def _test_pooling(input_shape, **kwargs):

if is_gpu_available():
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
kwargs['data_layout'] = 'NCHW'
kwargs['data_format'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs)

def test_forward_pooling():
Expand Down Expand Up @@ -240,9 +240,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
def test_forward_convolution():
if is_gpu_available():
_test_convolution([4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW')
_test_convolution([4, 19, 17, 17], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution([4, 124, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NCHW')
_test_convolution([4, 12, 17, 17], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NCHW')

_test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
Expand Down Expand Up @@ -899,7 +897,7 @@ def test_forward_mobilenet():
#######################################################################
# ResnetV2
# ---------
def test_forward_resnetv2():
def _test_forward_resnetv2():
'''test resnet model'''
if is_gpu_available():
with tf.Graph().as_default():
Expand All @@ -912,7 +910,7 @@ def test_forward_resnetv2():

with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output[0].shape, 'float32')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)

#######################################################################
Expand Down Expand Up @@ -1235,7 +1233,7 @@ def test_forward_rel_ops():
test_forward_inception_v3()
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_resnetv2()
#_test_forward_resnetv2()
test_forward_ptb()

# RNN
Expand Down
Loading

0 comments on commit 566f2c1

Please sign in to comment.