Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


import os
from unittest import mock

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -135,33 +136,28 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def setUp(self):
super(MockingEventAccumulatorTest, self).setUp()
self.stubs = tf.compat.v1.test.StubOutForTesting()
self._real_constructor = ea.EventAccumulator
self._real_generator = ea._GeneratorFromPath

def _FakeAccumulatorConstructor(generator, *args, **kwargs):
def _FakeGeneratorFromPath(path, event_file_active_filter=None):
return generator

ea._GeneratorFromPath = _FakeGeneratorFromPath
return self._real_constructor(generator, *args, **kwargs)

ea.EventAccumulator = _FakeAccumulatorConstructor

def tearDown(self):
super(MockingEventAccumulatorTest, self).tearDown()
self.stubs.CleanUp()
ea.EventAccumulator = self._real_constructor
ea._GeneratorFromPath = self._real_generator

def _make_accumulator(self, generator, **kwargs):
patcher = mock.patch.object(ea, "_GeneratorFromPath", autospec=True)
mock_impl = patcher.start()
mock_impl.return_value = generator
self.addCleanup(patcher.stop)
return ea.EventAccumulator("path/is/ignored", **kwargs)

def testEmptyAccumulator(self):
gen = _EventGenerator(self)
x = ea.EventAccumulator(gen)
x = self._make_accumulator(gen)
x.Reload()
self.assertTagsEqual(x.Tags(), {})

def testReload(self):
"""EventAccumulator contains suitable tags after calling Reload."""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)
acc.Reload()
self.assertTagsEqual(acc.Tags(), {})
gen.AddScalarTensor("s1", wall_time=1, step=10, value=50)
Expand All @@ -177,15 +173,15 @@ def testReload(self):
def testKeyError(self):
"""KeyError should be raised when accessing non-existing keys."""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)
acc.Reload()
with self.assertRaises(KeyError):
acc.Tensors("s1")

def testNonValueEvents(self):
"""Non-value events in the generator don't cause early exits."""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)
gen.AddScalarTensor("s1", wall_time=1, step=10, value=20)
gen.AddEvent(
event_pb2.Event(wall_time=2, step=20, file_version="nots2")
Expand Down Expand Up @@ -214,7 +210,7 @@ def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
self.stubs.Set(logger, "warning", warnings.append)

gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)

gen.AddEvent(
event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1")
Expand All @@ -239,7 +235,7 @@ def testOrphanedDataNotDiscardedIfFlagUnset(self):
"""Tests that events are not discarded if purge_orphaned_data is
false."""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen, purge_orphaned_data=False)
acc = self._make_accumulator(gen, purge_orphaned_data=False)

gen.AddEvent(
event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1")
Expand Down Expand Up @@ -275,7 +271,7 @@ def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
self.stubs.Set(logger, "warning", warnings.append)

gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)

gen.AddEvent(
event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1")
Expand Down Expand Up @@ -306,7 +302,7 @@ def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
def testOnlySummaryEventsTriggerDiscards(self):
"""Test that file version event does not trigger data purge."""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)
gen.AddScalarTensor("s1", wall_time=1, step=100, value=20)
ev1 = event_pb2.Event(wall_time=2, step=0, file_version="brain.Event:1")
graph_bytes = tf.compat.v1.GraphDef().SerializeToString()
Expand All @@ -325,7 +321,7 @@ def testSessionLogStartMessageDiscardsExpiredEvents(self):
event.proto for file_version >= brain.Event:2.
"""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)
slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START)

gen.AddEvent(
Expand All @@ -350,7 +346,7 @@ def testFirstEventTimestamp(self):
"""Test that FirstEventTimestamp() returns wall_time of the first
event."""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)
gen.AddEvent(
event_pb2.Event(wall_time=10, step=20, file_version="brain.Event:2")
)
Expand All @@ -360,7 +356,7 @@ def testFirstEventTimestamp(self):
def testReloadPopulatesFirstEventTimestamp(self):
"""Test that Reload() means FirstEventTimestamp() won't load events."""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)
gen.AddEvent(
event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2")
)
Expand All @@ -376,7 +372,7 @@ def _Die(*args, **kwargs): # pylint: disable=unused-argument
def testFirstEventTimestampLoadsEvent(self):
"""Test that FirstEventTimestamp() doesn't discard the loaded event."""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc = self._make_accumulator(gen)
gen.AddEvent(
event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2")
)
Expand All @@ -403,7 +399,7 @@ def testNewStyleScalarSummary(self):
summ = sess.run(merged, feed_dict={step: float(i)})
writer.add_summary(summ, global_step=i)

accumulator = ea.EventAccumulator(event_sink)
accumulator = self._make_accumulator(event_sink)
accumulator.Reload()

tags = [
Expand Down Expand Up @@ -451,7 +447,7 @@ def testNewStyleAudioSummary(self):
summ = sess.run(merged)
writer.add_summary(summ, global_step=i)

accumulator = ea.EventAccumulator(event_sink)
accumulator = self._make_accumulator(event_sink)
accumulator.Reload()

tags = [
Expand Down Expand Up @@ -498,7 +494,7 @@ def testNewStyleImageSummary(self):
summ = sess.run(merged)
writer.add_summary(summ, global_step=i)

accumulator = ea.EventAccumulator(event_sink)
accumulator = self._make_accumulator(event_sink)
accumulator.Reload()

tags = [
Expand Down Expand Up @@ -536,7 +532,7 @@ def testTFSummaryTensor(self):
summ = sess.run(merged)
writer.add_summary(summ, 0)

accumulator = ea.EventAccumulator(event_sink)
accumulator = self._make_accumulator(event_sink)
accumulator.Reload()

self.assertTagsEqual(
Expand Down Expand Up @@ -581,7 +577,7 @@ def _testTFSummaryTensor_SizeGuidance(
for step in range(steps):
writer.add_summary(sess.run(merged), global_step=step)

accumulator = ea.EventAccumulator(
accumulator = self._make_accumulator(
event_sink, tensor_size_guidance=tensor_size_guidance
)
accumulator.Reload()
Expand Down