diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index 83a127b3c5..f3e559a18d 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -1284,7 +1284,7 @@ def grad(self, inp, grads): return [grad_undefined(self, i, inp[i]) for i in range(3)] -def eye(n, m=None, k=0, dtype=None): +def eye(n: int, m: int = None, k: int = 0, dtype=None) -> TensorVariable: """Return a 2-D array with ones on the diagonal and zeros elsewhere. Parameters @@ -1302,17 +1302,41 @@ def eye(n, m=None, k=0, dtype=None): Returns ------- - ndarray of shape (N,M) - An array where all elements are equal to zero, except for the `k`-th - diagonal, whose values are equal to one. - + aesara tensor of shape (N,M) + A symbolic tensor representing a matrix where all elements are equal to zero, + except for the `k`-th diagonal, whose values are equal to one. """ - if dtype is None: - dtype = config.floatX + if m is None: m = n - localop = Eye(dtype) - return localop(n, m, k) + if dtype is None: + dtype = aesara.config.floatX + + n = aesara.scalar.as_scalar(n) + m = aesara.scalar.as_scalar(m) + k = aesara.scalar.as_scalar(k) + + i = aesara.scalar.switch(k >= 0, k, -k * m) + i_comp_op = aesara.scalar.Composite([n, m, k], [i]) + i_comp = i_comp_op(n, m, k) + + mkm = (m - k) * m + mkm_comp_op = aesara.scalar.Composite([m, k], [mkm]) + mkm_comp = mkm_comp_op(m, k) + + last_row = aesara.scalar.switch(m - k > 0, m - k, 0) + last_row_op = aesara.scalar.Composite([m, k], [last_row]) + last_valid_row = last_row_op(m, k) + + eye = zeros(n * m, dtype=dtype) + + ones_slice = slice(i_comp, mkm_comp, m + 1) + overflow_rows = slice(last_valid_row, None, None) + + eye = aesara.tensor.subtensor.set_subtensor(eye[ones_slice], 1).reshape((n, m)) + eye = aesara.tensor.subtensor.set_subtensor(eye[overflow_rows, :], 0) + + return eye def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 509b651085..a3c7aa2bba 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -822,11 +822,14 @@ def check(dtype, N, M_=None, k=0): # allowed. if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]: M = N + N_symb = iscalar() M_symb = iscalar() k_symb = iscalar() + f = function([N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)) result = f(N, M, k) + assert np.allclose(result, np.eye(N, M_, k, dtype=dtype)) assert result.dtype == np.dtype(dtype)