diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index e950699aca..8bb1665aaf 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -6691,34 +6691,41 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::pixel_shuffle")
+@torch_op("aten::pixel_shuffle", trace_only=True)
 def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal:
     """pixel_shuffle(Tensor self, int upscale_factor) -> Tensor"""
 
-    self_shape = op.Shape(self)
-    batch_dims = self_shape[:-3]
-    chw_in_dims = self_shape[-3:]
+    if len(self.shape) == 4:
+        return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD")
+    # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
+    batch_dims = op.Shape(self, end=-3)
+    chw_in_dims = op.Shape(self, start=-3)
+
     reshaped_self = op.Reshape(
         self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
     )
     depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
-    output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0)
+    final_dims = op.Shape(depth_to_space, start=1)
+    output_shape = op.Concat(batch_dims, final_dims, axis=0)
     return op.Reshape(depth_to_space, output_shape, allowzero=True)
 
 
-@torch_op("aten::pixel_unshuffle")
+@torch_op("aten::pixel_unshuffle", trace_only=True)
 def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal:
     """pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor"""
+    if len(self.shape) == 4:
+        return op.SpaceToDepth(self, blocksize=downscale_factor)
 
-    self_shape = op.Shape(self)
-    batch_dims = self_shape[:-3]
-    chw_in_dims = self_shape[-3:]
     # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
+    batch_dims = op.Shape(self, end=-3)
+    chw_in_dims = op.Shape(self, start=-3)
+
     reshaped_self = op.Reshape(
         self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
     )
     space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
-    output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0)
+    final_dims = op.Shape(space_to_depth, start=1)
+    output_shape = op.Concat(batch_dims, final_dims, axis=0)
     return op.Reshape(space_to_depth, output_shape, allowzero=True)
 
 
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 01db7161b5..646a5133fa 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1084,26 +1084,16 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "nn.functional.pixel_shuffle",
         core_ops.aten_pixel_shuffle,
-    )
-    .xfail(
+    ).xfail(
         dtypes=(torch.int32, torch.int64),
         reason="fixme: ONNX Runtime does not support int32/64 inputs",
-    )
-    .xfail(
-        matcher=lambda sample: sample.input.numel() == 0,
-        reason="fixme: ORT does not support empty tensor as input",
     ),
     TorchLibOpInfo(
         "nn.functional.pixel_unshuffle",
         core_ops.aten_pixel_unshuffle,
-    )
-    .xfail(
+    ).xfail(
         dtypes=(torch.int32, torch.int64),
         reason="fixme: ONNX Runtime does not support int32/64 inputs",
-    )
-    .xfail(
-        matcher=lambda sample: sample.input.numel() == 0,
-        reason="fixme: ORT does not support empty tensor as input",
     ),
     TorchLibOpInfo(
         "ops.aten.reflection_pad1d",