Skip to content

Commit

Permalink
index_select fixes, added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Liddell committed Feb 6, 2024
1 parent b46c489 commit 76582b4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
18 changes: 14 additions & 4 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2876,7 +2876,17 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
if (!scalarFold)
return nullptr;

auto indexInt = indexAttr.getSplatValue<IntegerAttr>().getInt();
auto splatValue = indexAttr.getSplatValue<IntegerAttr>();
uint64_t indexInt = 0;
if (splatValue.getType().isSignedInteger())
indexInt = uint64_t(splatValue.getSInt());
else if (splatValue.getType().isUnsignedInteger())
indexInt = splatValue.getUInt();
else if (splatValue.getType().isSignlessInteger())
indexInt = uint64_t(splatValue.getInt());
else
return nullptr;

auto splattr = selfAttr.getValues<Attribute>()[indexInt];

auto dty = resultTy.getDtype();
Expand All @@ -2885,10 +2895,10 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
return DenseElementsAttr::get(
attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));

if (auto intAttr = dyn_cast<IntegerAttr>(splattr))
if (auto intAttr = dyn_cast<IntegerAttr>(splattr)) {
return DenseElementsAttr::get(attrTy,
IntegerAttr::get(dty, intAttr.getInt()));

IntegerAttr::get(dty, intAttr.getValue()));
}
return nullptr;
}

Expand Down
29 changes: 25 additions & 4 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2197,19 +2197,40 @@ func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !tor
// CHECK-LABEL: func.func @torch.aten.index_select$noop(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,2,3],si64>
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[1,2,3],si64>
func.func @torch.aten.index_select$noop(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> (!torch.vtensor<[1,2,3],si64>) {
func.func @torch.aten.index_select$noop(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,2,3],si64> {
%0 = torch.aten.index_select %arg0, %arg1, %arg2 : !torch.vtensor<[1,2,3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1,2,3],si64>
return %0 : !torch.vtensor<[1,2,3],si64>
}

// CHECK-LABEL: func.func @torch.aten.index_select$const(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,2,3],si64>
// CHECK-LABEL: func.func @torch.aten.index_select$const_si_si(
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64>
func.func @torch.aten.index_select$const(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> (!torch.vtensor<[1],si64>) {
func.func @torch.aten.index_select$const_si_si() -> !torch.vtensor<[1],si64> {
%tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
%dim = torch.constant.int 0
%index = torch.vtensor.literal(dense<5> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
return %0 : !torch.vtensor<[1],si64>
}

// CHECK-LABEL: func.func @torch.aten.index_select$const_si_ui(
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64>
func.func @torch.aten.index_select$const_si_ui() -> !torch.vtensor<[1],si64> {
%tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
%dim = torch.constant.int 0
%index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64>
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],si64>
return %0 : !torch.vtensor<[1],si64>
}

// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_ui(
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<6.6{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32>
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32>
func.func @torch.aten.index_select$const_f32_ui() -> !torch.vtensor<[1],f32> {
%tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32>
%dim = torch.constant.int 0
%index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64>
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}

0 comments on commit 76582b4

Please sign in to comment.