diff --git a/tensorboard/data/grpc_provider.py b/tensorboard/data/grpc_provider.py index 19b2082b02..f83b60c98e 100644 --- a/tensorboard/data/grpc_provider.py +++ b/tensorboard/data/grpc_provider.py @@ -47,7 +47,25 @@ def __init__(self, addr, stub): self._stub = stub def data_location(self, ctx, *, experiment_id): - return "grpc://%s" % (self._addr,) + req = data_provider_pb2.GetExperimentRequest() + req.experiment_id = experiment_id + with _translate_grpc_error(): + res = self._stub.GetExperiment(req) + return res.data_location + + def experiment_metadata(self, ctx, *, experiment_id): + req = data_provider_pb2.GetExperimentRequest() + req.experiment_id = experiment_id + with _translate_grpc_error(): + res = self._stub.GetExperiment(req) + if not (res.name or res.description or res.HasField("creation_time")): + return None + res = provider.ExperimentMetadata( + experiment_name=res.name, + experiment_description=res.description, + creation_time=_timestamp_proto_to_float(res.creation_time), + ) + return res def list_plugins(self, ctx, *, experiment_id): req = data_provider_pb2.ListPluginsRequest() @@ -306,3 +324,8 @@ def _populate_rtf(run_tag_filter, rtf_proto): rtf_proto.runs.names[:] = sorted(run_tag_filter.runs) if run_tag_filter.tags is not None: rtf_proto.tags.names[:] = sorted(run_tag_filter.tags) + + +def _timestamp_proto_to_float(ts): + """Converts `timestamp_pb2.Timestamp` to float seconds since epoch.""" + return ts.ToNanoseconds() / 1e9 diff --git a/tensorboard/data/grpc_provider_test.py b/tensorboard/data/grpc_provider_test.py index 9f1c75f156..11513bbb1f 100644 --- a/tensorboard/data/grpc_provider_test.py +++ b/tensorboard/data/grpc_provider_test.py @@ -49,8 +49,73 @@ def setUp(self): self.ctx = context.RequestContext() def test_data_location(self): + res = data_provider_pb2.GetExperimentResponse() + res.data_location = "./logs/mnist" + self.stub.GetExperiment.return_value = res + actual = self.provider.data_location(self.ctx, experiment_id="123") - self.assertEqual(actual, "grpc://localhost:0") + self.assertEqual(actual, "./logs/mnist") + + req = data_provider_pb2.GetExperimentRequest() + req.experiment_id = "123" + self.stub.GetExperiment.assert_called_once_with(req) + + def test_experiment_metadata_when_only_data_location_set(self): + res = data_provider_pb2.GetExperimentResponse() + self.stub.GetExperiment.return_value = res + + actual = self.provider.experiment_metadata( + self.ctx, experiment_id="123" + ) + self.assertIsNone(actual) + + req = data_provider_pb2.GetExperimentRequest() + req.experiment_id = "123" + self.stub.GetExperiment.assert_called_once_with(req) + + def test_experiment_metadata_with_partial_metadata(self): + res = data_provider_pb2.GetExperimentResponse() + res.name = "mnist" + self.stub.GetExperiment.return_value = res + + actual = self.provider.experiment_metadata( + self.ctx, experiment_id="123" + ) + self.assertEqual( + actual, + provider.ExperimentMetadata( + experiment_name="mnist", + experiment_description="", + creation_time=0, + ), + ) + + req = data_provider_pb2.GetExperimentRequest() + req.experiment_id = "123" + self.stub.GetExperiment.assert_called_once_with(req) + + def test_experiment_metadata_with_creation_time(self): + res = data_provider_pb2.GetExperimentResponse() + res.name = "mnist" + res.description = "big breakthroughs" + res.creation_time.FromMilliseconds(1500) + self.stub.GetExperiment.return_value = res + + actual = self.provider.experiment_metadata( + self.ctx, experiment_id="123" + ) + self.assertEqual( + actual, + provider.ExperimentMetadata( + experiment_name="mnist", + experiment_description="big breakthroughs", + creation_time=1.5, + ), + ) + + req = data_provider_pb2.GetExperimentRequest() + req.experiment_id = "123" + self.stub.GetExperiment.assert_called_once_with(req) def test_list_plugins(self): res = data_provider_pb2.ListPluginsResponse() diff --git a/tensorboard/data/provider.py b/tensorboard/data/provider.py index 086f468619..e4e86f4fb0 100644 --- a/tensorboard/data/provider.py +++ b/tensorboard/data/provider.py @@ -404,6 +404,35 @@ def experiment_description(self): def creation_time(self): return self._creation_time + def __eq__(self, other): + if not isinstance(other, ExperimentMetadata): + return False + if self._experiment_name != other._experiment_name: + return False + if self._experiment_description != other._experiment_description: + return False + if self._creation_time != other._creation_time: + return False + return True + + def __hash__(self): + return hash( + ( + self._experiment_name, + self._experiment_description, + self._creation_time, + ) + ) + + def __repr__(self): + return "ExperimentMetadata(%s)" % ", ".join( + ( + "experiment_name=%r" % (self._experiment_name,), + "experiment_description=%r" % (self._experiment_description,), + "creation_time=%r" % (self._creation_time,), + ) + ) + class Run(object): """Metadata about a run. diff --git a/tensorboard/data/provider_test.py b/tensorboard/data/provider_test.py index a58b35373f..82380ee539 100644 --- a/tensorboard/data/provider_test.py +++ b/tensorboard/data/provider_test.py @@ -38,6 +38,31 @@ def test_attributes(self): self.assertEqual(e1.experiment_description, "Experiment on Foo") self.assertEqual(e1.creation_time, 1.25) + def test_eq(self): + def md(**kwargs): + kwargs.setdefault("experiment_name", "FooExperiment") + kwargs.setdefault("experiment_description", "Experiment on Foo") + kwargs.setdefault("creation_time", 1.25) + return provider.ExperimentMetadata(**kwargs) + + a1 = md() + a2 = md() + b = md(experiment_name="BarExperiment") + self.assertEqual(a1, a2) + self.assertNotEqual(a1, b) + self.assertNotEqual(b, object()) + + def test_repr(self): + x = provider.ExperimentMetadata( + experiment_name="FooExperiment", + experiment_description="Experiment on Foo", + creation_time=1.25, + ) + repr_ = repr(x) + self.assertIn(repr(x.experiment_name), repr_) + self.assertIn(repr(x.experiment_description), repr_) + self.assertIn(repr(x.creation_time), repr_) + class RunTest(tb_test.TestCase): def test_eq(self): diff --git a/tensorboard/data/server/cli.rs b/tensorboard/data/server/cli.rs index 189c793e51..7c3986fc06 100644 --- a/tensorboard/data/server/cli.rs +++ b/tensorboard/data/server/cli.rs @@ -183,6 +183,8 @@ pub async fn main() -> Result<(), Box> { eprintln!("listening on {:?}", bound); } + let data_location = opts.logdir.display().to_string(); + // Leak the commit object, since the Tonic server must have only 'static references. This only // leaks the outer commit structure (of constant size), not the pointers to the actual data. let commit: &'static Commit = Box::leak(Box::new(Commit::new())); @@ -215,7 +217,10 @@ pub async fn main() -> Result<(), Box> { }) .expect("failed to spawn reloader thread"); - let handler = DataProviderHandler { commit }; + let handler = DataProviderHandler { + data_location, + commit, + }; Server::builder() .add_service(TensorBoardDataProviderServer::new(handler)) .serve_with_incoming(TcpListenerStream::new(listener)) diff --git a/tensorboard/data/server/server.rs b/tensorboard/data/server/server.rs index cbb9a9413b..23cee27e8e 100644 --- a/tensorboard/data/server/server.rs +++ b/tensorboard/data/server/server.rs @@ -36,6 +36,7 @@ use data::tensor_board_data_provider_server::TensorBoardDataProvider; /// Data provider gRPC service implementation. #[derive(Debug)] pub struct DataProviderHandler { + pub data_location: String, pub commit: &'static Commit, } @@ -62,7 +63,10 @@ impl TensorBoardDataProvider for DataProviderHandler { &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("not yet implemented")) + Ok(Response::new(data::GetExperimentResponse { + data_location: self.data_location.clone(), + ..Default::default() + })) } async fn list_plugins( @@ -641,11 +645,24 @@ mod tests { fn sample_handler(commit: Commit) -> DataProviderHandler { DataProviderHandler { + data_location: String::from("./logs/mnist"), // Leak the commit object, since the Tonic server must have only 'static references. commit: Box::leak(Box::new(commit)), } } + #[tokio::test] + async fn test_get_experiment() { + let commit = CommitBuilder::new().build(); + let handler = sample_handler(commit); + let req = Request::new(data::GetExperimentRequest { + experiment_id: "123".to_string(), + }); + let res = handler.get_experiment(req).await.unwrap().into_inner(); + assert_eq!(res.data_location, "./logs/mnist"); // from `sample_handler` + assert_eq!(res.creation_time, None); + } + #[tokio::test] async fn test_list_plugins() { let commit = CommitBuilder::new()