diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index a6a73b7d46..6907b6dc8f 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -9509,6 +9509,7 @@ def interpolate_sample_generator(op, device, dtype, requires_grad, **kwargs):
 
         a_shape = b + c + spatial_dims
         yield SampleInput(make(a_shape), size=size)
+        yield SampleInput(make(a_shape), size=size, mode="nearest-exact")
 
     # Test scale/scale_factor passed as a scalar
     yield SampleInput(make(1, 1, 5, 5), scale_factor=0.5)
@@ -9569,6 +9570,11 @@ def interpolate_error_generator(op, device, dtype=torch.float32, **kwargs):
         RuntimeError,
         f"scale_factor(.*?) is expected to be (.*?) a sequence of strictly positive floating point numbers",
     )
+    yield (
+        SampleInput(make(1, 1, 1, 1), mode="bilinear"),
+        NotImplementedError,
+        f"only modes 'nearest' and 'nearest-exact' are supported at the moment, but got mode=(.*?)",
+    )
 
 
 interpolate_opinfo = OpInfo(
diff --git a/thunder/tests/test_ops.py b/thunder/tests/test_ops.py
index 23f4477342..4afd7783bd 100644
--- a/thunder/tests/test_ops.py
+++ b/thunder/tests/test_ops.py
@@ -219,6 +219,45 @@ def test_core_vs_numpy_consistency(op: OpInfo, device: str, dtype: dtypes.dtype,
     return result
 
 
+def test_interpolate_nearest_vs_nearest_exact():
+    # 'nearest' and 'nearest-exact' select different source indices, so the
+    # two modes must produce different results for each of the cases below.
+    t0 = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], dtype=torch.float16)
+
+    def foo(mode):
+        t1 = torch.nn.functional.interpolate(
+            t0,
+            scale_factor=1.5,
+            mode=mode,
+        )
+        return t1
+
+    def bar(mode):
+        t1 = torch.nn.functional.interpolate(
+            t0,
+            size=4,
+            mode=mode,
+        )
+        return t1
+
+    def baz(mode):
+        # Downsampling to at most half of the input size: for 'nearest' this is
+        # the branch implemented with a strided slice instead of nearest_sampler.
+        t1 = torch.nn.functional.interpolate(
+            t0,
+            scale_factor=0.5,
+            mode=mode,
+        )
+        return t1
+
+    tfoo = thunder.jit(foo)
+    assert not torch.equal(tfoo("nearest"), tfoo("nearest-exact"))
+    tbar = thunder.jit(bar)
+    assert not torch.equal(tbar("nearest"), tbar("nearest-exact"))
+    tbaz = thunder.jit(baz)
+    assert not torch.equal(tbaz("nearest"), tbaz("nearest-exact"))
+
+
 def test_notimplemented_interpolate_align():
     def foo():
         t0 = torch.randn((22, 288, 15, 20), dtype=torch.float16)
@@ -257,6 +296,23 @@ def foo():
         tfoo()
 
 
+def test_notimplemented_interpolate_modes():
+    def foo(mode):
+        t0 = torch.randn((22, 288, 15, 20), dtype=torch.float16)
+        t1 = torch.nn.functional.interpolate(
+            t0,
+            scale_factor=2.0,
+            mode=mode,
+        )
+        return t1
+
+    tfoo = thunder.jit(foo)
+    for mode in ["linear", "bilinear", "bicubic", "trilinear", "area"]:
+        match = f"only modes 'nearest' and 'nearest-exact' are supported at the moment, but got mode='{mode}'"
+        with pytest.raises(NotImplementedError, match=match):
+            tfoo(mode)
+
+
 @pytest.mark.parametrize("requires_grad", (True, False))
 def test_setitem(requires_grad):
 
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index c9fad871d3..305bccb177 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5278,7 +5278,10 @@ def _interpolate_scale_factor_helper(
     scale_factor: Sequence[float] | float,
     mode: str = "nearest",
 ) -> TensorLike:
-    assert mode == "nearest"
+    if mode not in ("nearest", "nearest-exact"):
+        raise ValueError(
+            f"_interpolate_scale_factor_helper expected mode to be 'nearest' or 'nearest-exact', but got {mode}"
+        )
 
     # a is assumed to be at least 3D.
     batch, channels, *spatial_dims = a.shape
