Skip to content

Commit 767360e

Browse files
authored
Make tensorboard.compat.{tf,tf2} lazily loaded (#1781)
* Make tensorboard.compat.{tf,tf2} lazily loaded * add summary_dep_test.py to check for no TF dep from tb.summary.v2 * CR: rewrite lazy_load() to use closure and guaranteed single initialization * add a couple extra checks to SummaryV2DepTest * CR: avoid leaking _importlib symbol * CR: clarify memoization contract, avoid deadlock, nicer repr * CR: no moar hasattr, typo fix
1 parent 6c67c02 commit 767360e

File tree

18 files changed

+214
-135
lines changed

18 files changed

+214
-135
lines changed

tensorboard/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121

2222
from tensorboard import lazy
2323

24-
pkg = lambda i: i # helps google sync process
25-
mod = lambda i: lazy.LazyLoader(i[i.rindex('.') + 1:], globals(), i) # noqa: F821
2624

27-
program = mod(pkg('tensorboard.program'))
28-
summary = mod(pkg('tensorboard.summary'))
25+
@lazy.lazy_load('tensorboard.program')
26+
def program():
27+
import tensorboard.program as module # pylint: disable=g-import-not-at-top
28+
return module
2929

30-
del lazy
31-
del mod
32-
del pkg
30+
31+
@lazy.lazy_load('tensorboard.summary')
32+
def summary():
33+
import tensorboard.summary as module # pylint: disable=g-import-not-at-top
34+
return module

tensorboard/compat/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ py_library(
2323
srcs = ["__init__.py"],
2424
srcs_version = "PY2AND3",
2525
visibility = ["//visibility:public"],
26+
deps = [
27+
"//tensorboard:lazy",
28+
],
2629
)
2730

2831
# This rule ensures that `from tensorboard.compat import tf` will provide a

tensorboard/compat/__init__.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,59 +14,59 @@
1414

1515
"""Compatibility interfaces for TensorBoard.
1616
17-
This module provides logic for importing variations on the TensorFlow APIs.
18-
19-
The alias `tf` is for the main TF API used by TensorBoard. By default this will
20-
be the result of `import tensorflow as tf`, or undefined if that fails. This
21-
can be used in combination with //tensorboard/compat:tensorflow (to fall back to
22-
a stub TF API implementation if the real one is not available) and
23-
//tensorboard/compat:no_tensorflow (to use the stub TF API unconditionally).
24-
25-
The function `import_tf_v2` provides common logic for importing the TF 2.0 API,
26-
and returns the root module of the API if found, or else raises ImportError.
27-
This is a function instead of a direct alias like `tf` in order to provide
28-
enough indirection to get around circular dependencies.
17+
This module provides logic for importing variations on the TensorFlow APIs, as
18+
lazily loaded imports to help avoid circular dependency issues and defer the
19+
search and loading of the module until necessary.
2920
"""
3021

3122
from __future__ import absolute_import
3223
from __future__ import division
3324
from __future__ import print_function
3425

35-
# First, check if using TF is explicitly disabled by request.
36-
USING_TF = True
37-
try:
38-
from tensorboard.compat import notf
39-
USING_TF = False
40-
except ImportError:
41-
pass
26+
import importlib as _importlib
27+
28+
import tensorboard.lazy as _lazy
4229

43-
# If TF is not disabled, check if it's available.
44-
if USING_TF:
45-
try:
46-
import tensorflow as tf
47-
except ImportError:
48-
USING_TF = False
4930

50-
if not USING_TF:
51-
# If we can't use TF, try to provide the stub instead.
52-
# This will only work if the tensorflow_stub dep is included
53-
# in the build, via the `tensorboard/compat:tensorflow` target.
31+
@_lazy.lazy_load('tensorboard.compat.tf')
32+
def tf():
33+
"""Provide the root module of a TF-like API for use within TensorBoard.
34+
35+
By default this is equivalent to `import tensorflow as tf`, but it can be used
36+
in combination with //tensorboard/compat:tensorflow (to fall back to a stub TF
37+
API implementation if the real one is not available) or with
38+
//tensorboard/compat:no_tensorflow (to force unconditional use of the stub).
39+
40+
Returns:
41+
The root module of a TF-like API, if available.
42+
43+
Raises:
44+
ImportError: if a TF-like API is not available.
45+
"""
5446
try:
55-
from tensorboard.compat import tensorflow_stub as tf
47+
_importlib.import_module('tensorboard.compat.notf')
5648
except ImportError:
57-
pass
49+
try:
50+
return _importlib.import_module('tensorflow')
51+
except ImportError:
52+
pass
53+
return _importlib.import_module('tensorboard.compat.tensorflow_stub') # pylint: disable=line-too-long
54+
55+
56+
@_lazy.lazy_load('tensorboard.compat.tf2')
57+
def tf2():
58+
"""Provide the root module of a TF-2.0 API for use within TensorBoard.
5859
60+
Returns:
61+
The root module of a TF-2.0 API, if available.
5962
60-
def import_tf_v2():
61-
"""Import the TF 2.0 API if possible, or raise an ImportError."""
62-
# We must be able to use TF in order to provide the TF 2.0 API.
63-
if USING_TF:
64-
# Check if this is TF 2.0 by looking for a known 2.0-only tf.summary symbol.
65-
# TODO(nickfelt): determine a cleaner way to do this.
66-
if hasattr(tf, 'summary') and hasattr(tf.summary, 'write'):
67-
return tf
68-
else:
69-
# As a fallback, try `tensorflow.compat.v2` if it's defined.
70-
if hasattr(tf, 'compat') and hasattr(tf.compat, 'v2'):
71-
return tf.compat.v2
63+
Raises:
64+
ImportError: if a TF-2.0 API is not available.
65+
"""
66+
# Import the `tf` compat API from this file and check if it's already TF 2.0.
67+
if tf.__version__.startswith('2.'):
68+
return tf
69+
elif hasattr(tf, 'compat') and hasattr(tf.compat, 'v2'):
70+
# As a fallback, try `tensorflow.compat.v2` if it's defined.
71+
return tf.compat.v2
7272
raise ImportError('cannot import tensorflow 2.0 API')

tensorboard/compat/tensorflow_stub/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,6 @@
3434
from . import gfile # noqa
3535
from . import pywrap_tensorflow # noqa
3636
from . import tensor_shape # noqa
37+
38+
# Set a fake __version__ to help distinguish this as our own stub API.
39+
__version__ = 'stub'

tensorboard/lazy.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,36 +19,68 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22-
import importlib
22+
import functools
23+
import threading
2324
import types
2425

2526

26-
class LazyLoader(types.ModuleType):
27-
"""Lazily import a module, mainly to avoid pulling in large dependencies."""
27+
def lazy_load(name):
28+
"""Decorator to define a function that lazily loads the module 'name'.
2829
29-
# The lint error here is incorrect.
30-
def __init__(self, local_name, parent_module_globals, name): # pylint: disable=super-on-old-class
31-
self._local_name = local_name
32-
self._parent_module_globals = parent_module_globals
30+
This can be used to defer importing troublesome dependencies - e.g. ones that
31+
are large and infrequently used, or that cause a dependency cycle -
32+
until they are actually used.
3333
34-
super(LazyLoader, self).__init__(name)
34+
Args:
35+
name: the fully-qualified name of the module; typically the last segment
36+
of 'name' matches the name of the decorated function
3537
36-
def _load(self):
37-
# Import the target module and insert it into the parent's namespace
38-
module = importlib.import_module(self.__name__)
39-
self._parent_module_globals[self._local_name] = module
38+
Returns:
39+
Decorator function that produces a lazy-loading module 'name' backed by the
40+
underlying decorated function.
41+
"""
42+
def wrapper(load_fn):
43+
# Wrap load_fn to call it exactly once and update __dict__ afterwards to
44+
# make future lookups efficient (only failed lookups call __getattr__).
45+
@_memoize
46+
def load_once(self):
47+
module = load_fn()
48+
self.__dict__.update(module.__dict__)
49+
load_once.loaded = True
50+
return module
51+
load_once.loaded = False
4052

41-
# Update this object's dict so that if someone keeps a reference to the
42-
# LazyLoader, lookups are efficient (__getattr__ is only called on lookups
43-
# that fail).
44-
self.__dict__.update(module.__dict__)
53+
# Define a module that proxies getattr() and dir() to the result of calling
54+
# load_once() the first time it's needed. The class is nested so we can close
55+
# over load_once() and avoid polluting the module's attrs with our own state.
56+
class LazyModule(types.ModuleType):
57+
def __getattr__(self, attr_name):
58+
return getattr(load_once(self), attr_name)
4559

46-
return module
60+
def __dir__(self):
61+
return dir(load_once(self))
4762

48-
def __getattr__(self, item):
49-
module = self._load()
50-
return getattr(module, item)
63+
def __repr__(self):
64+
if load_once.loaded:
65+
return repr(load_once(self))
66+
return '<module \'%s\' (LazyModule)>' % self.__name__
5167

52-
def __dir__(self):
53-
module = self._load()
54-
return dir(module)
68+
return LazyModule(name)
69+
return wrapper
70+
71+
72+
def _memoize(f):
73+
"""Memoizing decorator for f, which must have exactly 1 hashable argument."""
74+
nothing = object() # Unique "no value" sentinel object.
75+
cache = {}
76+
# Use a reentrant lock so that if f references the resulting wrapper we die
77+
# with recursion depth exceeded instead of deadlocking.
78+
lock = threading.RLock()
79+
@functools.wraps(f)
80+
def wrapper(arg):
81+
if cache.get(arg, nothing) == nothing:
82+
with lock:
83+
if cache.get(arg, nothing) == nothing:
84+
cache[arg] = f(arg)
85+
return cache[arg]
86+
return wrapper

tensorboard/plugins/audio/summary_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@
2929
# TODO(nickfelt): get encode_wav() exported in the public API.
3030
from tensorflow.python.ops import gen_audio_ops
3131

32+
from tensorboard.compat import tf2
3233
from tensorboard.plugins.audio import metadata
3334
from tensorboard.plugins.audio import summary
3435
from tensorboard.util import tensor_util
3536

3637

3738
try:
38-
from tensorboard import compat
39-
tf_v2 = compat.import_tf_v2()
39+
tf2.__version__ # Force lazy import to resolve
4040
except ImportError:
41-
tf_v2 = None
41+
tf2 = None
4242

4343
try:
4444
tf.compat.v1.enable_eager_execution()
@@ -198,12 +198,12 @@ def test_requires_nonnegative_max_outputs(self):
198198
class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase):
199199
def setUp(self):
200200
super(SummaryV2OpTest, self).setUp()
201-
if tf_v2 is None:
201+
if tf2 is None:
202202
self.skipTest('TF v2 summary API not available')
203203

