Skip to content

Commit 0aee979

Browse files
authored
Refactor tests for uploader._ScalarBatchedRequestSender (#3532)
Identify tests in uploader_test.py that specifically target logic in _ScalarBatchedRequestSender and rewrite them to only use _ScalarBatchedRequestSender and none of the other layers in uploader.py. This is a useful exercise for the work on _TensorBatchedRequestSender. I was having a difficult time using _ScalarBatchedRequestSender tests to establish a pattern for testing _TensorBatchedRequestSender.
1 parent 4fe081c commit 0aee979

File tree

1 file changed

+119
-125
lines changed

1 file changed

+119
-125
lines changed

tensorboard/uploader/uploader_test.py

Lines changed: 119 additions & 125 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,18 @@ def _create_request_sender(
165165
)
166166

167167

168+
def _create_scalar_request_sender(
169+
experiment_id=None, api=None,
170+
):
171+
if api is _USE_DEFAULT:
172+
api = _create_mock_client()
173+
return uploader_lib._ScalarBatchedRequestSender(
174+
experiment_id=experiment_id,
175+
api=api,
176+
rpc_rate_limiter=util.RateLimiter(0),
177+
)
178+
179+
168180
class TensorboardUploaderTest(tf.test.TestCase):
169181
def test_create_experiment(self):
170182
logdir = "/logs/foo"
@@ -564,42 +576,6 @@ def test_upload_swallows_rpc_failure(self):
564576
uploader._upload_once()
565577
mock_client.WriteScalar.assert_called_once()
566578

567-
def test_upload_propagates_experiment_deletion(self):
568-
logdir = self.get_temp_dir()
569-
with tb_test_util.FileWriter(logdir) as writer:
570-
writer.add_test_summary("foo")
571-
mock_client = _create_mock_client()
572-
uploader = _create_uploader(mock_client, logdir)
573-
uploader.create_experiment()
574-
error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
575-
mock_client.WriteScalar.side_effect = error
576-
with self.assertRaises(uploader_lib.ExperimentNotFoundError):
577-
uploader._upload_once()
578-
579-
def test_upload_preserves_wall_time(self):
580-
logdir = self.get_temp_dir()
581-
with tb_test_util.FileWriter(logdir) as writer:
582-
# Add a raw event so we can specify the wall_time value deterministically.
583-
writer.add_event(
584-
event_pb2.Event(
585-
step=1,
586-
wall_time=123.123123123,
587-
summary=scalar_v2.scalar_pb("foo", 5.0),
588-
)
589-
)
590-
mock_client = _create_mock_client()
591-
uploader = _create_uploader(mock_client, logdir)
592-
uploader.create_experiment()
593-
uploader._upload_once()
594-
mock_client.WriteScalar.assert_called_once()
595-
request = mock_client.WriteScalar.call_args[0][0]
596-
# Just check the wall_time value; everything else is covered in the full
597-
# logdir test below.
598-
self.assertEqual(
599-
123123123123,
600-
request.runs[0].tags[0].points[0].wall_time.ToNanoseconds(),
601-
)
602-
603579
def test_upload_full_logdir(self):
604580
logdir = self.get_temp_dir()
605581
mock_client = _create_mock_client()
@@ -735,41 +711,6 @@ def test_empty_events(self):
735711
run_proto, write_service_pb2.WriteScalarRequest.Run()
736712
)
737713

738-
def test_aggregation_by_tag(self):
739-
def make_event(step, wall_time, tag, value):
740-
return event_pb2.Event(
741-
step=step,
742-
wall_time=wall_time,
743-
summary=scalar_v2.scalar_pb(tag, value),
744-
)
745-
746-
events = [
747-
make_event(1, 1.0, "one", 11.0),
748-
make_event(1, 2.0, "two", 22.0),
749-
make_event(2, 3.0, "one", 33.0),
750-
make_event(2, 4.0, "two", 44.0),
751-
make_event(
752-
1, 5.0, "one", 55.0
753-
), # Should preserve duplicate step=1.
754-
make_event(1, 6.0, "three", 66.0),
755-
]
756-
run_proto = write_service_pb2.WriteScalarRequest.Run()
757-
self._populate_run_from_events(run_proto, events)
758-
tag_data = {
759-
tag.name: [
760-
(p.step, p.wall_time.ToSeconds(), p.value) for p in tag.points
761-
]
762-
for tag in run_proto.tags
763-
}
764-
self.assertEqual(
765-
tag_data,
766-
{
767-
"one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)],
768-
"two": [(1, 2.0, 22.0), (2, 4.0, 44.0)],
769-
"three": [(1, 6.0, 66.0)],
770-
},
771-
)
772-
773714
def test_skips_non_scalar_events(self):
774715
events = [
775716
event_pb2.Event(file_version="brain.Event:2"),
@@ -838,9 +779,11 @@ def test_remembers_first_metadata_in_scalar_time_series(self):
838779
tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags}
839780
self.assertEqual(tag_counts, {"loss": 2})
840781

