diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator.py b/tensorboard/backend/event_processing/plugin_event_accumulator.py index afa2f4d420..7e2bb206f7 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator.py @@ -159,6 +159,7 @@ def __init__( self._generator_mutex = threading.Lock() self.purge_orphaned_data = purge_orphaned_data + self._seen_session_start = False self.most_recent_step = -1 self.most_recent_wall_time = -1 @@ -510,11 +511,16 @@ def _MaybePurgeOrphanedData(self, event): def _CheckForRestartAndMaybePurge(self, event): """Check and discard expired events using SessionLog.START. - Check for a SessionLog.START event and purge all previously seen events - with larger steps, because they are out of date. Because of supervisor - threading, it is possible that this logic will cause the first few event - messages to be discarded since supervisor threading does not guarantee - that the START message is deterministically written first. + The first SessionLog.START event in a run indicates the start of a + supervisor session. Subsequent SessionLog.START events indicate a + *restart*, which may need to preempt old events. This method checks + for a session restart event and purges all previously seen events whose + step is larger than or equal to this event's step. + + Because of supervisor threading, it is possible that this logic will + cause the first few event messages to be discarded since supervisor + threading does not guarantee that the START message is deterministically + written first. This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which can inadvertently discard events due to supervisor threading. @@ -523,11 +529,13 @@ def _CheckForRestartAndMaybePurge(self, event): event: The event to use as reference. If the event is a START event, all previously seen events with a greater event.step will be purged. """ - if ( - event.HasField("session_log") - and event.session_log.status == event_pb2.SessionLog.START - ): - self._Purge(event, by_tags=False) + if event.session_log.status != event_pb2.SessionLog.START: + return + if not self._seen_session_start: + # Initial start event: does not indicate a restart. + self._seen_session_start = True + return + self._Purge(event, by_tags=False) def _CheckForOutOfOrderStepAndMaybePurge(self, event): """Check for out-of-order event.step and discard expired events for diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py index eb2ffb6b88..e6178480a4 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py @@ -326,11 +326,14 @@ def testSessionLogStartMessageDiscardsExpiredEvents(self): """ gen = _EventGenerator(self) acc = ea.EventAccumulator(gen) + slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START) + gen.AddEvent( event_pb2.Event(wall_time=0, step=1, file_version="brain.Event:2") ) gen.AddScalarTensor("s1", wall_time=1, step=100, value=20) + gen.AddEvent(event_pb2.Event(wall_time=1, step=100, session_log=slog)) gen.AddScalarTensor("s1", wall_time=1, step=200, value=20) gen.AddScalarTensor("s1", wall_time=1, step=300, value=20) gen.AddScalarTensor("s1", wall_time=1, step=400, value=20) @@ -338,7 +341,6 @@ def testSessionLogStartMessageDiscardsExpiredEvents(self): gen.AddScalarTensor("s2", wall_time=1, step=202, value=20) gen.AddScalarTensor("s2", wall_time=1, step=203, value=20) - slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START) gen.AddEvent(event_pb2.Event(wall_time=2, step=201, session_log=slog)) acc.Reload() self.assertEqual([x.step for x in acc.Tensors("s1")], [100, 200])