diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index 42749e845a..65aafa88ac 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -72,11 +72,23 @@ def histogram(name, data, step=None, buckets=None, description=None): summary_scope = ( getattr(tf.summary.experimental, 'summary_scope', None) or tf.summary.summary_scope) - with summary_scope( - name, 'histogram_summary', values=[data, buckets, step]) as (tag, _): - tensor = _buckets(data, bucket_count=buckets) - return tf.summary.write( - tag=tag, tensor=tensor, step=step, metadata=summary_metadata) + + def histogram_summary(data, buckets, histogram_metadata, step): + with summary_scope( + name, 'histogram_summary', values=[data, buckets, step]) as (tag, _): + tensor = _buckets(data, bucket_count=buckets) + return tf.summary.write( + tag=tag, tensor=tensor, step=step, metadata=histogram_metadata) + + # `_buckets()` has dynamic output shapes which is not supported on TPU's. As so, place + # the bucketing ops on outside compilation cluster so that the function in executed on CPU. + # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this special + # handling once dynamic shapes are supported on TPU's. + if isinstance(tf.distribute.get_strategy(), + tf.distribute.experimental.TPUStrategy): + return tf.compat.v1.tpu.outside_compilation( + histogram_summary, data, buckets, summary_metadata, step) + return histogram_summary(data, buckets, summary_metadata, step) def _buckets(data, bucket_count=None):