Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
HawkAaron committed May 28, 2018
1 parent c0df8c4 commit b7a5b48
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
1 change: 1 addition & 0 deletions tensorflow_binding/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

B = 2; T = 4; U = 3; V = 6; blank = 5

acts = tf.nn.log_softmax(acts)
costs = rnnt_loss(acts, labels, input_length, label_length, blank)
grad = tf.gradients(costs, [acts])

Expand Down
9 changes: 6 additions & 3 deletions tensorflow_binding/tests/test_warprnnt_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ def _run_rnnt(self, acts, labels, input_lengths, label_lengths,
labels_t = tf.constant(labels)
input_lengths_t = tf.constant(input_lengths)
label_lengths_t = tf.constant(label_lengths)

if not use_gpu: acts_t = tf.nn.log_softmax(acts_t)
costs = rnnt_loss(acts_t, labels_t, input_lengths_t, label_lengths_t, blank)

grads = tf.gradients(costs, [acts_t])
grads = tf.gradients(costs, [acts_t])[0]

self.assertShapeEqual(expected_costs, costs)

Expand Down Expand Up @@ -53,7 +55,8 @@ def test_forward(self):
labels_t = tf.constant(labels)
input_lengths_t = tf.constant(input_lengths)
label_lengths_t = tf.constant(label_lengths)
costs = rnnt_loss(acts_t, labels_t, input_lengths_t, label_lengths_t, blank)
acts_t = tf.nn.log_softmax(acts_t) # NOTE cpu
costs = rnnt_loss(acts_t, labels_t, input_lengths_t, label_lengths_t)
with self.test_session():
print(costs.eval())

Expand Down Expand Up @@ -91,7 +94,7 @@ def _test_multiple_batches(self, use_gpu):
input_lengths = np.array([4, 4], dtype=np.int32)
label_lengths = np.array([2, 2], dtype=np.int32)

self._run_rnnt(acts, labels, input_lengths, label_lengths, expected_costs, grads, blank, use_gpu)
self._run_rnnt(acts, labels, input_lengths, label_lengths, expected_costs, expected_grads, 0, use_gpu)

def test_multiple_batches_cpu(self):
self._test_multiple_batches(use_gpu=False)
Expand Down

0 comments on commit b7a5b48

Please sign in to comment.