diff --git a/tensorboard/uploader/uploader.py b/tensorboard/uploader/uploader.py index 1a97a82d8a..da3f0677f6 100644 --- a/tensorboard/uploader/uploader.py +++ b/tensorboard/uploader/uploader.py @@ -444,6 +444,7 @@ def __init__(self, experiment_id, api, rpc_rate_limiter): self._experiment_id = experiment_id self._api = api self._rpc_rate_limiter = rpc_rate_limiter + self._byte_budget_manager = _ByteBudgetManager() # A lower bound on the number of bytes that we may yet add to the # request. self._byte_budget = None # type: int @@ -459,11 +460,8 @@ def _new_request(self): self._request = write_service_pb2.WriteScalarRequest() self._runs.clear() self._tags.clear() - self._byte_budget = _MAX_REQUEST_LENGTH_BYTES self._request.experiment_id = self._experiment_id - self._byte_budget -= self._request.ByteSize() - if self._byte_budget < 0: - raise RuntimeError("Byte budget too small for experiment ID") + self._byte_budget_manager.reset(self._request) def add_event(self, run_name, event, value, metadata): """Attempts to add the given event to the current request. @@ -500,12 +498,7 @@ def flush(self): Starts a new, empty active request. """ request = self._request - for (run_idx, run) in reversed(list(enumerate(request.runs))): - for (tag_idx, tag) in reversed(list(enumerate(run.tags))): - if not tag.points: - del run.tags[tag_idx] - if not run.tags: - del request.runs[run_idx] + _prune_empty_tags_and_runs(request) if not request.runs: return @@ -536,12 +529,7 @@ def _create_run(self, run_name): request budget. """ run_proto = self._request.runs.add(name=run_name) - # We can't calculate the proto key cost exactly ahead of time, as - # it depends on the total size of all tags. Be conservative. - cost = run_proto.ByteSize() + _MAX_VARINT64_LENGTH_BYTES + 1 - if cost > self._byte_budget: - raise _OutOfSpaceError() - self._byte_budget -= cost + self._byte_budget_manager.add_run(run_proto) return run_proto def _create_tag(self, run_proto, tag_name, metadata): @@ -562,13 +550,7 @@ def _create_tag(self, run_proto, tag_name, metadata): """ tag_proto = run_proto.tags.add(name=tag_name) tag_proto.metadata.CopyFrom(metadata) - submessage_cost = tag_proto.ByteSize() - # We can't calculate the proto key cost exactly ahead of time, as - # it depends on the number of points. Be conservative. - cost = submessage_cost + _MAX_VARINT64_LENGTH_BYTES + 1 - if cost > self._byte_budget: - raise _OutOfSpaceError() - self._byte_budget -= cost + self._byte_budget_manager.add_tag(tag_proto) return tag_proto def _create_point(self, tag_proto, event, value): @@ -591,13 +573,124 @@ def _create_point(self, tag_proto, event, value): # TODO(@nfelt): skip tensor roundtrip for Value with simple_value set point.value = tensor_util.make_ndarray(value.tensor).item() util.set_timestamp(point.wall_time, event.wall_time) - submessage_cost = point.ByteSize() - cost = submessage_cost + _varint_cost(submessage_cost) + 1 # proto key - if cost > self._byte_budget: + try: + self._byte_budget_manager.add_point(point) + except _OutOfSpaceError as e: tag_proto.points.pop() + raise e + return point + + +class _ByteBudgetManager(object): + """Helper class for managing the request byte budget for certain RPCs. + + This should be used for RPCs that organize data by Runs, Tags, and Points, + specifically WriteScalar and WriteTensor. + + Any call to add_run(), add_tag(), or add_point() may raise an + _OutOfSpaceError, which is non-fatal. It signals to the caller that they + should flush the current request and begin a new one. + + For more information on the protocol buffer encoding and how byte cost + can be calculated, visit: + + https://developers.google.com/protocol-buffers/docs/encoding + """ + + def __init__(self): + # The remaining number of bytes that we may yet add to the request. + self._byte_budget = None # type: int + + def reset(self, base_request): + """Resets the byte budget and calculates the cost of the base request. + + Args: + base_request: Base request. + + Raises: + _OutOfSpaceError: If the size of the request exceeds the entire + request byte budget. + """ + self._byte_budget = _MAX_REQUEST_LENGTH_BYTES + self._byte_budget -= base_request.ByteSize() + if self._byte_budget < 0: + raise RuntimeError("Byte budget too small for base request") + + def add_run(self, run_proto): + """Integrates the cost of a run proto into the byte budget. + + Args: + run_proto: The proto representing a run. + + Raises: + _OutOfSpaceError: If adding the run would exceed the remaining request + budget. + """ + cost = ( + # The size of the run proto without any tag fields set. + run_proto.ByteSize() + # The size of the varint that describes the length of the run + # proto. We can't yet know the final size of the run proto -- we + # haven't yet set any tag or point values -- so we can't know the + # final size of this length varint. We conservatively assume it is + # maximum size. + + _MAX_VARINT64_LENGTH_BYTES + # The size of the proto key. + + 1 + ) + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + + def add_tag(self, tag_proto): + """Integrates the cost of a tag proto into the byte budget. + + Args: + tag_proto: The proto representing a tag. + + Raises: + _OutOfSpaceError: If adding the tag would exceed the remaining request + budget. + """ + cost = ( + # The size of the tag proto without any tag fields set. + tag_proto.ByteSize() + # The size of the varint that describes the length of the tag + # proto. We can't yet know the final size of the tag proto -- we + # haven't yet set any point values -- so we can't know the final + # size of this length varint. We conservatively assume it is maximum + # size. + + _MAX_VARINT64_LENGTH_BYTES + # The size of the proto key. + + 1 + ) + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + + def add_point(self, point_proto): + """Integrates the cost of a point proto into the byte budget. + + Args: + point_proto: The proto representing a point. + + Raises: + _OutOfSpaceError: If adding the point would exceed the remaining request + budget. + """ + submessage_cost = point_proto.ByteSize() + cost = ( + # The size of the point proto. + submessage_cost + # The size of the varint that describes the length of the point + # proto. + + _varint_cost(submessage_cost) + # The size of the proto key. + + 1 + ) + if cost > self._byte_budget: raise _OutOfSpaceError() self._byte_budget -= cost - return point class _BlobRequestSender(object): @@ -828,6 +921,15 @@ def _varint_cost(n): return result +def _prune_empty_tags_and_runs(request): + for (run_idx, run) in reversed(list(enumerate(request.runs))): + for (tag_idx, tag) in reversed(list(enumerate(run.tags))): + if not tag.points: + del run.tags[tag_idx] + if not run.tags: + del request.runs[run_idx] + + def _filter_graph_defs(event): for v in event.summary.value: if v.metadata.plugin_data.plugin_name != graphs_metadata.PLUGIN_NAME: diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py index 1c73945e59..88e6248df5 100644 --- a/tensorboard/uploader/uploader_test.py +++ b/tensorboard/uploader/uploader_test.py @@ -918,7 +918,7 @@ def test_propagates_experiment_deletion(self): with self.assertRaises(uploader_lib.ExperimentNotFoundError): sender.flush() - def test_no_budget_for_experiment_id(self): + def test_no_budget_for_base_request(self): mock_client = _create_mock_client() long_experiment_id = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES with self.assertRaises(RuntimeError) as cm: @@ -926,7 +926,7 @@ def test_no_budget_for_experiment_id(self): experiment_id=long_experiment_id, api=mock_client, ) self.assertEqual( - str(cm.exception), "Byte budget too small for experiment ID" + str(cm.exception), "Byte budget too small for base request" ) def test_no_room_for_single_point(self): @@ -1039,6 +1039,11 @@ def test_break_at_scalar_point_boundary(self): self.assertGreater(len(requests), 1) self.assertLess(len(requests), point_count) + # This is the observed number of requests when running the test. There + # is no reasonable way to derive this value from just reading the code. + # The number of requests does not have to be 33 to be correct but if it + # changes it probably warrants some investigation or thought. + self.assertEqual(33, len(requests)) total_points_in_result = 0 for request in requests: @@ -1064,24 +1069,17 @@ def test_prunes_tags_and_runs(self): event_2 = event_pb2.Event(step=2) event_2.summary.value.add(tag="bar", simple_value=-2.0) - real_create_point = ( - uploader_lib._ScalarBatchedRequestSender._create_point - ) - - create_point_call_count_box = [0] + add_point_call_count_box = [0] - def mock_create_point(uploader_self, *args, **kwargs): + def mock_add_point(byte_budget_manager_self, point): # Simulate out-of-space error the first time that we try to store # the second point. - create_point_call_count_box[0] += 1 - if create_point_call_count_box[0] == 2: + add_point_call_count_box[0] += 1 + if add_point_call_count_box[0] == 2: raise uploader_lib._OutOfSpaceError() - return real_create_point(uploader_self, *args, **kwargs) with mock.patch.object( - uploader_lib._ScalarBatchedRequestSender, - "_create_point", - mock_create_point, + uploader_lib._ByteBudgetManager, "add_point", mock_add_point, ): sender = _create_scalar_request_sender("123", mock_client) self._add_events(sender, "train", _apply_compat([event_1]))