diff --git a/tensorboard/plugins/histogram/histograms_plugin.py b/tensorboard/plugins/histogram/histograms_plugin.py index 95df054939..f0c8c20a30 100644 --- a/tensorboard/plugins/histogram/histograms_plugin.py +++ b/tensorboard/plugins/histogram/histograms_plugin.py @@ -23,10 +23,12 @@ from __future__ import print_function import collections +import csv import random import numpy as np import six +from six import StringIO from werkzeug import wrappers from tensorboard import plugin_util @@ -36,6 +38,10 @@ from tensorboard.plugins.histogram import metadata from tensorboard.util import tensor_util +class OutputFormat(object): + """An enum used to list the valid output formats for API calls.""" + JSON = 'json' + CSV = 'csv' class HistogramsPlugin(base_plugin.TBPlugin): """Histograms Plugin for TensorBoard. @@ -63,8 +69,8 @@ def __init__(self, context): def get_plugin_apps(self): return { - '/histograms': self.histograms_route, - '/tags': self.tags_route, + '/histograms': self.histograms_route, + '/tags': self.tags_route, } def is_active(self): @@ -124,7 +130,35 @@ def index_impl(self): return result - def histograms_impl(self, tag, run, downsample_to=None): + def _get_csv_response(self, events): + """Serializes the given histogram events as CSV text. + + Args: + events: A list of (wall_time, step, tensor_as_list) event tuples. + + Returns: + CSV string representation of the given events. + """ + string_io = StringIO() + writer = csv.writer(string_io) + writer.writerow(['Wall time', 'Step', 'BinStart', 'BinEnd', 'BinValue']) + # Convert the events in a way that we can sensibly export + # them as csv. Therefore, we split the start, end, value of + # each bin by semicolon. 
+ for e in events: + writer.writerow( + [ + e[0], + e[1], + ';'.join(['{:.17f}'.format(el[0]) for el in e[2]]), + ';'.join(['{:.17f}'.format(el[1]) for el in e[2]]), + ';'.join(['{}'.format(el[2]) for el in e[2]]), + ] + ) + return string_io.getvalue() + + def histograms_impl(self, tag, run, downsample_to=None, + output_format=OutputFormat.JSON): """Result of the form `(body, mime_type)`, or `ValueError`. At most `downsample_to` events will be returned. If this value is @@ -136,17 +170,17 @@ def histograms_impl(self, tag, run, downsample_to=None): cursor = db.cursor() # Prefetch the tag ID matching this run and tag. cursor.execute( - ''' - SELECT - tag_id - FROM Tags - JOIN Runs USING (run_id) - WHERE - Runs.run_name = :run - AND Tags.tag_name = :tag - AND Tags.plugin_name = :plugin - ''', - {'run': run, 'tag': tag, 'plugin': metadata.PLUGIN_NAME}) + ''' + SELECT + tag_id + FROM Tags + JOIN Runs USING (run_id) + WHERE + Runs.run_name = :run + AND Tags.tag_name = :tag + AND Tags.plugin_name = :plugin + ''', + {'run': run, 'tag': tag, 'plugin': metadata.PLUGIN_NAME}) row = cursor.fetchone() if not row: raise ValueError('No histogram tag %r for run %r' % (tag, run)) @@ -160,32 +194,32 @@ def histograms_impl(self, tag, run, downsample_to=None): # can be formally expressed as the following: # [s_min + math.ceil(i / k * (s_max - s_min)) for i in range(0, k + 1)] cursor.execute( - ''' + ''' + SELECT + MIN(step) AS step, + computed_time, + data, + dtype, + shape + FROM Tensors + INNER JOIN ( SELECT - MIN(step) AS step, - computed_time, - data, - dtype, - shape + MIN(step) AS min_step, + MAX(step) AS max_step FROM Tensors - INNER JOIN ( - SELECT - MIN(step) AS min_step, - MAX(step) AS max_step - FROM Tensors - /* Filter out NULL so we can use TensorSeriesStepIndex. */ - WHERE series = :tag_id AND step IS NOT NULL - ) - /* Ensure we omit reserved rows, which have NULL step values. */ + /* Filter out NULL so we can use TensorSeriesStepIndex. 
*/ WHERE series = :tag_id AND step IS NOT NULL - /* Bucket rows into sample_size linearly spaced buckets, or do - no sampling if sample_size is NULL. */ - GROUP BY - IFNULL(:sample_size - 1, max_step - min_step) - * (step - min_step) / (max_step - min_step) - ORDER BY step - ''', - {'tag_id': tag_id, 'sample_size': downsample_to}) + ) + /* Ensure we omit reserved rows, which have NULL step values. */ + WHERE series = :tag_id AND step IS NOT NULL + /* Bucket rows into sample_size linearly spaced buckets, or do + no sampling if sample_size is NULL. */ + GROUP BY + IFNULL(:sample_size - 1, max_step - min_step) + * (step - min_step) / (max_step - min_step) + ORDER BY step + ''', + {'tag_id': tag_id, 'sample_size': downsample_to}) events = [(computed_time, step, self._get_values(data, dtype, shape)) for step, computed_time, data, dtype, shape in cursor] else: @@ -196,11 +230,15 @@ def histograms_impl(self, tag, run, downsample_to=None): raise ValueError('No histogram tag %r for run %r' % (tag, run)) if downsample_to is not None and len(tensor_events) > downsample_to: rand_indices = random.Random(0).sample( - six.moves.xrange(len(tensor_events)), downsample_to) + six.moves.xrange(len(tensor_events)), downsample_to) indices = sorted(rand_indices) tensor_events = [tensor_events[i] for i in indices] events = [[e.wall_time, e.step, tensor_util.make_ndarray(e.tensor_proto).tolist()] for e in tensor_events] + + if output_format == OutputFormat.CSV: + return (self._get_csv_response(events), 'text/csv') + return (events, 'application/json') def _get_values(self, data_blob, dtype_enum, shape_string): @@ -225,9 +263,11 @@ def histograms_route(self, request): """Given a tag and single run, return array of histogram values.""" tag = request.args.get('tag') run = request.args.get('run') + output_format = request.args.get('format') try: (body, mime_type) = self.histograms_impl( - tag, run, downsample_to=self.SAMPLE_SIZE) + tag, run, downsample_to=self.SAMPLE_SIZE, + 
output_format=output_format) code = 200 except ValueError as e: (body, mime_type) = (str(e), 'text/plain') diff --git a/tensorboard/plugins/histogram/histograms_plugin_test.py b/tensorboard/plugins/histogram/histograms_plugin_test.py index fadff76094..f7dcf6a457 100644 --- a/tensorboard/plugins/histogram/histograms_plugin_test.py +++ b/tensorboard/plugins/histogram/histograms_plugin_test.py @@ -20,9 +20,11 @@ from __future__ import print_function import collections +import csv import os.path import six +from six import StringIO from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf @@ -173,6 +175,33 @@ def test_histograms_with_histogram(self): self._test_histograms(self._RUN_WITH_HISTOGRAM, '%s/histogram_summary' % self._HISTOGRAM_TAG) + def _test_histograms_csv(self, run_name, tag_name, should_work=True): + self.set_up_with_runs([self._RUN_WITH_LEGACY_HISTOGRAM, + self._RUN_WITH_HISTOGRAM]) + if should_work: + (data, mime_type) = self.plugin.histograms_impl( + tag_name, run_name, + output_format=histograms_plugin.OutputFormat.CSV) + self.assertEqual('text/csv', mime_type) + s = StringIO(data) + reader = csv.reader(s) + self.assertEqual(['Wall time', 'Step', 'BinStart', + 'BinEnd', 'BinValue'], next(reader)) + self.assertEqual(len(list(reader)), self._STEPS) + else: + with self.assertRaises(KeyError): + self.plugin.histograms_impl( + self._HISTOGRAM_TAG, run_name, + output_format=histograms_plugin.OutputFormat.CSV) + + def test_histograms_csv_with_legacy_histograms(self): + self._test_histograms_csv(self._RUN_WITH_LEGACY_HISTOGRAM, + self._LEGACY_HISTOGRAM_TAG) + + def test_histograms_csv_with_histogram(self): + self._test_histograms_csv(self._RUN_WITH_HISTOGRAM, self._HISTOGRAM_TAG, + should_work=False) + def test_active_with_legacy_histogram(self): self.set_up_with_runs([self._RUN_WITH_LEGACY_HISTOGRAM]) self.assertTrue(self.plugin.is_active()) diff --git 
a/tensorboard/plugins/histogram/tf_histogram_dashboard/tf-histogram-dashboard.html b/tensorboard/plugins/histogram/tf_histogram_dashboard/tf-histogram-dashboard.html index 7de5d414c9..cd5be343d8 100644 --- a/tensorboard/plugins/histogram/tf_histogram_dashboard/tf-histogram-dashboard.html +++ b/tensorboard/plugins/histogram/tf_histogram_dashboard/tf-histogram-dashboard.html @@ -41,6 +41,14 @@