204204
def audio(self, *args, **kwargs):
205205
kwargs.setdefault('step', 1)
206-
writer = tf_v2.summary.create_file_writer(self.get_temp_dir())
206+
writer = tf2.summary.create_file_writer(self.get_temp_dir())
207207
with writer.as_default():
208208
summary.audio(*args, **kwargs)
209209
writer.close()

tensorboard/plugins/audio/summary_v2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import functools
2929

30+
from tensorboard.compat import tf2 as tf
3031
from tensorboard.plugins.audio import metadata
3132

3233

@@ -64,9 +65,6 @@ def audio(name,
6465
True on success, or false if no summary was emitted because no default
6566
summary writer was available.
6667
"""
67-
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
68-
from tensorboard import compat
69-
tf = compat.import_tf_v2()
7068
# TODO(nickfelt): get encode_wav() exported in the public API.
7169
from tensorflow.python.ops import gen_audio_ops
7270

tensorboard/plugins/histogram/summary_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@
2525
import numpy as np
2626
import tensorflow as tf
2727

28+
from tensorboard.compat import tf2
2829
from tensorboard.plugins.histogram import metadata
2930
from tensorboard.plugins.histogram import summary
3031
from tensorboard.util import tensor_util
3132

3233

3334
try:
34-
from tensorboard import compat
35-
tf_v2 = compat.import_tf_v2()
35+
tf2.__version__ # Force lazy import to resolve
3636
except ImportError:
37-
tf_v2 = None
37+
tf2 = None
3838

3939
try:
4040
tf.compat.v1.enable_eager_execution()
@@ -160,12 +160,12 @@ def histogram(self, *args, **kwargs):
160160
class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase):
161161
def setUp(self):
162162
super(SummaryV2OpTest, self).setUp()
163-
if tf_v2 is None:
163+
if tf2 is None:
164164
self.skipTest('v2 summary API not available')
165165

166166
def histogram(self, *args, **kwargs):
167167
kwargs.setdefault('step', 1)
168-
writer = tf_v2.summary.create_file_writer(self.get_temp_dir())
168+
writer = tf2.summary.create_file_writer(self.get_temp_dir())
169169
with writer.as_default():
170170
summary.histogram(*args, **kwargs)
171171
writer.close()
@@ -194,12 +194,12 @@ def histogram(self, *args, **kwargs):
194194
# Hack to extract current scope since there's no direct API for it.
195195
with tf.name_scope('_') as temp_scope:
196196
scope = temp_scope.rstrip('/_')
197-
@tf_v2.function
197+
@tf2.function
198198
def graph_fn():
199199
# Recreate the active scope inside the defun since it won't propagate.
200200
with tf.name_scope(scope):
201201
summary.histogram(*args, **kwargs)
202-
writer = tf_v2.summary.create_file_writer(self.get_temp_dir())
202+
writer = tf2.summary.create_file_writer(self.get_temp_dir())
203203
with writer.as_default():
204204
graph_fn()
205205
writer.close()

tensorboard/plugins/histogram/summary_v2.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import numpy as np
3333

34+
from tensorboard.compat import tf2 as tf
3435
from tensorboard.compat.proto import summary_pb2
3536
from tensorboard.plugins.histogram import metadata
3637
from tensorboard.util import tensor_util
@@ -59,9 +60,6 @@ def histogram(name, data, step, buckets=None, description=None):
5960
True on success, or false if no summary was emitted because no default
6061
summary writer was available.
6162
"""
62-
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
63-
from tensorboard import compat
64-
tf = compat.import_tf_v2()
6563
summary_metadata = metadata.create_summary_metadata(
6664
display_name=None, description=description)
6765
with tf.summary.summary_scope(
@@ -82,9 +80,6 @@ def _buckets(data, bucket_count=None):
8280
a triple `[left_edge, right_edge, count]` for a single bucket.
8381
The value of `k` is either `bucket_count` or `1` or `0`.
8482
"""
85-
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
86-
from tensorboard import compat
87-
tf = compat.import_tf_v2()
8883
if bucket_count is None:
8984
bucket_count = DEFAULT_BUCKET_COUNT
9085
with tf.name_scope('buckets', values=[data, bucket_count]):
@@ -152,8 +147,6 @@ def histogram_pb(tag, data, buckets=None, description=None):
152147
Returns:
153148
A `summary_pb2.Summary` protobuf object.
154149
"""
155-
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
156-
from tensorboard.compat import tf
157150
bucket_count = DEFAULT_BUCKET_COUNT if buckets is None else buckets
158151
data = np.array(data).flatten().astype(float)
159152
if data.size == 0:
@@ -179,7 +172,7 @@ def histogram_pb(tag, data, buckets=None, description=None):
179172
left_edges = edges[:-1]
180173
right_edges = edges[1:]
181174
buckets = np.array([left_edges, right_edges, bucket_counts]).transpose()
182-
tensor = tensor_util.make_tensor_proto(buckets, dtype=tf.float64)
175+
tensor = tensor_util.make_tensor_proto(buckets, dtype=np.float64)
183176

184177
summary_metadata = metadata.create_summary_metadata(
185178
display_name=None, description=description)

0 commit comments

Comments
 (0)