22 changes: 10 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/core.py

@@ -2991,8 +2991,8 @@ def _aten_embedding_bag_onnx(
     indices_1d = op.Reshape(indices, neg_1)
     # Get weight out according to indices_1d,
     new_weight = op.Gather(weight, indices_1d)
-    # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
-    new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=1))
+    # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
+    new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=[1]))
     weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1)
     indices_size = op.Shape(indices_1d)

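The recurring change in this diff is the form of the axes argument: since opset 13, Unsqueeze (and Squeeze) take axes as a 1-D INT64 input rather than an int attribute, so call sites now pass a list such as [1], [0], or [-1] instead of a bare scalar or an explicit op.Constant node. A minimal numpy sketch of the reference semantics, with illustrative names not taken from the PR:

import numpy as np

def unsqueeze(data, axes):
    # Reference semantics of ONNX Unsqueeze: insert one size-1 dim per axis;
    # negative axes count from the rank of the *output* tensor.
    out_rank = data.ndim + len(axes)
    out = data
    for axis in sorted(a if a >= 0 else a + out_rank for a in axes):
        out = np.expand_dims(out, axis)
    return out

per_sample_weights = np.array([1.0, 2.0, 0.5, 1.0, 1.0], dtype=np.float32)
new_weight = np.ones((5, 3), dtype=np.float32)  # stands in for Gather(weight, indices_1d)
scaled = new_weight * unsqueeze(per_sample_weights, [1])  # broadcasts over the embedding dim
assert scaled.shape == (5, 3)
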
@@ -3131,8 +3131,8 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
     # Get weight out according to indices,
     # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
     indices_weight = op.Gather(weight, indices)
-    # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
-    indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=1))
+    # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
+    indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=[1]))

     # The element in sequence must be FLOAT32 dtype due to ORT bug
     indices_weight = op.Cast(indices_weight, to=FLOAT.dtype)
@@ -4145,7 +4145,6 @@ def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
     return op.Shape(broadcasted)


-@torch_op("aten::index.Tensor", private=True, trace_only=True)
 def _aten_index_onnx(
     self: TensorType,
     indices: Sequence[Optional[INT64]],
@@ -4173,7 +4172,7 @@ def _aten_index_onnx(
     not_none_indices = [idx for idx in indices if idx is not None]
     broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
     final_index = op.Concat(
-        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
+        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), [-1]) for idx in not_none_indices),
         axis=-1,
     )

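For orientation, the Expand/Unsqueeze/Concat chain above assembles a single coordinate tensor for N-dimensional gathering: each index tensor is broadcast to a common shape, given a trailing axis via Unsqueeze(..., [-1]), and concatenated along that axis. A hedged numpy sketch of the pattern (the gather step that consumes final_index is assumed, since it sits outside this hunk):

import numpy as np

data = np.arange(24).reshape(2, 3, 4)
idx0 = np.array([0, 1])  # indexes dim 0
idx1 = np.array([2, 0])  # indexes dim 1
broadcast_shape = np.broadcast_shapes(idx0.shape, idx1.shape)
final_index = np.concatenate(
    [np.broadcast_to(i, broadcast_shape)[..., None] for i in (idx0, idx1)],
    axis=-1,
)  # shape (2, 2): one (d0, d1) coordinate per row, as GatherND expects
gathered = data[final_index[..., 0], final_index[..., 1]]
assert gathered.shape == (2, 4)
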
@@ -7706,13 +7705,13 @@ def aten_select_backward(
     raise NotImplementedError()


-@torch_op("aten::select_scatter")
+@torch_op("aten::select_scatter", trace_only=True)
 def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int) -> TensorType:
     """select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor"""

     # Change src rank to self rank according to dim
     # e.g. if self is [2,3,4], src is [2,4], dim=1, then update is [2,1,4]
-    update = op.Unsqueeze(src, axes=dim)
+    update = op.Unsqueeze(src, axes=[dim])
     # Change index rank to the same as 'update' [2,1,4]
     indices = op.Expand(index, op.Shape(update))
     return op.ScatterElements(self, indices, update, axis=dim, reduction="none")
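
A small worked example of the decomposition, mirrored in numpy (illustrative only): with self of shape [2,3,4], src of shape [2,4], dim=1, and index=2, the unsqueezed update has shape [2,1,4] and ScatterElements writes it into slice 2 along dim 1.

import numpy as np

self_ = np.zeros((2, 3, 4))
src = np.ones((2, 4))
dim, index = 1, 2
update = np.expand_dims(src, dim)  # Unsqueeze(src, axes=[dim]) -> (2, 1, 4)
indices = np.full(update.shape, index, dtype=np.int64)  # Expand(index, Shape(update))
result = self_.copy()
np.put_along_axis(result, indices, update, axis=dim)  # ScatterElements, reduction="none"
assert result[:, 2, :].sum() == 8.0
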
@@ -7880,7 +7879,7 @@ def aten_slice_scatter(
         zero,
         op.Unsqueeze(step, zero),
     )
-    index_base = op.Unsqueeze(index_base, -1)
+    index_base = op.Unsqueeze(index_base, [-1])

     # Use trace only to construct the perm attribute in Transpose
     dims = None
@@ -8623,7 +8622,7 @@ def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:

     self_rank = len(self.shape)
     if self_rank == 0:
-        result = op.Unsqueeze(self, 0)
+        result = op.Unsqueeze(self, [0])
     else:
         # Handle negative dimension
         if dimension < 0:
@@ -8792,8 +8791,7 @@ def aten_unsafe_split_with_sizes(
 def aten_unsqueeze(self: TTensor, dim: int) -> TTensor:
     """unsqueeze(Tensor(a) self, int dim) -> Tensor(a)"""

-    dim = op.Cast(dim, to=INT64.dtype)
-    return op.Unsqueeze(self, dim)
+    return op.Unsqueeze(self, [dim])


 def aten_unsqueeze_copy(self: TensorType, dim: int) -> TensorType:

14 changes: 7 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/nn.py

@@ -1002,7 +1002,7 @@ def _aten_max_pool_onnx(
 ) -> TFloatOrUInt8:
     self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
     if self_rank_is_unbatched_rank:  # C,H,W -> N,C,H,W and N=1
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        self = op.Unsqueeze(self, [0])

     pool_result, _ = op.MaxPool(
         self,
@@ -1014,7 +1014,7 @@
     )

     if self_rank_is_unbatched_rank:
-        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+        pool_result = op.Squeeze(pool_result, [0])

     return pool_result

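The unbatched-rank handling above is a promote/demote wrapper: a C,H,W input gains a synthetic N=1 batch axis before MaxPool and drops it afterwards. A hedged numpy sketch of the wrapper, with a toy 2x2 stride-2 pooling function standing in for op.MaxPool:

import numpy as np

def pool_with_optional_batch(x, pool, unbatched_rank=3):
    unbatched = x.ndim == unbatched_rank
    if unbatched:
        x = np.expand_dims(x, 0)  # Unsqueeze(self, [0]): C,H,W -> 1,C,H,W
    y = pool(x)
    if unbatched:
        y = np.squeeze(y, 0)  # Squeeze(pool_result, [0])
    return y

def max_pool_2x2(x):  # toy stand-in for op.MaxPool
    n, c, h, w = x.shape
    return x.reshape(n, c, h // 2, 2, w // 2, 2).max(axis=(3, 5))

chw = np.arange(2 * 4 * 4, dtype=np.float32).reshape(2, 4, 4)
assert pool_with_optional_batch(chw, max_pool_2x2).shape == (2, 2, 2)
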
@@ -1136,7 +1136,7 @@ def _aten_max_pool_with_indices_onnx(
 ) -> Tuple[TFloatOrUInt8, INT64]:
     self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
     if self_rank_is_unbatched_rank:
-        self = op.Unsqueeze(self, axes=0)
+        self = op.Unsqueeze(self, axes=[0])

     pool_result, indices = op.MaxPool(
         self,
@@ -1191,8 +1191,8 @@ def _aten_max_pool_with_indices_onnx(
     indices = op.Sub(indices, delta)

     if self_rank_is_unbatched_rank:
-        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
-        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
+        pool_result = op.Squeeze(pool_result, [0])
+        indices = op.Squeeze(indices, [0])

     return (pool_result, indices)

@@ -1365,11 +1365,11 @@ def aten_nll_loss(

     self_rank_is_1 = Rank(self) == 1
     if self_rank_is_1:  # self rank should be at least 2
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        self = op.Unsqueeze(self, [0])

     rank_target = Rank(target)
     if rank_target == 0:  # target rank should be at least 1
-        target = op.Unsqueeze(target, op.Constant(value_ints=[0]))
+        target = op.Unsqueeze(target, [0])

     if reduction == 0:
         reduction_str = "none"

2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/ops/special.py

@@ -219,7 +219,7 @@ def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:

     self_is_scalar = len(self.shape) == 0
     if self_is_scalar:
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        self = op.Unsqueeze(self, [0])
     result = op.LogSoftmax(self, axis=dim)
     if dtype != -1:
         result = op.Cast(result, to=dtype)
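
A sanity check on the scalar path (illustrative, not from the PR): the 0-D input is promoted to rank 1, where LogSoftmax over a single element is always 0; the result is presumably squeezed back to rank 0 below this hunk.

import numpy as np

def log_softmax(x, axis):
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

scalar = np.array(3.5, dtype=np.float32)  # rank 0
promoted = np.expand_dims(scalar, 0)  # Unsqueeze(self, [0]) -> rank 1
assert np.allclose(log_softmax(promoted, 0), 0.0)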