diff --git a/tensorboard/backend/event_processing/data_provider.py b/tensorboard/backend/event_processing/data_provider.py index 2679ca61c6..859ed7be73 100644 --- a/tensorboard/backend/event_processing/data_provider.py +++ b/tensorboard/backend/event_processing/data_provider.py @@ -47,6 +47,16 @@ def __init__(self, multiplexer, logdir): self._multiplexer = multiplexer self._logdir = logdir + def _validate_experiment_id(self, experiment_id): + # This data provider doesn't consume the experiment ID at all, but + # as a courtesy to callers we require that it be a valid string, to + # help catch usage errors. + if not isinstance(experiment_id, str): + raise TypeError( + "experiment_id must be %r, but got %r: %r" + % (str, type(experiment_id), experiment_id) + ) + def _test_run_tag(self, run_tag_filter, run, tag): runs = run_tag_filter.runs if runs is not None and run not in runs: @@ -63,11 +73,11 @@ def _get_first_event_timestamp(self, run_name): return None def data_location(self, experiment_id): - del experiment_id # ignored + self._validate_experiment_id(experiment_id) return str(self._logdir) def list_runs(self, experiment_id): - del experiment_id # ignored for now + self._validate_experiment_id(experiment_id) return [ provider.Run( run_id=run, # use names as IDs @@ -78,6 +88,7 @@ def list_runs(self, experiment_id): ] def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None): + self._validate_experiment_id(experiment_id) run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name) return self._list( provider.ScalarTimeSeries, run_tag_content, run_tag_filter @@ -96,6 +107,7 @@ def read_scalars( return self._read(_convert_scalar_event, index) def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None): + self._validate_experiment_id(experiment_id) run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name) return self._list( provider.TensorTimeSeries, run_tag_content, run_tag_filter @@ -175,7 +187,7 @@ def _read(self, convert_event, index): def list_blob_sequences( self, experiment_id, plugin_name, run_tag_filter=None ): - del experiment_id # ignored for now + self._validate_experiment_id(experiment_id) if run_tag_filter is None: run_tag_filter = provider.RunTagFilter(runs=None, tags=None) @@ -206,6 +218,7 @@ def list_blob_sequences( def read_blob_sequences( self, experiment_id, plugin_name, downsample=None, run_tag_filter=None ): + self._validate_experiment_id(experiment_id) # TODO(davidsoergel, wchargin): consider images, etc. # Note this plugin_name can really just be 'graphs' for now; the # v2 cases are not handled yet. diff --git a/tensorboard/plugins/graph/graphs_plugin_test.py b/tensorboard/plugins/graph/graphs_plugin_test.py index ed0a9cd6b0..90becd6d00 100644 --- a/tensorboard/plugins/graph/graphs_plugin_test.py +++ b/tensorboard/plugins/graph/graphs_plugin_test.py @@ -191,11 +191,16 @@ def test_info(self, plugin): }, } - self.assertItemsEqual(expected, plugin.info_impl()) + self.assertItemsEqual(expected, plugin.info_impl('eid')) @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) def test_graph_simple(self, plugin): - graph = self._get_graph(plugin, tag=None, is_conceptual=False) + graph = self._get_graph( + plugin, + tag=None, + is_conceptual=False, + experiment='eid', + ) node_names = set(node.name for node in graph.node) self.assertEqual({ 'k1', 'k2', 'pow', 'sub', 'expected', 'sub_1', 'error', @@ -210,6 +215,7 @@ def test_graph_large_attrs(self, plugin): plugin, tag=None, is_conceptual=False, + experiment='eid', limit_attr_size=self._MESSAGE_PREFIX_LENGTH_LOWER_BOUND, large_attrs_key=key) large_attrs = { diff --git a/tensorboard/plugins/scalar/scalars_plugin_test.py b/tensorboard/plugins/scalar/scalars_plugin_test.py index 2b145107ff..170cb11671 100644 --- a/tensorboard/plugins/scalar/scalars_plugin_test.py +++ b/tensorboard/plugins/scalar/scalars_plugin_test.py @@ -183,26 +183,26 @@ def test_index(self, plugin): }, }, # _RUN_WITH_HISTOGRAM omitted: No scalar data. - }, plugin.index_impl()) + }, plugin.index_impl('eid')) @with_runs([_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM]) def _test_scalars_json(self, plugin, run_name, tag_name, should_work=True): if should_work: (data, mime_type) = plugin.scalars_impl( - tag_name, run_name, None, scalars_plugin.OutputFormat.JSON) + tag_name, run_name, 'eid', scalars_plugin.OutputFormat.JSON) self.assertEqual('application/json', mime_type) self.assertEqual(len(data), self._STEPS) else: with self.assertRaises(errors.NotFoundError): plugin.scalars_impl( - self._SCALAR_TAG, run_name, None, scalars_plugin.OutputFormat.JSON + self._SCALAR_TAG, run_name, 'eid', scalars_plugin.OutputFormat.JSON ) @with_runs([_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM]) def _test_scalars_csv(self, plugin, run_name, tag_name, should_work=True): if should_work: (data, mime_type) = plugin.scalars_impl( - tag_name, run_name, None, scalars_plugin.OutputFormat.CSV) + tag_name, run_name, 'eid', scalars_plugin.OutputFormat.CSV) self.assertEqual('text/csv', mime_type) s = StringIO(data) reader = csv.reader(s) @@ -211,7 +211,7 @@ def _test_scalars_csv(self, plugin, run_name, tag_name, should_work=True): else: with self.assertRaises(errors.NotFoundError): plugin.scalars_impl( - self._SCALAR_TAG, run_name, None, scalars_plugin.OutputFormat.CSV + self._SCALAR_TAG, run_name, 'eid', scalars_plugin.OutputFormat.CSV ) def test_scalars_json_with_legacy_scalars(self): @@ -264,7 +264,7 @@ def test_scalars_db_without_exp(self): self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS) (data, mime_type) = self.plugin.scalars_impl( - self._SCALAR_TAG, self._RUN_WITH_SCALARS, None, + self._SCALAR_TAG, self._RUN_WITH_SCALARS, 'eid', scalars_plugin.OutputFormat.JSON) self.assertEqual('application/json', mime_type) # When querying DB-based backend without an experiment id, it returns all