841-
def test_v1_summary_single_value(self):
782+
def test_expands_multiple_values_in_event(self):
842783
event = event_pb2.Event(step=1, wall_time=123.456)
843-
event.summary.value.add(tag="foo", simple_value=5.0)
784+
event.summary.value.add(tag="foo", simple_value=1.0)
785+
event.summary.value.add(tag="foo", simple_value=2.0)
786+
event.summary.value.add(tag="foo", simple_value=3.0)
844787
run_proto = write_service_pb2.WriteScalarRequest.Run()
845788
self._populate_run_from_events(run_proto, [event])
846789
expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
@@ -850,31 +793,82 @@ def test_v1_summary_single_value(self):
850793
foo_tag.metadata.plugin_data.plugin_name = "scalars"
851794
foo_tag.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR
852795
foo_tag.points.add(
853-
step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0
796+
step=1, wall_time=test_util.timestamp_pb(123456000000), value=1.0
797+
)
798+
foo_tag.points.add(
799+
step=1, wall_time=test_util.timestamp_pb(123456000000), value=2.0
800+
)
801+
foo_tag.points.add(
802+
step=1, wall_time=test_util.timestamp_pb(123456000000), value=3.0
854803
)
855804
self.assertProtoEquals(run_proto, expected_run_proto)
856805

857-
def test_v1_summary_multiple_value(self):
806+
807+
class ScalarBatchedRequestSenderTest(tf.test.TestCase):
808+
def _add_events(self, sender, run_name, events):
809+
for event in events:
810+
for value in event.summary.value:
811+
sender.add_event(run_name, event, value, value.metadata)
812+
813+
def _add_events_and_flush(self, events):
814+
mock_client = _create_mock_client()
815+
sender = _create_scalar_request_sender(
816+
experiment_id="123", api=mock_client,
817+
)
818+
self._add_events(sender, "", events)
819+
sender.flush()
820+
821+
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
822+
self.assertLen(requests, 1)
823+
self.assertLen(requests[0].runs, 1)
824+
return requests[0].runs[0]
825+
826+
def test_aggregation_by_tag(self):
827+
def make_event(step, wall_time, tag, value):
828+
return event_pb2.Event(
829+
step=step,
830+
wall_time=wall_time,
831+
summary=scalar_v2.scalar_pb(tag, value),
832+
)
833+
834+
events = [
835+
make_event(1, 1.0, "one", 11.0),
836+
make_event(1, 2.0, "two", 22.0),
837+
make_event(2, 3.0, "one", 33.0),
838+
make_event(2, 4.0, "two", 44.0),
839+
make_event(
840+
1, 5.0, "one", 55.0
841+
), # Should preserve duplicate step=1.
842+
make_event(1, 6.0, "three", 66.0),
843+
]
844+
run_proto = self._add_events_and_flush(events)
845+
tag_data = {
846+
tag.name: [
847+
(p.step, p.wall_time.ToSeconds(), p.value) for p in tag.points
848+
]
849+
for tag in run_proto.tags
850+
}
851+
self.assertEqual(
852+
tag_data,
853+
{
854+
"one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)],
855+
"two": [(1, 2.0, 22.0), (2, 4.0, 44.0)],
856+
"three": [(1, 6.0, 66.0)],
857+
},
858+
)
859+
860+
def test_v1_summary(self):
858861
event = event_pb2.Event(step=1, wall_time=123.456)
859-
event.summary.value.add(tag="foo", simple_value=1.0)
860-
event.summary.value.add(tag="foo", simple_value=2.0)
861-
event.summary.value.add(tag="foo", simple_value=3.0)
862-
run_proto = write_service_pb2.WriteScalarRequest.Run()
863-
self._populate_run_from_events(run_proto, [event])
862+
event.summary.value.add(tag="foo", simple_value=5.0)
863+
run_proto = self._add_events_and_flush(_apply_compat([event]))
864864
expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
865865
foo_tag = expected_run_proto.tags.add()
866866
foo_tag.name = "foo"
867867
foo_tag.metadata.display_name = "foo"
868868
foo_tag.metadata.plugin_data.plugin_name = "scalars"
869869
foo_tag.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR
870870
foo_tag.points.add(
871-
step=1, wall_time=test_util.timestamp_pb(123456000000), value=1.0
872-
)
873-
foo_tag.points.add(
874-
step=1, wall_time=test_util.timestamp_pb(123456000000), value=2.0
875-
)
876-
foo_tag.points.add(
877-
step=1, wall_time=test_util.timestamp_pb(123456000000), value=3.0
871+
step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0
878872
)
879873
self.assertProtoEquals(run_proto, expected_run_proto)
880874

