diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py
index d91415547c..1c73945e59 100644
--- a/tensorboard/uploader/uploader_test.py
+++ b/tensorboard/uploader/uploader_test.py
@@ -165,6 +165,18 @@ def _create_request_sender(
     )
 
 
+def _create_scalar_request_sender(
+    experiment_id=None, api=None,
+):
+    if api is _USE_DEFAULT:
+        api = _create_mock_client()
+    return uploader_lib._ScalarBatchedRequestSender(
+        experiment_id=experiment_id,
+        api=api,
+        rpc_rate_limiter=util.RateLimiter(0),
+    )
+
+
 class TensorboardUploaderTest(tf.test.TestCase):
     def test_create_experiment(self):
         logdir = "/logs/foo"
@@ -564,42 +576,6 @@ def test_upload_swallows_rpc_failure(self):
         uploader._upload_once()
         mock_client.WriteScalar.assert_called_once()
 
-    def test_upload_propagates_experiment_deletion(self):
-        logdir = self.get_temp_dir()
-        with tb_test_util.FileWriter(logdir) as writer:
-            writer.add_test_summary("foo")
-        mock_client = _create_mock_client()
-        uploader = _create_uploader(mock_client, logdir)
-        uploader.create_experiment()
-        error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
-        mock_client.WriteScalar.side_effect = error
-        with self.assertRaises(uploader_lib.ExperimentNotFoundError):
-            uploader._upload_once()
-
-    def test_upload_preserves_wall_time(self):
-        logdir = self.get_temp_dir()
-        with tb_test_util.FileWriter(logdir) as writer:
-            # Add a raw event so we can specify the wall_time value deterministically.
-            writer.add_event(
-                event_pb2.Event(
-                    step=1,
-                    wall_time=123.123123123,
-                    summary=scalar_v2.scalar_pb("foo", 5.0),
-                )
-            )
-        mock_client = _create_mock_client()
-        uploader = _create_uploader(mock_client, logdir)
-        uploader.create_experiment()
-        uploader._upload_once()
-        mock_client.WriteScalar.assert_called_once()
-        request = mock_client.WriteScalar.call_args[0][0]
-        # Just check the wall_time value; everything else is covered in the full
-        # logdir test below.
-        self.assertEqual(
-            123123123123,
-            request.runs[0].tags[0].points[0].wall_time.ToNanoseconds(),
-        )
-
     def test_upload_full_logdir(self):
         logdir = self.get_temp_dir()
         mock_client = _create_mock_client()
@@ -735,41 +711,6 @@ def test_empty_events(self):
             run_proto, write_service_pb2.WriteScalarRequest.Run()
         )
 
-    def test_aggregation_by_tag(self):
-        def make_event(step, wall_time, tag, value):
-            return event_pb2.Event(
-                step=step,
-                wall_time=wall_time,
-                summary=scalar_v2.scalar_pb(tag, value),
-            )
-
-        events = [
-            make_event(1, 1.0, "one", 11.0),
-            make_event(1, 2.0, "two", 22.0),
-            make_event(2, 3.0, "one", 33.0),
-            make_event(2, 4.0, "two", 44.0),
-            make_event(
-                1, 5.0, "one", 55.0
-            ),  # Should preserve duplicate step=1.
-            make_event(1, 6.0, "three", 66.0),
-        ]
-        run_proto = write_service_pb2.WriteScalarRequest.Run()
-        self._populate_run_from_events(run_proto, events)
-        tag_data = {
-            tag.name: [
-                (p.step, p.wall_time.ToSeconds(), p.value) for p in tag.points
-            ]
-            for tag in run_proto.tags
-        }
-        self.assertEqual(
-            tag_data,
-            {
-                "one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)],
-                "two": [(1, 2.0, 22.0), (2, 4.0, 44.0)],
-                "three": [(1, 6.0, 66.0)],
-            },
-        )
-
     def test_skips_non_scalar_events(self):
         events = [
             event_pb2.Event(file_version="brain.Event:2"),
@@ -838,9 +779,11 @@ def test_remembers_first_metadata_in_scalar_time_series(self):
         tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags}
         self.assertEqual(tag_counts, {"loss": 2})
 
