diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 08965d085db..6ed3146ff85 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -14,8 +14,8 @@ "__rpow__", # NOTE: cannot fix because torch test case has undefined behavior # such as 0 to negative power. "_segment_reduce", - "_upsample_bilinear2d_aa", "bincount", # NOTE: dtype for int input torch gives float. This is weird. + "_upsample_bilinear2d_aa", # test passing scales_h, scales_w is failing. "byte", "cat", "cauchy", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 9ea94bc46c4..62bab1927a0 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1,9 +1,10 @@ """Torch ops implemented using jax.""" import sys -from typing import Optional, Sequence +from typing import Optional, Sequence, Tuple, Union import functools +import math import jax from jax import numpy as jnp import functools @@ -4161,3 +4162,68 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): return output +@op(torch.ops.aten._upsample_bilinear2d_aa) +def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors=None, scales_h=None, scales_w=None): + # input: is of type jaxlib.xla_extension.ArrayImpl + image = input + method = "bilinear" + antialias = True # ignored for upsampling + + # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html + # Resize does not distinguish batch, channel size. + # We need to leave them as is + # https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions + # pytorch image shape is (C,H,W) or (N,C,H,W) + # N - batch size + # C - no of channels + # H,W - heigth, width + + shape = list(image.shape) + # overriding output_size + if scale_factors: + shape[-1] = int(math.floor(shape[-1]*scale_factors[-1])) + shape[-2] = int(math.floor(shape[-2]*scale_factors[-2])) + if scales_h: + shape[-2] = int(math.floor(shape[-2]*scales_h)) + if scales_w: + shape[-1] = int(math.floor(shape[-1]*scales_w)) + # output_size overrides scale_factors, scales_* + if output_size: + shape[-1] = output_size[-1] + shape[-2] = output_size[-2] + + # align_corners is not supported in resize() + # https://github.com/jax-ml/jax/issues/11206 + if align_corners: + return resize_with_aligned_corners2d(image, shape, scale_factors, method, antialias=True) + return jax.image.resize(image, shape, method, antialias) # precision=Precision.HIGHEST + +# From: https://github.com/jax-ml/jax/issues/11206 +def resize_with_aligned_corners2d( + image: jax.Array, + shape: Tuple[int, ...], + scale: Tuple[int, ...], + method: Union[str, jax.image.ResizeMethod], + antialias: bool, +): + """Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's + interpolation functions.""" + + + spatial_dims = (2,3) + if len(shape) == 3: + spatial_dims = (1,2) + + scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims]) + #translation = (scale / 2.0 - 0.5) + translation = (scale * 0.0 ) + + return jax.image.scale_and_translate( + image, + shape, + method=method, + scale=scale, + spatial_dims=spatial_dims, + translation=translation, + antialias=antialias, + ) \ No newline at end of file