1717import numpy as np
1818
1919import paddle
20- from paddle .compat import sort
20+ from paddle .compat import sort as compat_sort
2121
2222
2323class TestCompatSort (unittest .TestCase ):
2424
2525 def _compare_with_origin (
26- self , input_tensor , dtype , dim , descending , stable
26+ self , input_tensor , dtype , dim , descending , stable , use_out = False
2727 ):
28- sort_res = sort (
29- input_tensor , dim = dim , descending = descending , stable = stable
30- )
28+ """DO NOT set use_out to be True in static graph mode."""
29+ if use_out :
30+ sort_res = (paddle .to_tensor (0 ), paddle .to_tensor (0 ))
31+ compat_sort (input_tensor , dim , descending , stable , out = sort_res )
32+ else :
33+ sort_res = compat_sort (
34+ input_tensor , dim = dim , descending = descending , stable = stable
35+ )
3136
3237 origin_vals = paddle .sort (
3338 input_tensor , axis = dim , descending = descending , stable = stable
@@ -37,15 +42,11 @@ def _compare_with_origin(
3742 )
3843 if dtype .find ("int" ):
3944 np .testing .assert_array_equal (
40- sort_res . values .numpy (), origin_vals .numpy ()
45+ sort_res [ 0 ] .numpy (), origin_vals .numpy ()
4146 )
4247 else :
43- np .testing .assert_allclose (
44- sort_res .values .numpy (), origin_vals .numpy ()
45- )
46- np .testing .assert_array_equal (
47- sort_res .indices .numpy (), origin_inds .numpy ()
48- )
48+ np .testing .assert_allclose (sort_res [0 ].numpy (), origin_vals .numpy ())
49+ np .testing .assert_array_equal (sort_res [1 ].numpy (), origin_inds .numpy ())
4950
5051 def test_with_origin_static (self ):
5152 dtypes = [
@@ -75,7 +76,7 @@ def static_graph_tester(descending, stable):
7576 input_data = paddle .static .data (
7677 name = 'x' , shape = shape , dtype = dtype
7778 )
78- sort_res = sort (
79+ sort_res = compat_sort (
7980 input_data ,
8081 dim = dim ,
8182 descending = descending ,
@@ -149,19 +150,40 @@ def test_with_origin_dynamic(self, use_static=False):
149150 input_tensor = paddle .randint (0 , 255 , shape ).to (dtype )
150151 else :
151152 input_tensor = paddle .randn (shape , dtype = dtype )
152- for dim in range (len (shape )):
153- self ._compare_with_origin (
154- input_tensor , dtype , dim , False , False
155- )
156- self ._compare_with_origin (
157- input_tensor , dtype , dim - len (shape ), False , True
158- )
159- self ._compare_with_origin (
160- input_tensor , dtype , dim , True , False
161- )
162- self ._compare_with_origin (
163- input_tensor , dtype , dim - len (shape ), True , True
164- )
153+ for use_out in [False , True ]:
154+ for dim in range (len (shape )):
155+ self ._compare_with_origin (
156+ input_tensor ,
157+ dtype ,
158+ dim ,
159+ False ,
160+ False ,
161+ use_out = use_out ,
162+ )
163+ self ._compare_with_origin (
164+ input_tensor ,
165+ dtype ,
166+ dim - len (shape ),
167+ False ,
168+ True ,
169+ use_out = use_out ,
170+ )
171+ self ._compare_with_origin (
172+ input_tensor ,
173+ dtype ,
174+ dim ,
175+ True ,
176+ False ,
177+ use_out = use_out ,
178+ )
179+ self ._compare_with_origin (
180+ input_tensor ,
181+ dtype ,
182+ dim - len (shape ),
183+ True ,
184+ True ,
185+ use_out = use_out ,
186+ )
165187
166188 def test_sort_backward (self ):
167189 """test the backward behavior for all data types"""
@@ -177,7 +199,7 @@ def test_sort_backward(self):
177199 y = input_tensor * input_tensor
178200 else :
179201 y = input_tensor + 1
180- sort_vals , sort_inds = sort (y , dim = dim )
202+ sort_vals , sort_inds = compat_sort (y , dim = dim )
181203 sort_vals .backward ()
182204 if input_tensor .place .is_gpu_place ():
183205 np .testing .assert_allclose (
@@ -194,7 +216,7 @@ def test_sort_backward(self):
194216 def test_edge_cases (self ):
195217 """Test edge cases and error handling"""
196218 x = paddle .to_tensor ([])
197- sort_res = sort (x , descending = True , stable = True )
219+ sort_res = compat_sort (x , descending = True , stable = True )
198220
199221 np .testing .assert_array_equal (
200222 sort_res .values .numpy (), np .array ([], dtype = np .float32 )
@@ -204,7 +226,7 @@ def test_edge_cases(self):
204226 )
205227
206228 x = paddle .to_tensor (1 )
207- sort_res = sort (input = x , stable = True )
229+ sort_res = compat_sort (input = x , stable = True )
208230
209231 np .testing .assert_array_equal (
210232 sort_res .values .numpy (), np .array (1 , dtype = np .float32 )
@@ -213,8 +235,8 @@ def test_edge_cases(self):
213235 sort_res .indices .numpy (), np .array (0 , dtype = np .int64 )
214236 )
215237
216- msg_gt_1 = "paddle.sort() received unexpected keyword arguments 'input ', 'dim '. \n Did you mean to use paddle.compat.sort() instead?"
217- msg_gt_2 = "paddle.compat.sort() received unexpected keyword arguments 'x ', 'axis '. \n Did you mean to use paddle.sort() instead?"
238+ msg_gt_1 = "paddle.sort() received unexpected keyword arguments 'dim ', 'input '. \n Did you mean to use paddle.compat.sort() instead?"
239+ msg_gt_2 = "paddle.compat.sort() received unexpected keyword arguments 'axis ', 'x '. \n Did you mean to use paddle.sort() instead?"
218240
219241 # invalid split sections
220242 with self .assertRaises (TypeError ) as cm :
@@ -223,9 +245,44 @@ def test_edge_cases(self):
223245
224246 # invalid split axis
225247 with self .assertRaises (TypeError ) as cm :
226- sort (x = paddle .to_tensor ([2 , 1 , 3 ]), axis = 0 )
248+ compat_sort (x = paddle .to_tensor ([2 , 1 , 3 ]), axis = 0 )
227249 self .assertEqual (str (cm .exception ), msg_gt_2 )
228250
251+ def test_wrong_out_input (dim , out_input ):
252+ with self .assertRaises (TypeError ) as cm :
253+ compat_sort (paddle .to_tensor ([1 , 2 ]), out = out_input )
254+
255+ test_wrong_out_input (0 , [0 , paddle .to_tensor (0 )])
256+ test_wrong_out_input (0 , paddle .to_tensor (0 ))
257+ test_wrong_out_input (None , 0 )
258+ test_wrong_out_input (None , (paddle .to_tensor (0 ),))
259+
260+ paddle .enable_static ()
261+ with (
262+ self .assertRaises (RuntimeError ) as cm ,
263+ paddle .static .program_guard (paddle .static .Program ()),
264+ ):
265+ x = paddle .static .data (name = 'x' , shape = [None , 6 ], dtype = 'float32' )
266+ result0 , result1 = compat_sort (
267+ paddle .arange (24 ),
268+ out = (
269+ paddle .zeros ([24 ]),
270+ paddle .zeros ([24 ], dtype = paddle .int64 ),
271+ ),
272+ )
273+
274+ place = (
275+ paddle .CUDAPlace (0 )
276+ if paddle .is_compiled_with_cuda ()
277+ else paddle .CPUPlace ()
278+ )
279+ paddle .static .Executor (place ).run ()
280+ self .assertEqual (
281+ str (cm .exception ),
282+ "Using `out` static graph CINN backend is currently not supported. Directly return the tensor tuple instead.\n " ,
283+ )
284+ paddle .disable_static ()
285+
229286
230287if __name__ == "__main__" :
231288 unittest .main ()
0 commit comments