diff --git a/.github/workflows/test-m1.yml b/.github/workflows/test-m1.yml index c03fa9f76e4..8b1c6147d15 100644 --- a/.github/workflows/test-m1.yml +++ b/.github/workflows/test-m1.yml @@ -46,5 +46,13 @@ jobs: run: | . ~/miniconda3/etc/profile.d/conda.sh set -ex - conda run -p ${ENV_NAME} --no-capture-output python3 -u -mpytest -v --tb=long --durations 20 + + # Run resize benchmark + echo "--- Run resize benchmark ---" + # wget https://gist.githubusercontent.com/vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2/raw/8d4731c785140356a89c0afe4b012b304c4f1787/check_resize_uint8.py + curl https://gist.githubusercontent.com/vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2/raw/8d4731c785140356a89c0afe4b012b304c4f1787/check_resize_uint8.py -o /tmp/check_resize_uint8.py + conda run -p ${ENV_NAME} python3 -u /tmp/check_resize_uint8.py + echo "--- END Run resize benchmark ---" + + # conda run -p ${ENV_NAME} --no-capture-output python3 -u -mpytest -v --tb=long --durations 20 conda env remove -p ${ENV_NAME} diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index c48250f3b96..75ee90de68a 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -183,7 +183,10 @@ def resize_image_tensor( image = image.reshape(-1, num_channels, old_height, old_width) dtype = image.dtype - need_cast = dtype not in (torch.float32, torch.float64) + acceptable_dtypes = [torch.float32, torch.float64] + if interpolation.value in ["nearest", "bilinear"]: + acceptable_dtypes.append(torch.uint8) + need_cast = dtype not in acceptable_dtypes if need_cast: image = image.to(dtype=torch.float32)