6 changes: 6 additions & 0 deletions thunder/tests/opinfos.py
@@ -9509,6 +9509,7 @@ def interpolate_sample_generator(op, device, dtype, requires_grad, **kwargs):
a_shape = b + c + spatial_dims

yield SampleInput(make(a_shape), size=size)
yield SampleInput(make(a_shape), size=size, mode="nearest-exact")

# Test scale/scale_factor passed as a scalar
yield SampleInput(make(1, 1, 5, 5), scale_factor=0.5)
@@ -9569,6 +9570,11 @@ def interpolate_error_generator(op, device, dtype=torch.float32, **kwargs):
RuntimeError,
f"scale_factor(.*?) is expected to be (.*?) a sequence of strictly positive floating point numbers",
)
yield (
SampleInput(make(1, 1, 1, 1), mode="bilinear"),
RuntimeError,
f"only modes 'nearest' and 'nearest-exact' are supported at the moment, but got mode=(.*?)",
)


interpolate_opinfo = OpInfo(
42 changes: 42 additions & 0 deletions thunder/tests/test_ops.py
@@ -219,6 +219,31 @@ def test_core_vs_numpy_consistency(op: OpInfo, device: str, dtype: dtypes.dtype,
return result


def test_interpolate_nearest_vs_nearest_exact():
t0 = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], dtype=torch.float16)

def foo(mode):
t1 = torch.nn.functional.interpolate(
t0,
scale_factor=1.5,
mode=mode,
)
return t1

def bar(mode):
t1 = torch.nn.functional.interpolate(
t0,
size=4,
mode=mode,
)
return t1

tfoo = thunder.jit(foo)
assert not torch.equal(tfoo("nearest"), tfoo("nearest-exact"))
tbar = thunder.jit(bar)
assert not torch.equal(tbar("nearest"), tbar("nearest-exact"))

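A quick aside on why the asserts above should hold (not part of the diff): for the 3 → 4 resize exercised by both `foo` and `bar`, 'nearest' maps output index i to int(i * in/out), while 'nearest-exact' maps it to int((i + 0.5) * in/out), so the two modes select different source indices. A minimal plain-PyTorch check of that, assuming a recent PyTorch with 'nearest-exact' support:

```python
import torch

# Map the 4 output positions back to source indices for a length-3 input.
in_dim, out_dim = 3, 4
scale = in_dim / out_dim  # 0.75, the reciprocal of the user-facing scale_factor

i = torch.arange(out_dim, dtype=torch.float32)
nearest_idx = (i * scale).to(torch.int64)                 # tensor([0, 0, 1, 2])
nearest_exact_idx = ((i + 0.5) * scale).to(torch.int64)   # tensor([0, 1, 1, 2])

# The differing index sets are what makes the torch.equal checks above fail.
t = torch.tensor([[[1.0, 2.0, 3.0]]])
assert torch.equal(torch.nn.functional.interpolate(t, size=4, mode="nearest"), t[..., nearest_idx])
assert torch.equal(torch.nn.functional.interpolate(t, size=4, mode="nearest-exact"), t[..., nearest_exact_idx])
```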

def test_notimplemented_interpolate_align():
def foo():
t0 = torch.randn((22, 288, 15, 20), dtype=torch.float16)
@@ -257,6 +282,23 @@ def foo():
tfoo()


def test_notimplemented_interpolate_modes():
def foo(mode):
t0 = torch.randn((22, 288, 15, 20), dtype=torch.float16)
t1 = torch.nn.functional.interpolate(
t0,
scale_factor=2.0,
mode=mode,
)
return t1

tfoo = thunder.jit(foo)
for mode in ["linear", "bilinear", "bicubic", "trilinear", "area"]:
match = f"only modes 'nearest' and 'nearest-exact' are supported at the moment, but got mode='{mode}'"
with pytest.raises(NotImplementedError, match=match):
tfoo(mode)


@pytest.mark.parametrize("requires_grad", (True, False))
def test_setitem(requires_grad):

32 changes: 23 additions & 9 deletions thunder/torch/__init__.py
@@ -5278,7 +5278,10 @@ def _interpolate_scale_factor_helper(
scale_factor: Sequence[float] | float,
mode: str = "nearest",
) -> TensorLike:
assert mode == "nearest"
if mode not in ("nearest", "nearest-exact"):
raise ValueError(
f"_interpolate_scale_factor_helper expected mode to be 'nearest' or 'nearest-exact', but got {mode}"
)

# a is assumed to be at least 3D.
batch, channels, *spatial_dims = a.shape
@@ -5299,13 +5302,23 @@ def _interpolate_scale_factor_helper(
)

# perform nearest up/down-sampling
def nearest_sampler(t, input_dim, output_dim, *, scale, dim):
def nearest_sampler(
t: TensorLike,
input_dim: int,
output_dim: int,
*,
scale: float,
dim: int,
exact: bool,
) -> TensorLike:
# It is expected that output_dim = int(input_dim * scale).
# Indices [0, ..., output_dim - 1] are mapped to [0, ..., input_dim - 1]
# with the rule i -> int(i * scale)
# with the rule i -> int(i * scale) or i -> round((i + 0.5) * scale - 0.5),
# corresponding to modes 'nearest' or 'nearest-exact', respectively.
        # The values at these indices are the result.
# References https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/UpSample.h
selected_idx = arange(0, output_dim, device=a.device)
selected_idx = to(selected_idx * scale, selected_idx.dtype)
selected_idx = to((selected_idx + exact * 0.5) * scale, selected_idx.dtype)
return clang.take(t, selected_idx, dim=dim)

def dim_expander(t, dim, n_repeats):
@@ -5327,6 +5340,7 @@ def dim_expander(t, dim, n_repeats):
        # dimensions corresponding to batches and channels.
curr_dim = 2 + (len(spatial_dims) - k - 1)

exact: bool = mode == "nearest-exact"
if output_dim <= input_dim:
if output_dim <= input_dim // 2:
# scale_factor <= 1 (i.e. output_dim <= input_dim) implies simple slice
@@ -5336,15 +5350,15 @@ def dim_expander(t, dim, n_repeats):
a = clang.slice_in_dim(a, 0, end, stride=stride, dim=curr_dim)
else:
# In this case slice will not do and explicit downsample is needed.
a = nearest_sampler(a, input_dim, output_dim, scale=1.0 / scale, dim=curr_dim)
a = nearest_sampler(a, input_dim, output_dim, scale=1.0 / scale, dim=curr_dim, exact=exact)
else:
if output_dim % input_dim == 0:
# In this case we can just expand dim.
n_repeats = output_dim // input_dim
a = dim_expander(a, curr_dim, n_repeats)
else:
# In this case expand will not cut it and explicit upsampling is needed.
a = nearest_sampler(a, input_dim, output_dim, scale=1.0 / scale, dim=curr_dim)
a = nearest_sampler(a, input_dim, output_dim, scale=1.0 / scale, dim=curr_dim, exact=exact)

output_shape = [batch, channels] + res_output_spatial_dims[::-1]
return reshape(a, output_shape)
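An aside on the dim_expander branch above (not part of the diff): when the output size is an integer multiple of the input size, both index rules select every source element the same number of times, so plain repetition needs no `exact` handling. A small plain-PyTorch sketch illustrating that, assuming 'nearest-exact' is available:

```python
import torch

x = torch.tensor([[[1.0, 2.0, 3.0]]])  # (N, C, L) = (1, 1, 3)
repeated = torch.repeat_interleave(x, 2, dim=-1)  # [1, 1, 2, 2, 3, 3]

# int(i * 0.5) and int((i + 0.5) * 0.5) both yield [0, 0, 1, 1, 2, 2] for i = 0..5,
# so 'nearest' and 'nearest-exact' agree with simple repetition here.
for mode in ("nearest", "nearest-exact"):
    out = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=mode)
    assert torch.equal(out, repeated)
```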
@@ -5374,7 +5388,7 @@ def _interpolate_size_helper(

scale_factor = tuple(output_size / input_size for output_size, input_size in zip(size, spatial_dims))

return _interpolate_scale_factor_helper(a, scale_factor)
return _interpolate_scale_factor_helper(a, scale_factor, mode)


# TODO Implement additional modes and parameters
@@ -5390,8 +5404,8 @@ def interpolate(
antialias: bool = False,
) -> TensorLike:
utils.check(
mode == "nearest",
lambda: f"only mode='nearest' is supported at the moment, but got {mode=}",
(mode == "nearest" or mode == "nearest-exact"),
lambda: f"only modes 'nearest' and 'nearest-exact' are supported at the moment, but got {mode=}",
exception_type=NotImplementedError,
)

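For reference (not part of the diff), the overall technique used by `_interpolate_scale_factor_helper` — arange over the output positions, map them to source indices with the mode's rule, then gather along the spatial dim via `clang.take` — can be sketched in plain PyTorch. The `nearest_resize_dim` helper below is hypothetical and illustrative only; the shapes are chosen so the scale products are exact in floating point:

```python
import torch

def nearest_resize_dim(t: torch.Tensor, out_dim: int, dim: int, exact: bool) -> torch.Tensor:
    # Mirror of nearest_sampler: arange over the output, map to source indices, gather.
    in_dim = t.shape[dim]
    scale = in_dim / out_dim  # reciprocal of the user-facing scale_factor
    i = torch.arange(out_dim, dtype=torch.float32, device=t.device)
    # 'nearest': int(i * scale); 'nearest-exact': int((i + 0.5) * scale)
    src = ((i + (0.5 if exact else 0.0)) * scale).to(torch.int64)
    return t.index_select(dim, src)

x = torch.arange(15.0).reshape(1, 1, 3, 5)  # NCHW input, resized to (4, 8)
for mode, exact in (("nearest", False), ("nearest-exact", True)):
    got = nearest_resize_dim(nearest_resize_dim(x, 4, dim=2, exact=exact), 8, dim=3, exact=exact)
    want = torch.nn.functional.interpolate(x, size=(4, 8), mode=mode)
    assert torch.equal(got, want), mode
```

Applying the per-dimension gather twice reproduces the separable 2D nearest kernel, which is the same reason the helper can loop over the spatial dims one at a time.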