diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index cdc982bbd8..ea43c2c4db 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -2991,8 +2991,8 @@ def _aten_embedding_bag_onnx(
     indices_1d = op.Reshape(indices, neg_1)
     # Get weight out according to indices_1d,
     new_weight = op.Gather(weight, indices_1d)
-    # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
-    new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=1))
+    # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
+    new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=[1]))
     weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1)
     indices_size = op.Shape(indices_1d)
 
@@ -3131,8 +3131,8 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
     # Get weight out according to indices,
     # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
     indices_weight = op.Gather(weight, indices)
-    # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
-    indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=1))
+    # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
+    indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=[1]))
 
     # The element in sequence must be FLOAT32 dtype due to ORT bug
     indices_weight = op.Cast(indices_weight, to=FLOAT.dtype)
@@ -4145,7 +4145,6 @@ def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
     return op.Shape(broadcasted)
 
 
-@torch_op("aten::index.Tensor", private=True, trace_only=True)
 def _aten_index_onnx(
     self: TensorType,
     indices: Sequence[Optional[INT64]],
@@ -4173,7 +4172,7 @@ def _aten_index_onnx(
     not_none_indices = [idx for idx in indices if idx is not None]
     broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
     final_index = op.Concat(
-        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
+        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), [-1]) for idx in not_none_indices),
         axis=-1,
     )
 
@@ -7706,13 +7705,13 @@ def aten_select_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::select_scatter")
+@torch_op("aten::select_scatter", trace_only=True)
 def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int) -> TensorType:
     """select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor"""
 
     # Change src rank to self rank according to dim
     # e.g. if self is [2,3,4], src is [2,4], dim=1, then update is [2,1,4]
-    update = op.Unsqueeze(src, axes=dim)
+    update = op.Unsqueeze(src, axes=[dim])
     # Change index rank to the same as 'update' [2,1,4]
     indices = op.Expand(index, op.Shape(update))
     return op.ScatterElements(self, indices, update, axis=dim, reduction="none")
@@ -7880,7 +7879,7 @@ def aten_slice_scatter(
         zero,
         op.Unsqueeze(step, zero),
     )
-    index_base = op.Unsqueeze(index_base, -1)
+    index_base = op.Unsqueeze(index_base, [-1])
 
     # Use trace only to construct the perm attribute in Transpose
     dims = None
@@ -8623,7 +8622,7 @@ def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
 
     self_rank = len(self.shape)
     if self_rank == 0:
-        result = op.Unsqueeze(self, 0)
+        result = op.Unsqueeze(self, [0])
     else:
         # Handle negative dimension
         if dimension < 0:
@@ -8792,8 +8791,7 @@ def aten_unsafe_split_with_sizes(
 def aten_unsqueeze(self: TTensor, dim: int) -> TTensor:
     """unsqueeze(Tensor(a) self, int dim) -> Tensor(a)"""
 
-    dim = op.Cast(dim, to=INT64.dtype)
-    return op.Unsqueeze(self, dim)
+    return op.Unsqueeze(self, [dim])
 
 
 def aten_unsqueeze_copy(self: TensorType, dim: int) -> TensorType:
diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 20127cec88..34f143b4ee 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1002,7 +1002,7 @@ def _aten_max_pool_onnx(
 ) -> TFloatOrUInt8:
     self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
     if self_rank_is_unbatched_rank:  # C,H,W -> N,C,H,W and N=1
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        self = op.Unsqueeze(self, [0])
 
     pool_result, _ = op.MaxPool(
         self,
@@ -1014,7 +1014,7 @@ def _aten_max_pool_onnx(
     )
 
     if self_rank_is_unbatched_rank:
-        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+        pool_result = op.Squeeze(pool_result, [0])
 
     return pool_result
 
@@ -1136,7 +1136,7 @@ def _aten_max_pool_with_indices_onnx(
 ) -> Tuple[TFloatOrUInt8, INT64]:
     self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
     if self_rank_is_unbatched_rank:
-        self = op.Unsqueeze(self, axes=0)
+        self = op.Unsqueeze(self, axes=[0])
 
     pool_result, indices = op.MaxPool(
         self,
@@ -1191,8 +1191,8 @@ def _aten_max_pool_with_indices_onnx(
         indices = op.Sub(indices, delta)
 
     if self_rank_is_unbatched_rank:
-        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
-        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
+        pool_result = op.Squeeze(pool_result, [0])
+        indices = op.Squeeze(indices, [0])
 
     return (pool_result, indices)
 
@@ -1365,11 +1365,11 @@ def aten_nll_loss(
 
     self_rank_is_1 = Rank(self) == 1
     if self_rank_is_1:  # self rank should be at least 2
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        self = op.Unsqueeze(self, [0])
 
     rank_target = Rank(target)
     if rank_target == 0:  # target rank should be at least 1
-        target = op.Unsqueeze(target, op.Constant(value_ints=[0]))
+        target = op.Unsqueeze(target, [0])
 
     if reduction == 0:
         reduction_str = "none"
diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py
index 6a7f465885..1b123394d3 100644
--- a/onnxscript/function_libs/torch_lib/ops/special.py
+++ b/onnxscript/function_libs/torch_lib/ops/special.py
@@ -219,7 +219,7 @@ def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
 
     self_is_scalar = len(self.shape) == 0
     if self_is_scalar:
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        self = op.Unsqueeze(self, [0])
     result = op.LogSoftmax(self, axis=dim)
     if dtype != -1:
         result = op.Cast(result, to=dtype)
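
For context, a minimal sketch (not part of the patch) of the ONNX detail behind passing axes as a Python list: since opset 13, Unsqueeze and Squeeze take `axes` as a second input tensor rather than a node attribute, so a literal list such as [0] or [dim] can be supplied and promoted to a constant input in the traced functions above. The snippet uses the standard onnx.helper API; the node and tensor names ("x", "y", "unsqueeze_axes") are made up for illustration.

    from onnx import TensorProto, helper

    # opset >= 13: axes is the second *input*, a constant INT64 tensor,
    # which is what a literal list like [0] corresponds to.
    axes = helper.make_tensor("unsqueeze_axes", TensorProto.INT64, dims=[1], vals=[0])
    new_style = helper.make_node("Unsqueeze", inputs=["x", "unsqueeze_axes"], outputs=["y"])

    # opset <= 12 (old form): axes was a node attribute, as in the removed axes=... calls.
    old_style = helper.make_node("Unsqueeze", inputs=["x"], outputs=["y"], axes=[0])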