From 8b9da4ee4464e31241f220f9ed7d58ee4ce7bf13 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 12:12:33 -0700 Subject: [PATCH 1/8] [torchlib] Improve pixel_shuffle Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e950699aca..1533ad172d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6691,13 +6691,17 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType: raise NotImplementedError() -@torch_op("aten::pixel_shuffle") +@torch_op("aten::pixel_shuffle", trace_only=True) def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: """pixel_shuffle(Tensor self, int upscale_factor) -> Tensor""" + if len(self.shape) == 4: + return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD") + + # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) self_shape = op.Shape(self) batch_dims = self_shape[:-3] chw_in_dims = self_shape[-3:] - # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) + reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) @@ -6709,11 +6713,14 @@ def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: @torch_op("aten::pixel_unshuffle") def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: """pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor""" + if len(self.shape) == 4: + return op.SpaceToDepth(self, blocksize=downscale_factor) + # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) self_shape = op.Shape(self) batch_dims = self_shape[:-3] chw_in_dims = self_shape[-3:] - # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) + reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) From 42b03c03bd45efede74696958ae89180be43ac4c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 12:16:59 -0700 Subject: [PATCH 2/8] fix Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1533ad172d..01ee4b22e0 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6699,14 +6699,15 @@ def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) self_shape = op.Shape(self) - batch_dims = self_shape[:-3] - chw_in_dims = self_shape[-3:] + batch_dims = op.Slice(self_shape, starts=[0], ends=[-3], axes=[0]) + chw_in_dims = op.Slice(self_shape, starts=[-3], ends=[_INT64_MAX], axes=[0]) reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD") - output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0) + final_dims = op.Slice(op.Shape(depth_to_space), starts=[1], ends=[_INT64_MAX], axes=[0]) + output_shape = op.Concat(batch_dims, final_dims, axis=0) return op.Reshape(depth_to_space, output_shape, allowzero=True) @@ -6718,14 +6719,15 @@ def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) self_shape = op.Shape(self) - batch_dims = self_shape[:-3] - chw_in_dims = self_shape[-3:] + batch_dims = op.Slice(self_shape, starts=[0], ends=[-3], axes=[0]) + chw_in_dims = op.Slice(self_shape, starts=[-3], ends=[_INT64_MAX], axes=[0]) reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor) - output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0) + final_dims = op.Slice(op.Shape(depth_to_space), starts=[1], ends=[_INT64_MAX], axes=[0]) + output_shape = op.Concat(batch_dims, final_dims, axis=0) return op.Reshape(space_to_depth, output_shape, allowzero=True) From 0203af027acfa7d7cbf08e66b48c5fd56228760b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 12:17:29 -0700 Subject: [PATCH 3/8] trace_only Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 01ee4b22e0..332752d927 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6711,7 +6711,7 @@ def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: return op.Reshape(depth_to_space, output_shape, allowzero=True) -@torch_op("aten::pixel_unshuffle") +@torch_op("aten::pixel_unshuffle" trace_only=True) def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: """pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor""" if len(self.shape) == 4: From 9c2d5cd6af63471cbf910ce8641dc5c9b5275234 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 12:24:58 -0700 Subject: [PATCH 4/8] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 332752d927..1a209ddc84 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6711,7 +6711,7 @@ def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: return op.Reshape(depth_to_space, output_shape, allowzero=True) -@torch_op("aten::pixel_unshuffle" trace_only=True) +@torch_op("aten::pixel_unshuffle", trace_only=True) def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: """pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor""" if len(self.shape) == 4: From a7dce4464e90a8cc180ad290a3b8b2deafa3f39f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 12:25:11 -0700 Subject: [PATCH 5/8] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1a209ddc84..45040978d2 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6726,7 +6726,7 @@ def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor) - final_dims = op.Slice(op.Shape(depth_to_space), starts=[1], ends=[_INT64_MAX], axes=[0]) + final_dims = op.Slice(op.Shape(space_to_depth), starts=[1], ends=[_INT64_MAX], axes=[0]) output_shape = op.Concat(batch_dims, final_dims, axis=0) return op.Reshape(space_to_depth, output_shape, allowzero=True) From 513ec494c026a58629eb419b6192d0bdfa42ae3a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 12:39:39 -0700 Subject: [PATCH 6/8] test Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test_data.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 01db7161b5..10ed8a9423 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1088,10 +1088,6 @@ def _where_input_wrangler( .xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", - ) - .xfail( - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: ORT does not support empty tensor as input", ), TorchLibOpInfo( "nn.functional.pixel_unshuffle", @@ -1100,10 +1096,6 @@ def _where_input_wrangler( .xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", - ) - .xfail( - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: ORT does not support empty tensor as input", ), TorchLibOpInfo( "ops.aten.reflection_pad1d", From 3f90ef70da119b5d158b6e63d3dbafcc4cdca2a9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 13:03:52 -0700 Subject: [PATCH 7/8] format Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test_data.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 10ed8a9423..646a5133fa 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1084,16 +1084,14 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.pixel_shuffle", core_ops.aten_pixel_shuffle, - ) - .xfail( + ).xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", ), TorchLibOpInfo( "nn.functional.pixel_unshuffle", core_ops.aten_pixel_unshuffle, - ) - .xfail( + ).xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", ), From 908ca91fef4c5b6fee1156fccf8ea6b6fbb20502 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 15:22:44 -0700 Subject: [PATCH 8/8] use shape Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 45040978d2..8bb1665aaf 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6698,15 +6698,14 @@ def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD") # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) - self_shape = op.Shape(self) - batch_dims = op.Slice(self_shape, starts=[0], ends=[-3], axes=[0]) - chw_in_dims = op.Slice(self_shape, starts=[-3], ends=[_INT64_MAX], axes=[0]) + batch_dims = op.Shape(self, end=-3) + chw_in_dims = op.Shape(self, start=-3) reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD") - final_dims = op.Slice(op.Shape(depth_to_space), starts=[1], ends=[_INT64_MAX], axes=[0]) + final_dims = op.Shape(depth_to_space, start=1) output_shape = op.Concat(batch_dims, final_dims, axis=0) return op.Reshape(depth_to_space, output_shape, allowzero=True) @@ -6718,15 +6717,14 @@ def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: return op.SpaceToDepth(self, blocksize=downscale_factor) # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) - self_shape = op.Shape(self) - batch_dims = op.Slice(self_shape, starts=[0], ends=[-3], axes=[0]) - chw_in_dims = op.Slice(self_shape, starts=[-3], ends=[_INT64_MAX], axes=[0]) + batch_dims = op.Shape(self, end=-3) + chw_in_dims = op.Shape(self, start=-3) reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor) - final_dims = op.Slice(op.Shape(space_to_depth), starts=[1], ends=[_INT64_MAX], axes=[0]) + final_dims = op.Shape(space_to_depth, start=1) output_shape = op.Concat(batch_dims, final_dims, axis=0) return op.Reshape(space_to_depth, output_shape, allowzero=True)