[FRONTEND][TENSORFLOW] Enhance with left over patches from NNVM. (#2757)
* [FRONTEND][TENSORFLOW] Enhance with left over patches from NNVM.

commit 76188a4
Author: Siva sivar.b@huawei.com
[NNVM][TENSORFLOW] bugfix. (#2444)

commit 6737739
Author: Ashutosh Parkhi ashutosh.parkhi@imgtec.com
[Tensorflow] Support for Crop (#2285)

commit f6c3f99
Author: Alexey Romanov alexey.v.romanov@gmail.com
[FRONTEND][TENSORFLOW] Use input shapes directly instead of 1-element lists (#2242)

commit e5d92e1
Author: Dominic Symes 36929632+dominicsymes@users.noreply.github.com
[FRONTEND][TENSORFLOW] Bugfix (#2326)

commit 00d509d
Author: Alexey Romanov alexey.v.romanov@gmail.com
[FRONTEND][TENSORFLOW] Support Unstack and Split (#2105)

commit df9d3ad
Author: Siva sivar.b@huawei.com
[FRONTEND][TENSORFLOW] Bugfix (#2267)

commit d1a0c90
Author: Zhebin Jin zhebin.jzb@alibaba-inc.com
[FRONTEND][TENSORFLOW]Add Split and realdiv op support (#2123)
* Add Split and realdiv op support
* Fix the pad calculation in the case of dilated convolution

* review comments

* resnet fix.

* review comments
srkreddy1238 authored Mar 19, 2019
1 parent f63631f commit bb3c815
Showing 3 changed files with 309 additions and 48 deletions.
23 changes: 16 additions & 7 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -137,7 +137,7 @@ def is_gpu_available():
     from tensorflow.python.client import device_lib
     local_device_protos = device_lib.list_local_devices()
     gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU']
-    if len(gpu_list) < 0:
+    if len(gpu_list) > 0:
         print("Tensorflow GPU:", gpu_list)
         return True
     else:
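The hunk above fixes a check that could never pass: `len(gpu_list) < 0` is always false, so the helper always reported that no GPU was available and the NCHW code paths below never ran. A minimal standalone sketch of the corrected helper, assuming TensorFlow 1.x; the `else` branch is truncated in the hunk, so returning False there is an assumption:

# Sketch of the fixed helper; the else branch is not shown in the hunk,
# so returning False is an assumption.
from tensorflow.python.client import device_lib

def is_gpu_available():
    local_device_protos = device_lib.list_local_devices()
    gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU']
    if len(gpu_list) > 0:
        print("Tensorflow GPU:", gpu_list)
        return True
    else:
        return False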
@@ -168,7 +168,7 @@ def _test_pooling(input_shape, **kwargs):
 
     if is_gpu_available():
         input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
-        kwargs['data_layout'] = 'NCHW'
+        kwargs['data_format'] = 'NCHW'
     _test_pooling_iteration(input_shape, **kwargs)
 
 def test_forward_pooling():
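The renamed keyword matters because TensorFlow's pooling ops take `data_format`, not `data_layout`, so the old key never actually switched the test to NCHW. The index permutation on the line above it converts an NHWC shape to NCHW; a small illustration with hypothetical values:

# Hypothetical shapes illustrating the (0, 3, 1, 2) permutation used above.
nhwc_shape = [1, 224, 224, 3]                       # N, H, W, C
nchw_shape = [nhwc_shape[i] for i in (0, 3, 1, 2)]
assert nchw_shape == [1, 3, 224, 224]               # N, C, H, W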
@@ -225,8 +225,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
         in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
-        strides = [1] + strides + [1]
-        dilations = [1] + dilations + [1]
+        if data_format == 'NHWC':
+            strides = [1] + strides + [1]
+            dilations = [1] + dilations + [1]
+        else:
+            strides = [1, 1] + strides
+            dilations = [1, 1] + dilations
 
         nn_ops.conv2d(in_data,
                       in_filter,
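This change follows TensorFlow's convention that `strides` and `dilations` for `conv2d` are 4-element lists ordered to match `data_format`, with the batch and channel positions set to 1. A short sketch with hypothetical spatial strides:

# Hypothetical values showing where the spatial strides land per layout.
spatial = [2, 2]                 # stride along H and W
nhwc = [1] + spatial + [1]       # N, H, W, C
nchw = [1, 1] + spatial          # N, C, H, W
assert nhwc == [1, 2, 2, 1] and nchw == [1, 1, 2, 2]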
@@ -898,7 +902,7 @@ def test_forward_mobilenet():
 
 #######################################################################
 # ResnetV2
-# ---------
+# --------
 def test_forward_resnetv2():
     '''test resnet model'''
     if is_gpu_available():
@@ -912,8 +916,13 @@ def test_forward_resnetv2():
 
     with tf.Session() as sess:
         tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
-        tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
-        tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
+        for device in ["llvm", "cuda"]:
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                continue
+            tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device)
+            tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
 
 #######################################################################
 # PTB
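The ResnetV2 test now iterates over both targets instead of assuming one, skips any target the local TVM build does not enable, and calls `run_tvm_graph` with the output count (`len(tf_output)`) and an explicit `target`. A stripped-down sketch of the skip pattern; `tvm.context` and `ctx.exist` match the TVM API of this era, and the loop body is elided:

# Minimal sketch of the per-target skip loop, assuming a contemporary TVM build.
import tvm

for device in ["llvm", "cuda"]:
    ctx = tvm.context(device, 0)
    if not ctx.exist:                 # target not enabled in this TVM build
        print("Skip because %s is not enabled" % device)
        continue
    # compile and run the graph for `device`, then compare against the TF output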
