Skip to content
16 changes: 12 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,12 +2368,20 @@ def aten_upsample_bilinear2d_backward(
raise NotImplementedError()


@torch_op("aten::upsample_linear1d", trace_only=True)
def aten_upsample_linear1d(
self: TensorType, output_size: INT64, align_corners: bool, scales: Optional[float] = None
) -> TensorType:
self: TReal, output_size: INT64, align_corners: bool, scales: Optional[float] = None
) -> TReal:
"""upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor"""

raise NotImplementedError()
# FIXME(justinchuby): Support when scales is provided and align_corners is False
del scales
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
return _aten_upsample_output_size(
self,
output_size,
mode="linear",
coordinate_transformation_mode=coordinate_transformation_mode,
)


def aten_upsample_linear1d_backward(
Expand Down
56 changes: 49 additions & 7 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,9 +1464,7 @@ def shape(size, rank, with_batch_channel=True):
make_arg(shape(D, rank)), shape(S, rank, False), align_corners
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
align_corners,
make_arg(shape(D, rank)), shape(L, rank, False), align_corners
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
Expand Down Expand Up @@ -1513,10 +1511,7 @@ def shape(size, rank, with_batch_channel=True):
make_arg(shape(D, rank)), shape(S, rank, False), align_corners, None
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
align_corners,
None,
make_arg(shape(D, rank)), shape(L, rank, False), align_corners, None
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
Expand Down Expand Up @@ -1544,6 +1539,46 @@ def shape(size, rank, with_batch_channel=True):
)


def sample_inputs_upsample_linear1d(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

N, C = 2, 3
D = 4
SS = 3
L = 5

align_corners_options = (True, False)
rank = 1

def shape(size, rank, with_batch_channel=True):
if with_batch_channel:
return tuple([N, C] + ([size] * rank))
return tuple([size] * rank)

make_arg = functools.partial(
torch_testing.make_tensor,
device=device,
dtype=dtype,
requires_grad=requires_grad,
low=-1,
high=1,
)

yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True)

for align_corners in align_corners_options:
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)), shape(S, rank, False), align_corners
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)), shape(L, rank, False), align_corners
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)), shape(L, rank, False), align_corners, scales=4.2
)


class _TestParamsMaxPoolEmptyStrideBase:
# Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
def __init__(self):
Expand Down Expand Up @@ -2037,6 +2072,13 @@ def __init__(self):
sample_inputs_func=sample_inputs_upsample_2d_vec,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_linear1d",
aten_name="upsample_linear1d",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_linear1d,
supports_out=False,
),
opinfo_core.OpInfo(
"nn.functional.max_pool1d_with_indices",
aten_name="max_pool1d_with_indices",
Expand Down
9 changes: 9 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2132,6 +2132,15 @@ def _where_input_wrangler(
nn_ops.aten_upsample_bicubic2d_vec,
trace_only=True,
),
TorchLibOpInfo(
"ops.aten.upsample_linear1d",
nn_ops.aten_upsample_linear1d,
trace_only=True,
).xfail(
matcher=lambda sample: sample.args[1] is False
and sample.kwargs.get("scales") is not None,
reason="fixme: align_corners=False output mismatch when scales are provided",
),
TorchLibOpInfo(
"nn.functional.upsample_nearest2d",
nn_ops.aten_upsample_nearest2d,
Expand Down