Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions smdebug/tensorflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,22 @@ def _get_matching_collections(

if match_inc(ts_name, current_coll.include_regex):
# In TF 2.x eager mode, we can't put tensors in a set/dictionary as tensor.__hash__()
# is no longer available. tensor.experimental_ref() returns a hashable reference
# is no longer available. tensor.ref() returns a hashable reference
# object to this Tensor.
if is_tf_version_2x() and tf.executing_eagerly():
# tensor.experimental_ref is an experimental API
# and can be changed or removed.
# Ref: https://www.tensorflow.org/api_docs/python/tf/Tensor#experimental_ref
tensor = tensor.experimental_ref()
if hasattr(tensor, "ref"):
# See: https://www.tensorflow.org/api_docs/python/tf/Tensor#ref
# experimental_ref is being deprecated for ref
tensor = tensor.ref()
elif hasattr(tensor, "experimental_ref"):
# tensor.experimental_ref is an experimental API
# and can be changed or removed.
# Ref: https://www.tensorflow.org/api_docs/python/tf/Tensor#experimental_ref
tensor = tensor.experimental_ref()
else:
raise Exception(
"Neither ref nor experimental_ref API present. Check TF version"
)
if not current_coll.has_tensor(tensor):
# tensor will be added to this coll below
colls_with_tensor.add(current_coll)
Expand Down