Skip to content

Commit

Permalink
implement Tensor.remaining_dims
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 10, 2023
1 parent af1548a commit 124a6d2
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions returnn/tensor/_tensor_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,21 @@ def dim_tags_set_implicit(self):
dims.update(self.dim_tags_set_implicit_only)
return dims

def remaining_dims(self: _t.Tensor, remove: Optional[Union[Dim, Sequence[Dim]]] = None) -> List[Dim]:
"""
:param remove: dims to remove from self.dims
:return: ordered batch dims
"""
batch_dims = list(self._dims)
if not remove:
pass
elif isinstance(remove, Dim):
batch_dims.remove(remove)
else:
for remove_ in remove:
batch_dims.remove(remove_)
return batch_dims

@property
def ndim(self):
"""
Expand Down

0 comments on commit 124a6d2

Please sign in to comment.