1+ from functools import partial
2+
13import numpy as np
24import pytest
35from numpy .testing import assert_array_almost_equal
@@ -148,41 +150,51 @@ def test_qr_modes():
148150
149151class TestSvd (utt .InferShapeTester ):
150152 op_class = SVD
151- dtype = "float32"
152153
153154 def setup_method (self ):
154155 super ().setup_method ()
155156 self .rng = np .random .default_rng (utt .fetch_seed ())
156- self .A = matrix (dtype = self . dtype )
157+ self .A = matrix (dtype = config . floatX )
157158 self .op = svd
158159
159- def test_svd (self ):
160- A = matrix ("A" , dtype = self .dtype )
161- U , S , VT = svd (A )
162- fn = function ([A ], [U , S , VT ])
163- a = self .rng .random ((4 , 4 )).astype (self .dtype )
164- n_u , n_s , n_vt = np .linalg .svd (a )
165- t_u , t_s , t_vt = fn (a )
166-
167- assert _allclose (n_u , t_u )
168- assert _allclose (n_s , t_s )
169- assert _allclose (n_vt , t_vt )
170-
171- fn = function ([A ], svd (A , compute_uv = False ))
172- t_s = fn (a )
173- assert _allclose (n_s , t_s )
174-
175- A = tensor3 ("A" , dtype = self .dtype )
176- U , S , VT = svd (A )
177- fn = function ([A ], [U , S , VT ])
178- a = self .rng .random ((10 , 4 , 4 )).astype (self .dtype )
179- t_u , t_s , t_vt = fn (a )
180- n_u , n_s , n_vt = np .vectorize (
181- np .linalg .svd , signature = "(i,j)->(i,k),(k),(k,j)"
182- )(a )
183- assert _allclose (n_u , t_u )
184- assert _allclose (n_s , t_s )
185- assert _allclose (n_vt , t_vt )
160+ @pytest .mark .parametrize (
161+ "compute_uv" , [True , False ], ids = ["compute_uv=True" , "compute_uv=False" ]
162+ )
163+ @pytest .mark .parametrize (
164+ "batched" , [True , False ], ids = ["batched=True" , "batched=False" ]
165+ )
166+ @pytest .mark .parametrize (
167+ "test_imag" , [True , False ], ids = ["test_imag=True" , "test_imag=False" ]
168+ )
169+ def test_svd (self , compute_uv , batched , test_imag ):
170+ dtype = config .floatX
171+ if test_imag :
172+ dtype = "complex128" if dtype .endswith ("64" ) else "complex64"
173+
174+ if batched :
175+ A = tensor3 ("A" , dtype = dtype )
176+ size = (10 , 4 , 4 )
177+ else :
178+ A = matrix ("A" , dtype = dtype )
179+ size = (4 , 4 )
180+ a = self .rng .random (size ).astype (dtype )
181+
182+ outputs = svd (A , compute_uv = compute_uv , full_matrices = False )
183+ outputs = outputs if isinstance (outputs , list ) else [outputs ]
184+ fn = function (inputs = [A ], outputs = outputs )
185+
186+ np_fn = np .vectorize (
187+ partial (np .linalg .svd , compute_uv = compute_uv , full_matrices = False ),
188+ signature = outputs [0 ].owner .op .core_op .gufunc_signature ,
189+ )
190+
191+ np_outputs = np_fn (a )
192+ pt_outputs = fn (a )
193+
194+ np_outputs = np_outputs if isinstance (np_outputs , tuple ) else [np_outputs ]
195+
196+ for np_val , pt_val in zip (np_outputs , pt_outputs ):
197+ assert _allclose (np_val , pt_val )
186198
187199 def test_svd_infer_shape (self ):
188200 self .validate_shape ((4 , 4 ), full_matrices = True , compute_uv = True )
@@ -193,7 +205,7 @@ def test_svd_infer_shape(self):
193205
194206 def validate_shape (self , shape , compute_uv = True , full_matrices = True ):
195207 A = self .A
196- A_v = self .rng .random (shape ).astype (self . dtype )
208+ A_v = self .rng .random (shape ).astype (config . floatX )
197209 outputs = self .op (A , full_matrices = full_matrices , compute_uv = compute_uv )
198210 if not compute_uv :
199211 outputs = [outputs ]
@@ -465,9 +477,9 @@ def test_non_tensorial_input(self):
465477 [None , np .inf , - np .inf , 1 , - 1 , 2 , - 2 ],
466478 ids = ["None" , "inf" , "-inf" , "1" , "-1" , "2" , "-2" ],
467479 )
468- @pytest .mark .parametrize ("core_dims" , [(4 ,), (4 , 4 )], ids = ["vector" , "matrix" ])
480+ @pytest .mark .parametrize ("core_dims" , [(4 ,), (4 , 3 )], ids = ["vector" , "matrix" ])
469481 @pytest .mark .parametrize ("batch_dims" , [(), (2 ,)], ids = ["no_batch" , "batch" ])
470- @pytest .mark .parametrize ("test_imag" , [True , False ], ids = ["real " , "complex " ])
482+ @pytest .mark .parametrize ("test_imag" , [True , False ], ids = ["complex " , "real " ])
471483 def test_numpy_compare (
472484 self ,
473485 ord : float ,
@@ -481,6 +493,8 @@ def test_numpy_compare(
481493 has_batch = len (batch_dims ) > 0
482494 if ord in [np .inf , - np .inf ] and not is_matrix :
483495 pytest .skip ("Infinity norm not defined for vectors" )
496+ if test_imag and is_matrix and ord == - 2 :
497+ pytest .skip ("Complex matrices not supported" )
484498 if has_batch and not is_matrix :
485499 # Handle batched vectors by row-normalizing a matrix
486500 axis = (- 1 ,)
@@ -491,8 +505,8 @@ def test_numpy_compare(
491505 x_real , x_imag = rng .standard_normal ((2 ,) + batch_dims + core_dims ).astype (
492506 config .floatX
493507 )
494- dtype = "complex64 " if config .floatX .endswith ("64" ) else "complex32 "
495- X = x_real . astype ( dtype ) + 1j * x_imag .astype (dtype )
508+ dtype = "complex128 " if config .floatX .endswith ("64" ) else "complex64 "
509+ X = ( x_real + 1j * x_imag ) .astype (dtype )
496510 else :
497511 X = rng .standard_normal (batch_dims + core_dims ).astype (config .floatX )
498512
@@ -505,6 +519,7 @@ def test_numpy_compare(
505519
506520 pt_norm = norm (X , ord = ord , axis = axis , keepdims = keepdims )
507521 f = function ([], pt_norm , mode = "FAST_COMPILE" )
522+
508523 utt .assert_allclose (np_norm , f ())
509524
510525
0 commit comments