Skip to content

Commit

Permalink
add torch flip with support for single axis
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Dec 18, 2024
1 parent a09d42f commit 7b858ab
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2062,6 +2062,20 @@ def torch_indices(dimensions):
return _meshgrid(*map(_arange, dimensions), indexing="ij")


def torch_flip_wrap(torch_flip):
def numpy_like(x, axis=None):
if axis is None:
dims = tuple(range(x.ndimension()))
elif isinstance(axis, int):
dims = (axis,)
else:
# already tuple/list
dims = axis
return torch_flip(x, dims)

return numpy_like


_FUNCS["torch", "pad"] = torch_pad
_FUNCS["torch", "real"] = torch_real
_FUNCS["torch", "imag"] = torch_imag
Expand Down Expand Up @@ -2125,6 +2139,7 @@ def torch_indices(dimensions):
[("a", ("input",)), ("axis", ("dim",))]
)
_CUSTOM_WRAPPERS["torch", "sort"] = torch_sort_wrap
_CUSTOM_WRAPPERS["torch", "flip"] = torch_flip_wrap

# for older versions of torch, can provide some alternative implementations
_MODULE_ALIASES["torch[alt]"] = "torch"
Expand Down

0 comments on commit 7b858ab

Please sign in to comment.