@@ -884,8 +878,7 @@ def test_v1_summary_tb_summary(self):
884878
tf_summary.SerializeToString()
885879
)
886880
event = event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary)
887-
run_proto = write_service_pb2.WriteScalarRequest.Run()
888-
self._populate_run_from_events(run_proto, [event])
881+
run_proto = self._add_events_and_flush(_apply_compat([event]))
889882
expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
890883
foo_tag = expected_run_proto.tags.add()
891884
foo_tag.name = "foo/scalar_summary"
@@ -901,8 +894,7 @@ def test_v2_summary(self):
901894
event = event_pb2.Event(
902895
step=1, wall_time=123.456, summary=scalar_v2.scalar_pb("foo", 5.0)
903896
)
904-
run_proto = write_service_pb2.WriteScalarRequest.Run()
905-
self._populate_run_from_events(run_proto, [event])
897+
run_proto = self._add_events_and_flush(_apply_compat([event]))
906898
expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
907899
foo_tag = expected_run_proto.tags.add()
908900
foo_tag.name = "foo"
@@ -913,16 +905,26 @@ def test_v2_summary(self):
913905
)
914906
self.assertProtoEquals(run_proto, expected_run_proto)
915907

908+
def test_propagates_experiment_deletion(self):
909+
event = event_pb2.Event(step=1)
910+
event.summary.value.add(tag="foo", simple_value=1.0)
911+
912+
mock_client = _create_mock_client()
913+
sender = _create_scalar_request_sender("123", mock_client)
914+
self._add_events(sender, "run", _apply_compat([event]))
915+
916+
error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
917+
mock_client.WriteScalar.side_effect = error
918+
with self.assertRaises(uploader_lib.ExperimentNotFoundError):
919+
sender.flush()
920+
916921
def test_no_budget_for_experiment_id(self):
917922
mock_client = _create_mock_client()
918-
event = event_pb2.Event(step=1, wall_time=123.456)
919-
event.summary.value.add(tag="foo", simple_value=1.0)
920-
run_to_events = {"run_name": [event]}
921923
long_experiment_id = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES
922-
mock_client = _create_mock_client()
923924
with self.assertRaises(RuntimeError) as cm:
924-
builder = _create_request_sender(long_experiment_id, mock_client)
925-
builder.send_requests(run_to_events)
925+
_create_scalar_request_sender(
926+
experiment_id=long_experiment_id, api=mock_client,
927+
)
926928
self.assertEqual(
927929
str(cm.exception), "Byte budget too small for experiment ID"
928930
)
@@ -932,10 +934,9 @@ def test_no_room_for_single_point(self):
932934
event = event_pb2.Event(step=1, wall_time=123.456)
933935
event.summary.value.add(tag="foo", simple_value=1.0)
934936
long_run_name = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES
935-
run_to_events = {long_run_name: _apply_compat([event])}
936937
with self.assertRaises(RuntimeError) as cm:
937-
builder = _create_request_sender("123", mock_client)
938-
builder.send_requests(run_to_events)
938+
sender = _create_scalar_request_sender("123", mock_client)
939+
self._add_events(sender, long_run_name, [event])
939940
self.assertEqual(str(cm.exception), "add_event failed despite flush")
940941

941942
@mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024)
@@ -948,20 +949,17 @@ def test_break_at_run_boundary(self):
948949
event_1.summary.value.add(tag="foo", simple_value=1.0)
949950
event_2 = event_pb2.Event(step=2)
950951
event_2.summary.value.add(tag="bar", simple_value=-2.0)
951-
run_to_events = collections.OrderedDict(
952-
[
953-
(long_run_1, _apply_compat([event_1])),
954-
(long_run_2, _apply_compat([event_2])),
955-
]
956-
)
957952

958-
builder = _create_request_sender("123", mock_client)
959-
builder.send_requests(run_to_events)
953+
sender = _create_scalar_request_sender("123", mock_client)
954+
self._add_events(sender, long_run_1, _apply_compat([event_1]))
955+
self._add_events(sender, long_run_2, _apply_compat([event_2]))
956+
sender.flush()
960957
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
961958

