diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 758d87b904..afad831518 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4390,7 +4390,7 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
             reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
 
             # Reshape and expand the index.
-            idx = op.Reshape(idx, reshape_list)
+            idx = op.Reshape(idx, reshape_list, allowzero=True)
             idx = op.Expand(idx, values_shape)
 
             # Flatten the index to 1D and unsqueeze to form a column vector.
@@ -4547,7 +4547,7 @@ def aten_instance_norm(
         momentum=1.0 - momentum,
         training_mode=False,
     )
-    return op.Reshape(norm, op.Shape(input))
+    return op.Reshape(norm, op.Shape(input), allowzero=True)
 
 
 def aten_int_repr(self: TensorType) -> TensorType:
@@ -6244,7 +6244,7 @@ def _aten_native_group_norm_onnx(
         input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
     )
     # Reshape back to input's shape
-    norm = op.Reshape(norm, op.Shape(input))
+    norm = op.Reshape(norm, op.Shape(input), allowzero=True)
     # Using the input weight and bias to do affine
     # But need to unsqueeze to the target shape for broading cast easy
     input_rank = Rank(input)
@@ -6693,7 +6693,7 @@ def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal:
     )
     depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
     output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0)
-    return op.Reshape(depth_to_space, output_shape)
+    return op.Reshape(depth_to_space, output_shape, allowzero=True)
 
 
 @torch_op("aten::pixel_unshuffle")
@@ -6709,7 +6709,7 @@ def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal:
     )
     space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
     output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0)
-    return op.Reshape(space_to_depth, output_shape)
+    return op.Reshape(space_to_depth, output_shape, allowzero=True)
 
 
 def aten_poisson(self: TensorType, generator: Optional[str] = None) -> TensorType:
@@ -8390,7 +8390,7 @@ def aten_tile(self: TTensor, dims: INT64) -> TTensor:
         exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
         self_shape = op.Shape(self)
         self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0)
-        self = op.Reshape(self, self_final_shape)
+        self = op.Reshape(self, self_final_shape, allowzero=True)
 
     return op.Tile(self, dims)
 
@@ -8630,7 +8630,7 @@ def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]):
         final_shape = op.Concat(head_part_rank, *sizes, axis=0)
     else:
         final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
-    return op.Reshape(self, final_shape)
+    return op.Reshape(self, final_shape, allowzero=True)
 
 
 @torch_op("aten::unfold", trace_only=True)
@@ -8706,11 +8706,11 @@ def aten__unique(
     unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
     input_size = op.Shape(self)
     if return_inverse:
-        inverse_indices = op.Reshape(inverse_indices, input_size)
+        inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
     else:
         input_numel = op.ReduceProd(input_size, keepdims=False)
         if input_numel == 0:
-            inverse_indices = op.Reshape(inverse_indices, input_size)
+            inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
         else:
             inverse_indices = op.ConstantOfShape([0])
             inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
@@ -8729,11 +8729,11 @@ def aten__unique2(
     unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
     input_size = op.Shape(self)
    if return_inverse:
-        inverse_indices = op.Reshape(inverse_indices, input_size)
+        inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
     else:
         input_numel = op.ReduceProd(input_size, keepdims=False)
         if input_numel == 0:
-            inverse_indices = op.Reshape(inverse_indices, input_size)
+            inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
         else:
             inverse_indices = op.ConstantOfShape([0])
             inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
@@ -9019,7 +9019,7 @@ def aten_view(self: TTensor, size: IntType) -> TTensor:
     """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
 
     size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
-    return op.Reshape(self, size)
+    return op.Reshape(self, size, allowzero=True)
 
 
 @torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
@@ -9028,7 +9028,7 @@ def aten_view_complex(self: TTensor, size: IntType) -> TTensor:
 
     size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
     complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
-    return op.Reshape(self, complex_size)
+    return op.Reshape(self, complex_size, allowzero=True)
 
 
 @torch_op("aten::view_as")
@@ -9036,7 +9036,7 @@ def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
     """view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
 
     size = op.Shape(other)
-    return op.Reshape(self, size)
+    return op.Reshape(self, size, allowzero=True)
 
 
 @torch_op("aten::view_as_complex", trace_only=True)
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 18ddc69445..18683101ac 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1457,13 +1457,7 @@ def _where_input_wrangler(
         dtypes=(torch.bool,),
         reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
     ),
-    TorchLibOpInfo(
-        "unflatten",
-        core_ops.aten_unflatten,
-    ).xfail(
-        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
-        reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
-    ),
+    TorchLibOpInfo("unflatten", core_ops.aten_unflatten),
     TorchLibOpInfo("unfold", core_ops.aten_unfold),
     TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold),
     TorchLibOpInfo("unsqueeze", core_ops.aten_unsqueeze),
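Note (outside the diff): ONNX Reshape gives a 0 in the target shape two possible meanings. With the default allowzero=0, a zero entry means "copy the corresponding dimension from the data tensor"; with allowzero=1, the zero is kept as a literal zero-sized dimension. The hunks above pass allowzero=True where the target shape is computed at runtime (e.g. from op.Shape) and may legitimately contain zeros, which appears to be what the removed "size 0 inputs" xfail alluded to. Below is a minimal NumPy sketch of the two interpretations; the helper reshape_like_onnx is hypothetical and only illustrates the attribute's semantics.

import numpy as np

def reshape_like_onnx(data, shape, allowzero=False):
    # Hypothetical helper mimicking how ONNX Reshape treats 0 entries in `shape`.
    if not allowzero:
        # Default behavior: a 0 copies the corresponding dimension of `data`.
        shape = [data.shape[i] if s == 0 else s for i, s in enumerate(shape)]
    return data.reshape(shape)

flat = np.zeros((0,), dtype=np.int64)  # e.g. flattened inverse_indices of an empty tensor
target = [2, 0]                        # shape taken from an empty (2, 0) input

print(reshape_like_onnx(flat, target, allowzero=True).shape)  # (2, 0): the zero is honored
# With allowzero=False, the 0 at position 1 would index data.shape[1], which does not
# exist for the 1-D `flat`, mirroring the size-0 failures these changes address.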