Commit d8f0e32

Refactor TestSvd and TestNormTests to handle complex and batched cases
Parent: 97c9504

2 files changed: +50 additions, -35 deletions

pytensor/tensor/nlinalg.py

Lines changed: 1 addition & 1 deletion
@@ -559,7 +559,7 @@ def make_node(self, x):
         assert x.ndim == 2, "The input of svd function should be a matrix."
 
         in_dtype = x.type.numpy_dtype
-        out_dtype = np.dtype(f"f{in_dtype.itemsize}")
+        out_dtype = in_dtype
 
         s = vector(dtype=out_dtype)
 

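The dropped expression mapped the input dtype to a float of the same byte width, which is only correct for real inputs. A minimal plain-NumPy sketch (not part of the commit) of why it misfires on complex dtypes:

```python
import numpy as np

# The old rule: build a float dtype with the same byte width as the input.
np.dtype(f"f{np.dtype('float32').itemsize}")    # "f4" -> float32, fine
np.dtype(f"f{np.dtype('float64').itemsize}")    # "f8" -> float64, fine

# Complex dtypes are twice as wide as their component floats, so the rule
# lands on the wrong type:
np.dtype(f"f{np.dtype('complex64').itemsize}")  # "f8" -> float64, not complex64
# complex128 (itemsize 16) would map to "f16" (np.longdouble, where it
# exists at all), so keeping the input dtype is the simpler, correct choice.
```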
tests/tensor/test_nlinalg.py

Lines changed: 49 additions & 34 deletions
@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 import pytest
 from numpy.testing import assert_array_almost_equal
@@ -148,41 +150,51 @@ def test_qr_modes():
 
 class TestSvd(utt.InferShapeTester):
     op_class = SVD
-    dtype = "float32"
 
     def setup_method(self):
         super().setup_method()
         self.rng = np.random.default_rng(utt.fetch_seed())
-        self.A = matrix(dtype=self.dtype)
+        self.A = matrix(dtype=config.floatX)
         self.op = svd
 
-    def test_svd(self):
-        A = matrix("A", dtype=self.dtype)
-        U, S, VT = svd(A)
-        fn = function([A], [U, S, VT])
-        a = self.rng.random((4, 4)).astype(self.dtype)
-        n_u, n_s, n_vt = np.linalg.svd(a)
-        t_u, t_s, t_vt = fn(a)
-
-        assert _allclose(n_u, t_u)
-        assert _allclose(n_s, t_s)
-        assert _allclose(n_vt, t_vt)
-
-        fn = function([A], svd(A, compute_uv=False))
-        t_s = fn(a)
-        assert _allclose(n_s, t_s)
-
-        A = tensor3("A", dtype=self.dtype)
-        U, S, VT = svd(A)
-        fn = function([A], [U, S, VT])
-        a = self.rng.random((10, 4, 4)).astype(self.dtype)
-        t_u, t_s, t_vt = fn(a)
-        n_u, n_s, n_vt = np.vectorize(
-            np.linalg.svd, signature="(i,j)->(i,k),(k),(k,j)"
-        )(a)
-        assert _allclose(n_u, t_u)
-        assert _allclose(n_s, t_s)
-        assert _allclose(n_vt, t_vt)
+    @pytest.mark.parametrize(
+        "compute_uv", [True, False], ids=["compute_uv=True", "compute_uv=False"]
+    )
+    @pytest.mark.parametrize(
+        "batched", [True, False], ids=["batched=True", "batched=False"]
+    )
+    @pytest.mark.parametrize(
+        "test_imag", [True, False], ids=["test_imag=True", "test_imag=False"]
+    )
+    def test_svd(self, compute_uv, batched, test_imag):
+        dtype = config.floatX
+        if test_imag:
+            dtype = "complex128" if dtype.endswith("64") else "complex64"
+
+        if batched:
+            A = tensor3("A", dtype=dtype)
+            size = (10, 4, 4)
+        else:
+            A = matrix("A", dtype=dtype)
+            size = (4, 4)
+        a = self.rng.random(size).astype(dtype)
+
+        outputs = svd(A, compute_uv=compute_uv, full_matrices=False)
+        outputs = outputs if isinstance(outputs, list) else [outputs]
+        fn = function(inputs=[A], outputs=outputs)
+
+        np_fn = np.vectorize(
+            partial(np.linalg.svd, compute_uv=compute_uv, full_matrices=False),
+            signature=outputs[0].owner.op.core_op.gufunc_signature,
+        )
+
+        np_outputs = np_fn(a)
+        pt_outputs = fn(a)
+
+        np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]
+
+        for np_val, pt_val in zip(np_outputs, pt_outputs):
+            assert _allclose(np_val, pt_val)
 
     def test_svd_infer_shape(self):
         self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
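For reference, a standalone sketch (not from the commit) of the np.vectorize pattern the new test relies on. The test reads the signature from the op's gufunc_signature attribute; the reduced-SVD signature "(m,n)->(m,k),(k),(k,n)" below is an assumed stand-in for the compute_uv=True, full_matrices=False case:

```python
from functools import partial

import numpy as np

# Wrap a reduced SVD so it broadcasts over leading batch dimensions,
# mirroring how the test builds its NumPy reference function.
batched_svd = np.vectorize(
    partial(np.linalg.svd, compute_uv=True, full_matrices=False),
    signature="(m,n)->(m,k),(k),(k,n)",  # assumed gufunc signature
)

a = np.random.default_rng(0).random((10, 4, 4))
u, s, vt = batched_svd(a)
assert u.shape == (10, 4, 4)
assert s.shape == (10, 4)
assert vt.shape == (10, 4, 4)
```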
@@ -193,7 +205,7 @@ def test_svd_infer_shape(self):
 
     def validate_shape(self, shape, compute_uv=True, full_matrices=True):
         A = self.A
-        A_v = self.rng.random(shape).astype(self.dtype)
+        A_v = self.rng.random(shape).astype(config.floatX)
         outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv)
         if not compute_uv:
             outputs = [outputs]
@@ -465,9 +477,9 @@ def test_non_tensorial_input(self):
         [None, np.inf, -np.inf, 1, -1, 2, -2],
         ids=["None", "inf", "-inf", "1", "-1", "2", "-2"],
     )
-    @pytest.mark.parametrize("core_dims", [(4,), (4, 4)], ids=["vector", "matrix"])
+    @pytest.mark.parametrize("core_dims", [(4,), (4, 3)], ids=["vector", "matrix"])
     @pytest.mark.parametrize("batch_dims", [(), (2,)], ids=["no_batch", "batch"])
-    @pytest.mark.parametrize("test_imag", [True, False], ids=["real", "complex"])
+    @pytest.mark.parametrize("test_imag", [True, False], ids=["complex", "real"])
     def test_numpy_compare(
         self,
         ord: float,
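A plausible reading of the (4, 4) to (4, 3) change, sketched in plain NumPy (not from the commit): with a square test matrix, a rows-versus-columns axis mix-up yields arrays of identical shape and can pass by accident; a non-square matrix turns it into a visible mismatch:

```python
import numpy as np

X = np.ones((4, 3))
assert np.linalg.norm(X, axis=-1).shape == (4,)  # one norm per row
assert np.linalg.norm(X, axis=-2).shape == (3,)  # one norm per column
# On a 4x4 input both shapes would be (4,), hiding a swapped axis.
```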
@@ -481,6 +493,8 @@ def test_numpy_compare(
         has_batch = len(batch_dims) > 0
         if ord in [np.inf, -np.inf] and not is_matrix:
             pytest.skip("Infinity norm not defined for vectors")
+        if test_imag and is_matrix and ord == -2:
+            pytest.skip("Complex matrices not supported")
         if has_batch and not is_matrix:
             # Handle batched vectors by row-normalizing a matrix
             axis = (-1,)
@@ -491,8 +505,8 @@ def test_numpy_compare(
             x_real, x_imag = rng.standard_normal((2,) + batch_dims + core_dims).astype(
                 config.floatX
             )
-            dtype = "complex64" if config.floatX.endswith("64") else "complex32"
-            X = x_real.astype(dtype) + 1j * x_imag.astype(dtype)
+            dtype = "complex128" if config.floatX.endswith("64") else "complex64"
+            X = (x_real + 1j * x_imag).astype(dtype)
         else:
             X = rng.standard_normal(batch_dims + core_dims).astype(config.floatX)

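A quick plain-NumPy check (not part of the commit) of why this branch needed fixing: each complex dtype pairs with the float type of half its byte width, and NumPy has no complex32 at all:

```python
import numpy as np

assert np.dtype("complex64").itemsize == 8    # two float32 components
assert np.dtype("complex128").itemsize == 16  # two float64 components

# The old branch requested "complex32", which NumPy rejects outright:
try:
    np.dtype("complex32")
except TypeError:
    pass  # data type "complex32" not understood
```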
@@ -505,6 +519,7 @@ def test_numpy_compare(
 
         pt_norm = norm(X, ord=ord, axis=axis, keepdims=keepdims)
         f = function([], pt_norm, mode="FAST_COMPILE")
+
         utt.assert_allclose(np_norm, f())
 