diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py index c720edaf5f..f7bb34351d 100644 --- a/tensorboard/uploader/uploader_test.py +++ b/tensorboard/uploader/uploader_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import collections +import itertools import os import grpc @@ -81,6 +82,12 @@ def _create_mock_client(): experiment_id="123", url="should not be used!" ) mock_client.CreateExperiment.return_value = fake_exp_response + mock_client.GetOrCreateBlobSequence.side_effect = ( + write_service_pb2.GetOrCreateBlobSequenceResponse( + blob_sequence_id="blob%d" % i + ) + for i in itertools.count() + ) return mock_client @@ -281,7 +288,6 @@ def test_start_uploading_graphs(self): graph_event = event_pb2.Event( graph_def=_create_example_graph_bytes(950) ) - mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) mock_logdir_loader.get_run_events.side_effect = [ { @@ -302,6 +308,13 @@ def test_start_uploading_graphs(self): uploader.start_uploading() self.assertEqual(1, mock_client.CreateExperiment.call_count) self.assertEqual(10, mock_client.WriteBlob.call_count) + for (i, call) in enumerate(mock_client.WriteBlob.call_args_list): + requests = list(call[0][0]) + data = b"".join(r.data for r in requests) + self.assertEqual(data, graph_event.graph_def) + self.assertEqual( + set(r.blob_sequence_id for r in requests), {"blob%d" % i}, + ) self.assertEqual(0, mock_rate_limiter.tick.call_count) self.assertEqual(10, mock_blob_rate_limiter.tick.call_count)