Skip to content

Commit

Permalink
First step in implementing upsample_bilinear2d_aa (#8090)
Browse files Browse the repository at this point in the history
  • Loading branch information
barney-s authored Oct 2, 2024
1 parent b6bcf03 commit 365841d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
2 changes: 1 addition & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
"__rpow__", # NOTE: cannot fix because torch test case has undefined behavior
# such as 0 to negative power.
"_segment_reduce",
"_upsample_bilinear2d_aa",
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"_upsample_bilinear2d_aa", # test passing scales_h, scales_w is failing.
"byte",
"cat",
"cauchy",
Expand Down
68 changes: 67 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Torch ops implemented using jax."""

import sys
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple, Union
import functools

import math
import jax
from jax import numpy as jnp
import functools
Expand Down Expand Up @@ -4161,3 +4162,68 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):

return output

@op(torch.ops.aten._upsample_bilinear2d_aa)
def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors=None, scales_h=None, scales_w=None):
# input: is of type jaxlib.xla_extension.ArrayImpl
image = input
method = "bilinear"
antialias = True # ignored for upsampling

# https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
# Resize does not distinguish batch, channel size.
# We need to leave them as is
# https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions
# pytorch image shape is (C,H,W) or (N,C,H,W)
# N - batch size
# C - no of channels
# H,W - heigth, width

shape = list(image.shape)
# overriding output_size
if scale_factors:
shape[-1] = int(math.floor(shape[-1]*scale_factors[-1]))
shape[-2] = int(math.floor(shape[-2]*scale_factors[-2]))
if scales_h:
shape[-2] = int(math.floor(shape[-2]*scales_h))
if scales_w:
shape[-1] = int(math.floor(shape[-1]*scales_w))
# output_size overrides scale_factors, scales_*
if output_size:
shape[-1] = output_size[-1]
shape[-2] = output_size[-2]

# align_corners is not supported in resize()
# https://github.com/jax-ml/jax/issues/11206
if align_corners:
return resize_with_aligned_corners2d(image, shape, scale_factors, method, antialias=True)
return jax.image.resize(image, shape, method, antialias) # precision=Precision.HIGHEST

# From: https://github.com/jax-ml/jax/issues/11206
def resize_with_aligned_corners2d(
image: jax.Array,
shape: Tuple[int, ...],
scale: Tuple[int, ...],
method: Union[str, jax.image.ResizeMethod],
antialias: bool,
):
"""Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's
interpolation functions."""


spatial_dims = (2,3)
if len(shape) == 3:
spatial_dims = (1,2)

scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
#translation = (scale / 2.0 - 0.5)
translation = (scale * 0.0 )

return jax.image.scale_and_translate(
image,
shape,
method=method,
scale=scale,
spatial_dims=spatial_dims,
translation=translation,
antialias=antialias,
)

0 comments on commit 365841d

Please sign in to comment.