962959
for request in requests:
963960
_clear_wall_times(request)
964961

962+
# Expect two RPC calls despite a single explicit call to flush().
965963
expected = [
966964
write_service_pb2.WriteScalarRequest(experiment_id="123"),
967965
write_service_pb2.WriteScalarRequest(experiment_id="123"),
@@ -990,14 +988,15 @@ def test_break_at_tag_boundary(self):
990988
event = event_pb2.Event(step=1)
991989
event.summary.value.add(tag=long_tag_1, simple_value=1.0)
992990
event.summary.value.add(tag=long_tag_2, simple_value=2.0)
993-
run_to_events = {"train": _apply_compat([event])}
994991

995-
builder = _create_request_sender("123", mock_client)
996-
builder.send_requests(run_to_events)
992+
sender = _create_scalar_request_sender("123", mock_client)
993+
self._add_events(sender, "train", _apply_compat([event]))
994+
sender.flush()
997995
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
998996
for request in requests:
999997
_clear_wall_times(request)
1000998

999+
# Expect two RPC calls despite a single explicit call to flush().
10011000
expected = [
10021001
write_service_pb2.WriteScalarRequest(experiment_id="123"),
10031002
write_service_pb2.WriteScalarRequest(experiment_id="123"),
@@ -1030,10 +1029,10 @@ def test_break_at_scalar_point_boundary(self):
10301029
if step > 0:
10311030
summary.value[0].ClearField("metadata")
10321031
events.append(event_pb2.Event(summary=summary, step=step))
1033-
run_to_events = {"train": _apply_compat(events)}
10341032

1035-
builder = _create_request_sender("123", mock_client)
1036-
builder.send_requests(run_to_events)
1033+
sender = _create_scalar_request_sender("123", mock_client)
1034+
self._add_events(sender, "train", _apply_compat(events))
1035+
sender.flush()
10371036
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
10381037
for request in requests:
10391038
_clear_wall_times(request)
@@ -1064,12 +1063,6 @@ def test_prunes_tags_and_runs(self):
10641063
event_1.summary.value.add(tag="foo", simple_value=1.0)
10651064
event_2 = event_pb2.Event(step=2)
10661065
event_2.summary.value.add(tag="bar", simple_value=-2.0)
1067-
run_to_events = collections.OrderedDict(
1068-
[
1069-
("train", _apply_compat([event_1])),
1070-
("test", _apply_compat([event_2])),
1071-
]
1072-
)
10731066

10741067
real_create_point = (
10751068
uploader_lib._ScalarBatchedRequestSender._create_point
@@ -1090,8 +1083,10 @@ def mock_create_point(uploader_self, *args, **kwargs):
10901083
"_create_point",
10911084
mock_create_point,
10921085
):
1093-
builder = _create_request_sender("123", mock_client)
1094-
builder.send_requests(run_to_events)
1086+
sender = _create_scalar_request_sender("123", mock_client)
1087+
self._add_events(sender, "train", _apply_compat([event_1]))
1088+
self._add_events(sender, "test", _apply_compat([event_2]))
1089+
sender.flush()
10951090
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
10961091
for request in requests:
10971092
_clear_wall_times(request)
@@ -1116,15 +1111,14 @@ def mock_create_point(uploader_self, *args, **kwargs):
11161111

11171112
def test_wall_time_precision(self):
11181113
# Test a wall time that is exactly representable in float64 but has enough
1119-
# digits to incur error if converted to nanonseconds the naive way (* 1e9).
1114+
# digits to incur error if converted to nanoseconds the naive way (* 1e9).
11201115
event1 = event_pb2.Event(step=1, wall_time=1567808404.765432119)
11211116
event1.summary.value.add(tag="foo", simple_value=1.0)
11221117
# Test a wall time where as a float64, the fractional part on its own will
11231118
# introduce error if truncated to 9 decimal places instead of rounded.
11241119
event2 = event_pb2.Event(step=2, wall_time=1.000000002)
11251120
event2.summary.value.add(tag="foo", simple_value=2.0)
1126-
run_proto = write_service_pb2.WriteScalarRequest.Run()
1127-
self._populate_run_from_events(run_proto, [event1, event2])
1121+
run_proto = self._add_events_and_flush(_apply_compat([event1, event2]))
11281122
self.assertEqual(
11291123
test_util.timestamp_pb(1567808404765432119),
11301124
run_proto.tags[0].points[0].wall_time,

Comments (0)