@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 import pytest
 from numpy.testing import assert_array_almost_equal
@@ -148,41 +150,51 @@ def test_qr_modes():
 
 class TestSvd(utt.InferShapeTester):
     op_class = SVD
-    dtype = "float32"
 
     def setup_method(self):
         super().setup_method()
         self.rng = np.random.default_rng(utt.fetch_seed())
-        self.A = matrix(dtype=self.dtype)
+        self.A = matrix(dtype=config.floatX)
         self.op = svd
 
-    def test_svd(self):
-        A = matrix("A", dtype=self.dtype)
-        U, S, VT = svd(A)
-        fn = function([A], [U, S, VT])
-        a = self.rng.random((4, 4)).astype(self.dtype)
-        n_u, n_s, n_vt = np.linalg.svd(a)
-        t_u, t_s, t_vt = fn(a)
-
-        assert _allclose(n_u, t_u)
-        assert _allclose(n_s, t_s)
-        assert _allclose(n_vt, t_vt)
-
-        fn = function([A], svd(A, compute_uv=False))
-        t_s = fn(a)
-        assert _allclose(n_s, t_s)
-
-        A = tensor3("A", dtype=self.dtype)
-        U, S, VT = svd(A)
-        fn = function([A], [U, S, VT])
-        a = self.rng.random((10, 4, 4)).astype(self.dtype)
-        t_u, t_s, t_vt = fn(a)
-        n_u, n_s, n_vt = np.vectorize(
-            np.linalg.svd, signature="(i,j)->(i,k),(k),(k,j)"
-        )(a)
-        assert _allclose(n_u, t_u)
-        assert _allclose(n_s, t_s)
-        assert _allclose(n_vt, t_vt)
+    @pytest.mark.parametrize(
+        "compute_uv", [True, False], ids=["compute_uv=True", "compute_uv=False"]
+    )
+    @pytest.mark.parametrize(
+        "batched", [True, False], ids=["batched=True", "batched=False"]
+    )
+    @pytest.mark.parametrize(
+        "test_imag", [True, False], ids=["test_imag=True", "test_imag=False"]
+    )
+    def test_svd(self, compute_uv, batched, test_imag):
+        dtype = config.floatX
+        if test_imag:
+            dtype = "complex128" if dtype.endswith("64") else "complex64"
+
+        if batched:
+            A = tensor3("A", dtype=dtype)
+            size = (10, 4, 4)
+        else:
+            A = matrix("A", dtype=dtype)
+            size = (4, 4)
+        a = self.rng.random(size).astype(dtype)
+
+        outputs = svd(A, compute_uv=compute_uv, full_matrices=False)
+        outputs = outputs if isinstance(outputs, list) else [outputs]
+        fn = function(inputs=[A], outputs=outputs)
+
+        np_fn = np.vectorize(
+            partial(np.linalg.svd, compute_uv=compute_uv, full_matrices=False),
+            signature=outputs[0].owner.op.core_op.gufunc_signature,
+        )
+
+        np_outputs = np_fn(a)
+        pt_outputs = fn(a)
+
+        np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]
+
+        for np_val, pt_val in zip(np_outputs, pt_outputs):
+            assert _allclose(np_val, pt_val)
 
     def test_svd_infer_shape(self):
         self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
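The rewritten test builds its NumPy reference by vectorizing `np.linalg.svd` over the core op's `gufunc_signature`. A minimal standalone sketch of that pattern, with a hand-written signature for reduced SVD standing in for the attribute read off the PyTensor op (an assumption; the op's string may not match letter for letter):

    from functools import partial

    import numpy as np

    # Hand-written gufunc signature for reduced SVD, where k = min(m, n);
    # the test instead reads it from outputs[0].owner.op.core_op.gufunc_signature.
    batched_svd = np.vectorize(
        partial(np.linalg.svd, compute_uv=True, full_matrices=False),
        signature="(m,n)->(m,k),(k),(k,n)",
    )

    a = np.random.default_rng(0).random((10, 4, 4))
    u, s, vt = batched_svd(a)
    assert u.shape == (10, 4, 4) and s.shape == (10, 4) and vt.shape == (10, 4, 4)

`np.linalg.svd` already broadcasts over leading dimensions on its own; routing the reference through `np.vectorize` keeps the comparison uniform across the `compute_uv` branches, which return a different number of outputs.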
@@ -193,7 +205,7 @@ def test_svd_infer_shape(self):
 
     def validate_shape(self, shape, compute_uv=True, full_matrices=True):
         A = self.A
-        A_v = self.rng.random(shape).astype(self.dtype)
+        A_v = self.rng.random(shape).astype(config.floatX)
         outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv)
         if not compute_uv:
             outputs = [outputs]
@@ -465,9 +477,9 @@ def test_non_tensorial_input(self):
         [None, np.inf, -np.inf, 1, -1, 2, -2],
         ids=["None", "inf", "-inf", "1", "-1", "2", "-2"],
     )
-    @pytest.mark.parametrize("core_dims", [(4,), (4, 4)], ids=["vector", "matrix"])
+    @pytest.mark.parametrize("core_dims", [(4,), (4, 3)], ids=["vector", "matrix"])
     @pytest.mark.parametrize("batch_dims", [(), (2,)], ids=["no_batch", "batch"])
-    @pytest.mark.parametrize("test_imag", [True, False], ids=["real", "complex"])
+    @pytest.mark.parametrize("test_imag", [True, False], ids=["complex", "real"])
     def test_numpy_compare(
         self,
         ord: float,
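Switching the matrix core shape from (4, 4) to (4, 3) presumably makes axis mix-ups visible: on a square input, a norm reduced along the wrong axis can still come out with a plausible shape. A small NumPy sketch of the two axis-sensitive matrix norms exercised here (the shapes and names are illustrative only):

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((4, 3))  # non-square, like the new core_dims

    # ord=1 is the maximum absolute column sum, ord=inf the maximum absolute row sum;
    # on a non-square matrix the two reductions collapse axes of different lengths.
    assert np.isclose(np.linalg.norm(X, ord=1), np.abs(X).sum(axis=0).max())
    assert np.isclose(np.linalg.norm(X, ord=np.inf), np.abs(X).sum(axis=1).max())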
@@ -481,6 +493,8 @@ def test_numpy_compare(
         has_batch = len(batch_dims) > 0
         if ord in [np.inf, -np.inf] and not is_matrix:
             pytest.skip("Infinity norm not defined for vectors")
+        if test_imag and is_matrix and ord == -2:
+            pytest.skip("Complex matrices not supported")
         if has_batch and not is_matrix:
             # Handle batched vectors by row-normalizing a matrix
             axis = (-1,)
@@ -491,8 +505,8 @@ def test_numpy_compare(
             x_real, x_imag = rng.standard_normal((2,) + batch_dims + core_dims).astype(
                 config.floatX
             )
-            dtype = "complex64" if config.floatX.endswith("64") else "complex32"
-            X = x_real.astype(dtype) + 1j * x_imag.astype(dtype)
+            dtype = "complex128" if config.floatX.endswith("64") else "complex64"
+            X = (x_real + 1j * x_imag).astype(dtype)
         else:
             X = rng.standard_normal(batch_dims + core_dims).astype(config.floatX)
 
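The replaced branch asked for "complex32", which NumPy does not define; a complex dtype packs two floats of the base precision, so float32 pairs with complex64 and float64 with complex128. A quick standalone check of that mapping (plain NumPy, independent of the test):

    import numpy as np

    x32 = np.zeros(3, dtype="float32")
    x64 = np.zeros(3, dtype="float64")

    # Adding an imaginary part doubles the itemsize: two base-precision floats per element.
    assert (x32 + 1j * x32).dtype == np.complex64
    assert (x64 + 1j * x64).dtype == np.complex128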
@@ -505,6 +519,7 @@ def test_numpy_compare(
 
         pt_norm = norm(X, ord=ord, axis=axis, keepdims=keepdims)
         f = function([], pt_norm, mode="FAST_COMPILE")
+
         utt.assert_allclose(np_norm, f())
 
 