16 changes: 9 additions & 7 deletions tensorboard/__init__.py
@@ -21,12 +21,14 @@

from tensorboard import lazy

pkg = lambda i: i # helps google sync process
mod = lambda i: lazy.LazyLoader(i[i.rindex('.') + 1:], globals(), i) # noqa: F821

program = mod(pkg('tensorboard.program'))
summary = mod(pkg('tensorboard.summary'))
@lazy.lazy_load('tensorboard.program')
def program():
import tensorboard.program as module # pylint: disable=g-import-not-at-top
return module

del lazy
del mod
del pkg

@lazy.lazy_load('tensorboard.summary')
def summary():
import tensorboard.summary as module # pylint: disable=g-import-not-at-top
return module
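
The lazy_load-based definitions above keep `import tensorboard` cheap: `tensorboard.program` and `tensorboard.summary` become module proxies whose backing packages are only imported on first attribute access. A minimal sketch of the resulting behavior, not part of the diff (the `--logdir` value is a hypothetical example):

import tensorboard

# The proxy has not imported tensorboard.program yet.
print(repr(tensorboard.program))  # <module 'tensorboard.program' (LazyModule)>

# First attribute access runs the decorated loader and fills in the real attrs.
tb = tensorboard.program.TensorBoard()
tb.configure(argv=[None, '--logdir', '/tmp/logs'])
print(repr(tensorboard.program))  # now the repr of the real tensorboard.program module
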
3 changes: 3 additions & 0 deletions tensorboard/compat/BUILD
@@ -23,6 +23,9 @@ py_library(
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorboard:lazy",
],
)

# This rule ensures that `from tensorboard.compat import tf` will provide a
86 changes: 43 additions & 43 deletions tensorboard/compat/__init__.py
@@ -14,59 +14,59 @@

"""Compatibility interfaces for TensorBoard.

This module provides logic for importing variations on the TensorFlow APIs.

The alias `tf` is for the main TF API used by TensorBoard. By default this will
be the result of `import tensorflow as tf`, or undefined if that fails. This
can be used in combination with //tensorboard/compat:tensorflow (to fall back to
a stub TF API implementation if the real one is not available) and
//tensorboard/compat:no_tensorflow (to use the stub TF API unconditionally).

The function `import_tf_v2` provides common logic for importing the TF 2.0 API,
and returns the root module of the API if found, or else raises ImportError.
This is a function instead of a direct alias like `tf` in order to provide
enough indirection to get around circular dependencies.
This module provides logic for importing variations on the TensorFlow APIs, as
lazily loaded imports to help avoid circular dependency issues and defer the
search and loading of the module until necessary.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# First, check if using TF is explicitly disabled by request.
USING_TF = True
try:
from tensorboard.compat import notf
USING_TF = False
except ImportError:
pass
import importlib as _importlib

import tensorboard.lazy as _lazy

# If TF is not disabled, check if it's available.
if USING_TF:
try:
import tensorflow as tf
except ImportError:
USING_TF = False

if not USING_TF:
# If we can't use TF, try to provide the stub instead.
# This will only work if the tensorflow_stub dep is included
# in the build, via the `tensorboard/compat:tensorflow` target.
@_lazy.lazy_load('tensorboard.compat.tf')
def tf():
"""Provide the root module of a TF-like API for use within TensorBoard.

By default this is equivalent to `import tensorflow as tf`, but it can be used
in combination with //tensorboard/compat:tensorflow (to fall back to a stub TF
API implementation if the real one is not available) or with
//tensorboard/compat:no_tensorflow (to force unconditional use of the stub).

Returns:
The root module of a TF-like API, if available.

Raises:
ImportError: if a TF-like API is not available.
"""
try:
from tensorboard.compat import tensorflow_stub as tf
_importlib.import_module('tensorboard.compat.notf')
except ImportError:
pass
try:
return _importlib.import_module('tensorflow')
except ImportError:
pass
return _importlib.import_module('tensorboard.compat.tensorflow_stub') # pylint: disable=line-too-long


@_lazy.lazy_load('tensorboard.compat.tf2')
def tf2():
"""Provide the root module of a TF-2.0 API for use within TensorBoard.

Returns:
The root module of a TF-2.0 API, if available.

def import_tf_v2():
"""Import the TF 2.0 API if possible, or raise an ImportError."""
# We must be able to use TF in order to provide the TF 2.0 API.
if USING_TF:
# Check if this is TF 2.0 by looking for a known 2.0-only tf.summary symbol.
# TODO(nickfelt): determine a cleaner way to do this.
if hasattr(tf, 'summary') and hasattr(tf.summary, 'write'):
return tf
else:
# As a fallback, try `tensorflow.compat.v2` if it's defined.
if hasattr(tf, 'compat') and hasattr(tf.compat, 'v2'):
return tf.compat.v2
Raises:
ImportError: if a TF-2.0 API is not available.
"""
# Import the `tf` compat API from this file and check if it's already TF 2.0.
if tf.__version__.startswith('2.'):
return tf
elif hasattr(tf, 'compat') and hasattr(tf.compat, 'v2'):
# As a fallback, try `tensorflow.compat.v2` if it's defined.
return tf.compat.v2
raise ImportError('cannot import tensorflow 2.0 API')
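
With the two lazy definitions above, callers can write `from tensorboard.compat import tf` at module level without paying for, or failing on, the TensorFlow import until an attribute is first used. A hedged sketch of the resolution order encoded in `tf()`, not part of the diff:

from tensorboard.compat import tf

# Resolution happens here, on first attribute access, in this order:
#   1. if `tensorboard.compat.notf` is importable, real TensorFlow is skipped;
#   2. otherwise `import tensorflow` is attempted;
#   3. otherwise the bundled tensorflow_stub is returned.
print(tf.__version__)  # a real TF version string, or 'stub' with the stub API
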
3 changes: 3 additions & 0 deletions tensorboard/compat/tensorflow_stub/__init__.py
@@ -34,3 +34,6 @@
from . import gfile # noqa
from . import pywrap_tensorflow # noqa
from . import tensor_shape # noqa

# Set a fake __version__ to help distinguish this as our own stub API.
__version__ = 'stub'
Contributor
Is it strange that this makes tf.__version__ and tf.version.VERSION
different? Do we want to change the latter as well?

Contributor Author
The stub doesn't actually have a version module so I'm not quite sure I follow? I picked __version__ because it's the one place that's consistent across TF 1.x and 2.x, since 2.0 only has tf.version.VERSION, but most of 1.x used tf.VERSION.

Contributor
Ah, got it. Thanks.
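
In other words, the stub can be told apart from a real TensorFlow through the one version attribute that both 1.x and 2.x expose; an illustrative check, not taken from the PR:

from tensorboard.compat import tf

if tf.__version__ == 'stub':
  print('running against tensorboard.compat.tensorflow_stub')
else:
  # Real TF: 1.x also had tf.VERSION, 2.x has tf.version.VERSION, but
  # __version__ is the attribute common to both, which is why the stub sets it.
  print('running against TensorFlow %s' % tf.__version__)
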

78 changes: 55 additions & 23 deletions tensorboard/lazy.py
@@ -19,36 +19,68 @@
from __future__ import division
from __future__ import print_function

import importlib
import functools
import threading
import types


class LazyLoader(types.ModuleType):
"""Lazily import a module, mainly to avoid pulling in large dependencies."""
def lazy_load(name):
"""Decorator to define a function that lazily loads the module 'name'.

# The lint error here is incorrect.
def __init__(self, local_name, parent_module_globals, name): # pylint: disable=super-on-old-class
self._local_name = local_name
self._parent_module_globals = parent_module_globals
This can be used to defer importing troublesome dependencies - e.g. ones that
are large and infrequently used, or that cause a dependency cycle -
until they are actually used.

super(LazyLoader, self).__init__(name)
Args:
name: the fully-qualified name of the module; typically the last segment
of 'name' matches the name of the decorated function

def _load(self):
# Import the target module and insert it into the parent's namespace
module = importlib.import_module(self.__name__)
self._parent_module_globals[self._local_name] = module
Returns:
Decorator function that produces a lazy-loading module 'name' backed by the
underlying decorated function.
"""
def wrapper(load_fn):
# Wrap load_fn to call it exactly once and update __dict__ afterwards to
# make future lookups efficient (only failed lookups call __getattr__).
@_memoize
def load_once(self):
module = load_fn()
self.__dict__.update(module.__dict__)
load_once.loaded = True
return module
load_once.loaded = False

# Update this object's dict so that if someone keeps a reference to the
# LazyLoader, lookups are efficient (__getattr__ is only called on lookups
# that fail).
self.__dict__.update(module.__dict__)
# Define a module that proxies getattr() and dir() to the result of calling
# load_once() the first time it's needed. The class is nested so we can close
# over load_once() and avoid polluting the module's attrs with our own state.
class LazyModule(types.ModuleType):
def __getattr__(self, attr_name):
return getattr(load_once(self), attr_name)

return module
def __dir__(self):
return dir(load_once(self))

def __getattr__(self, item):
module = self._load()
return getattr(module, item)
def __repr__(self):
if load_once.loaded:
return repr(load_once(self))
return '<module \'%s\' (LazyModule)>' % self.__name__

def __dir__(self):
module = self._load()
return dir(module)
return LazyModule(name)
return wrapper


def _memoize(f):
"""Memoizing decorator for f, which must have exactly 1 hashable argument."""
nothing = object() # Unique "no value" sentinel object.
cache = {}
# Use a reentrant lock so that if f references the resulting wrapper we die
# with recursion depth exceeded instead of deadlocking.
lock = threading.RLock()
@functools.wraps(f)
def wrapper(arg):
if cache.get(arg, nothing) == nothing:
with lock:
if cache.get(arg, nothing) == nothing:
cache[arg] = f(arg)
return cache[arg]
return wrapper
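
A small self-contained sketch of how `lazy_load` (together with the `_memoize` helper above) behaves; `json` stands in here for a heavyweight or cycle-prone dependency and is not something TensorBoard itself wraps:

from tensorboard import lazy


@lazy.lazy_load('json')
def json():
  import json as module  # runs at most once, on first attribute access
  return module


print(repr(json))          # <module 'json' (LazyModule)> -- nothing imported yet
print(json.dumps([1, 2]))  # first lookup triggers the real import: prints [1, 2]
print(repr(json))          # now the repr of the real json module
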
10 changes: 5 additions & 5 deletions tensorboard/plugins/audio/summary_test.py
@@ -29,16 +29,16 @@
# TODO(nickfelt): get encode_wav() exported in the public API.
from tensorflow.python.ops import gen_audio_ops

from tensorboard.compat import tf2
from tensorboard.plugins.audio import metadata
from tensorboard.plugins.audio import summary
from tensorboard.util import tensor_util


try:
from tensorboard import compat
tf_v2 = compat.import_tf_v2()
tf2.__version__ # Force lazy import to resolve
except ImportError:
tf_v2 = None
tf2 = None

try:
tf.compat.v1.enable_eager_execution()
@@ -198,12 +198,12 @@ def test_requires_nonnegative_max_outputs(self):
class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase):
def setUp(self):
super(SummaryV2OpTest, self).setUp()
if tf_v2 is None:
if tf2 is None:
self.skipTest('TF v2 summary API not available')

def audio(self, *args, **kwargs):
kwargs.setdefault('step', 1)
writer = tf_v2.summary.create_file_writer(self.get_temp_dir())
writer = tf2.summary.create_file_writer(self.get_temp_dir())
with writer.as_default():
summary.audio(*args, **kwargs)
writer.close()
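
Because `tensorboard.compat.tf2` is now a lazy proxy, the test module can import it unconditionally and only learns whether a TF 2.0 API exists when the proxy is first resolved. A condensed sketch of the guard pattern used in these tests (the test class and temp directory handling are illustrative, not from the PR):

import tempfile
import unittest

from tensorboard.compat import tf2

try:
  tf2.__version__  # any attribute access forces the deferred import
except ImportError:
  tf2 = None


class SomeV2OpTest(unittest.TestCase):
  def setUp(self):
    if tf2 is None:
      self.skipTest('TF v2 summary API not available')

  def test_create_writer(self):
    writer = tf2.summary.create_file_writer(tempfile.mkdtemp())
    writer.close()
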
4 changes: 1 addition & 3 deletions tensorboard/plugins/audio/summary_v2.py
@@ -27,6 +27,7 @@

import functools

from tensorboard.compat import tf2 as tf
from tensorboard.plugins.audio import metadata


@@ -64,9 +65,6 @@ def audio(name,
True on success, or false if no summary was emitted because no default
summary writer was available.
"""
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
from tensorboard import compat
tf = compat.import_tf_v2()
# TODO(nickfelt): get encode_wav() exported in the public API.
from tensorflow.python.ops import gen_audio_ops

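
The deleted on-demand `compat.import_tf_v2()` call is replaced by the module-level `from tensorboard.compat import tf2 as tf`, which is safe because the real TF 2.0 API is only resolved when the op first touches `tf.summary`. A rough sketch of that pattern outside this file (the op name and body are hypothetical, assuming the standard TF 2.x `summary_scope` and `write` APIs):

from tensorboard.compat import tf2 as tf  # lazy proxy; no import cycle at load time


def scalar_like(name, data, step):
  # The TF 2.0 API is imported here, on first attribute access, or an
  # ImportError surfaces if no TF 2.0 API is available.
  with tf.summary.summary_scope(name, 'scalar_summary', values=[data, step]) as (tag, _):
    return tf.summary.write(tag=tag, tensor=tf.cast(data, tf.float32), step=step)
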
14 changes: 7 additions & 7 deletions tensorboard/plugins/histogram/summary_test.py
@@ -25,16 +25,16 @@
import numpy as np
import tensorflow as tf

from tensorboard.compat import tf2
from tensorboard.plugins.histogram import metadata
from tensorboard.plugins.histogram import summary
from tensorboard.util import tensor_util


try:
from tensorboard import compat
tf_v2 = compat.import_tf_v2()
tf2.__version__ # Force lazy import to resolve
except ImportError:
tf_v2 = None
tf2 = None

try:
tf.compat.v1.enable_eager_execution()
@@ -160,12 +160,12 @@ def histogram(self, *args, **kwargs):
class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase):
def setUp(self):
super(SummaryV2OpTest, self).setUp()
if tf_v2 is None:
if tf2 is None:
self.skipTest('v2 summary API not available')

def histogram(self, *args, **kwargs):
kwargs.setdefault('step', 1)
writer = tf_v2.summary.create_file_writer(self.get_temp_dir())
writer = tf2.summary.create_file_writer(self.get_temp_dir())
with writer.as_default():
summary.histogram(*args, **kwargs)
writer.close()
@@ -194,12 +194,12 @@ def histogram(self, *args, **kwargs):
# Hack to extract current scope since there's no direct API for it.
with tf.name_scope('_') as temp_scope:
scope = temp_scope.rstrip('/_')
@tf_v2.function
@tf2.function
def graph_fn():
# Recreate the active scope inside the defun since it won't propagate.
with tf.name_scope(scope):
summary.histogram(*args, **kwargs)
writer = tf_v2.summary.create_file_writer(self.get_temp_dir())
writer = tf2.summary.create_file_writer(self.get_temp_dir())
with writer.as_default():
graph_fn()
writer.close()
11 changes: 2 additions & 9 deletions tensorboard/plugins/histogram/summary_v2.py
@@ -31,6 +31,7 @@

import numpy as np

from tensorboard.compat import tf2 as tf
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.histogram import metadata
from tensorboard.util import tensor_util
@@ -59,9 +60,6 @@ def histogram(name, data, step, buckets=None, description=None):
True on success, or false if no summary was emitted because no default
summary writer was available.
"""
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
from tensorboard import compat
tf = compat.import_tf_v2()
summary_metadata = metadata.create_summary_metadata(
display_name=None, description=description)
with tf.summary.summary_scope(
@@ -82,9 +80,6 @@ def _buckets(data, bucket_count=None):
a triple `[left_edge, right_edge, count]` for a single bucket.
The value of `k` is either `bucket_count` or `1` or `0`.
"""
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
from tensorboard import compat
tf = compat.import_tf_v2()
if bucket_count is None:
bucket_count = DEFAULT_BUCKET_COUNT
with tf.name_scope('buckets', values=[data, bucket_count]):
@@ -152,8 +147,6 @@ def histogram_pb(tag, data, buckets=None, description=None):
Returns:
A `summary_pb2.Summary` protobuf object.
"""
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
from tensorboard.compat import tf
bucket_count = DEFAULT_BUCKET_COUNT if buckets is None else buckets
data = np.array(data).flatten().astype(float)
if data.size == 0:
Expand All @@ -179,7 +172,7 @@ def histogram_pb(tag, data, buckets=None, description=None):
left_edges = edges[:-1]
right_edges = edges[1:]
buckets = np.array([left_edges, right_edges, bucket_counts]).transpose()
tensor = tensor_util.make_tensor_proto(buckets, dtype=tf.float64)
tensor = tensor_util.make_tensor_proto(buckets, dtype=np.float64)

summary_metadata = metadata.create_summary_metadata(
display_name=None, description=description)