-    def test_v1_summary_single_value(self):
+    def test_expands_multiple_values_in_event(self):
         event = event_pb2.Event(step=1, wall_time=123.456)
-        event.summary.value.add(tag="foo", simple_value=5.0)
+        event.summary.value.add(tag="foo", simple_value=1.0)
+        event.summary.value.add(tag="foo", simple_value=2.0)
+        event.summary.value.add(tag="foo", simple_value=3.0)
         run_proto = write_service_pb2.WriteScalarRequest.Run()
         self._populate_run_from_events(run_proto, [event])
         expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
@@ -850,17 +793,74 @@
         foo_tag.metadata.plugin_data.plugin_name = "scalars"
         foo_tag.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR
         foo_tag.points.add(
-            step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0
+            step=1, wall_time=test_util.timestamp_pb(123456000000), value=1.0
+        )
+        foo_tag.points.add(
+            step=1, wall_time=test_util.timestamp_pb(123456000000), value=2.0
+        )
+        foo_tag.points.add(
+            step=1, wall_time=test_util.timestamp_pb(123456000000), value=3.0
         )
         self.assertProtoEquals(run_proto, expected_run_proto)
 
-    def test_v1_summary_multiple_value(self):
+
+class ScalarBatchedRequestSenderTest(tf.test.TestCase):
+    def _add_events(self, sender, run_name, events):
+        for event in events:
+            for value in event.summary.value:
+                sender.add_event(run_name, event, value, value.metadata)
+
+    def _add_events_and_flush(self, events):
+        mock_client = _create_mock_client()
+        sender = _create_scalar_request_sender(
+            experiment_id="123", api=mock_client,
+        )
+        self._add_events(sender, "", events)
+        sender.flush()
+
+        requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
+        self.assertLen(requests, 1)
+        self.assertLen(requests[0].runs, 1)
+        return requests[0].runs[0]
+
+    def test_aggregation_by_tag(self):
+        def make_event(step, wall_time, tag, value):
+            return event_pb2.Event(
+                step=step,
+                wall_time=wall_time,
+                summary=scalar_v2.scalar_pb(tag, value),
+            )
+
+        events = [
+            make_event(1, 1.0, "one", 11.0),
+            make_event(1, 2.0, "two", 22.0),
+            make_event(2, 3.0, "one", 33.0),
+            make_event(2, 4.0, "two", 44.0),
+            make_event(
+                1, 5.0, "one", 55.0
+            ),  # Should preserve duplicate step=1.
+            make_event(1, 6.0, "three", 66.0),
+        ]
+        run_proto = self._add_events_and_flush(events)
+        tag_data = {
+            tag.name: [
+                (p.step, p.wall_time.ToSeconds(), p.value) for p in tag.points
+            ]
+            for tag in run_proto.tags
+        }
+        self.assertEqual(
+            tag_data,
+            {
+                "one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)],
+                "two": [(1, 2.0, 22.0), (2, 4.0, 44.0)],
+                "three": [(1, 6.0, 66.0)],
+            },
+        )
+
+    def test_v1_summary(self):
         event = event_pb2.Event(step=1, wall_time=123.456)
-        event.summary.value.add(tag="foo", simple_value=1.0)
-        event.summary.value.add(tag="foo", simple_value=2.0)
-        event.summary.value.add(tag="foo", simple_value=3.0)
-        run_proto = write_service_pb2.WriteScalarRequest.Run()
-        self._populate_run_from_events(run_proto, [event])
+        event.summary.value.add(tag="foo", simple_value=5.0)
+        run_proto = self._add_events_and_flush(_apply_compat([event]))
         expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
         foo_tag = expected_run_proto.tags.add()
         foo_tag.name = "foo"
@@ -868,13 +868,7 @@ def test_v1_summary_multiple_value(self):
         foo_tag.metadata.plugin_data.plugin_name = "scalars"
         foo_tag.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR
         foo_tag.points.add(
-            step=1, wall_time=test_util.timestamp_pb(123456000000), value=1.0
-        )
-        foo_tag.points.add(
-            step=1, wall_time=test_util.timestamp_pb(123456000000), value=2.0
-        )
-        foo_tag.points.add(
-            step=1, wall_time=test_util.timestamp_pb(123456000000), value=3.0
+            step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0
         )
         self.assertProtoEquals(run_proto, expected_run_proto)
 
