tensorboard/backend/event_processing/data_provider.py (19 changes: 16 additions & 3 deletions)
@@ -47,6 +47,16 @@ def __init__(self, multiplexer, logdir):
     self._multiplexer = multiplexer
     self._logdir = logdir
 
+  def _validate_experiment_id(self, experiment_id):
+    # This data provider doesn't consume the experiment ID at all, but
+    # as a courtesy to callers we require that it be a valid string, to
+    # help catch usage errors.
+    if not isinstance(experiment_id, str):
+      raise TypeError(
+          "experiment_id must be %r, but got %r: %r"
+          % (str, type(experiment_id), experiment_id)
+      )
+
   def _test_run_tag(self, run_tag_filter, run, tag):
     runs = run_tag_filter.runs
     if runs is not None and run not in runs:
@@ -63,11 +73,11 @@ def _get_first_event_timestamp(self, run_name):
     return None
 
   def data_location(self, experiment_id):
-    del experiment_id  # ignored
+    self._validate_experiment_id(experiment_id)
     return str(self._logdir)
 
   def list_runs(self, experiment_id):
-    del experiment_id  # ignored for now
+    self._validate_experiment_id(experiment_id)
     return [
         provider.Run(
             run_id=run,  # use names as IDs
@@ -78,6 +88,7 @@
     ]
 
   def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
+    self._validate_experiment_id(experiment_id)
     run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
     return self._list(
         provider.ScalarTimeSeries, run_tag_content, run_tag_filter
@@ -96,6 +107,7 @@ def read_scalars(
     return self._read(_convert_scalar_event, index)
 
   def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
+    self._validate_experiment_id(experiment_id)
     run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
     return self._list(
         provider.TensorTimeSeries, run_tag_content, run_tag_filter
@@ -175,7 +187,7 @@ def _read(self, convert_event, index):
   def list_blob_sequences(
       self, experiment_id, plugin_name, run_tag_filter=None
   ):
-    del experiment_id  # ignored for now
+    self._validate_experiment_id(experiment_id)
     if run_tag_filter is None:
       run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
 
@@ -206,6 +218,7 @@ def list_blob_sequences(
   def read_blob_sequences(
       self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
   ):
+    self._validate_experiment_id(experiment_id)
     # TODO(davidsoergel, wchargin): consider images, etc.
     # Note this plugin_name can really just be 'graphs' for now; the
     # v2 cases are not handled yet.
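
For illustration, here is a runnable sketch of the fail-fast behavior this helper gives every entry point above. `DemoProvider` is a hypothetical stand-in for the multiplexer-backed data provider, not code from this PR; the check mirrors `_validate_experiment_id`, and Python 3 string semantics are assumed.

class DemoProvider(object):
  # Hypothetical stand-in for the multiplexer-backed data provider.

  def _validate_experiment_id(self, experiment_id):
    # Same pattern as the helper above: reject anything that is not
    # a true `str`, such as bytes or an integer ID.
    if not isinstance(experiment_id, str):
      raise TypeError(
          "experiment_id must be %r, but got %r: %r"
          % (str, type(experiment_id), experiment_id)
      )

  def data_location(self, experiment_id):
    self._validate_experiment_id(experiment_id)
    return "/tmp/logdir"  # illustrative value

demo = DemoProvider()
print(demo.data_location("eid"))  # any string passes the check
try:
  demo.data_location(123)  # usage error: caught immediately
except TypeError as e:
  print(e)  # experiment_id must be <class 'str'>, but got <class 'int'>: 123

The point of validating an argument that is otherwise unused: callers that accidentally pass an integer or bytes ID get a TypeError at the call site instead of silently getting results for the wrong (ignored) argument.
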
tensorboard/plugins/graph/graphs_plugin_test.py (10 changes: 8 additions & 2 deletions)
@@ -191,11 +191,16 @@ def test_info(self, plugin):
         },
     }
 
-    self.assertItemsEqual(expected, plugin.info_impl())
+    self.assertItemsEqual(expected, plugin.info_impl('eid'))
 
   @with_runs([_RUN_WITH_GRAPH_WITH_METADATA])
   def test_graph_simple(self, plugin):
-    graph = self._get_graph(plugin, tag=None, is_conceptual=False)
+    graph = self._get_graph(
+        plugin,
+        tag=None,
+        is_conceptual=False,
+        experiment='eid',
+    )
     node_names = set(node.name for node in graph.node)
     self.assertEqual({
         'k1', 'k2', 'pow', 'sub', 'expected', 'sub_1', 'error',
@@ -210,6 +215,7 @@ def test_graph_large_attrs(self, plugin):
         plugin,
         tag=None,
         is_conceptual=False,
+        experiment='eid',
         limit_attr_size=self._MESSAGE_PREFIX_LENGTH_LOWER_BOUND,
         large_attrs_key=key)
     large_attrs = {
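
These test changes show the new calling convention: `info_impl` and the test's `_get_graph` helper now take an experiment ID. As a hedged sketch of the contract being exercised, here is a self-contained test using a stub; `_FakeGraphsPlugin` is illustrative only and assumes the plugin forwards the ID to a type-checking provider like the one above.

import unittest

class _FakeGraphsPlugin(object):
  # Stub standing in for the real graphs plugin; illustrative only.
  # The real plugin forwards the experiment ID to its data provider,
  # which (for the multiplexer-backed provider) type-checks it.

  def info_impl(self, experiment):
    if not isinstance(experiment, str):
      raise TypeError('experiment must be a str')
    return {}

class ExperimentIdContractTest(unittest.TestCase):

  def test_string_experiment_id_is_accepted(self):
    self.assertEqual({}, _FakeGraphsPlugin().info_impl('eid'))

  def test_non_string_experiment_id_is_rejected(self):
    with self.assertRaises(TypeError):
      _FakeGraphsPlugin().info_impl(123)

if __name__ == '__main__':
  unittest.main()
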
tensorboard/plugins/scalar/scalars_plugin_test.py (12 changes: 6 additions & 6 deletions)
@@ -183,26 +183,26 @@ def test_index(self, plugin):
             },
         },
         # _RUN_WITH_HISTOGRAM omitted: No scalar data.
-    }, plugin.index_impl())
+    }, plugin.index_impl('eid'))
 
   @with_runs([_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM])
   def _test_scalars_json(self, plugin, run_name, tag_name, should_work=True):
     if should_work:
       (data, mime_type) = plugin.scalars_impl(
-          tag_name, run_name, None, scalars_plugin.OutputFormat.JSON)
+          tag_name, run_name, 'eid', scalars_plugin.OutputFormat.JSON)
       self.assertEqual('application/json', mime_type)
       self.assertEqual(len(data), self._STEPS)
     else:
       with self.assertRaises(errors.NotFoundError):
         plugin.scalars_impl(
-            self._SCALAR_TAG, run_name, None, scalars_plugin.OutputFormat.JSON
+            self._SCALAR_TAG, run_name, 'eid', scalars_plugin.OutputFormat.JSON
         )
 
   @with_runs([_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM])
   def _test_scalars_csv(self, plugin, run_name, tag_name, should_work=True):
     if should_work:
       (data, mime_type) = plugin.scalars_impl(
-          tag_name, run_name, None, scalars_plugin.OutputFormat.CSV)
+          tag_name, run_name, 'eid', scalars_plugin.OutputFormat.CSV)
       self.assertEqual('text/csv', mime_type)
       s = StringIO(data)
       reader = csv.reader(s)
@@ -211,7 +211,7 @@ def _test_scalars_csv(self, plugin, run_name, tag_name, should_work=True):
     else:
       with self.assertRaises(errors.NotFoundError):
         plugin.scalars_impl(
-            self._SCALAR_TAG, run_name, None, scalars_plugin.OutputFormat.CSV
+            self._SCALAR_TAG, run_name, 'eid', scalars_plugin.OutputFormat.CSV
         )
 
   def test_scalars_json_with_legacy_scalars(self):
@@ -264,7 +264,7 @@ def test_scalars_db_without_exp(self):
     self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS)
 
     (data, mime_type) = self.plugin.scalars_impl(
-        self._SCALAR_TAG, self._RUN_WITH_SCALARS, None,
+        self._SCALAR_TAG, self._RUN_WITH_SCALARS, 'eid',
         scalars_plugin.OutputFormat.JSON)
     self.assertEqual('application/json', mime_type)
     # When querying DB-based backend without an experiment id, it returns all
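
In the scalars tests, the third positional argument to `scalars_impl` changes from `None` to a real experiment ID string, matching the call shape `scalars_impl(tag, run, experiment, output_format)` seen throughout the diff. A runnable sketch of that shape, using a stub rather than the real plugin (the stub class and its return payload are illustrative, not the plugin's actual implementation):

class _FakeScalarsPlugin(object):
  # Stub mirroring the argument order used in the tests above; the
  # third positional argument is now an experiment ID, not None.

  def scalars_impl(self, tag, run, experiment, output_format):
    if not isinstance(experiment, str):
      raise TypeError('experiment must be a str')
    return ('[]', 'application/json')  # illustrative payload

(body, mime_type) = _FakeScalarsPlugin().scalars_impl(
    'loss', 'train', 'eid', 'json')
assert mime_type == 'application/json'
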