@@ -231,56 +231,5 @@ def test_check_grad(self):
231231 pass
232232
233233
234- class TestMinMaxWithIndexPlace (unittest .TestCase ):
235- """min/max_with_index has no CPU version, so when CUDA is not available,
236- we skip all the above test. A runtime error will be emitted if min/max_with_index
237- is called on CPU, this unit test tries capturing it.
238- """
239-
240- def init (self ):
241- self .input_shape = [30 , 10 , 10 ]
242- self .data = np .random .randn (30 , 10 , 10 )
243-
244- def setUp (self ):
245- self .init ()
246-
247- def cpu_place (self ):
248- self .place = core .CPUPlace ()
249-
250- def test_api_static_cpu_err_handling_1 (self ):
251- self .cpu_place ()
252- with (
253- self .assertRaises (RuntimeError ),
254- paddle .static .program_guard (paddle .static .Program ()),
255- ):
256- input = paddle .static .data (
257- name = "input" , shape = self .input_shape , dtype = "float64"
258- )
259- output = max_with_index (input , dim = 0 )
260- exe = paddle .static .Executor (self .place )
261- result = exe .run (
262- paddle .static .default_main_program (),
263- feed = {'input' : self .data },
264- fetch_list = [output ],
265- )
266-
267- def test_api_static_cpu_err_handling_2 (self ):
268- self .cpu_place ()
269- with (
270- self .assertRaises (RuntimeError ),
271- paddle .static .program_guard (paddle .static .Program ()),
272- ):
273- input = paddle .static .data (
274- name = "input" , shape = self .input_shape , dtype = "float32"
275- )
276- output = min_with_index (input , dim = - 2 , keepdim = True )
277- exe = paddle .static .Executor (self .place )
278- result = exe .run (
279- paddle .static .default_main_program (),
280- feed = {'input' : self .data .astype (np .float32 )},
281- fetch_list = [output ],
282- )
283-
284-
285234if __name__ == "__main__" :
286235 unittest .main ()
0 commit comments