Skip to content

Commit

Permalink
Add unit test verifying tf.Dataset support for dense_image_warp (#332) (
Browse files Browse the repository at this point in the history
  • Loading branch information
mels630 authored and WindQAQ committed Oct 31, 2019
1 parent e4c974d commit 22abf3c
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions tensorflow_addons/image/dense_image_warp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,25 @@ def _check_interpolation_correctness(self,
shape,
image_type,
flow_type,
call_with_unknown_shapes=False,
num_probes=5):
"""Interpolate, and then assert correctness for a few query
locations."""
low_precision = image_type == "float16" or flow_type == "float16"
rand_image, rand_flows = self._get_random_image_and_flows(
shape, image_type, flow_type)

interp = dense_image_warp(
image=tf.convert_to_tensor(rand_image),
flow=tf.convert_to_tensor(rand_flows))
if call_with_unknown_shapes:
fn = dense_image_warp.get_concrete_function(
tf.TensorSpec(shape=None, dtype=image_type),
tf.TensorSpec(shape=None, dtype=flow_type))
interp = fn(
image=tf.convert_to_tensor(rand_image),
flow=tf.convert_to_tensor(rand_flows))
else:
interp = dense_image_warp(
image=tf.convert_to_tensor(rand_image),
flow=tf.convert_to_tensor(rand_flows))

for _ in range(num_probes):
batch_index = np.random.randint(0, shape[0])
Expand All @@ -189,6 +198,14 @@ def test_interpolation(self):
self._check_interpolation_correctness(
shape, im_type, flow_type)

def test_unknown_shapes(self):
"""Apply _check_interpolation_correctness() for a few sizes and check
for tf.Dataset compatibility."""
shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]]
for shape in shapes_to_try:
self._check_interpolation_correctness(shape, "float32", "float32",
True)

def test_gradients_exist(self):
"""Check that backprop can run.
Expand Down

0 comments on commit 22abf3c

Please sign in to comment.