Skip to content

Commit 8e6351b

Browse files
committed
Fixed broken cuda vs cpu tests
1 parent d3c5751 commit 8e6351b

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

test/transforms_v2_kernel_infos.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,17 +257,20 @@ def sample_inputs_resize_image_tensor():
257257

258258
for image_loader, interpolation in itertools.product(
259259
make_image_loaders(sizes=["random"], color_spaces=["RGB"]),
260-
[
261-
F.InterpolationMode.NEAREST,
262-
F.InterpolationMode.BILINEAR,
263-
F.InterpolationMode.BICUBIC,
264-
],
260+
[F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR],
265261
):
266262
yield ArgsKwargs(image_loader, size=[min(image_loader.spatial_size) + 1], interpolation=interpolation)
267263

268264
yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25)
269265

270266

267+
def sample_inputs_resize_image_tensor_bicubic():
268+
for image_loader, interpolation in itertools.product(
269+
make_image_loaders(sizes=["random"], color_spaces=["RGB"]), [F.InterpolationMode.BICUBIC]
270+
):
271+
yield ArgsKwargs(image_loader, size=[min(image_loader.spatial_size) + 1], interpolation=interpolation)
272+
273+
271274
@pil_reference_wrapper
272275
def reference_resize_image_tensor(*args, **kwargs):
273276
if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in {
@@ -364,6 +367,21 @@ def reference_inputs_resize_bounding_box():
364367
xfail_jit_python_scalar_arg("size"),
365368
],
366369
),
370+
KernelInfo(
371+
F.resize_image_tensor,
372+
sample_inputs_fn=sample_inputs_resize_image_tensor_bicubic,
373+
reference_fn=reference_resize_image_tensor,
374+
reference_inputs_fn=reference_inputs_resize_image_tensor,
375+
float32_vs_uint8=True,
376+
closeness_kwargs={
377+
**pil_reference_pixel_difference(10, mae=True),
378+
**cuda_vs_cpu_pixel_difference(atol=30),
379+
**float32_vs_uint8_pixel_difference(1, mae=True),
380+
},
381+
test_marks=[
382+
xfail_jit_python_scalar_arg("size"),
383+
],
384+
),
367385
KernelInfo(
368386
F.resize_bounding_box,
369387
sample_inputs_fn=sample_inputs_resize_bounding_box,

0 commit comments

Comments
 (0)