Skip to content

Commit

Permalink
Implemented Eye Op in PyTorch (#877)
Browse files Browse the repository at this point in the history
  • Loading branch information
twaclaw authored Jul 7, 2024
1 parent ca10298 commit 4ea96b2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
18 changes: 17 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join


@singledispatch
Expand Down Expand Up @@ -100,3 +100,19 @@ def join(axis, *tensors):
return torch.cat(tensors, dim=axis)

return join


@pytorch_funcify.register(Eye)
def pytorch_funcify_eye(op, **kwargs):
torch_dtype = getattr(torch, op.dtype)

def eye(N, M, k):
major, minor = (M, N) if k > 0 else (N, M)
k_abs = torch.abs(k)
zeros = torch.zeros(N, M, dtype=torch_dtype)
if k_abs < major:
l_ones = torch.min(major - k_abs, minor)
return zeros.diagonal_scatter(torch.ones(l_ones, dtype=torch_dtype), k)
return zeros

return eye
21 changes: 20 additions & 1 deletion tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrix, scalar, vector


Expand Down Expand Up @@ -275,3 +275,22 @@ def test_pytorch_Join():
np.c_[[5.0, 6.0]].astype(config.floatX),
],
)


@pytest.mark.parametrize(
"dtype",
["int64", config.floatX],
)
def test_eye(dtype):
N = scalar("N", dtype="int64")
M = scalar("M", dtype="int64")
k = scalar("k", dtype="int64")

out = eye(N, M, k, dtype=dtype)

fn = function([N, M, k], out, mode=pytorch_mode)

for _N in range(1, 6):
for _M in range(1, 6):
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))

0 comments on commit 4ea96b2

Please sign in to comment.