From 26c4d80308d04c06867384b0476f6477f0dd7daf Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 27 May 2025 10:47:57 -0700
Subject: [PATCH 1/3] [torchlib] Set allowzero=True on Reshape where appropriate

---
 .../function_libs/torch_lib/ops/core.py | 33 ++++++++++---------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 758d87b904..a01bb249d4 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4390,7 +4390,7 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
     reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
 
     # Reshape and expand the index.
-    idx = op.Reshape(idx, reshape_list)
+    idx = op.Reshape(idx, reshape_list, allowzero=True)
     idx = op.Expand(idx, values_shape)
 
     # Flatten the index to 1D and unsqueeze to form a column vector.
@@ -4531,6 +4531,7 @@ def aten_instance_norm(
     bn_input = op.Reshape(
         input,
         op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0),
+        allowzero=True
     )
     weight = op.Tile(weight, batch_size)
     bias = op.Tile(bias, batch_size)
@@ -4547,7 +4548,7 @@ def aten_instance_norm(
         momentum=1.0 - momentum,
         training_mode=False,
     )
-    return op.Reshape(norm, op.Shape(input))
+    return op.Reshape(norm, op.Shape(input), allowzero=True)
 
 
 def aten_int_repr(self: TensorType) -> TensorType:
@@ -6237,14 +6238,14 @@ def _aten_native_group_norm_onnx(
     group_tensor = op.Reshape(group, neg_1)
     # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need [0,group,-1]
     shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
-    input_reshaped = op.Reshape(input, shape_input)
+    input_reshaped = op.Reshape(input, shape_input, allowzero=True)
     weight_inst_norm = op.Expand(op.CastLike(1.0, input), group_tensor)
     bias_inst_norm = op.Expand(op.CastLike(0.0, input), group_tensor)
     norm = op.InstanceNormalization(
         input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
     )
     # Reshape back to input's shape
-    norm = op.Reshape(norm, op.Shape(input))
+    norm = op.Reshape(norm, op.Shape(input), allowzero=True)
     # Using the input weight and bias to do affine
     # But need to unsqueeze to the target shape for broading cast easy
     input_rank = Rank(input)
@@ -6259,7 +6260,7 @@ def _aten_native_group_norm_onnx(
     # The returned shape for mean and vstd should be [N, group, -1]
     N = op.Shape(input, start=0, end=1)
     shape_N_group_neg1 = op.Concat(N, group_tensor, neg_1, axis=0)
-    input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1)
+    input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1, allowzero=True)
     # The output size is [N, group], so dims = [2]
     axes = op.Constant(value_ints=[2])
     # Get mean which size is [N, group, 1], for broadcasting
@@ -6693,7 +6694,7 @@ def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal:
     )
     depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
     output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0)
-    return op.Reshape(depth_to_space, output_shape)
+    return op.Reshape(depth_to_space, output_shape, allowzero=True)
 
 
 @torch_op("aten::pixel_unshuffle")
@@ -6709,7 +6710,7 @@ def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal:
     )
     space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
     output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0)
-    return op.Reshape(space_to_depth, output_shape)
+    return op.Reshape(space_to_depth, output_shape, allowzero=True)
 
 
 def aten_poisson(self: TensorType, generator: Optional[str] = None) -> TensorType:
@@ -8390,7 +8391,7 @@ def aten_tile(self: TTensor, dims: INT64) -> TTensor:
         exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
         self_shape = op.Shape(self)
         self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0)
-        self = op.Reshape(self, self_final_shape)
+        self = op.Reshape(self, self_final_shape, allowzero=True)
 
     return op.Tile(self, dims)
 
@@ -8630,7 +8631,7 @@ def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]):
         final_shape = op.Concat(head_part_rank, *sizes, axis=0)
     else:
         final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
-    return op.Reshape(self, final_shape)
+    return op.Reshape(self, final_shape, allowzero=True)
 
 
 @torch_op("aten::unfold", trace_only=True)
@@ -8706,11 +8707,11 @@ def aten__unique(
     unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
     input_size = op.Shape(self)
     if return_inverse:
-        inverse_indices = op.Reshape(inverse_indices, input_size)
+        inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
     else:
         input_numel = op.ReduceProd(input_size, keepdims=False)
         if input_numel == 0:
-            inverse_indices = op.Reshape(inverse_indices, input_size)
+            inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
         else:
             inverse_indices = op.ConstantOfShape([0])
             inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
@@ -8729,11 +8730,11 @@ def aten__unique2(
     unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
     input_size = op.Shape(self)
     if return_inverse:
-        inverse_indices = op.Reshape(inverse_indices, input_size)
+        inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
    else:
         input_numel = op.ReduceProd(input_size, keepdims=False)
         if input_numel == 0:
-            inverse_indices = op.Reshape(inverse_indices, input_size)
+            inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
         else:
             inverse_indices = op.ConstantOfShape([0])
             inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
@@ -9019,7 +9020,7 @@ def aten_view(self: TTensor, size: IntType) -> TTensor:
     """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
 
     size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
-    return op.Reshape(self, size)
+    return op.Reshape(self, size, allowzero=True)
 
 
 @torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
@@ -9028,7 +9029,7 @@ def aten_view_complex(self: TTensor, size: IntType) -> TTensor:
 
     size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
     complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
-    return op.Reshape(self, complex_size)
+    return op.Reshape(self, complex_size, allowzero=True)
 
 
 @torch_op("aten::view_as")
@@ -9036,7 +9037,7 @@ def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
     """view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
 
     size = op.Shape(other)
-    return op.Reshape(self, size)
+    return op.Reshape(self, size, allowzero=True)
 
 
 @torch_op("aten::view_as_complex", trace_only=True)

From ecf5620894f60030e92eee42e0646481f332d8c8 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 27 May 2025 11:00:04 -0700
Subject: [PATCH 2/3] lint

---
 onnxscript/function_libs/torch_lib/ops/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index a01bb249d4..90e468dad1 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4531,7 +4531,7 @@ def aten_instance_norm(
     bn_input = op.Reshape(
         input,
         op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0),
-        allowzero=True
+        allowzero=True,
     )
     weight = op.Tile(weight, batch_size)
     bias = op.Tile(bias, batch_size)

From 9a415cdb1d9aa51158344e253728d2452cbc0c46 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 27 May 2025 11:22:55 -0700
Subject: [PATCH 3/3] Tests

---
 onnxscript/function_libs/torch_lib/ops/core.py |  5 ++---
 tests/function_libs/torch_lib/ops_test_data.py |  8 +-------
 2 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 90e468dad1..afad831518 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4531,7 +4531,6 @@ def aten_instance_norm(
     bn_input = op.Reshape(
         input,
         op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0),
-        allowzero=True,
     )
     weight = op.Tile(weight, batch_size)
     bias = op.Tile(bias, batch_size)
@@ -6238,7 +6237,7 @@ def _aten_native_group_norm_onnx(
     group_tensor = op.Reshape(group, neg_1)
     # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need [0,group,-1]
     shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
-    input_reshaped = op.Reshape(input, shape_input, allowzero=True)
+    input_reshaped = op.Reshape(input, shape_input)
     weight_inst_norm = op.Expand(op.CastLike(1.0, input), group_tensor)
     bias_inst_norm = op.Expand(op.CastLike(0.0, input), group_tensor)
     norm = op.InstanceNormalization(
         input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
     )
@@ -6260,7 +6259,7 @@ def _aten_native_group_norm_onnx(
     # The returned shape for mean and vstd should be [N, group, -1]
     N = op.Shape(input, start=0, end=1)
     shape_N_group_neg1 = op.Concat(N, group_tensor, neg_1, axis=0)
-    input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1, allowzero=True)
+    input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1)
     # The output size is [N, group], so dims = [2]
     axes = op.Constant(value_ints=[2])
     # Get mean which size is [N, group, 1], for broadcasting
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 18ddc69445..18683101ac 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1457,13 +1457,7 @@ def _where_input_wrangler(
         dtypes=(torch.bool,),
         reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
     ),
-    TorchLibOpInfo(
-        "unflatten",
-        core_ops.aten_unflatten,
-    ).xfail(
-        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
-        reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
-    ),
+    TorchLibOpInfo("unflatten", core_ops.aten_unflatten),
     TorchLibOpInfo("unfold", core_ops.aten_unfold),
     TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold),
     TorchLibOpInfo("unsqueeze", core_ops.aten_unsqueeze),
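
Editor's note (illustrative, not part of the patch series): ONNX Reshape defaults to
allowzero=0, where a 0 in the target shape means "copy the corresponding input
dimension". When the target shape comes from op.Shape on a possibly zero-sized tensor,
that 0 is meant literally, which is why the patches pass allowzero=True. A minimal
standalone sketch of the difference, assuming `onnx` and `onnxruntime` are installed
(the model/graph names below are made up for the demo):

    import numpy as np
    import onnxruntime as ort
    from onnx import TensorProto, helper

    def make_reshape_model(allowzero: int):
        # Single-node graph: Reshape(data, shape) with the attribute under test.
        node = helper.make_node("Reshape", ["data", "shape"], ["out"], allowzero=allowzero)
        graph = helper.make_graph(
            [node],
            "reshape_allowzero_demo",
            [
                helper.make_tensor_value_info("data", TensorProto.FLOAT, ["d0", "d1"]),
                helper.make_tensor_value_info("shape", TensorProto.INT64, [2]),
            ],
            [helper.make_tensor_value_info("out", TensorProto.FLOAT, None)],
        )
        # allowzero was introduced in Reshape-14, so target opset 14.
        return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])

    data = np.zeros((2, 0), dtype=np.float32)  # zero-sized input tensor
    shape = np.array([0, 2], dtype=np.int64)   # intended literal shape (0, 2)

    # allowzero=1: the 0 in `shape` is taken literally -> output shape (0, 2).
    sess = ort.InferenceSession(
        make_reshape_model(1).SerializeToString(), providers=["CPUExecutionProvider"]
    )
    print(sess.run(None, {"data": data, "shape": shape})[0].shape)  # (0, 2)

    # allowzero=0 (default): the 0 copies input dim 0 (= 2), so the requested
    # shape becomes (2, 2), which cannot hold a zero-sized input -> runtime error.
    sess = ort.InferenceSession(
        make_reshape_model(0).SerializeToString(), providers=["CPUExecutionProvider"]
    )
    try:
        sess.run(None, {"data": data, "shape": shape})
    except Exception as exc:
        print("allowzero=0 failed as expected:", type(exc).__name__)

This is the behavior that lets patch 1 safely reshape to op.Shape-derived targets
(e.g. in aten_view_as and aten_instance_norm) and lets patch 3 drop the size-0
xfail for unflatten in ops_test_data.py.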