@@ -771,10 +771,9 @@ def test_basic_1(self):
771771 v = eval_outputs (max_and_argmax (n )[0 ].shape )
772772 assert len (v ) == 0
773773
774- def test_basic_2 (self ):
775- data = random (2 , 3 )
776- n = as_tensor_variable (data )
777- for (axis , np_axis ) in [
774+ @pytest .mark .parametrize (
775+ "axis,np_axis" ,
776+ [
778777 (- 1 , - 1 ),
779778 (0 , 0 ),
780779 (1 , 1 ),
@@ -783,19 +782,28 @@ def test_basic_2(self):
783782 ([1 , 0 ], None ),
784783 (NoneConst .clone (), None ),
785784 (constant (0 ), 0 ),
786- ]:
787- v , i = eval_outputs (max_and_argmax (n , axis ))
788- assert i .dtype == "int64"
789- assert np .all (v == np .max (data , np_axis ))
790- assert np .all (i == np .argmax (data , np_axis ))
791- v_shape = eval_outputs (max_and_argmax (n , axis )[0 ].shape )
792- assert tuple (v_shape ) == np .max (data , np_axis ).shape
793-
794- def test_basic_2_float16 (self ):
795- # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
796- data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
797- n = shared (data )
798- for (axis , np_axis ) in [
785+ ],
786+ )
787+ def test_basic_2 (self , axis , np_axis ):
788+ data = random (2 , 3 )
789+ n = as_tensor_variable (data )
790+ # Test shape propagates (static & eval)
791+ vt , it = max_and_argmax (n , axis )
792+ np_max , np_argm = np .max (data , np_axis ), np .argmax (data , np_axis )
793+ assert vt .type .shape == np_max .shape
794+ assert it .type .shape == np_argm .shape
795+ v_shape , i_shape = eval_outputs ([vt .shape , it .shape ])
796+ assert tuple (v_shape ) == vt .type .shape
797+ assert tuple (i_shape ) == it .type .shape
798+ # Test values
799+ v , i = eval_outputs ([vt , it ])
800+ assert i .dtype == "int64"
801+ assert np .all (v == np_max )
802+ assert np .all (i == np_argm )
803+
804+ @pytest .mark .parametrize (
805+ "axis,np_axis" ,
806+ [
799807 (- 1 , - 1 ),
800808 (0 , 0 ),
801809 (1 , 1 ),
@@ -804,13 +812,25 @@ def test_basic_2_float16(self):
804812 ([1 , 0 ], None ),
805813 (NoneConst .clone (), None ),
806814 (constant (0 ), 0 ),
807- ]:
808- v , i = eval_outputs (max_and_argmax (n , axis ), (MaxAndArgmax ,))
809- assert i .dtype == "int64"
810- assert np .all (v == np .max (data , np_axis ))
811- assert np .all (i == np .argmax (data , np_axis ))
812- v_shape = eval_outputs (max_and_argmax (n , axis )[0 ].shape )
813- assert tuple (v_shape ) == np .max (data , np_axis ).shape
815+ ],
816+ )
817+ def test_basic_2_float16 (self , axis , np_axis ):
818+ # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
819+ data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
820+ n = as_tensor_variable (data )
821+ # Test shape propagates (static & eval)
822+ vt , it = max_and_argmax (n , axis )
823+ np_max , np_argm = np .max (data , np_axis ), np .argmax (data , np_axis )
824+ assert vt .type .shape == np_max .shape
825+ assert it .type .shape == np_argm .shape
826+ v_shape , i_shape = eval_outputs ([vt .shape , it .shape ])
827+ assert tuple (v_shape ) == vt .type .shape
828+ assert tuple (i_shape ) == it .type .shape
829+ # Test values
830+ v , i = eval_outputs ([vt , it ])
831+ assert i .dtype == "int64"
832+ assert np .all (v == np_max )
833+ assert np .all (i == np_argm )
814834
815835 def test_basic_2_invalid (self ):
816836 n = as_tensor_variable (random (2 , 3 ))
@@ -840,23 +860,33 @@ def test_basic_2_valid_neg(self):
840860 v = eval_outputs (max_and_argmax (n , - 2 )[0 ].shape )
841861 assert v == (3 )
842862
843- def test_basic_3 (self ):
844- data = random (2 , 3 , 4 )
845- n = as_tensor_variable (data )
846- for (axis , np_axis ) in [
863+ @pytest .mark .parametrize (
864+ "axis,np_axis" ,
865+ [
847866 (- 1 , - 1 ),
848867 (0 , 0 ),
849868 (1 , 1 ),
850869 (None , None ),
851870 ([0 , 1 , 2 ], None ),
852871 ([1 , 2 , 0 ], None ),
853- ]:
854- v , i = eval_outputs (max_and_argmax (n , axis ))
855- assert i .dtype == "int64"
856- assert np .all (v == np .max (data , np_axis ))
857- assert np .all (i == np .argmax (data , np_axis ))
858- v = eval_outputs (max_and_argmax (n , axis )[0 ].shape )
859- assert tuple (v ) == np .max (data , np_axis ).shape
872+ ],
873+ )
874+ def test_basic_3 (self , axis , np_axis ):
875+ data = random (2 , 3 , 4 )
876+ n = as_tensor_variable (data )
877+ # Test shape propagates (static & eval)
878+ vt , it = max_and_argmax (n , axis )
879+ np_max , np_argm = np .max (data , np_axis ), np .argmax (data , np_axis )
880+ assert vt .type .shape == np_max .shape
881+ assert it .type .shape == np_argm .shape
882+ v_shape , i_shape = eval_outputs ([vt .shape , it .shape ])
883+ assert tuple (v_shape ) == vt .type .shape
884+ assert tuple (i_shape ) == it .type .shape
885+ # Test values
886+ v , i = eval_outputs ([vt , it ])
887+ assert i .dtype == "int64"
888+ assert np .all (v == np_max )
889+ assert np .all (i == np_argm )
860890
861891 def test_arg_grad (self ):
862892 # The test checks that the gradient of argmax(x).sum() is 0
@@ -948,17 +978,19 @@ def test_preserve_broadcastable(self):
948978 # Ensure the original broadcastable flags are preserved by Max/Argmax.
949979 x = matrix ().dimshuffle ("x" , 0 , "x" , 1 , "x" )
950980 y = x .max (axis = 1 )
981+ assert y .type .shape == (1 , 1 , None , 1 )
951982 assert y .type .broadcastable == (True , True , False , True )
952983
953984 def test_multiple_axes (self ):
954985 data = np .arange (24 ).reshape (3 , 2 , 4 )
955986 x = as_tensor_variable (data )
956- v , i = eval_outputs (max_and_argmax (x , [1 , - 1 ]))
987+ vt , it = max_and_argmax (x , [1 , - 1 ])
988+ assert vt .type .shape == it .type .shape == (3 ,)
989+ v , i = eval_outputs ([vt , it ])
957990 assert np .all (v == np .array ([7 , 15 , 23 ]))
958991 assert np .all (i == np .array ([7 , 7 , 7 ]))
959-
960- v = eval_outputs (max_and_argmax (x , [1 , - 1 ])[0 ].shape )
961- assert tuple (v ) == np .max (data , (1 , - 1 )).shape
992+ v = eval_outputs (vt .shape )
993+ assert tuple (v ) == vt .type .shape
962994
963995 def test_zero_shape (self ):
964996 x = matrix ()
@@ -972,8 +1004,8 @@ def test_zero_shape(self):
9721004 def test_numpy_input (self ):
9731005 ar = np .array ([1 , 2 , 3 ])
9741006 max_at , argmax_at = max_and_argmax (ar , axis = None )
975- assert max_at .eval (), 3
976- assert argmax_at .eval (), 2
1007+ assert max_at .eval () == 3
1008+ assert argmax_at .eval () == 2
9771009
9781010
9791011class TestArgminArgmax :
0 commit comments