diff --git a/returnn/tensor/_dim_extra.py b/returnn/tensor/_dim_extra.py
index bf7f926259..5c3355de81 100644
--- a/returnn/tensor/_dim_extra.py
+++ b/returnn/tensor/_dim_extra.py
@@ -4,7 +4,7 @@
 """
 
 from __future__ import annotations
-from typing import TYPE_CHECKING, Optional, Union, Tuple, Dict
+from typing import TYPE_CHECKING, Optional, Union, Tuple, Dict, List
 
 from returnn.util.basic import Entity
 from returnn.util import basic as util
@@ -757,7 +757,7 @@ def is_same_size_tensor(self, x):
             return True
         if not self._extra:
             return False
-        if util.TensorRef(x) in self._extra.dyn_size_same:
+        if util.RefIdEq(x) in self._extra.dyn_size_same:
             return True
         return False
 
@@ -794,10 +794,10 @@ def set_tag_on_size_tensor(self: Dim, x, batch=None, same_as_before=False) -> Di
             new_dim_tag.set_tag_on_size_tensor(x, batch=batch)
             return new_dim_tag
         if self.dyn_size is not None and self.dyn_size is not x:
-            if self._extra and util.TensorRef(x) in self._extra.dyn_size_same:
+            if self._extra and util.RefIdEq(x) in self._extra.dyn_size_same:
                 pass  # ok, pass on
             elif same_as_before:
-                self._make_extra().dyn_size_same.add(util.TensorRef(x))
+                self._make_extra().dyn_size_same.add(util.RefIdEq(x))
                 # And now pass on.
             else:
                 assert self.batch and batch
@@ -1448,14 +1448,14 @@ def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None):
     @classmethod
     def get_all_dimension_tags(cls, data_list, is_equal_opts=None, unique_separate_axes=True):
         """
-        :param list[Data] data_list:
+        :param list[_t.Tensor] data_list:
         :param dict[str]|None is_equal_opts: passed to Dim.is_equal
         :param bool unique_separate_axes: e.g. data_list=[Data with shape (B,5,5,10)] results in 4 dim tags, not 3.
         :return: list of dimension tags, dict for data -> list of dimension tags (for each axis)
-        :rtype: (list[Dim], dict[Data, list[Dim]])
+        :rtype: (list[Dim], util.DictRefKeys[_t.Tensor, list[Dim]])
         """
         tags = []
-        data_axes_dict = {}
+        data_axes_dict = util.DictRefKeys()  # type: util.DictRefKeys[_t.Tensor, List[Dim]]
         for data in data_list:
             data_axes_dict[data] = []
             existing_tag_collection_for_data = list(tags) if unique_separate_axes else tags
diff --git a/returnn/tensor/_tensor_op_overloads.py b/returnn/tensor/_tensor_op_overloads.py
index 4461d88594..fbf0899c4a 100644
--- a/returnn/tensor/_tensor_op_overloads.py
+++ b/returnn/tensor/_tensor_op_overloads.py
@@ -16,13 +16,8 @@ class _TensorOpOverloadsMixin(_TensorMixinBase):
 
     # --- comparisons
 
-    # _TensorMixin.__eq__ is disabled as per the following error in some TF tests:
-    # AssertionError: unhashable type: 'Tensor'.
-    # See CI https://github.com/rwth-i6/returnn/actions/runs/4406240591
-    """
     def __eq__(self: Tensor, other: Union[_frontend_api.RawTensorTypes, Tensor]) -> Tensor:
         return self.raw_frontend.compare(self, "==", other)
-    """
 
     def __ne__(self: Tensor, other: Union[_frontend_api.RawTensorTypes, Tensor]) -> Tensor:
         return self.raw_frontend.compare(self, "!=", other)
diff --git a/returnn/util/basic.py b/returnn/util/basic.py
index 2ed3679a9f..68af829311 100644
--- a/returnn/util/basic.py
+++ b/returnn/util/basic.py
@@ -5,7 +5,7 @@
 """
 
 from __future__ import annotations
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Iterable, Tuple
 
 import subprocess
 from subprocess import CalledProcessError
@@ -63,6 +63,8 @@
 
 
 T = TypeVar("T")
+K = TypeVar("K")
+V = TypeVar("V")
 
 returnn_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
@@ -2343,35 +2345,75 @@ def make_hashable(obj):
         import tensorflow as tf
 
         if isinstance(obj, tf.Tensor):
-            return TensorRef(obj)
+            return RefIdEq(obj)
     assert False, "don't know how to make hashable: %r (%r)" % (obj, type(obj))
 
 
-class TensorRef(Generic[T]):
+class RefIdEq(Generic[T]):
     """
-    Reference to the original tensor, which is hashable.
-    We have this here for compatibility because tf.Tensor.ref() was not available in earlier TF versions.
+    Reference to some object (e.g. t.fTensor), but this object is always hashable,
+    and uses the `id` of the function for the hash and equality.
+
+    (In case of tf.Tensor, this is for compatibility
+     because tf.Tensor.ref() was not available in earlier TF versions.
+     However, we also need this for :class:`DictRefKeys`.)
     """
 
-    def __init__(self, tensor: T):
+    def __init__(self, obj: T):
         """
-        :param tensor: for example tf.Tensor
+        :param obj: for example tf.Tensor
         """
-        self.tensor = tensor
+        self.obj = obj
 
     def __repr__(self):
-        return "TensorRef{%r}" % self.tensor
+        return "TensorRef{%r}" % self.obj
 
     def __eq__(self, other):
-        if other is None or not isinstance(other, TensorRef):
+        if other is None or not isinstance(other, RefIdEq):
             return False
-        return self.tensor is other.tensor
+        return self.obj is other.obj
 
     def __ne__(self, other):
         return not self.__eq__(other)
 
     def __hash__(self):
-        return id(self.tensor)
+        return id(self.obj)
+
+
+class DictRefKeys(Generic[K, V]):
+    """
+    Like `dict`, but hash and equality of the keys
+    """
+
+    def __init__(self):
+        self._d = {}  # type: dict[RefIdEq[K], V]
+
+    def __repr__(self):
+        return "DictRefKeys(%s)" % ", ".join(["%r: %r" % (k, v) for (k, v) in self.items()])
+
+    def items(self) -> Iterable[Tuple[K, V]]:
+        """items"""
+        for k, v in self._d.items():
+            yield k.obj, v
+
+    def keys(self) -> Iterable[K]:
+        """keys"""
+        for k in self._d.keys():
+            yield k.obj
+
+    def values(self) -> Iterable[V]:
+        """values"""
+        for v in self._d.values():
+            yield v
+
+    def __getitem__(self, item: K) -> V:
+        return self._d[RefIdEq(item)]
+
+    def __setitem__(self, key: K, value: V):
+        self._d[RefIdEq(key)] = value
+
+    def __contains__(self, item: K):
+        return RefIdEq(item) in self._d
 
 
 def make_dll_name(basename):