@@ -884,8 +878,7 @@ def test_v1_summary_tb_summary(self):
             tf_summary.SerializeToString()
         )
         event = event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary)
-        run_proto = write_service_pb2.WriteScalarRequest.Run()
-        self._populate_run_from_events(run_proto, [event])
+        run_proto = self._add_events_and_flush(_apply_compat([event]))
         expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
         foo_tag = expected_run_proto.tags.add()
         foo_tag.name = "foo/scalar_summary"
@@ -901,8 +894,7 @@ def test_v2_summary(self):
         event = event_pb2.Event(
             step=1, wall_time=123.456, summary=scalar_v2.scalar_pb("foo", 5.0)
         )
-        run_proto = write_service_pb2.WriteScalarRequest.Run()
-        self._populate_run_from_events(run_proto, [event])
+        run_proto = self._add_events_and_flush(_apply_compat([event]))
         expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
         foo_tag = expected_run_proto.tags.add()
         foo_tag.name = "foo"
@@ -913,16 +905,26 @@
             step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0
         )
         self.assertProtoEquals(run_proto, expected_run_proto)
 
+    def test_propagates_experiment_deletion(self):
+        event = event_pb2.Event(step=1)
+        event.summary.value.add(tag="foo", simple_value=1.0)
+
+        mock_client = _create_mock_client()
+        sender = _create_scalar_request_sender("123", mock_client)
+        self._add_events(sender, "run", _apply_compat([event]))
+
+        error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
+        mock_client.WriteScalar.side_effect = error
+        with self.assertRaises(uploader_lib.ExperimentNotFoundError):
+            sender.flush()
+
     def test_no_budget_for_experiment_id(self):
         mock_client = _create_mock_client()
-        event = event_pb2.Event(step=1, wall_time=123.456)
-        event.summary.value.add(tag="foo", simple_value=1.0)
-        run_to_events = {"run_name": [event]}
         long_experiment_id = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES
-        mock_client = _create_mock_client()
         with self.assertRaises(RuntimeError) as cm:
-            builder = _create_request_sender(long_experiment_id, mock_client)
-            builder.send_requests(run_to_events)
+            _create_scalar_request_sender(
+                experiment_id=long_experiment_id, api=mock_client,
+            )
         self.assertEqual(
             str(cm.exception), "Byte budget too small for experiment ID"
         )
 
@@ -932,10 +934,9 @@ def test_no_room_for_single_point(self):
         mock_client = _create_mock_client()
         event = event_pb2.Event(step=1, wall_time=123.456)
         event.summary.value.add(tag="foo", simple_value=1.0)
         long_run_name = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES
-        run_to_events = {long_run_name: _apply_compat([event])}
         with self.assertRaises(RuntimeError) as cm:
-            builder = _create_request_sender("123", mock_client)
-            builder.send_requests(run_to_events)
+            sender = _create_scalar_request_sender("123", mock_client)
+            self._add_events(sender, long_run_name, [event])
         self.assertEqual(str(cm.exception), "add_event failed despite flush")
 
     @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024)
@@ -948,20 +949,17 @@ def test_break_at_run_boundary(self):
         event_1 = event_pb2.Event(step=1)
         event_1.summary.value.add(tag="foo", simple_value=1.0)
         event_2 = event_pb2.Event(step=2)
         event_2.summary.value.add(tag="bar", simple_value=-2.0)
-        run_to_events = collections.OrderedDict(
-            [
-                (long_run_1, _apply_compat([event_1])),
-                (long_run_2, _apply_compat([event_2])),
-            ]
-        )
-        builder = _create_request_sender("123", mock_client)
-        builder.send_requests(run_to_events)
+        sender = _create_scalar_request_sender("123", mock_client)
+        self._add_events(sender, long_run_1, _apply_compat([event_1]))
+        self._add_events(sender, long_run_2, _apply_compat([event_2]))
+        sender.flush()
         requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
         for request in requests:
             _clear_wall_times(request)
 
