Skip to content

Commit

Permalink
Tf2 test fixups (apache#5391)
Browse files Browse the repository at this point in the history
* Fix oversight in importing tf.compat.v1 as tf.

* Actually disable test for lstm in TF2.1

Since the testing framework actually uses pytest, the version
check needs to be moved.
  • Loading branch information
Ramana Radhakrishnan authored and trevor-m committed Jun 18, 2020
1 parent ff32fe9 commit 51e89b8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
5 changes: 4 additions & 1 deletion tests/python/frontend/tensorflow/test_bn_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
"""
import tvm
import numpy as np
import tensorflow as tf
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from tvm import relay
from tensorflow.python.framework import graph_util

Expand Down
8 changes: 4 additions & 4 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1901,7 +1901,9 @@ def _get_tensorflow_output():

def test_forward_lstm():
'''test LSTM block cell'''
_test_lstm_cell(1, 2, 1, 0.5, 'float32')
if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
#in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
_test_lstm_cell(1, 2, 1, 0.5, 'float32')


#######################################################################
Expand Down Expand Up @@ -3308,9 +3310,7 @@ def test_forward_isfinite():
test_forward_ptb()

# RNN
if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
#in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
test_forward_lstm()
test_forward_lstm()

# Elementwise
test_forward_ceil()
Expand Down

0 comments on commit 51e89b8

Please sign in to comment.