@@ -5299,13 +5302,23 @@ def _interpolate_scale_factor_helper(
     )
 
     # perform nearest up/down-sampling
-    def nearest_sampler(t, input_dim, output_dim, *, scale, dim):
+    def nearest_sampler(
+        t: TensorLike,
+        input_dim: int,
+        output_dim: int,
+        *,
+        scale: float,
+        dim: int,
+        exact: bool,
+    ) -> TensorLike:
         # It is expected that output_dim = int(input_dim * scale).
         # Indices [0, ..., output_dim - 1] are mapped to [0, ..., input_dim - 1]
-        # with the rule i -> int(i * scale)
+        # with the rule i -> int(i * scale) for mode 'nearest' (exact=False) or
+        # i -> int((i + 0.5) * scale) for mode 'nearest-exact' (exact=True).
         # Values at these indices is the result.
+        # Reference: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/UpSample.h
         selected_idx = arange(0, output_dim, device=a.device)
-        selected_idx = to(selected_idx * scale, selected_idx.dtype)
+        selected_idx = to((selected_idx + (0.5 if exact else 0.0)) * scale, selected_idx.dtype)
         return clang.take(t, selected_idx, dim=dim)
 
     def dim_expander(t, dim, n_repeats):
@@ -5327,6 +5340,10 @@ def dim_expander(t, dim, n_repeats):
         # dimenions corresponding to batches and channels.
         curr_dim = 2 + (len(spatial_dims) - k - 1)
 
+        exact: bool = mode == "nearest-exact"
         if output_dim <= input_dim:
-            if output_dim <= input_dim // 2:
+            # The strided slice below always starts at index 0, which matches 'nearest'
+            # but not 'nearest-exact' (its i-th source index is int((i + 0.5) / scale)),
+            # so 'nearest-exact' must always go through nearest_sampler.
+            if not exact and output_dim <= input_dim // 2:
                 # scale_factor <= 1 (i.e. output_dim <= input_dim) implies simple slice
@@ -5336,7 +5353,7 @@ def dim_expander(t, dim, n_repeats):
                 a = clang.slice_in_dim(a, 0, end, stride=stride, dim=curr_dim)
             else:
                 # In this case slice will not do and explicit downsample is needed.
-                a = nearest_sampler(a, input_dim, output_dim, scale=1.0 / scale, dim=curr_dim)
+                a = nearest_sampler(a, input_dim, output_dim, scale=1.0 / scale, dim=curr_dim, exact=exact)
         else:
             if output_dim % input_dim == 0:
                 # In this case we can just expand dim.
@@ -5344,7 +5361,7 @@ def dim_expander(t, dim, n_repeats):
                 a = dim_expander(a, curr_dim, n_repeats)
             else:
                 # In this case expand will not cut it and explicit upsampling is needed.
-                a = nearest_sampler(a, input_dim, output_dim, scale=1.0 / scale, dim=curr_dim)
+                a = nearest_sampler(a, input_dim, output_dim, scale=1.0 / scale, dim=curr_dim, exact=exact)
 
     output_shape = [batch, channels] + res_output_spatial_dims[::-1]
     return reshape(a, output_shape)
@@ -5374,7 +5391,7 @@ def _interpolate_size_helper(
 
     scale_factor = tuple(output_size / input_size for output_size, input_size in zip(size, spatial_dims))
 
-    return _interpolate_scale_factor_helper(a, scale_factor)
+    return _interpolate_scale_factor_helper(a, scale_factor, mode)
 
 
 # TODO Implement additional modes and parameters
@@ -5390,8 +5407,8 @@ def interpolate(
     antialias: bool = False,
 ) -> TensorLike:
     utils.check(
-        mode == "nearest",
-        lambda: f"only mode='nearest' is supported at the moment, but got {mode=}",
+        mode in ("nearest", "nearest-exact"),
+        lambda: f"only modes 'nearest' and 'nearest-exact' are supported at the moment, but got {mode=}",
         exception_type=NotImplementedError,
     )
 