From 124a6d2f2cf0112cda9ef0e6dbdf5f0888ded8d7 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 10 Mar 2023 14:47:40 +0100 Subject: [PATCH] implement Tensor.remaining_dims https://github.com/rwth-i6/returnn_common/issues/252 --- returnn/tensor/_tensor_extra.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/returnn/tensor/_tensor_extra.py b/returnn/tensor/_tensor_extra.py index e51d9d5dbf..9fd6ad2ce5 100644 --- a/returnn/tensor/_tensor_extra.py +++ b/returnn/tensor/_tensor_extra.py @@ -1695,6 +1695,21 @@ def dim_tags_set_implicit(self): dims.update(self.dim_tags_set_implicit_only) return dims + def remaining_dims(self: _t.Tensor, remove: Optional[Union[Dim, Sequence[Dim]]] = None) -> List[Dim]: + """ + :param remove: dims to remove from self.dims + :return: ordered batch dims + """ + batch_dims = list(self._dims) + if not remove: + pass + elif isinstance(remove, Dim): + batch_dims.remove(remove) + else: + for remove_ in remove: + batch_dims.remove(remove_) + return batch_dims + @property def ndim(self): """