@@ -193,16 +193,16 @@ def test_mean_over_axis_0_unsupported_out_types(
193193 input = dpt .empty ((height , width ), dtype = input_type , device = device )
194194 output = dpt .empty (width , dtype = output_type , device = device )
195195
196- if func (input , output ):
197- print (output_type )
198196 assert func (input , output ) is None
199197
200198
201199@pytest .mark .parametrize (
202200 "func, device, input_type, output_type" ,
203201 product (mean_sum , all_devices , [dpt .float32 ], [dpt .float32 ]),
204202)
205- def test_mean_over_axis_0_f_contig_input (func , device , input_type , output_type ):
203+ def test_mean_sum_over_axis_0_f_contig_input (
204+ func , device , input_type , output_type
205+ ):
206206 skip_unsupported (device , input_type )
207207 skip_unsupported (device , output_type )
208208
@@ -212,16 +212,14 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
212212 input = dpt .empty ((height , width ), dtype = input_type , device = device ).T
213213 output = dpt .empty (width , dtype = output_type , device = device )
214214
215- if func (input , output ):
216- print (output_type )
217215 assert func (input , output ) is None
218216
219217
220218@pytest .mark .parametrize (
221219 "func, device, input_type, output_type" ,
222220 product (mean_sum , all_devices , [dpt .float32 ], [dpt .float32 ]),
223221)
224- def test_mean_over_axis_0_f_contig_output (
222+ def test_mean_sum_over_axis_0_f_contig_output (
225223 func , device , input_type , output_type
226224):
227225 skip_unsupported (device , input_type )
@@ -230,9 +228,25 @@ def test_mean_over_axis_0_f_contig_output(
230228 height = 1
231229 width = 10
232230
233- input = dpt .empty ((height , 10 ), dtype = input_type , device = device )
234- output = dpt .empty (20 , dtype = output_type , device = device )[::2 ]
231+ input = dpt .empty ((height , width ), dtype = input_type , device = device )
232+ output = dpt .empty (width * 2 , dtype = output_type , device = device )[::2 ]
233+
234+ assert func (input , output ) is None
235+
236+
237+ @pytest .mark .parametrize (
238+ "func, device, input_type, output_type" ,
239+ product (mean_sum , all_devices , [dpt .float32 ], [dpt .float32 , dpt .float64 ]),
240+ )
241+ def test_mean_sum_over_axis_0_big_output (func , device , input_type , output_type ):
242+ skip_unsupported (device , input_type )
243+ skip_unsupported (device , output_type )
244+
245+ local_mem_size = device .local_mem_size
246+ height = 1
247+ width = 1 + local_mem_size // output_type .itemsize
248+
249+ input = dpt .empty ((height , width ), dtype = input_type , device = device )
250+ output = dpt .empty (width , dtype = output_type , device = device )
235251
236- if func (input , output ):
237- print (output_type )
238252 assert func (input , output ) is None
0 commit comments