Skip to content

Commit 73b4df9

Browse files
authored
Ensure EventFileLoader only uses no-TF stub when required (#3194)
* add notf test for event_file_loader_test.py
* implement truncation recovery in stub PyRecordReader
* Check explicitly for no-TF case in _make_tf_record_iterator
* remove unused tensorboard.compat._pywrap_tensorflow import in projector plugin
* remove now unused tensorboard.compat._pywrap_tensorflow lazy-loader
1 parent 3b8320c commit 73b4df9

File tree

5 files changed

+90
-56
lines changed

5 files changed

+90
-56
lines changed

tensorboard/backend/event_processing/BUILD

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,22 @@ py_test(
156156
],
157157
)
158158

159+
py_test(
160+
name = "event_file_loader_notf_test",
161+
size = "small",
162+
srcs = ["event_file_loader_test.py"],
163+
main = "event_file_loader_test.py",
164+
srcs_version = "PY2AND3",
165+
deps = [
166+
":event_file_loader",
167+
"//tensorboard:expect_tensorflow_installed",
168+
"//tensorboard/compat:no_tensorflow",
169+
"//tensorboard/compat/proto:protos_all_py_pb2",
170+
"//tensorboard/summary/writer",
171+
"@org_pythonhosted_six",
172+
],
173+
)
174+
159175
py_library(
160176
name = "event_accumulator",
161177
srcs = [

tensorboard/backend/event_processing/event_file_loader.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,21 +29,29 @@
2929

3030
def _make_tf_record_iterator(file_path):
3131
"""Returns an iterator over TF records for the given tfrecord file."""
32-
try:
33-
from tensorboard.compat import _pywrap_tensorflow
34-
35-
py_record_reader_new = _pywrap_tensorflow.PyRecordReader_New
36-
except (ImportError, AttributeError):
37-
py_record_reader_new = None
32+
# If we don't have TF at all, use the stub implementation.
33+
if tf.__version__ == "stub":
34+
# TODO(#1711): Reshape stub implementation to fit tf_record_iterator API
35+
# rather than needlessly emulating the old PyRecordReader_New API.
36+
logger.debug("Opening a stub record reader pointing at %s", file_path)
37+
return _PyRecordReaderIterator(
38+
tf.pywrap_tensorflow.PyRecordReader_New, file_path
39+
)
3840
# If PyRecordReader exists, use it, otherwise use tf_record_iterator().
3941
# Check old first, then new, since tf_record_iterator existed previously but
4042
# only gained the semantics we need at the time PyRecordReader was removed.
4143
#
4244
# TODO(#1711): Eventually remove PyRecordReader fallback once we can drop
4345
# support for TF 2.1 and prior, and find a non-deprecated replacement for
4446
# tf.compat.v1.io.tf_record_iterator.
47+
try:
48+
from tensorflow.python import pywrap_tensorflow
49+
50+
py_record_reader_new = pywrap_tensorflow.PyRecordReader_New
51+
except (ImportError, AttributeError):
52+
py_record_reader_new = None
4553
if py_record_reader_new:
46-
logger.debug("Opening a record reader pointing at %s", file_path)
54+
logger.debug("Opening a PyRecordReader pointing at %s", file_path)
4755
return _PyRecordReaderIterator(py_record_reader_new, file_path)
4856
else:
4957
logger.debug("Opening a tf_record_iterator pointing at %s", file_path)

tensorboard/compat/__init__.py

Lines changed: 0 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -74,36 +74,3 @@ def tf2():
7474
# As a fallback, try `tensorflow.compat.v2` if it's defined.
7575
return tf.compat.v2
7676
raise ImportError("cannot import tensorflow 2.0 API")
77-
78-
79-
# TODO(https://github.com/tensorflow/tensorboard/issues/1711): remove this
80-
@_lazy.lazy_load("tensorboard.compat._pywrap_tensorflow")
81-
def _pywrap_tensorflow():
82-
"""Provide pywrap_tensorflow access in TensorBoard.
83-
84-
pywrap_tensorflow cannot be accessed from tf.python.pywrap_tensorflow
85-
and needs to be imported using
86-
`from tensorflow.python import pywrap_tensorflow`. Therefore, we provide
87-
a separate accessor function for it here.
88-
89-
NOTE: pywrap_tensorflow is not part of TensorFlow API and this
90-
dependency will go away soon.
91-
92-
Returns:
93-
pywrap_tensorflow import, if available.
94-
95-
Raises:
96-
ImportError: if we couldn't import pywrap_tensorflow.
97-
"""
98-
try:
99-
from tensorboard.compat import notf
100-
except ImportError:
101-
try:
102-
from tensorflow.python import pywrap_tensorflow
103-
104-
return pywrap_tensorflow
105-
except ImportError:
106-
pass
107-
from tensorboard.compat.tensorflow_stub import pywrap_tensorflow
108-
109-
return pywrap_tensorflow

tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py

Lines changed: 59 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -197,19 +197,29 @@ def __init__(
197197
self.status = status
198198
self.curr_event = None
199199
self.file_handle = gfile.GFile(self.filename, "rb")
200+
# Maintain a buffer of partially read records, so we can recover from
201+
# truncated records upon a retry.
202+
self._buffer = b""
203+
self._buffer_pos = 0
200204

201205
def GetNext(self):
206+
# Each new read should start at the beginning of any partial record.
207+
self._buffer_pos = 0
202208
# Read the header
203209
self.curr_event = None
204-
header_str = self.file_handle.read(8)
205-
if len(header_str) != 8:
210+
header_str = self._read(8)
211+
if not header_str:
206212
# Hit EOF so raise and exit
207213
raise errors.OutOfRangeError(None, None, "No more events to read")
214+
if len(header_str) < 8:
215+
raise self._truncation_error("header")
208216
header = struct.unpack("Q", header_str)
209217

210218
# Read the crc32, which is 4 bytes, and check it against
211219
# the crc32 of the header
212-
crc_header_str = self.file_handle.read(4)
220+
crc_header_str = self._read(4)
221+
if len(crc_header_str) < 4:
222+
raise self._truncation_error("header crc")
213223
crc_header = struct.unpack("I", crc_header_str)
214224
header_crc_calc = masked_crc32c(header_str)
215225
if header_crc_calc != crc_header[0]:
@@ -220,25 +230,59 @@ def GetNext(self):
220230
# The length of the header tells us how many bytes the Event
221231
# string takes
222232
header_len = int(header[0])
223-
event_str = self.file_handle.read(header_len)
233+
event_str = self._read(header_len)
234+
if len(event_str) < header_len:
235+
raise self._truncation_error("data")
224236

225237
event_crc_calc = masked_crc32c(event_str)
226238

227239
# The next 4 bytes contain the crc32 of the Event string,
228-
# which we check for integrity. Sometimes, the last Event
229-
# has no crc32, in which case we skip.
230-
crc_event_str = self.file_handle.read(4)
231-
if crc_event_str:
232-
crc_event = struct.unpack("I", crc_event_str)
233-
if event_crc_calc != crc_event[0]:
234-
raise errors.DataLossError(
235-
None,
236-
None,
237-
"{} failed event crc32 check".format(self.filename),
238-
)
240+
# which we check for integrity.
241+
crc_event_str = self._read(4)
242+
if len(crc_event_str) < 4:
243+
raise self._truncation_error("data crc")
244+
crc_event = struct.unpack("I", crc_event_str)
245+
if event_crc_calc != crc_event[0]:
246+
raise errors.DataLossError(
247+
None, None, "{} failed event crc32 check".format(self.filename),
248+
)
239249

240250
# Set the current event to be read later by record() call
241251
self.curr_event = event_str
252+
# Clear the buffered partial record since we're done reading it.
253+
self._buffer = b""
254+
255+
def _read(self, n):
256+
"""Read up to n bytes from the underlying file, with buffering.
257+
258+
Reads are satisfied from a buffer of previous data read starting at
259+
`self._buffer_pos` until the buffer is exhausted, and then from the
260+
actual underlying file. Any new data is added to the buffer, and
261+
`self._buffer_pos` is advanced to the point in the buffer past all
262+
data returned as part of this read.
263+
264+
Args:
265+
n: non-negative number of bytes to read
266+
267+
Returns:
268+
bytestring of data read, up to n bytes
269+
"""
270+
result = self._buffer[self._buffer_pos : self._buffer_pos + n]
271+
self._buffer_pos += len(result)
272+
n -= len(result)
273+
if n > 0:
274+
new_data = self.file_handle.read(n)
275+
result += new_data
276+
self._buffer += new_data
277+
self._buffer_pos += len(new_data)
278+
return result
279+
280+
def _truncation_error(self, section):
281+
return errors.DataLossError(
282+
None,
283+
None,
284+
"{} has truncated record in {}".format(self.filename, section),
285+
)
242286

243287
def record(self):
244288
return self.curr_event

tensorboard/plugins/projector/projector_plugin.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@
3434

3535
from tensorboard.backend.http_util import Respond
3636
from tensorboard.compat import tf
37-
from tensorboard.compat import _pywrap_tensorflow
3837
from tensorboard.plugins import base_plugin
3938
from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
4039
from tensorboard.util import tb_logging

0 commit comments

Comments
 (0)