Skip to content

Commit

Permalink
fix Tensor not hashable, DictRefKeys, RefIdEq
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 16, 2023
1 parent e0541e4 commit de9fcb4
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 19 deletions.
14 changes: 7 additions & 7 deletions returnn/tensor/_dim_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union, Tuple, Dict
from typing import TYPE_CHECKING, Optional, Union, Tuple, Dict, List

from returnn.util.basic import Entity
from returnn.util import basic as util
Expand Down Expand Up @@ -757,7 +757,7 @@ def is_same_size_tensor(self, x):
return True
if not self._extra:
return False
if util.TensorRef(x) in self._extra.dyn_size_same:
if util.RefIdEq(x) in self._extra.dyn_size_same:
return True
return False

Expand Down Expand Up @@ -794,10 +794,10 @@ def set_tag_on_size_tensor(self: Dim, x, batch=None, same_as_before=False) -> Di
new_dim_tag.set_tag_on_size_tensor(x, batch=batch)
return new_dim_tag
if self.dyn_size is not None and self.dyn_size is not x:
if self._extra and util.TensorRef(x) in self._extra.dyn_size_same:
if self._extra and util.RefIdEq(x) in self._extra.dyn_size_same:
pass # ok, pass on
elif same_as_before:
self._make_extra().dyn_size_same.add(util.TensorRef(x))
self._make_extra().dyn_size_same.add(util.RefIdEq(x))
# And now pass on.
else:
assert self.batch and batch
Expand Down Expand Up @@ -1448,14 +1448,14 @@ def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None):
@classmethod
def get_all_dimension_tags(cls, data_list, is_equal_opts=None, unique_separate_axes=True):
"""
:param list[Data] data_list:
:param list[_t.Tensor] data_list:
:param dict[str]|None is_equal_opts: passed to Dim.is_equal
:param bool unique_separate_axes: e.g. data_list=[Data with shape (B,5,5,10)] results in 4 dim tags, not 3.
:return: list of dimension tags, dict for data -> list of dimension tags (for each axis)
:rtype: (list[Dim], dict[Data, list[Dim]])
:rtype: (list[Dim], util.DictRefKeys[_t.Tensor, list[Dim]])
"""
tags = []
data_axes_dict = {}
data_axes_dict = util.DictRefKeys() # type: util.DictRefKeys[_t.Tensor, List[Dim]]
for data in data_list:
data_axes_dict[data] = []
existing_tag_collection_for_data = list(tags) if unique_separate_axes else tags
Expand Down
66 changes: 54 additions & 12 deletions returnn/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from __future__ import annotations
from typing import Generic, TypeVar
from typing import Generic, TypeVar, Iterable, Tuple

import subprocess
from subprocess import CalledProcessError
Expand Down Expand Up @@ -63,6 +63,8 @@


T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")

returnn_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

Expand Down Expand Up @@ -2343,35 +2345,75 @@ def make_hashable(obj):
import tensorflow as tf

if isinstance(obj, tf.Tensor):
return TensorRef(obj)
return RefIdEq(obj)
assert False, "don't know how to make hashable: %r (%r)" % (obj, type(obj))


class TensorRef(Generic[T]):
class RefIdEq(Generic[T]):
"""
Reference to the original tensor, which is hashable.
We have this here for compatibility because tf.Tensor.ref() was not available in earlier TF versions.
Reference to some object (e.g. tf.Tensor), but this reference object is always hashable,
and uses the `id` of the wrapped object for the hash and equality.
(In case of tf.Tensor, this is for compatibility
because tf.Tensor.ref() was not available in earlier TF versions.
However, we also need this for :class:`DictRefKeys`.)
"""

def __init__(self, tensor: T):
def __init__(self, obj: T):
"""
:param tensor: for example tf.Tensor
:param obj: for example tf.Tensor
"""
self.tensor = tensor
self.obj = obj

def __repr__(self):
return "TensorRef{%r}" % self.tensor
return "TensorRef{%r}" % self.obj

def __eq__(self, other):
if other is None or not isinstance(other, TensorRef):
if other is None or not isinstance(other, RefIdEq):
return False
return self.tensor is other.tensor
return self.obj is other.obj

def __ne__(self, other):
return not self.__eq__(other)

def __hash__(self):
return id(self.tensor)
return id(self.obj)


class DictRefKeys(Generic[K, V]):
    """
    Like `dict`, but hash and equality of the keys
    are based on the identity (`id`) of the key object,
    not on the key's own `__hash__` / `__eq__`.
    Internally, every key is wrapped in :class:`RefIdEq`,
    so two keys are the same entry only if they are the very same object.
    """

    def __init__(self):
        # Maps the identity wrapper of each key to its value.
        # The wrapper keeps a reference to the original key object.
        self._d = {}  # type: dict[RefIdEq[K], V]

    def __repr__(self):
        return "DictRefKeys(%s)" % ", ".join("%r: %r" % (k, v) for (k, v) in self.items())

    def __iter__(self):
        """
        Iterate over the original (unwrapped) keys, like `dict` does.

        Without this, Python would fall back to the legacy sequence protocol
        via `__getitem__` (calling `self[0]`, `self[1]`, ...),
        which raises `KeyError` here.

        :return: iterator over the keys
        """
        return iter(self.keys())

    def items(self) -> Iterable[Tuple[K, V]]:
        """
        :return: pairs of (original key object, value)
        """
        for ref, value in self._d.items():
            yield ref.obj, value

    def keys(self) -> Iterable[K]:
        """
        :return: the original (unwrapped) key objects
        """
        for ref in self._d:
            yield ref.obj

    def values(self) -> Iterable[V]:
        """
        :return: the stored values
        """
        yield from self._d.values()

    def __getitem__(self, item: K) -> V:
        """
        :param item: key object, looked up by identity
        :return: value stored for this exact object
        :raises KeyError: if this exact object was never used as a key
        """
        return self._d[RefIdEq(item)]

    def __setitem__(self, key: K, value: V):
        """
        :param key: key object, stored by identity
        :param value: value to store for this exact object
        """
        self._d[RefIdEq(key)] = value

    def __contains__(self, item: K):
        """
        :param item: key object, looked up by identity
        :return: whether this exact object is a key
        """
        return RefIdEq(item) in self._d


def make_dll_name(basename):
Expand Down

0 comments on commit de9fcb4

Please sign in to comment.