Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion tensorboard/data/grpc_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,25 @@ def __init__(self, addr, stub):
self._stub = stub

def data_location(self, ctx, *, experiment_id):
return "grpc://%s" % (self._addr,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No action needed, just musing that it was kind of useful actually to see this when testing Rustboard (easy to remember which grpc port it was on). I wonder if maybe it'd be worth formatting the data location as something like <logdir> served by grpc://<address>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I also miss it. I feel weakly inclined to keep it to just the
logdir to avoid confusing users. It’s nice to be able to say that
--load_fast is meant to be a straight upgrade.

You can get the address from --verbosity 0 or RUST_LOG, and I’d be
happy to patch /data/environment to return str(data_provider) for
debug use. Could even let that str be rendered on hover, if we want to
be really aggressive. That doesn’t cover the use case of “a user sent us
only a screenshot and we want to divine what data provider they’re
using”, but perhaps it gets most of the value—what do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea of patching /data/environment. If we want user screenshots to be more useful probably the first thing we want is a version number anyway, so it's not critical that it be visible. But nice if the frontend has some way to indicate to a curious soul "what DataProvider am I actually talking to".

req = data_provider_pb2.GetExperimentRequest()
req.experiment_id = experiment_id
with _translate_grpc_error():
res = self._stub.GetExperiment(req)
return res.data_location

def experiment_metadata(self, ctx, *, experiment_id):
req = data_provider_pb2.GetExperimentRequest()
req.experiment_id = experiment_id
with _translate_grpc_error():
res = self._stub.GetExperiment(req)
if not (res.name or res.description or res.HasField("creation_time")):
return None
res = provider.ExperimentMetadata(
experiment_name=res.name,
experiment_description=res.description,
creation_time=_timestamp_proto_to_float(res.creation_time),
)
return res

def list_plugins(self, ctx, *, experiment_id):
req = data_provider_pb2.ListPluginsRequest()
Expand Down Expand Up @@ -306,3 +324,8 @@ def _populate_rtf(run_tag_filter, rtf_proto):
rtf_proto.runs.names[:] = sorted(run_tag_filter.runs)
if run_tag_filter.tags is not None:
rtf_proto.tags.names[:] = sorted(run_tag_filter.tags)


def _timestamp_proto_to_float(ts):
"""Converts `timestamp_pb2.Timestamp` to float seconds since epoch."""
return ts.ToNanoseconds() / 1e9
67 changes: 66 additions & 1 deletion tensorboard/data/grpc_provider_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,73 @@ def setUp(self):
self.ctx = context.RequestContext()

def test_data_location(self):
res = data_provider_pb2.GetExperimentResponse()
res.data_location = "./logs/mnist"
self.stub.GetExperiment.return_value = res

actual = self.provider.data_location(self.ctx, experiment_id="123")
self.assertEqual(actual, "grpc://localhost:0")
self.assertEqual(actual, "./logs/mnist")

req = data_provider_pb2.GetExperimentRequest()
req.experiment_id = "123"
self.stub.GetExperiment.assert_called_once_with(req)

def test_experiment_metadata_when_only_data_location_set(self):
res = data_provider_pb2.GetExperimentResponse()
self.stub.GetExperiment.return_value = res

actual = self.provider.experiment_metadata(
self.ctx, experiment_id="123"
)
self.assertIsNone(actual)

req = data_provider_pb2.GetExperimentRequest()
req.experiment_id = "123"
self.stub.GetExperiment.assert_called_once_with(req)

def test_experiment_metadata_with_partial_metadata(self):
res = data_provider_pb2.GetExperimentResponse()
res.name = "mnist"
self.stub.GetExperiment.return_value = res

actual = self.provider.experiment_metadata(
self.ctx, experiment_id="123"
)
self.assertEqual(
actual,
provider.ExperimentMetadata(
experiment_name="mnist",
experiment_description="",
creation_time=0,
),
)

req = data_provider_pb2.GetExperimentRequest()
req.experiment_id = "123"
self.stub.GetExperiment.assert_called_once_with(req)

def test_experiment_metadata_with_creation_time(self):
res = data_provider_pb2.GetExperimentResponse()
res.name = "mnist"
res.description = "big breakthroughs"
res.creation_time.FromMilliseconds(1500)
self.stub.GetExperiment.return_value = res

actual = self.provider.experiment_metadata(
self.ctx, experiment_id="123"
)
self.assertEqual(
actual,
provider.ExperimentMetadata(
experiment_name="mnist",
experiment_description="big breakthroughs",
creation_time=1.5,
),
)

req = data_provider_pb2.GetExperimentRequest()
req.experiment_id = "123"
self.stub.GetExperiment.assert_called_once_with(req)

def test_list_plugins(self):
res = data_provider_pb2.ListPluginsResponse()
Expand Down
29 changes: 29 additions & 0 deletions tensorboard/data/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,35 @@ def experiment_description(self):
def creation_time(self):
return self._creation_time

def __eq__(self, other):
if not isinstance(other, ExperimentMetadata):
return False
if self._experiment_name != other._experiment_name:
return False
if self._experiment_description != other._experiment_description:
return False
if self._creation_time != other._creation_time:
return False
return True

def __hash__(self):
return hash(
(
self._experiment_name,
self._experiment_description,
self._creation_time,
)
)

def __repr__(self):
return "ExperimentMetadata(%s)" % ", ".join(
(
"experiment_name=%r" % (self._experiment_name,),
"experiment_description=%r" % (self._experiment_description,),
"creation_time=%r" % (self._creation_time,),
)
)


class Run(object):
"""Metadata about a run.
Expand Down
25 changes: 25 additions & 0 deletions tensorboard/data/provider_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,31 @@ def test_attributes(self):
self.assertEqual(e1.experiment_description, "Experiment on Foo")
self.assertEqual(e1.creation_time, 1.25)

def test_eq(self):
def md(**kwargs):
kwargs.setdefault("experiment_name", "FooExperiment")
kwargs.setdefault("experiment_description", "Experiment on Foo")
kwargs.setdefault("creation_time", 1.25)
return provider.ExperimentMetadata(**kwargs)

a1 = md()
a2 = md()
b = md(experiment_name="BarExperiment")
self.assertEqual(a1, a2)
self.assertNotEqual(a1, b)
self.assertNotEqual(b, object())

def test_repr(self):
x = provider.ExperimentMetadata(
experiment_name="FooExperiment",
experiment_description="Experiment on Foo",
creation_time=1.25,
)
repr_ = repr(x)
self.assertIn(repr(x.experiment_name), repr_)
self.assertIn(repr(x.experiment_description), repr_)
self.assertIn(repr(x.creation_time), repr_)


class RunTest(tb_test.TestCase):
def test_eq(self):
Expand Down
7 changes: 6 additions & 1 deletion tensorboard/data/server/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("listening on {:?}", bound);
}

let data_location = opts.logdir.display().to_string();

// Leak the commit object, since the Tonic server must have only 'static references. This only
// leaks the outer commit structure (of constant size), not the pointers to the actual data.
let commit: &'static Commit = Box::leak(Box::new(Commit::new()));
Expand Down Expand Up @@ -215,7 +217,10 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
})
.expect("failed to spawn reloader thread");

let handler = DataProviderHandler { commit };
let handler = DataProviderHandler {
data_location,
commit,
};
Server::builder()
.add_service(TensorBoardDataProviderServer::new(handler))
.serve_with_incoming(TcpListenerStream::new(listener))
Expand Down
19 changes: 18 additions & 1 deletion tensorboard/data/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use data::tensor_board_data_provider_server::TensorBoardDataProvider;
/// Data provider gRPC service implementation.
#[derive(Debug)]
pub struct DataProviderHandler {
pub data_location: String,
pub commit: &'static Commit,
}

Expand All @@ -62,7 +63,10 @@ impl TensorBoardDataProvider for DataProviderHandler {
&self,
_request: Request<data::GetExperimentRequest>,
) -> Result<Response<data::GetExperimentResponse>, Status> {
Err(Status::unimplemented("not yet implemented"))
Ok(Response::new(data::GetExperimentResponse {
data_location: self.data_location.clone(),
..Default::default()
}))
}

async fn list_plugins(
Expand Down Expand Up @@ -641,11 +645,24 @@ mod tests {

fn sample_handler(commit: Commit) -> DataProviderHandler {
DataProviderHandler {
data_location: String::from("./logs/mnist"),
// Leak the commit object, since the Tonic server must have only 'static references.
commit: Box::leak(Box::new(commit)),
}
}

#[tokio::test]
async fn test_get_experiment() {
let commit = CommitBuilder::new().build();
let handler = sample_handler(commit);
let req = Request::new(data::GetExperimentRequest {
experiment_id: "123".to_string(),
});
let res = handler.get_experiment(req).await.unwrap().into_inner();
assert_eq!(res.data_location, "./logs/mnist"); // from `sample_handler`
assert_eq!(res.creation_time, None);
}

#[tokio::test]
async fn test_list_plugins() {
let commit = CommitBuilder::new()
Expand Down