@@ -165,6 +165,18 @@ def _create_request_sender(
165165 )
166166
167167
168+ def _create_scalar_request_sender (
169+ experiment_id = None , api = None ,
170+ ):
171+ if api is _USE_DEFAULT :
172+ api = _create_mock_client ()
173+ return uploader_lib ._ScalarBatchedRequestSender (
174+ experiment_id = experiment_id ,
175+ api = api ,
176+ rpc_rate_limiter = util .RateLimiter (0 ),
177+ )
178+
179+
168180class TensorboardUploaderTest (tf .test .TestCase ):
169181 def test_create_experiment (self ):
170182 logdir = "/logs/foo"
@@ -564,42 +576,6 @@ def test_upload_swallows_rpc_failure(self):
564576 uploader ._upload_once ()
565577 mock_client .WriteScalar .assert_called_once ()
566578
567- def test_upload_propagates_experiment_deletion (self ):
568- logdir = self .get_temp_dir ()
569- with tb_test_util .FileWriter (logdir ) as writer :
570- writer .add_test_summary ("foo" )
571- mock_client = _create_mock_client ()
572- uploader = _create_uploader (mock_client , logdir )
573- uploader .create_experiment ()
574- error = test_util .grpc_error (grpc .StatusCode .NOT_FOUND , "nope" )
575- mock_client .WriteScalar .side_effect = error
576- with self .assertRaises (uploader_lib .ExperimentNotFoundError ):
577- uploader ._upload_once ()
578-
579- def test_upload_preserves_wall_time (self ):
580- logdir = self .get_temp_dir ()
581- with tb_test_util .FileWriter (logdir ) as writer :
582- # Add a raw event so we can specify the wall_time value deterministically.
583- writer .add_event (
584- event_pb2 .Event (
585- step = 1 ,
586- wall_time = 123.123123123 ,
587- summary = scalar_v2 .scalar_pb ("foo" , 5.0 ),
588- )
589- )
590- mock_client = _create_mock_client ()
591- uploader = _create_uploader (mock_client , logdir )
592- uploader .create_experiment ()
593- uploader ._upload_once ()
594- mock_client .WriteScalar .assert_called_once ()
595- request = mock_client .WriteScalar .call_args [0 ][0 ]
596- # Just check the wall_time value; everything else is covered in the full
597- # logdir test below.
598- self .assertEqual (
599- 123123123123 ,
600- request .runs [0 ].tags [0 ].points [0 ].wall_time .ToNanoseconds (),
601- )
602-
603579 def test_upload_full_logdir (self ):
604580 logdir = self .get_temp_dir ()
605581 mock_client = _create_mock_client ()
@@ -735,41 +711,6 @@ def test_empty_events(self):
735711 run_proto , write_service_pb2 .WriteScalarRequest .Run ()
736712 )
737713
738- def test_aggregation_by_tag (self ):
739- def make_event (step , wall_time , tag , value ):
740- return event_pb2 .Event (
741- step = step ,
742- wall_time = wall_time ,
743- summary = scalar_v2 .scalar_pb (tag , value ),
744- )
745-
746- events = [
747- make_event (1 , 1.0 , "one" , 11.0 ),
748- make_event (1 , 2.0 , "two" , 22.0 ),
749- make_event (2 , 3.0 , "one" , 33.0 ),
750- make_event (2 , 4.0 , "two" , 44.0 ),
751- make_event (
752- 1 , 5.0 , "one" , 55.0
753- ), # Should preserve duplicate step=1.
754- make_event (1 , 6.0 , "three" , 66.0 ),
755- ]
756- run_proto = write_service_pb2 .WriteScalarRequest .Run ()
757- self ._populate_run_from_events (run_proto , events )
758- tag_data = {
759- tag .name : [
760- (p .step , p .wall_time .ToSeconds (), p .value ) for p in tag .points
761- ]
762- for tag in run_proto .tags
763- }
764- self .assertEqual (
765- tag_data ,
766- {
767- "one" : [(1 , 1.0 , 11.0 ), (2 , 3.0 , 33.0 ), (1 , 5.0 , 55.0 )],
768- "two" : [(1 , 2.0 , 22.0 ), (2 , 4.0 , 44.0 )],
769- "three" : [(1 , 6.0 , 66.0 )],
770- },
771- )
772-
773714 def test_skips_non_scalar_events (self ):
774715 events = [
775716 event_pb2 .Event (file_version = "brain.Event:2" ),
@@ -838,9 +779,11 @@ def test_remembers_first_metadata_in_scalar_time_series(self):
838779 tag_counts = {tag .name : len (tag .points ) for tag in run_proto .tags }
839780 self .assertEqual (tag_counts , {"loss" : 2 })
840781
841- def test_v1_summary_single_value (self ):
782+ def test_expands_multiple_values_in_event (self ):
842783 event = event_pb2 .Event (step = 1 , wall_time = 123.456 )
843- event .summary .value .add (tag = "foo" , simple_value = 5.0 )
784+ event .summary .value .add (tag = "foo" , simple_value = 1.0 )
785+ event .summary .value .add (tag = "foo" , simple_value = 2.0 )
786+ event .summary .value .add (tag = "foo" , simple_value = 3.0 )
844787 run_proto = write_service_pb2 .WriteScalarRequest .Run ()
845788 self ._populate_run_from_events (run_proto , [event ])
846789 expected_run_proto = write_service_pb2 .WriteScalarRequest .Run ()
@@ -850,31 +793,82 @@ def test_v1_summary_single_value(self):
850793 foo_tag .metadata .plugin_data .plugin_name = "scalars"
851794 foo_tag .metadata .data_class = summary_pb2 .DATA_CLASS_SCALAR
852795 foo_tag .points .add (
853- step = 1 , wall_time = test_util .timestamp_pb (123456000000 ), value = 5.0
796+ step = 1 , wall_time = test_util .timestamp_pb (123456000000 ), value = 1.0
797+ )
798+ foo_tag .points .add (
799+ step = 1 , wall_time = test_util .timestamp_pb (123456000000 ), value = 2.0
800+ )
801+ foo_tag .points .add (
802+ step = 1 , wall_time = test_util .timestamp_pb (123456000000 ), value = 3.0
854803 )
855804 self .assertProtoEquals (run_proto , expected_run_proto )
856805
857- def test_v1_summary_multiple_value (self ):
806+
807+ class ScalarBatchedRequestSenderTest (tf .test .TestCase ):
808+ def _add_events (self , sender , run_name , events ):
809+ for event in events :
810+ for value in event .summary .value :
811+ sender .add_event (run_name , event , value , value .metadata )
812+
813+ def _add_events_and_flush (self , events ):
814+ mock_client = _create_mock_client ()
815+ sender = _create_scalar_request_sender (
816+ experiment_id = "123" , api = mock_client ,
817+ )
818+ self ._add_events (sender , "" , events )
819+ sender .flush ()
820+
821+ requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
822+ self .assertLen (requests , 1 )
823+ self .assertLen (requests [0 ].runs , 1 )
824+ return requests [0 ].runs [0 ]
825+
826+ def test_aggregation_by_tag (self ):
827+ def make_event (step , wall_time , tag , value ):
828+ return event_pb2 .Event (
829+ step = step ,
830+ wall_time = wall_time ,
831+ summary = scalar_v2 .scalar_pb (tag , value ),
832+ )
833+
834+ events = [
835+ make_event (1 , 1.0 , "one" , 11.0 ),
836+ make_event (1 , 2.0 , "two" , 22.0 ),
837+ make_event (2 , 3.0 , "one" , 33.0 ),
838+ make_event (2 , 4.0 , "two" , 44.0 ),
839+ make_event (
840+ 1 , 5.0 , "one" , 55.0
841+ ), # Should preserve duplicate step=1.
842+ make_event (1 , 6.0 , "three" , 66.0 ),
843+ ]
844+ run_proto = self ._add_events_and_flush (events )
845+ tag_data = {
846+ tag .name : [
847+ (p .step , p .wall_time .ToSeconds (), p .value ) for p in tag .points
848+ ]
849+ for tag in run_proto .tags
850+ }
851+ self .assertEqual (
852+ tag_data ,
853+ {
854+ "one" : [(1 , 1.0 , 11.0 ), (2 , 3.0 , 33.0 ), (1 , 5.0 , 55.0 )],
855+ "two" : [(1 , 2.0 , 22.0 ), (2 , 4.0 , 44.0 )],
856+ "three" : [(1 , 6.0 , 66.0 )],
857+ },
858+ )
859+
860+ def test_v1_summary (self ):
858861 event = event_pb2 .Event (step = 1 , wall_time = 123.456 )
859- event .summary .value .add (tag = "foo" , simple_value = 1.0 )
860- event .summary .value .add (tag = "foo" , simple_value = 2.0 )
861- event .summary .value .add (tag = "foo" , simple_value = 3.0 )
862- run_proto = write_service_pb2 .WriteScalarRequest .Run ()
863- self ._populate_run_from_events (run_proto , [event ])
862+ event .summary .value .add (tag = "foo" , simple_value = 5.0 )
863+ run_proto = self ._add_events_and_flush (_apply_compat ([event ]))
864864 expected_run_proto = write_service_pb2 .WriteScalarRequest .Run ()
865865 foo_tag = expected_run_proto .tags .add ()
866866 foo_tag .name = "foo"
867867 foo_tag .metadata .display_name = "foo"
868868 foo_tag .metadata .plugin_data .plugin_name = "scalars"
869869 foo_tag .metadata .data_class = summary_pb2 .DATA_CLASS_SCALAR
870870 foo_tag .points .add (
871- step = 1 , wall_time = test_util .timestamp_pb (123456000000 ), value = 1.0
872- )
873- foo_tag .points .add (
874- step = 1 , wall_time = test_util .timestamp_pb (123456000000 ), value = 2.0
875- )
876- foo_tag .points .add (
877- step = 1 , wall_time = test_util .timestamp_pb (123456000000 ), value = 3.0
871+ step = 1 , wall_time = test_util .timestamp_pb (123456000000 ), value = 5.0
878872 )
879873 self .assertProtoEquals (run_proto , expected_run_proto )
880874
@@ -884,8 +878,7 @@ def test_v1_summary_tb_summary(self):
884878 tf_summary .SerializeToString ()
885879 )
886880 event = event_pb2 .Event (step = 1 , wall_time = 123.456 , summary = tb_summary )
887- run_proto = write_service_pb2 .WriteScalarRequest .Run ()
888- self ._populate_run_from_events (run_proto , [event ])
881+ run_proto = self ._add_events_and_flush (_apply_compat ([event ]))
889882 expected_run_proto = write_service_pb2 .WriteScalarRequest .Run ()
890883 foo_tag = expected_run_proto .tags .add ()
891884 foo_tag .name = "foo/scalar_summary"
@@ -901,8 +894,7 @@ def test_v2_summary(self):
901894 event = event_pb2 .Event (
902895 step = 1 , wall_time = 123.456 , summary = scalar_v2 .scalar_pb ("foo" , 5.0 )
903896 )
904- run_proto = write_service_pb2 .WriteScalarRequest .Run ()
905- self ._populate_run_from_events (run_proto , [event ])
897+ run_proto = self ._add_events_and_flush (_apply_compat ([event ]))
906898 expected_run_proto = write_service_pb2 .WriteScalarRequest .Run ()
907899 foo_tag = expected_run_proto .tags .add ()
908900 foo_tag .name = "foo"
@@ -913,16 +905,26 @@ def test_v2_summary(self):
913905 )
914906 self .assertProtoEquals (run_proto , expected_run_proto )
915907
908+ def test_propagates_experiment_deletion (self ):
909+ event = event_pb2 .Event (step = 1 )
910+ event .summary .value .add (tag = "foo" , simple_value = 1.0 )
911+
912+ mock_client = _create_mock_client ()
913+ sender = _create_scalar_request_sender ("123" , mock_client )
914+ self ._add_events (sender , "run" , _apply_compat ([event ]))
915+
916+ error = test_util .grpc_error (grpc .StatusCode .NOT_FOUND , "nope" )
917+ mock_client .WriteScalar .side_effect = error
918+ with self .assertRaises (uploader_lib .ExperimentNotFoundError ):
919+ sender .flush ()
920+
916921 def test_no_budget_for_experiment_id (self ):
917922 mock_client = _create_mock_client ()
918- event = event_pb2 .Event (step = 1 , wall_time = 123.456 )
919- event .summary .value .add (tag = "foo" , simple_value = 1.0 )
920- run_to_events = {"run_name" : [event ]}
921923 long_experiment_id = "A" * uploader_lib ._MAX_REQUEST_LENGTH_BYTES
922- mock_client = _create_mock_client ()
923924 with self .assertRaises (RuntimeError ) as cm :
924- builder = _create_request_sender (long_experiment_id , mock_client )
925- builder .send_requests (run_to_events )
925+ _create_scalar_request_sender (
926+ experiment_id = long_experiment_id , api = mock_client ,
927+ )
926928 self .assertEqual (
927929 str (cm .exception ), "Byte budget too small for experiment ID"
928930 )
@@ -932,10 +934,9 @@ def test_no_room_for_single_point(self):
932934 event = event_pb2 .Event (step = 1 , wall_time = 123.456 )
933935 event .summary .value .add (tag = "foo" , simple_value = 1.0 )
934936 long_run_name = "A" * uploader_lib ._MAX_REQUEST_LENGTH_BYTES
935- run_to_events = {long_run_name : _apply_compat ([event ])}
936937 with self .assertRaises (RuntimeError ) as cm :
937- builder = _create_request_sender ("123" , mock_client )
938- builder . send_requests ( run_to_events )
938+ sender = _create_scalar_request_sender ("123" , mock_client )
939+ self . _add_events ( sender , long_run_name , [ event ] )
939940 self .assertEqual (str (cm .exception ), "add_event failed despite flush" )
940941
941942 @mock .patch .object (uploader_lib , "_MAX_REQUEST_LENGTH_BYTES" , 1024 )
@@ -948,20 +949,17 @@ def test_break_at_run_boundary(self):
948949 event_1 .summary .value .add (tag = "foo" , simple_value = 1.0 )
949950 event_2 = event_pb2 .Event (step = 2 )
950951 event_2 .summary .value .add (tag = "bar" , simple_value = - 2.0 )
951- run_to_events = collections .OrderedDict (
952- [
953- (long_run_1 , _apply_compat ([event_1 ])),
954- (long_run_2 , _apply_compat ([event_2 ])),
955- ]
956- )
957952
958- builder = _create_request_sender ("123" , mock_client )
959- builder .send_requests (run_to_events )
953+ sender = _create_scalar_request_sender ("123" , mock_client )
954+ self ._add_events (sender , long_run_1 , _apply_compat ([event_1 ]))
955+ self ._add_events (sender , long_run_2 , _apply_compat ([event_2 ]))
956+ sender .flush ()
960957 requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
961958
962959 for request in requests :
963960 _clear_wall_times (request )
964961
962+ # Expect two RPC calls despite a single explicit call to flush().
965963 expected = [
966964 write_service_pb2 .WriteScalarRequest (experiment_id = "123" ),
967965 write_service_pb2 .WriteScalarRequest (experiment_id = "123" ),
@@ -990,14 +988,15 @@ def test_break_at_tag_boundary(self):
990988 event = event_pb2 .Event (step = 1 )
991989 event .summary .value .add (tag = long_tag_1 , simple_value = 1.0 )
992990 event .summary .value .add (tag = long_tag_2 , simple_value = 2.0 )
993- run_to_events = {"train" : _apply_compat ([event ])}
994991
995- builder = _create_request_sender ("123" , mock_client )
996- builder .send_requests (run_to_events )
992+ sender = _create_scalar_request_sender ("123" , mock_client )
993+ self ._add_events (sender , "train" , _apply_compat ([event ]))
994+ sender .flush ()
997995 requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
998996 for request in requests :
999997 _clear_wall_times (request )
1000998
999+ # Expect two RPC calls despite a single explicit call to flush().
10011000 expected = [
10021001 write_service_pb2 .WriteScalarRequest (experiment_id = "123" ),
10031002 write_service_pb2 .WriteScalarRequest (experiment_id = "123" ),
@@ -1030,10 +1029,10 @@ def test_break_at_scalar_point_boundary(self):
10301029 if step > 0 :
10311030 summary .value [0 ].ClearField ("metadata" )
10321031 events .append (event_pb2 .Event (summary = summary , step = step ))
1033- run_to_events = {"train" : _apply_compat (events )}
10341032
1035- builder = _create_request_sender ("123" , mock_client )
1036- builder .send_requests (run_to_events )
1033+ sender = _create_scalar_request_sender ("123" , mock_client )
1034+ self ._add_events (sender , "train" , _apply_compat (events ))
1035+ sender .flush ()
10371036 requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
10381037 for request in requests :
10391038 _clear_wall_times (request )
@@ -1064,12 +1063,6 @@ def test_prunes_tags_and_runs(self):
10641063 event_1 .summary .value .add (tag = "foo" , simple_value = 1.0 )
10651064 event_2 = event_pb2 .Event (step = 2 )
10661065 event_2 .summary .value .add (tag = "bar" , simple_value = - 2.0 )
1067- run_to_events = collections .OrderedDict (
1068- [
1069- ("train" , _apply_compat ([event_1 ])),
1070- ("test" , _apply_compat ([event_2 ])),
1071- ]
1072- )
10731066
10741067 real_create_point = (
10751068 uploader_lib ._ScalarBatchedRequestSender ._create_point
@@ -1090,8 +1083,10 @@ def mock_create_point(uploader_self, *args, **kwargs):
10901083 "_create_point" ,
10911084 mock_create_point ,
10921085 ):
1093- builder = _create_request_sender ("123" , mock_client )
1094- builder .send_requests (run_to_events )
1086+ sender = _create_scalar_request_sender ("123" , mock_client )
1087+ self ._add_events (sender , "train" , _apply_compat ([event_1 ]))
1088+ self ._add_events (sender , "test" , _apply_compat ([event_2 ]))
1089+ sender .flush ()
10951090 requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
10961091 for request in requests :
10971092 _clear_wall_times (request )
@@ -1116,15 +1111,14 @@ def mock_create_point(uploader_self, *args, **kwargs):
11161111
11171112 def test_wall_time_precision (self ):
11181113 # Test a wall time that is exactly representable in float64 but has enough
1119- # digits to incur error if converted to nanonseconds the naive way (* 1e9).
1114+ # digits to incur error if converted to nanoseconds the naive way (* 1e9).
11201115 event1 = event_pb2 .Event (step = 1 , wall_time = 1567808404.765432119 )
11211116 event1 .summary .value .add (tag = "foo" , simple_value = 1.0 )
11221117 # Test a wall time where as a float64, the fractional part on its own will
11231118 # introduce error if truncated to 9 decimal places instead of rounded.
11241119 event2 = event_pb2 .Event (step = 2 , wall_time = 1.000000002 )
11251120 event2 .summary .value .add (tag = "foo" , simple_value = 2.0 )
1126- run_proto = write_service_pb2 .WriteScalarRequest .Run ()
1127- self ._populate_run_from_events (run_proto , [event1 , event2 ])
1121+ run_proto = self ._add_events_and_flush (_apply_compat ([event1 , event2 ]))
11281122 self .assertEqual (
11291123 test_util .timestamp_pb (1567808404765432119 ),
11301124 run_proto .tags [0 ].points [0 ].wall_time ,
0 commit comments