diff --git a/tensorboard/plugins/image/summary_v2.py b/tensorboard/plugins/image/summary_v2.py index 9a5e8c2b9a..de184bb611 100644 --- a/tensorboard/plugins/image/summary_v2.py +++ b/tensorboard/plugins/image/summary_v2.py @@ -68,21 +68,25 @@ def image(name, tf.summary.summary_scope) with summary_scope( name, 'image_summary', values=[data, max_outputs, step]) as (tag, _): - tf.debugging.assert_rank(data, 4) - tf.debugging.assert_non_negative(max_outputs) - images = tf.image.convert_image_dtype(data, tf.uint8, saturate=True) - limited_images = images[:max_outputs] - encoded_images = tf.map_fn(tf.image.encode_png, limited_images, - dtype=tf.string, - name='encode_each_image') - # Workaround for map_fn returning float dtype for an empty elems input. - encoded_images = tf.cond( - tf.shape(input=encoded_images)[0] > 0, - lambda: encoded_images, lambda: tf.constant([], tf.string)) - image_shape = tf.shape(input=images) - dimensions = tf.stack([tf.as_string(image_shape[2], name='width'), - tf.as_string(image_shape[1], name='height')], - name='dimensions') - tensor = tf.concat([dimensions, encoded_images], axis=0) + def _encode_image_data(): + tf.debugging.assert_rank(data, 4) + tf.debugging.assert_non_negative(max_outputs) + images = tf.image.convert_image_dtype(data, tf.uint8, saturate=True) + limited_images = images[:max_outputs] + encoded_images = tf.map_fn(tf.image.encode_png, limited_images, + dtype=tf.string, + name='encode_each_image') + # Workaround for map_fn returning float dtype for an empty elems input. + encoded_images = tf.cond( + tf.shape(input=encoded_images)[0] > 0, + lambda: encoded_images, lambda: tf.constant([], tf.string)) + image_shape = tf.shape(input=images) + dimensions = tf.stack([tf.as_string(image_shape[2], name='width'), + tf.as_string(image_shape[1], name='height')], + name='dimensions') + return tf.concat([dimensions, encoded_images], axis=0) + + # To ensure that image encoding logic is only executed when summaries + # are written, we pass callable to `tensor` parameter. return tf.summary.write( - tag=tag, tensor=tensor, step=step, metadata=summary_metadata) + tag=tag, tensor=_encode_image_data, step=step, metadata=summary_metadata)