Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor eq #1276

Merged
merged 2 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions returnn/tensor/_dim_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union, Tuple, Dict
from typing import TYPE_CHECKING, Optional, Union, Tuple, Dict, List

from returnn.util.basic import Entity
from returnn.util import basic as util
Expand Down Expand Up @@ -757,7 +757,7 @@ def is_same_size_tensor(self, x):
return True
if not self._extra:
return False
if util.TensorRef(x) in self._extra.dyn_size_same:
if util.RefIdEq(x) in self._extra.dyn_size_same:
return True
return False

Expand Down Expand Up @@ -794,10 +794,10 @@ def set_tag_on_size_tensor(self: Dim, x, batch=None, same_as_before=False) -> Di
new_dim_tag.set_tag_on_size_tensor(x, batch=batch)
return new_dim_tag
if self.dyn_size is not None and self.dyn_size is not x:
if self._extra and util.TensorRef(x) in self._extra.dyn_size_same:
if self._extra and util.RefIdEq(x) in self._extra.dyn_size_same:
pass # ok, pass on
elif same_as_before:
self._make_extra().dyn_size_same.add(util.TensorRef(x))
self._make_extra().dyn_size_same.add(util.RefIdEq(x))
# And now pass on.
else:
assert self.batch and batch
Expand Down Expand Up @@ -1448,14 +1448,14 @@ def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None):
@classmethod
def get_all_dimension_tags(cls, data_list, is_equal_opts=None, unique_separate_axes=True):
"""
:param list[Data] data_list:
:param list[_t.Tensor] data_list:
:param dict[str]|None is_equal_opts: passed to Dim.is_equal
:param bool unique_separate_axes: e.g. data_list=[Data with shape (B,5,5,10)] results in 4 dim tags, not 3.
:return: list of dimension tags, dict for data -> list of dimension tags (for each axis)
:rtype: (list[Dim], dict[Data, list[Dim]])
:rtype: (list[Dim], util.DictRefKeys[_t.Tensor, list[Dim]])
"""
tags = []
data_axes_dict = {}
data_axes_dict = util.DictRefKeys() # type: util.DictRefKeys[_t.Tensor, List[Dim]]
for data in data_list:
data_axes_dict[data] = []
existing_tag_collection_for_data = list(tags) if unique_separate_axes else tags
Expand Down
5 changes: 0 additions & 5 deletions returnn/tensor/_tensor_op_overloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,8 @@ class _TensorOpOverloadsMixin(_TensorMixinBase):

# --- comparisons

# _TensorMixin.__eq__ is disabled as per the following error in some TF tests:
# AssertionError: unhashable type: 'Tensor'.
# See CI https://github.com/rwth-i6/returnn/actions/runs/4406240591
"""
def __eq__(self: Tensor, other: Union[_frontend_api.RawTensorTypes, Tensor]) -> Tensor:
return self.raw_frontend.compare(self, "==", other)
"""

def __ne__(self: Tensor, other: Union[_frontend_api.RawTensorTypes, Tensor]) -> Tensor:
return self.raw_frontend.compare(self, "!=", other)
Expand Down
66 changes: 54 additions & 12 deletions returnn/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from __future__ import annotations
from typing import Generic, TypeVar
from typing import Generic, TypeVar, Iterable, Tuple

import subprocess
from subprocess import CalledProcessError
Expand Down Expand Up @@ -63,6 +63,8 @@


T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")

returnn_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

Expand Down Expand Up @@ -2343,35 +2345,75 @@ def make_hashable(obj):
import tensorflow as tf

if isinstance(obj, tf.Tensor):
return TensorRef(obj)
return RefIdEq(obj)
assert False, "don't know how to make hashable: %r (%r)" % (obj, type(obj))


class TensorRef(Generic[T]):
class RefIdEq(Generic[T]):
"""
Reference to the original tensor, which is hashable.
We have this here for compatibility because tf.Tensor.ref() was not available in earlier TF versions.
Reference to some object (e.g. t.fTensor), but this object is always hashable,
and uses the `id` of the function for the hash and equality.

(In case of tf.Tensor, this is for compatibility
because tf.Tensor.ref() was not available in earlier TF versions.
However, we also need this for :class:`DictRefKeys`.)
"""

def __init__(self, tensor: T):
def __init__(self, obj: T):
"""
:param tensor: for example tf.Tensor
:param obj: for example tf.Tensor
"""
self.tensor = tensor
self.obj = obj

def __repr__(self):
return "TensorRef{%r}" % self.tensor
return "TensorRef{%r}" % self.obj

def __eq__(self, other):
if other is None or not isinstance(other, TensorRef):
if other is None or not isinstance(other, RefIdEq):
return False
return self.tensor is other.tensor
return self.obj is other.obj

def __ne__(self, other):
return not self.__eq__(other)

def __hash__(self):
return id(self.tensor)
return id(self.obj)


class DictRefKeys(Generic[K, V]):
"""
Like `dict`, but hash and equality of the keys
"""

def __init__(self):
self._d = {} # type: dict[RefIdEq[K], V]

def __repr__(self):
return "DictRefKeys(%s)" % ", ".join(["%r: %r" % (k, v) for (k, v) in self.items()])

def items(self) -> Iterable[Tuple[K, V]]:
"""items"""
for k, v in self._d.items():
yield k.obj, v

def keys(self) -> Iterable[K]:
"""keys"""
for k in self._d.keys():
yield k.obj

def values(self) -> Iterable[V]:
"""values"""
for v in self._d.values():
yield v

def __getitem__(self, item: K) -> V:
return self._d[RefIdEq(item)]

def __setitem__(self, key: K, value: V):
self._d[RefIdEq(key)] = value

def __contains__(self, item: K):
return RefIdEq(item) in self._d


def make_dll_name(basename):
Expand Down