Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6691,34 +6691,41 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType:
raise NotImplementedError()


@torch_op("aten::pixel_shuffle")
@torch_op("aten::pixel_shuffle", trace_only=True)
def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal:
"""pixel_shuffle(Tensor self, int upscale_factor) -> Tensor"""
self_shape = op.Shape(self)
batch_dims = self_shape[:-3]
chw_in_dims = self_shape[-3:]
if len(self.shape) == 4:
return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD")

# Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
batch_dims = op.Shape(self, end=-3)
chw_in_dims = op.Shape(self, start=-3)

reshaped_self = op.Reshape(
self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
)
depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0)
final_dims = op.Shape(depth_to_space, start=1)
output_shape = op.Concat(batch_dims, final_dims, axis=0)
return op.Reshape(depth_to_space, output_shape, allowzero=True)


@torch_op("aten::pixel_unshuffle")
@torch_op("aten::pixel_unshuffle", trace_only=True)
def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal:
"""pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor"""
if len(self.shape) == 4:
return op.SpaceToDepth(self, blocksize=downscale_factor)

self_shape = op.Shape(self)
batch_dims = self_shape[:-3]
chw_in_dims = self_shape[-3:]
# Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
batch_dims = op.Shape(self, end=-3)
chw_in_dims = op.Shape(self, start=-3)

reshaped_self = op.Reshape(
self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
)
space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0)
final_dims = op.Shape(space_to_depth, start=1)
output_shape = op.Concat(batch_dims, final_dims, axis=0)
return op.Reshape(space_to_depth, output_shape, allowzero=True)


Expand Down
14 changes: 2 additions & 12 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,26 +1084,16 @@ def _where_input_wrangler(
TorchLibOpInfo(
"nn.functional.pixel_shuffle",
core_ops.aten_pixel_shuffle,
)
.xfail(
).xfail(
dtypes=(torch.int32, torch.int64),
reason="fixme: ONNX Runtime does not support int32/64 inputs",
)
.xfail(
matcher=lambda sample: sample.input.numel() == 0,
reason="fixme: ORT does not support empty tensor as input",
),
TorchLibOpInfo(
"nn.functional.pixel_unshuffle",
core_ops.aten_pixel_unshuffle,
)
.xfail(
).xfail(
dtypes=(torch.int32, torch.int64),
reason="fixme: ONNX Runtime does not support int32/64 inputs",
)
.xfail(
matcher=lambda sample: sample.input.numel() == 0,
reason="fixme: ORT does not support empty tensor as input",
),
TorchLibOpInfo(
"ops.aten.reflection_pad1d",
Expand Down
Loading