Skip to content

Commit

Permalink
[MLIR][TORCH] Modify Onnx.Reshape lowering for static shape cases (ll…
Browse files Browse the repository at this point in the history
…vm#2852)

This commit modifies the OnnxToTorch lowering of Onnx.Reshape op by
creating the result shape list for the aten.reshape using the result
shape values inferred from the op's result shape.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
  • Loading branch information
vivekkhandelwal1 authored Feb 8, 2024
1 parent a8aad2a commit 4df9661
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 142 deletions.
23 changes: 23 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(allowzero, "allowzero", 0))
return failure();

// If the result shape is static then we can create a result shape list
// directly using the result shape values (integers).
if (resultType.hasSizes()) {
bool hasStaticShape = resultType.areAllSizesKnown();
ArrayRef<int64_t> resultShapeInt = resultType.getSizes();
if (hasStaticShape) {
SmallVector<Value> resultShape;
for (int64_t dim : resultShapeInt) {
resultShape.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dim)));
}
Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
resultShape);
rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
binder.op, resultType, data, resultShapeList);
return success();
}
}

Torch::BaseTensorType shapeType =
shape.getType().cast<Torch::BaseTensorType>();
SmallVector<Value> dimList;
Expand Down
161 changes: 19 additions & 142 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1256,33 +1256,11 @@ func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1:

// CHECK-LABEL: func.func @test_reshape_negative_dim
func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
// CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
// CHECK: %[[INT4:.+]] = torch.constant.int 4
// CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,6,2],f32>
// CHECK: %[[INT6:.+]] = torch.constant.int 6
// CHECK: %[[INT2_0:.+]] = torch.constant.int 2
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2_0]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,6,2],f32>
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32>
return %0 : !torch.vtensor<[2,6,2],f32>
}
Expand All @@ -1291,40 +1269,12 @@ func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1:

// CHECK-LABEL: func.func @test_reshape_negative_extended_dims
func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
// CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
// CHECK: %[[INT4:.+]] = torch.constant.int 4
// CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT3_2:.+]] = torch.constant.int 3
// CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT2]], %[[INT3]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32>
return %0 : !torch.vtensor<[1,2,3,4],f32>
}
Expand All @@ -1333,17 +1283,9 @@ func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32

// CHECK-LABEL: func.func @test_reshape_one_dim
func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %6 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[24],f32>
// CHECK: %[[INT24:.+]] = torch.constant.int 24
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT24]] : (!torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[24],f32>
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32>
return %0 : !torch.vtensor<[24],f32>
}
Expand All @@ -1352,25 +1294,10 @@ func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torc

// CHECK-LABEL: func.func @test_reshape_reduced_dims
func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %12 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,12],f32>
// CHECK: %[[INT12:.+]] = torch.constant.int 12
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT12]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,12],f32>
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32>
return %0 : !torch.vtensor<[2,12],f32>
}
Expand All @@ -1379,33 +1306,11 @@ func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1:

// CHECK-LABEL: func.func @test_reshape_reordered_all_dims
func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: %[[INT4:.+]] = torch.constant.int 4
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
// CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
// CHECK: %[[INT4:.+]] = torch.constant.int 4
// CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[4,2,3],f32>
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT4]], %[[INT2]], %[[INT3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[4,2,3],f32>
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32>
return %0 : !torch.vtensor<[4,2,3],f32>
}
Expand All @@ -1414,40 +1319,12 @@ func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %

// CHECK-LABEL: func.func @test_reshape_zero_and_negative_dim
func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
// CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[INT4:.+]] = torch.constant.int 4
// CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT3_2:.+]] = torch.constant.int 3
// CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1,4],f32>
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]], %[[INT1]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1,4],f32>
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32>
return %0 : !torch.vtensor<[2,3,1,4],f32>
}
Expand Down

0 comments on commit 4df9661

Please sign in to comment.