diff --git a/tensorboard/uploader/exporter.py b/tensorboard/uploader/exporter.py index b87a0acdf6..123dd6c4c0 100644 --- a/tensorboard/uploader/exporter.py +++ b/tensorboard/uploader/exporter.py @@ -85,7 +85,8 @@ def __init__(self, reader_service_client, output_directory): self._api = reader_service_client self._outdir = output_directory parent_dir = os.path.dirname(self._outdir) - _mkdir_p(parent_dir) + if parent_dir: + _mkdir_p(parent_dir) try: os.mkdir(self._outdir) except OSError as e: diff --git a/tensorboard/uploader/exporter_test.py b/tensorboard/uploader/exporter_test.py index 9648fd162f..9072458cc2 100644 --- a/tensorboard/uploader/exporter_test.py +++ b/tensorboard/uploader/exporter_test.py @@ -199,6 +199,25 @@ def stream_experiments(request, **kwargs): self.assertIn("../authorized_keys", msg) mock_api_client.StreamExperimentData.assert_not_called() + def test_handles_outdir_with_no_slash(self): + oldcwd = os.getcwd() + try: + os.chdir(self.get_temp_dir()) + mock_api_client = self._create_mock_api_client() + mock_api_client.StreamExperiments.return_value = iter([ + export_service_pb2.StreamExperimentsResponse(experiment_ids=["123"]), + ]) + mock_api_client.StreamExperimentData.return_value = iter([ + export_service_pb2.StreamExperimentDataResponse() + ]) + + exporter = exporter_lib.TensorBoardExporter(mock_api_client, "outdir") + generator = exporter.export() + self.assertEqual(list(generator), ["123"]) + self.assertTrue(os.path.isdir("outdir")) + finally: + os.chdir(oldcwd) + def test_rejects_existing_directory(self): mock_api_client = self._create_mock_api_client() outdir = os.path.join(self.get_temp_dir(), "outdir")