diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 0f5c1b2fe0..37622a8294 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -6,7 +6,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise -from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join +from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join @singledispatch @@ -100,3 +100,19 @@ def join(axis, *tensors): return torch.cat(tensors, dim=axis) return join + + +@pytorch_funcify.register(Eye) +def pytorch_funcify_eye(op, **kwargs): + torch_dtype = getattr(torch, op.dtype) + + def eye(N, M, k): + major, minor = (M, N) if k > 0 else (N, M) + k_abs = torch.abs(k) + zeros = torch.zeros(N, M, dtype=torch_dtype) + if k_abs < major: + l_ones = torch.min(major - k_abs, minor) + return zeros.diagonal_scatter(torch.ones(l_ones, dtype=torch_dtype), k) + return zeros + + return eye diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index c6750361a7..0ccb1c454f 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -13,7 +13,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.raise_op import CheckAndRaise -from pytensor.tensor import alloc, arange, as_tensor, empty +from pytensor.tensor import alloc, arange, as_tensor, empty, eye from pytensor.tensor.type import matrix, scalar, vector @@ -275,3 +275,22 @@ def test_pytorch_Join(): np.c_[[5.0, 6.0]].astype(config.floatX), ], ) + + +@pytest.mark.parametrize( + "dtype", + ["int64", config.floatX], +) +def test_eye(dtype): + N = scalar("N", dtype="int64") + M = scalar("M", dtype="int64") + k = scalar("k", dtype="int64") + + out = eye(N, M, k, dtype=dtype) + + fn = function([N, M, k], out, mode=pytorch_mode) + + for _N in range(1, 6): + for _M in range(1, 6): + for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]: + np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))