+        # Expect two RPC calls despite a single explicit call to flush().
         expected = [
             write_service_pb2.WriteScalarRequest(experiment_id="123"),
             write_service_pb2.WriteScalarRequest(experiment_id="123"),
         ]
@@ -990,14 +988,15 @@ def test_break_at_tag_boundary(self):
         event = event_pb2.Event(step=1)
         event.summary.value.add(tag=long_tag_1, simple_value=1.0)
         event.summary.value.add(tag=long_tag_2, simple_value=2.0)
-        run_to_events = {"train": _apply_compat([event])}
-        builder = _create_request_sender("123", mock_client)
-        builder.send_requests(run_to_events)
+        sender = _create_scalar_request_sender("123", mock_client)
+        self._add_events(sender, "train", _apply_compat([event]))
+        sender.flush()
         requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
         for request in requests:
             _clear_wall_times(request)
 
+        # Expect two RPC calls despite a single explicit call to flush().
         expected = [
             write_service_pb2.WriteScalarRequest(experiment_id="123"),
             write_service_pb2.WriteScalarRequest(experiment_id="123"),
         ]
@@ -1030,10 +1029,10 @@ def test_break_at_scalar_point_boundary(self):
             if step > 0:
                 summary.value[0].ClearField("metadata")
             events.append(event_pb2.Event(summary=summary, step=step))
-        run_to_events = {"train": _apply_compat(events)}
-        builder = _create_request_sender("123", mock_client)
-        builder.send_requests(run_to_events)
+        sender = _create_scalar_request_sender("123", mock_client)
+        self._add_events(sender, "train", _apply_compat(events))
+        sender.flush()
         requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
         for request in requests:
             _clear_wall_times(request)
@@ -1064,12 +1063,6 @@ def test_prunes_tags_and_runs(self):
         event_1.summary.value.add(tag="foo", simple_value=1.0)
         event_2 = event_pb2.Event(step=2)
         event_2.summary.value.add(tag="bar", simple_value=-2.0)
-        run_to_events = collections.OrderedDict(
-            [
-                ("train", _apply_compat([event_1])),
-                ("test", _apply_compat([event_2])),
-            ]
-        )
 
         real_create_point = (
             uploader_lib._ScalarBatchedRequestSender._create_point
@@ -1090,8 +1083,10 @@ def mock_create_point(uploader_self, *args, **kwargs):
             "_create_point",
             mock_create_point,
         ):
-            builder = _create_request_sender("123", mock_client)
-            builder.send_requests(run_to_events)
+            sender = _create_scalar_request_sender("123", mock_client)
+            self._add_events(sender, "train", _apply_compat([event_1]))
+            self._add_events(sender, "test", _apply_compat([event_2]))
+            sender.flush()
         requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
         for request in requests:
             _clear_wall_times(request)
@@ -1116,15 +1111,14 @@ def mock_create_point(uploader_self, *args, **kwargs):
 
     def test_wall_time_precision(self):
         # Test a wall time that is exactly representable in float64 but has enough
-        # digits to incur error if converted to nanonseconds the naive way (* 1e9).
+        # digits to incur error if converted to nanoseconds the naive way (* 1e9).
         event1 = event_pb2.Event(step=1, wall_time=1567808404.765432119)
         event1.summary.value.add(tag="foo", simple_value=1.0)
         # Test a wall time where as a float64, the fractional part on its own will
         # introduce error if truncated to 9 decimal places instead of rounded.
         event2 = event_pb2.Event(step=2, wall_time=1.000000002)
         event2.summary.value.add(tag="foo", simple_value=2.0)
-        run_proto = write_service_pb2.WriteScalarRequest.Run()
-        self._populate_run_from_events(run_proto, [event1, event2])
+        run_proto = self._add_events_and_flush(_apply_compat([event1, event2]))
         self.assertEqual(
             test_util.timestamp_pb(1567808404765432119),
             run_proto.tags[0].points[0].wall_time,