@@ -1523,10 +1523,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::broadcast_to")
-def aten_broadcast_to(self: TTensor, size: INT64) -> TTensor:
+@torch_op("aten::broadcast_to", trace_only=True)
+def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
-
+    size = common_ops.merge_dims(size)
     return op.Expand(self, size)
 
 
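A note on the two changes above, which recur throughout this diff: `trace_only=True` makes torchlib execute the function body as plain Python at export time instead of compiling it into an ONNX function, which is what lets `size` be a Python-level `Sequence[INT64]` mixing static ints with dynamic dims. The implementation of `common_ops.merge_dims` is not part of this diff; a minimal sketch of what such a helper plausibly does, under those assumptions:

# Hypothetical sketch only -- merge_dims is not shown in this diff. Assuming
# each dim is either a Python int (static) or a traced INT64 scalar (dynamic),
# the helper folds them into one 1-D INT64 shape tensor for Expand/Reshape.
def merge_dims(dims):
    if all(isinstance(d, int) for d in dims):
        # All-static fast path: emit a single constant shape tensor.
        return op.Constant(value_ints=list(dims))
    # Mixed case: lift every dim to a 1-D INT64 tensor and concatenate.
    parts = [
        op.Constant(value_ints=[d]) if isinstance(d, int)
        else op.Reshape(d, op.Constant(value_ints=[1]))
        for d in dims
    ]
    return op.Concat(*parts, axis=0)

The all-static path matters: a fully known shape stays a single constant that downstream shape inference can fold, rather than a Concat of one-element tensors.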
@@ -3286,20 +3286,20 @@ def aten_embedding_sparse_backward(
 
 
 @torch_op("aten::empty.memory_format", trace_only=True)
 def aten_empty(
-    size: IntType,
+    size: Sequence[INT64],
     dtype: int = FLOAT.dtype,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
     memory_format: str = "",
 ) -> TensorType:  # type: ignore[type-var]
-    # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
     if dtype == -1:
         dtype = FLOAT.dtype
-    # using Zeros to simulate np.empty()
-    size = op.Cast(size, to=INT64.dtype)
-    zero = op.Constant(value_float=0.0)
-    zero = op.Cast(zero, to=dtype)
+
+    # using Zeros to simulate empty()
+    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
+    size = common_ops.merge_dims(size)
 
     return op.Expand(zero, size)
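ONNX has no notion of uninitialized memory, so `empty` is lowered to zeros, as the comment in the hunk above says: one scalar `Constant` created directly in the target dtype (replacing the old `Constant` + `Cast` pair with a single node), broadcast to `size` by `Expand`. An illustrative numpy analogue of what the exported subgraph computes:

import numpy as np

# Illustrative only: Expand(zero_scalar, size) behaves like np.zeros(size, dtype),
# which stands in for np.empty since ONNX cannot express uninitialized values.
zero = np.zeros((), dtype=np.float32)  # op.Constant(value=ir.tensor(0, dtype=...))
out = np.broadcast_to(zero, (2, 3))    # op.Expand(zero, size)
assert out.shape == (2, 3) and not out.any()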
@@ -3334,17 +3334,18 @@ def aten_empty_quantized(
 
 
 @torch_op("aten::empty_strided", trace_only=True)
 def aten_empty_strided(
-    size: INT64,
+    size: Sequence[INT64],
     stride: INT64,
     layout: str = "",
+    dtype: int = FLOAT.dtype,
     device: str = "",
     pin_memory: bool = False,
 ) -> TTensor:  # type: ignore[type-var]
     # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
     # using Zeros to simulate empty()
-    size = op.Cast(size, to=INT64.dtype)
-    zero = op.Constant(value_float=0.0)
+    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
+    size = common_ops.merge_dims(size)
 
     return op.Expand(zero, size)
@@ -3392,13 +3393,14 @@ def aten_exp2(self: TFloat) -> TFloat:
 
 
 @torch_op("aten::expand", trace_only=True)
-def aten_expand(self: TTensor, size: TInt, implicit: bool = False) -> TTensor:
+def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor:
     """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
-    size = op.Cast(size, to=INT64.dtype)
     # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1.
     # To support -1 dim, we need to convert -1 to 1.
-    size = op.Abs(size)
-    return op.Expand(self, size)
+    # Even though in theory a dynamic dim can still be -1, in practice it is very unlikely
+    # and isn't expected to appear from correct usages of SymInt.
+    size = [1 if isinstance(s, int) and s == -1 else s for s in size]
+    return op.Expand(self, common_ops.merge_dims(size))
 
 
 @torch_op("aten::expand_as", trace_only=True)
@@ -7409,12 +7411,10 @@ def aten_repeat_interleave_Tensor(
     )
 
 
-@torch_op("aten::reshape")
-def aten_reshape(self: TTensor, shape: IntType) -> TTensor:
+@torch_op("aten::reshape", trace_only=True)
+def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor:
     """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)"""
-
-    # Reshape only support INT64 as 'shape'
-    shape = op.Cast(shape, to=INT64.dtype)
+    shape = common_ops.merge_dims(shape)
     return op.Reshape(self, shape)
 
 
@@ -9153,23 +9153,22 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
 
 
 @torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True)
-def aten_view(self: TTensor, size: IntType) -> TTensor:
+def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
 
-    size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
+    size = common_ops.merge_dims(size)
     return op.Reshape(self, size, allowzero=True)
 
 
-@torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
-def aten_view_complex(self: TTensor, size: IntType) -> TTensor:
+@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True)
+def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
 
-    size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
-    complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
+    complex_size = common_ops.merge_dims([*size, 2])
     return op.Reshape(self, complex_size, allowzero=True)
 
 
-@torch_op("aten::view_as")
+@torch_op("aten::view_as", trace_only=True)
 def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
     """view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
 
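The `[*size, 2]` in the complex overload reflects how torchlib carries complex tensors: as real tensors with a trailing dimension of 2 holding the real and imaginary parts, so a complex `view(size)` must reshape the backing real tensor to `[*size, 2]`. An illustrative numpy analogue:

import numpy as np

# Illustrative only: a complex tensor carried as real with a trailing dim of 2.
x = np.stack([np.arange(6.0), -np.arange(6.0)], axis=-1)  # shape (6, 2): (real, imag)
size = (2, 3)                   # requested complex view shape
viewed = x.reshape(*size, 2)    # Reshape(self, merge_dims([*size, 2]))
assert viewed.shape == (2, 3, 2)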
@@ -9213,11 +9212,11 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:
     return op.Identity(self)
 
 
-@torch_op("aten::view_copy")
-def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
+@torch_op("aten::view_copy", trace_only=True)
+def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor:
     """view_copy(Tensor self, SymInt[] size) -> Tensor"""
 
-    size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
+    size = common_ops.merge_dims(size)
     return op.Reshape(self, size)
 
 
@@ -9245,7 +9244,8 @@ def reshape_to_2d(tensor):
         "aten::where.ScalarSelf",
         "aten::where.ScalarOther",
         "aten::where.self",
-    )
+    ),
+    trace_only=True,
 )
 def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
     """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""
@@ -9261,7 +9261,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
 
 @torch_op("aten::zeros", trace_only=True)
 def aten_zeros(
-    size: IntType,
+    size: Sequence[INT64],
     dtype: int = FLOAT.dtype,
     layout: str = "",
     device: str = "",
@@ -9270,9 +9270,9 @@ def aten_zeros(
     """zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
     if dtype == -1:
         dtype = FLOAT.dtype
-    size = op.Cast(size, to=INT64.dtype)
-    zero = op.Constant(value_float=0.0)
-    zero = op.Cast(zero, to=dtype)
+
+    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
+    size = common_ops.merge_dims(size)
 
     return op.Expand(zero, size)