Add back the implementation of adaptive_pool_2d (#8526)
qihqi authored Jan 3, 2025
1 parent bc3bf1f commit 0ca733b
Showing 1 changed file with 124 additions and 31 deletions.
experimental/torch_xla2/torch_xla2/ops/jaten.py (155 changes: 124 additions & 31 deletions)
@@ -1,7 +1,7 @@
"""Torch ops implemented using jax."""

import sys
-from typing import Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence, Tuple, Union, Callable
import functools

import math
@@ -140,7 +140,8 @@ def _aten_clone(x, memory_format=None):
# aten.trunc
@op(torch.ops.aten.trunc)
def _aten_trunc(x):
-  return jnp.trunc(x)
+  res = jnp.trunc(x)
+  return res.astype(x)


@op(torch.ops.aten.index_copy)
@@ -1605,7 +1606,7 @@ def _aten_tanh(self):
# aten.ceil
@op(torch.ops.aten.ceil)
def _aten_ceil(self):
-  return jnp.ceil(self)
+  return jnp.ceil(self).astype(self)


# aten.asin
@@ -1888,36 +1889,128 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    y = jnp.squeeze(y, axis=0)
  return y


-@op(torch.ops.aten._adaptive_avg_pool2d)
-@op(torch.ops.aten._adaptive_avg_pool3d)
-def _aten_adaptive_avg_pool3d(x, output_shape):
-  assert len(x.shape) in (4,5), f'Expected 4D or 5D input but got {len(x.shape)} dimensions'
-  assert len(output_shape) == 3, f'Expected 3D output but got {len(output_shape)} dimensions'
-
-  # Reference PyTorch implementation:
-  # https://github.com/pytorch/pytorch/blob/ef4475f9025b3c46a13bdd054b6adfbcb5f8ab8c/aten/src/ATen/native/AdaptiveAveragePooling.cpp
-  output_shape = x.shape[:-3] + tuple(output_shape)
-  output = jnp.zeros(output_shape, dtype = x.dtype)
-  stride_d = x.shape[-3] / output_shape[-3]
-  stride_h = x.shape[-2] / output_shape[-2]
-  stride_w = x.shape[-1] / output_shape[-1]
-
-  def avg_pool_batch(d, h, w):
-    start_d = int(jnp.floor(d * stride_d))
-    end_d = int(jnp.ceil((d+1) * stride_d))
-    start_h = int(jnp.floor(h * stride_h))
-    end_h = int(jnp.ceil((h+1) * stride_h))
-    start_w = int(jnp.floor(w * stride_w))
-    end_w = int(jnp.ceil((w+1) * stride_w))
-    return jnp.mean(x[..., start_d:end_d, start_h:end_h, start_w:end_w], axis=(-3, -2, -1))
-
-  # TODO: Replace this with more performant implementation.
-  # Related JAX issue requiring adaptive pooling: https://github.com/jax-ml/jax/issues/20098
-  for d in range(output_shape[-3]):
-    for h in range(output_shape[-2]):
-      for w in range(output_shape[-1]):
-        output = output.at[..., d, h, w].set(avg_pool_batch(d, h, w))
-  return output
+def adaptive_avg_pool2or3d(input: jnp.ndarray, output_size: Tuple[int, int]) -> jnp.ndarray:
+  """
+  Applies 2D or 3D adaptive average pooling over an input signal composed of several input planes.
+  See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.
+  Args:
+    input: input tensor
+    output_size: the target output size (a tuple of 2 or 3 ints)
+  Context:
+    https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401
+  """
+  shape = input.shape
+  ndim = len(shape)
+  out_dim = len(output_size)
+  num_spatial_dim = ndim - out_dim
+
+  # Preconditions
+
+  assert ndim in (out_dim + 1, out_dim + 2), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}"
+  for d in input.shape[-2:]:
+    assert d != 0, f"adaptive_avg_pool{num_spatial_dim}d(): Expected input to have non-zero size for " \
+        f"non-batch dimensions, but input has shape {tuple(shape)}."
+
+  # Optimisation (we should also do this in the kernel implementation)
+  if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])):
+    stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size))
+    kernel = tuple(i - (o - 1) * s for i, o, s in zip(shape[-out_dim:], output_size, stride))
+    return _aten_avg_pool(
+        input,
+        kernel,
+        strides=stride,
+    )
+
+  def start_index(a, b, c):
+    return (a * c) // b
+
+  def end_index(a, b, c):
+    return ((a + 1) * c + b - 1) // b
+
+  def compute_idx(in_size, out_size):
+    orange = jnp.arange(out_size, dtype=jnp.int64)
+    i0 = start_index(orange, out_size, in_size)
+    # Let length = end_index - start_index, i.e. the length of the pooling kernels
+    # length.max() can be computed analytically as follows:
+    maxlength = in_size // out_size + 1
+    in_size_mod = in_size % out_size
+    # adaptive = True iff there are kernels with different lengths
+    adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
+    if adaptive:
+      maxlength += 1
+    elif in_size_mod == 0:
+      maxlength -= 1
+
+    range_max = jnp.arange(maxlength, dtype=jnp.int64)
+    idx = i0[:, None] + range_max
+    if adaptive:
+      # Need to clamp to avoid accessing out-of-bounds memory
+      idx = jnp.minimum(idx, in_size - 1)
+
+      # Compute the length
+      i1 = end_index(orange, out_size, in_size)
+      length = i1 - i0
+    else:
+      length = maxlength
+    return idx, length, range_max, adaptive
+
+  idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)]
+  # length is not None if it's constant, otherwise we'll need to compute it
+  for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)):
+    idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o)
+
+  def _unsqueeze_to_dim(x, dim):
+    ndim = len(x.shape)
+    return jax.lax.expand_dims(x, tuple(range(ndim, dim)))
+
+  if out_dim == 2:
+    # NOTE: unsqueeze to insert extra 1 in ranks; so they
+    # would broadcast
+    vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]]
+    reduce_axis = (-3, -1)
+  else:
+    assert out_dim == 3
+    vals = input[..., _unsqueeze_to_dim(idx[0], 6),
+                 _unsqueeze_to_dim(idx[1], 4),
+                 idx[2]]
+    reduce_axis = (-5, -3, -1)
+
+  # Shortcut for the simpler case
+  if not any(adaptive):
+    return jnp.mean(vals, axis=reduce_axis)
+
+  def maybe_mask(vals, length, range_max, adaptive, dim):
+    if isinstance(length, int):
+      return vals, length
+    else:
+      # zero-out the things we didn't really want to select
+      assert dim < 0
+      # hack
+      mask = range_max >= length[:, None]
+      if dim == -2:
+        mask = _unsqueeze_to_dim(mask, 4)
+      elif dim == -3:
+        mask = _unsqueeze_to_dim(mask, 6)
+      vals = jnp.where(mask, 0.0, vals)
+      # Compute the length of each window
+      length = _unsqueeze_to_dim(length, -dim)
+      return vals, length
+
+  for i in range(len(length)):
+    vals, length[i] = maybe_mask(vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim))
+
+  # We unroll the sum as we assume that the kernels are going to be small
+  ret = jnp.sum(vals, axis=reduce_axis)
+  # NOTE: math.prod because we want to expand it to length[0] * length[1] * ...
+  # this is multiplication with broadcasting, not regular pointwise product
+  return ret / math.prod(length)


@op(torch.ops.aten.avg_pool1d)
@op(torch.ops.aten.avg_pool2d)
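For a quick parity check of the new path, the sketch below compares adaptive_avg_pool2or3d against PyTorch's eager adaptive_avg_pool2d on a random input. It is illustrative only: it assumes the torch_xla2 package from experimental/torch_xla2 is installed so that torch_xla2.ops.jaten is importable, and the sizes are made up, chosen so the spatial dims are not divisible by the output size (exercising the gather-based path rather than the _aten_avg_pool shortcut).

```python
import numpy as np
import torch
import jax.numpy as jnp
from torch_xla2.ops import jaten  # assumes the experimental/torch_xla2 package is installed

# Random NCHW input whose spatial sizes (10, 9) are not divisible by the output size (4, 4).
x = np.random.rand(2, 3, 10, 9).astype(np.float32)

want = torch.nn.functional.adaptive_avg_pool2d(torch.from_numpy(x), (4, 4)).numpy()
got = np.asarray(jaten.adaptive_avg_pool2or3d(jnp.asarray(x), (4, 4)))

print(np.allclose(want, got, atol=1e-5))  # expected: True
```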

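The heart of adaptive_avg_pool2or3d (ported from the PyTorch decomposition linked in its docstring) is easiest to see in one dimension: compute per-output-window start and end indices, gather a fixed-width block of maxlength elements per window (clamped at the right edge), zero out positions that overran the true window, then divide by each window's real length. The following is a minimal 1-D sketch of that idea with made-up sizes (in_size=8, out_size=3); start_index and end_index mirror the helpers in the diff, while maxlength is simply read off the computed lengths rather than derived analytically as compute_idx does.

```python
import jax.numpy as jnp

def start_index(a, b, c):
  return (a * c) // b

def end_index(a, b, c):
  return ((a + 1) * c + b - 1) // b

in_size, out_size = 8, 3                    # 8 % 3 != 0 and 3 % (8 % 3) != 0, so window lengths differ
x = jnp.arange(in_size, dtype=jnp.float32)

o = jnp.arange(out_size)
i0 = start_index(o, out_size, in_size)      # [0, 2, 5]
i1 = end_index(o, out_size, in_size)        # [3, 6, 8]
length = i1 - i0                            # [3, 4, 3]: windows of different sizes
maxlength = int(length.max())               # compute_idx derives this bound analytically

range_max = jnp.arange(maxlength)
idx = jnp.minimum(i0[:, None] + range_max, in_size - 1)  # clamp so the gather stays in bounds
vals = x[idx]                                            # shape (out_size, maxlength)
mask = range_max >= length[:, None]                      # True where the gather overran the window
vals = jnp.where(mask, 0.0, vals)                        # zero out the overrun entries
pooled = vals.sum(axis=-1) / length                      # per-window mean

reference = jnp.stack([x[int(a):int(b)].mean() for a, b in zip(i0, i1)])
print(pooled, reference)                    # both: [1.0, 3.5, 6.0]
```

Because maxlength is a plain Python int, the gather has a static shape, which keeps the whole computation traceable and jittable instead of looping over output positions the way the removed implementation did.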
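When each spatial size is an exact multiple of the corresponding output size, the function short-circuits to a plain average pool with stride = i // o and kernel = i - (o - 1) * stride, which collapses to i // o in that case. A small 1-D check of that equivalence, using reshape-and-mean in place of _aten_avg_pool and made-up sizes:

```python
import jax.numpy as jnp

in_size, out_size = 12, 3
stride = in_size // out_size                        # 4
kernel = in_size - (out_size - 1) * stride          # 12 - 2 * 4 = 4, i.e. kernel == stride here

x = jnp.arange(in_size, dtype=jnp.float32)

# Fixed, non-overlapping windows of size `kernel`: what an average pool with this
# kernel and stride computes.
fixed = x.reshape(out_size, kernel).mean(axis=-1)

# Adaptive windows built with the start/end index formulas from the diff.
adaptive = jnp.stack([
    x[(i * in_size) // out_size : ((i + 1) * in_size + out_size - 1) // out_size].mean()
    for i in range(out_size)
])

print(jnp.allclose(fixed, adaptive))                # True
```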