tensorboard/uploader/uploader.py: 156 changes (129 additions, 27 deletions)
@@ -444,6 +444,7 @@ def __init__(self, experiment_id, api, rpc_rate_limiter):
self._experiment_id = experiment_id
self._api = api
self._rpc_rate_limiter = rpc_rate_limiter
self._byte_budget_manager = _ByteBudgetManager()
# A lower bound on the number of bytes that we may yet add to the
# request.
self._byte_budget = None # type: int
@@ -459,11 +460,8 @@ def _new_request(self):
self._request = write_service_pb2.WriteScalarRequest()
self._runs.clear()
self._tags.clear()
self._byte_budget = _MAX_REQUEST_LENGTH_BYTES
self._request.experiment_id = self._experiment_id
self._byte_budget -= self._request.ByteSize()
if self._byte_budget < 0:
raise RuntimeError("Byte budget too small for experiment ID")
self._byte_budget_manager.reset(self._request)

def add_event(self, run_name, event, value, metadata):
"""Attempts to add the given event to the current request.
@@ -500,12 +498,7 @@ def flush(self):
Starts a new, empty active request.
"""
request = self._request
for (run_idx, run) in reversed(list(enumerate(request.runs))):
for (tag_idx, tag) in reversed(list(enumerate(run.tags))):
if not tag.points:
del run.tags[tag_idx]
if not run.tags:
del request.runs[run_idx]
_prune_empty_tags_and_runs(request)
if not request.runs:
return

Expand Down Expand Up @@ -536,12 +529,7 @@ def _create_run(self, run_name):
request budget.
"""
run_proto = self._request.runs.add(name=run_name)
# We can't calculate the proto key cost exactly ahead of time, as
# it depends on the total size of all tags. Be conservative.
cost = run_proto.ByteSize() + _MAX_VARINT64_LENGTH_BYTES + 1
if cost > self._byte_budget:
raise _OutOfSpaceError()
self._byte_budget -= cost
self._byte_budget_manager.add_run(run_proto)
return run_proto

def _create_tag(self, run_proto, tag_name, metadata):
@@ -562,13 +550,7 @@ def _create_tag(self, run_proto, tag_name, metadata):
"""
tag_proto = run_proto.tags.add(name=tag_name)
tag_proto.metadata.CopyFrom(metadata)
submessage_cost = tag_proto.ByteSize()
# We can't calculate the proto key cost exactly ahead of time, as
# it depends on the number of points. Be conservative.
cost = submessage_cost + _MAX_VARINT64_LENGTH_BYTES + 1
if cost > self._byte_budget:
raise _OutOfSpaceError()
self._byte_budget -= cost
self._byte_budget_manager.add_tag(tag_proto)
return tag_proto

def _create_point(self, tag_proto, event, value):
@@ -591,13 +573,124 @@ def _create_point(self, tag_proto, event, value):
# TODO(@nfelt): skip tensor roundtrip for Value with simple_value set
point.value = tensor_util.make_ndarray(value.tensor).item()
util.set_timestamp(point.wall_time, event.wall_time)
submessage_cost = point.ByteSize()
cost = submessage_cost + _varint_cost(submessage_cost) + 1 # proto key
if cost > self._byte_budget:
try:
self._byte_budget_manager.add_point(point)
except _OutOfSpaceError as e:
tag_proto.points.pop()
raise e
return point


class _ByteBudgetManager(object):
"""Helper class for managing the request byte budget for certain RPCs.

This should be used for RPCs that organize data by Runs, Tags, and Points,
specifically WriteScalar and WriteTensor.

Any call to add_run(), add_tag(), or add_point() may raise an
_OutOfSpaceError, which is non-fatal. It signals to the caller that they
should flush the current request and begin a new one.

For more information on the protocol buffer encoding and how byte cost
can be calculated, visit:

https://developers.google.com/protocol-buffers/docs/encoding
"""

def __init__(self):
# The remaining number of bytes that we may yet add to the request.
self._byte_budget = None # type: int

def reset(self, base_request):
"""Resets the byte budget and calculates the cost of the base request.

Args:
base_request: The partially built request proto whose base size is
charged against the fresh byte budget.

Raises:
RuntimeError: If the size of the base request exceeds the entire
request byte budget.
"""
self._byte_budget = _MAX_REQUEST_LENGTH_BYTES
self._byte_budget -= base_request.ByteSize()
if self._byte_budget < 0:
raise RuntimeError("Byte budget too small for base request")

def add_run(self, run_proto):
"""Integrates the cost of a run proto into the byte budget.

Args:
run_proto: The proto representing a run.

Raises:
_OutOfSpaceError: If adding the run would exceed the remaining request
budget.
"""
cost = (
# The size of the run proto without any tag fields set.
run_proto.ByteSize()
# The size of the varint that describes the length of the run
# proto. We can't yet know the final size of the run proto -- we
# haven't yet set any tag or point values -- so we can't know the
# final size of this length varint. We conservatively assume it is
# maximum size.
+ _MAX_VARINT64_LENGTH_BYTES
# The size of the proto key.
+ 1
)
if cost > self._byte_budget:
raise _OutOfSpaceError()
self._byte_budget -= cost

def add_tag(self, tag_proto):
"""Integrates the cost of a tag proto into the byte budget.

Args:
tag_proto: The proto representing a tag.

Raises:
_OutOfSpaceError: If adding the tag would exceed the remaining request
budget.
"""
cost = (
# The size of the tag proto without any point fields set.
tag_proto.ByteSize()
# The size of the varint that describes the length of the tag
# proto. We can't yet know the final size of the tag proto -- we
# haven't yet set any point values -- so we can't know the final
# size of this length varint. We conservatively assume it is maximum
# size.
+ _MAX_VARINT64_LENGTH_BYTES
# The size of the proto key.
+ 1
)
if cost > self._byte_budget:
raise _OutOfSpaceError()
self._byte_budget -= cost

def add_point(self, point_proto):
"""Integrates the cost of a point proto into the byte budget.

Args:
point_proto: The proto representing a point.

Raises:
_OutOfSpaceError: If adding the point would exceed the remaining request
budget.
"""
submessage_cost = point_proto.ByteSize()
cost = (
# The size of the point proto.
submessage_cost
# The size of the varint that describes the length of the point
# proto.
+ _varint_cost(submessage_cost)
# The size of the proto key.
+ 1
)
if cost > self._byte_budget:
raise _OutOfSpaceError()
self._byte_budget -= cost
return point


class _BlobRequestSender(object):
Expand Down Expand Up @@ -828,6 +921,15 @@ def _varint_cost(n):
return result


def _prune_empty_tags_and_runs(request):
for (run_idx, run) in reversed(list(enumerate(request.runs))):
for (tag_idx, tag) in reversed(list(enumerate(run.tags))):
if not tag.points:
del run.tags[tag_idx]
if not run.tags:
del request.runs[run_idx]


def _filter_graph_defs(event):
for v in event.summary.value:
if v.metadata.plugin_data.plugin_name != graphs_metadata.PLUGIN_NAME:
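To make the byte-budget arithmetic in _ByteBudgetManager concrete, here is a minimal illustrative sketch, not part of this diff, of how the three cost components combine: the submessage's ByteSize(), the varint length prefix, and the one-byte field key. The names below mirror the module's _varint_cost and _MAX_VARINT64_LENGTH_BYTES, but this particular implementation and the way the helpers are split out are assumptions made for the example.

# Illustrative sketch only; implementation details are assumed, not taken
# from the diff above.

_MAX_VARINT64_LENGTH_BYTES = 10  # a varint-encoded uint64 never needs more than 10 bytes


def _varint_cost(n):
    # Number of bytes needed to encode the non-negative integer n as a varint.
    result = 1
    while n >= 0x80:
        result += 1
        n >>= 7
    return result


def conservative_submessage_cost(submessage_bytes):
    # Used for runs and tags: their final size is unknown when they are added
    # (tags and points will still be appended), so the length prefix is
    # budgeted at its maximum possible size.
    return submessage_bytes + _MAX_VARINT64_LENGTH_BYTES + 1  # +1 for the field key


def exact_submessage_cost(submessage_bytes):
    # Used for points: a point is complete when it is added, so its length
    # prefix can be computed exactly.
    return submessage_bytes + _varint_cost(submessage_bytes) + 1  # +1 for the field key

For instance, a 300-byte point charges 300 + 2 + 1 = 303 bytes against the budget, since 300 needs a two-byte varint, whereas a 300-byte tag would conservatively charge 300 + 10 + 1 = 311.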
tensorboard/uploader/uploader_test.py: 26 changes (12 additions, 14 deletions)
@@ -918,15 +918,15 @@ def test_propagates_experiment_deletion(self):
with self.assertRaises(uploader_lib.ExperimentNotFoundError):
sender.flush()

def test_no_budget_for_experiment_id(self):
def test_no_budget_for_base_request(self):
mock_client = _create_mock_client()
long_experiment_id = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES
with self.assertRaises(RuntimeError) as cm:
_create_scalar_request_sender(
experiment_id=long_experiment_id, api=mock_client,
)
self.assertEqual(
str(cm.exception), "Byte budget too small for experiment ID"
str(cm.exception), "Byte budget too small for base request"
)

def test_no_room_for_single_point(self):
@@ -1039,6 +1039,11 @@ def test_break_at_scalar_point_boundary(self):

self.assertGreater(len(requests), 1)
self.assertLess(len(requests), point_count)
# This is the observed number of requests when running the test. There
# is no reasonable way to derive this value from just reading the code.
# The number of requests does not have to be 33 to be correct, but if it
# changes it probably warrants some investigation or thought.
self.assertEqual(33, len(requests))

total_points_in_result = 0
for request in requests:
@@ -1064,24 +1069,17 @@ def test_prunes_tags_and_runs(self):
event_2 = event_pb2.Event(step=2)
event_2.summary.value.add(tag="bar", simple_value=-2.0)

real_create_point = (
uploader_lib._ScalarBatchedRequestSender._create_point
)

create_point_call_count_box = [0]
add_point_call_count_box = [0]

def mock_create_point(uploader_self, *args, **kwargs):
def mock_add_point(byte_budget_manager_self, point):
# Simulate out-of-space error the first time that we try to store
# the second point.
create_point_call_count_box[0] += 1
if create_point_call_count_box[0] == 2:
add_point_call_count_box[0] += 1
if add_point_call_count_box[0] == 2:
raise uploader_lib._OutOfSpaceError()
return real_create_point(uploader_self, *args, **kwargs)

with mock.patch.object(
uploader_lib._ScalarBatchedRequestSender,
"_create_point",
mock_create_point,
uploader_lib._ByteBudgetManager, "add_point", mock_add_point,
):
sender = _create_scalar_request_sender("123", mock_client)
self._add_events(sender, "train", _apply_compat([event_1]))
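To round out how the non-fatal _OutOfSpaceError is meant to be consumed (the _ByteBudgetManager docstring says the caller should flush the current request and begin a new one, and the test above simulates exactly that situation), here is a self-contained toy sketch of the flush-and-retry pattern. Every name in it is invented for illustration; it is not code from this PR.

class OutOfSpaceError(Exception):
    """Stand-in for uploader_lib._OutOfSpaceError."""


class ToyBudget(object):
    """Tracks a byte budget the way _ByteBudgetManager does, in miniature."""

    def __init__(self, limit):
        self._limit = limit
        self._used = 0

    def reset(self):
        self._used = 0

    def charge(self, cost):
        if self._used + cost > self._limit:
            raise OutOfSpaceError()
        self._used += cost


def pack(items, limit):
    # Pack items into "requests", flushing whenever the budget runs out.
    budget = ToyBudget(limit)
    requests, current = [], []
    for item in items:
        try:
            budget.charge(len(item))
        except OutOfSpaceError:
            requests.append(current)  # flush the full request
            current = []
            budget.reset()            # fresh budget for the new request
            budget.charge(len(item))  # retry; assumed to fit in an empty request
        current.append(item)
    if current:
        requests.append(current)
    return requests


print(pack(["aaaa", "bb", "ccccc", "d"], limit=6))
# prints [['aaaa', 'bb'], ['ccccc', 'd']]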