From 80e32df236f6cb425a611732e1302de6ea444f15 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 16 Mar 2023 11:58:34 +0100 Subject: [PATCH 1/2] Tensor eq --- returnn/tensor/_tensor_op_overloads.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/returnn/tensor/_tensor_op_overloads.py b/returnn/tensor/_tensor_op_overloads.py index 4461d88594..fbf0899c4a 100644 --- a/returnn/tensor/_tensor_op_overloads.py +++ b/returnn/tensor/_tensor_op_overloads.py @@ -16,13 +16,8 @@ class _TensorOpOverloadsMixin(_TensorMixinBase): # --- comparisons - # _TensorMixin.__eq__ is disabled as per the following error in some TF tests: - # AssertionError: unhashable type: 'Tensor'. - # See CI https://github.com/rwth-i6/returnn/actions/runs/4406240591 - """ def __eq__(self: Tensor, other: Union[_frontend_api.RawTensorTypes, Tensor]) -> Tensor: return self.raw_frontend.compare(self, "==", other) - """ def __ne__(self: Tensor, other: Union[_frontend_api.RawTensorTypes, Tensor]) -> Tensor: return self.raw_frontend.compare(self, "!=", other) From 544f3e18a3f3a00301f8fa3c32e653276822745e Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 16 Mar 2023 12:12:55 +0100 Subject: [PATCH 2/2] fix Tensor not hashable, DictRefKeys, RefIdEq --- returnn/tensor/_dim_extra.py | 14 ++++---- returnn/util/basic.py | 66 +++++++++++++++++++++++++++++------- 2 files changed, 61 insertions(+), 19 deletions(-) diff --git a/returnn/tensor/_dim_extra.py b/returnn/tensor/_dim_extra.py index bf7f926259..5c3355de81 100644 --- a/returnn/tensor/_dim_extra.py +++ b/returnn/tensor/_dim_extra.py @@ -4,7 +4,7 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union, Tuple, Dict +from typing import TYPE_CHECKING, Optional, Union, Tuple, Dict, List from returnn.util.basic import Entity from returnn.util import basic as util @@ -757,7 +757,7 @@ def is_same_size_tensor(self, x): return True if not self._extra: return False - if util.TensorRef(x) in self._extra.dyn_size_same: + if util.RefIdEq(x) in self._extra.dyn_size_same: return True return False @@ -794,10 +794,10 @@ def set_tag_on_size_tensor(self: Dim, x, batch=None, same_as_before=False) -> Di new_dim_tag.set_tag_on_size_tensor(x, batch=batch) return new_dim_tag if self.dyn_size is not None and self.dyn_size is not x: - if self._extra and util.TensorRef(x) in self._extra.dyn_size_same: + if self._extra and util.RefIdEq(x) in self._extra.dyn_size_same: pass # ok, pass on elif same_as_before: - self._make_extra().dyn_size_same.add(util.TensorRef(x)) + self._make_extra().dyn_size_same.add(util.RefIdEq(x)) # And now pass on. else: assert self.batch and batch @@ -1448,14 +1448,14 @@ def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None): @classmethod def get_all_dimension_tags(cls, data_list, is_equal_opts=None, unique_separate_axes=True): """ - :param list[Data] data_list: + :param list[_t.Tensor] data_list: :param dict[str]|None is_equal_opts: passed to Dim.is_equal :param bool unique_separate_axes: e.g. data_list=[Data with shape (B,5,5,10)] results in 4 dim tags, not 3. :return: list of dimension tags, dict for data -> list of dimension tags (for each axis) - :rtype: (list[Dim], dict[Data, list[Dim]]) + :rtype: (list[Dim], util.DictRefKeys[_t.Tensor, list[Dim]]) """ tags = [] - data_axes_dict = {} + data_axes_dict = util.DictRefKeys() # type: util.DictRefKeys[_t.Tensor, List[Dim]] for data in data_list: data_axes_dict[data] = [] existing_tag_collection_for_data = list(tags) if unique_separate_axes else tags diff --git a/returnn/util/basic.py b/returnn/util/basic.py index 2ed3679a9f..68af829311 100644 --- a/returnn/util/basic.py +++ b/returnn/util/basic.py @@ -5,7 +5,7 @@ """ from __future__ import annotations -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Iterable, Tuple import subprocess from subprocess import CalledProcessError @@ -63,6 +63,8 @@ T = TypeVar("T") +K = TypeVar("K") +V = TypeVar("V") returnn_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -2343,35 +2345,75 @@ def make_hashable(obj): import tensorflow as tf if isinstance(obj, tf.Tensor): - return TensorRef(obj) + return RefIdEq(obj) assert False, "don't know how to make hashable: %r (%r)" % (obj, type(obj)) -class TensorRef(Generic[T]): +class RefIdEq(Generic[T]): """ - Reference to the original tensor, which is hashable. - We have this here for compatibility because tf.Tensor.ref() was not available in earlier TF versions. + Reference to some object (e.g. t.fTensor), but this object is always hashable, + and uses the `id` of the function for the hash and equality. + + (In case of tf.Tensor, this is for compatibility + because tf.Tensor.ref() was not available in earlier TF versions. + However, we also need this for :class:`DictRefKeys`.) """ - def __init__(self, tensor: T): + def __init__(self, obj: T): """ - :param tensor: for example tf.Tensor + :param obj: for example tf.Tensor """ - self.tensor = tensor + self.obj = obj def __repr__(self): - return "TensorRef{%r}" % self.tensor + return "TensorRef{%r}" % self.obj def __eq__(self, other): - if other is None or not isinstance(other, TensorRef): + if other is None or not isinstance(other, RefIdEq): return False - return self.tensor is other.tensor + return self.obj is other.obj def __ne__(self, other): return not self.__eq__(other) def __hash__(self): - return id(self.tensor) + return id(self.obj) + + +class DictRefKeys(Generic[K, V]): + """ + Like `dict`, but hash and equality of the keys + """ + + def __init__(self): + self._d = {} # type: dict[RefIdEq[K], V] + + def __repr__(self): + return "DictRefKeys(%s)" % ", ".join(["%r: %r" % (k, v) for (k, v) in self.items()]) + + def items(self) -> Iterable[Tuple[K, V]]: + """items""" + for k, v in self._d.items(): + yield k.obj, v + + def keys(self) -> Iterable[K]: + """keys""" + for k in self._d.keys(): + yield k.obj + + def values(self) -> Iterable[V]: + """values""" + for v in self._d.values(): + yield v + + def __getitem__(self, item: K) -> V: + return self._d[RefIdEq(item)] + + def __setitem__(self, key: K, value: V): + self._d[RefIdEq(key)] = value + + def __contains__(self, item: K): + return RefIdEq(item) in self._d def make_dll_name(basename):