diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 731d12370d..50f62701dc 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2368,12 +2368,20 @@ def aten_upsample_bilinear2d_backward( raise NotImplementedError() +@torch_op("aten::upsample_linear1d", trace_only=True) def aten_upsample_linear1d( - self: TensorType, output_size: INT64, align_corners: bool, scales: Optional[float] = None -) -> TensorType: + self: TReal, output_size: INT64, align_corners: bool, scales: Optional[float] = None +) -> TReal: """upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor""" - - raise NotImplementedError() + # FIXME(justinchuby): Support when scales is provided and align_corners is False + del scales + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) def aten_upsample_linear1d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 26920953a2..c274df2beb 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1464,9 +1464,7 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(S, rank, False), align_corners ) yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - shape(L, rank, False), - align_corners, + make_arg(shape(D, rank)), shape(L, rank, False), align_corners ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), @@ -1513,10 +1511,7 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(S, rank, False), align_corners, None ) yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - shape(L, rank, False), - align_corners, - None, + make_arg(shape(D, rank)), shape(L, rank, False), align_corners, None ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), @@ -1544,6 +1539,46 @@ def shape(size, rank, with_batch_channel=True): ) +def sample_inputs_upsample_linear1d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 1 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(L, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(L, rank, False), align_corners, scales=4.2 + ) + + class _TestParamsMaxPoolEmptyStrideBase: # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 def __init__(self): @@ -2037,6 +2072,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_2d_vec, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_linear1d", + aten_name="upsample_linear1d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_linear1d, + supports_out=False, + ), opinfo_core.OpInfo( "nn.functional.max_pool1d_with_indices", aten_name="max_pool1d_with_indices", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 368a051700..4b5d70c22b 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -2132,6 +2132,15 @@ def _where_input_wrangler( nn_ops.aten_upsample_bicubic2d_vec, trace_only=True, ), + TorchLibOpInfo( + "ops.aten.upsample_linear1d", + nn_ops.aten_upsample_linear1d, + trace_only=True, + ).xfail( + matcher=lambda sample: sample.args[1] is False + and sample.kwargs.get("scales") is not None, + reason="fixme: align_corners=False output mismatch when scales are provided", + ), TorchLibOpInfo( "nn.functional.upsample_nearest2d", nn_ops.aten_upsample_nearest2d,