Skip to content

Commit

Permalink
fix for CI failure
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Liddell committed Feb 6, 2024
1 parent 510b71f commit f0a3f1a
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2843,19 +2843,16 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
auto self = getSelf();
auto index = getIndex();
auto selfTy = cast<ValueTensorType>(self.getType());
auto indexTy = cast<ValueTensorType>(index.getType());
assert(index.getType().isa<IntegerType>());
auto resultTy = cast<ValueTensorType>(getType());

auto selfSizes = selfTy.getSizes();
auto indexSizes = indexTy.getSizes();
auto resultSizes = resultTy.getSizes();

if (selfTy.getDtype() != resultTy.getDtype())
return nullptr;
if (selfSizes.size() != resultSizes.size())
return nullptr;
if (indexSizes.size() != 1)
return nullptr;

// If the selection results in a tensor of the same dimensions as the
// input, the selection must have specified every index of the input,
Expand All @@ -2868,7 +2865,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
fullTensor &= resultSizes[i] != Torch::kUnknownSize;
}

if (fullTensor && indexSizes[0] == 1)
if (fullTensor)
return self;

// If the input tensor, index dimension, or indexes are non-constant,
Expand Down

0 comments on commit f0a3f1a

Please sign in to comment.