Skip to content

Commit

Permalink
remove extract slice from grid sample
Browse files Browse the repository at this point in the history
  • Loading branch information
IanWood1 committed Jul 1, 2024
1 parent 0e71a19 commit 77f4eea
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 56 deletions.
81 changes: 27 additions & 54 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "PopulatePatterns.h"
Expand Down Expand Up @@ -2380,7 +2381,6 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
Type floatType = rewriter.getF32Type();
Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value oneIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value twoIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2);
Value zeroFloat = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 0.0));
Value oneFloat = rewriter.create<arith::ConstantOp>(
Expand All @@ -2389,7 +2389,6 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
loc, rewriter.getFloatAttr(floatType, 2.0));
Value input = adaptor.getInput();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputShape = inputType.getShape();
Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
Value innerDim0b =
Expand All @@ -2410,43 +2409,12 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
Value grid = adaptor.getGrid();
auto gridType = cast<RankedTensorType>(grid.getType());
auto gridShape = gridType.getShape();
auto gridRank = gridType.getRank();
SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
SmallVector<Value> extractGridShape = getTensorSizes(rewriter, loc, grid);
SmallVector<Value> extractGridStride(gridRank, oneIndex);
int64_t lastGridDim = gridRank - 1;
extractGridShape[lastGridDim] = oneIndex;
extractGridStride[lastGridDim] = twoIndex;
SmallVector<Value> extractGridOffsets1(gridRank, zeroIndex);
extractGridOffsets1[lastGridDim] = oneIndex;
SmallVector<int64_t> gridShapeExtracted(gridShape);
gridShapeExtracted.back() = 1;
SmallVector<int64_t> gridShapeCollapsed{gridShape[0], gridShape[1],
gridShape[2]};
auto grid0 = rewriter.create<tensor::ExtractSliceOp>(
loc, grid, extractGridOffsets0, extractGridShape, extractGridStride);
auto grid1 = rewriter.create<tensor::ExtractSliceOp>(
loc, grid, extractGridOffsets1, extractGridShape, extractGridStride);
SmallVector<ReassociationIndices> associations{ReassociationIndices{0},
ReassociationIndices{1},
ReassociationIndices{2, 3}};
auto gridCollapsed0 =
rewriter.create<tensor::CollapseShapeOp>(loc, grid0, associations);
auto gridCollapsed1 =
rewriter.create<tensor::CollapseShapeOp>(loc, grid1, associations);
AffineMap gridMap = AffineMap::get(4, 0,
{rewriter.getAffineDimExpr(0),
rewriter.getAffineDimExpr(2),
rewriter.getAffineDimExpr(3)},
op->getContext());
SmallVector<AffineMap> gridMaps{gridMap, gridMap,
rewriter.getMultiDimIdentityMap(gridRank)};
SmallVector<AffineMap> gridMaps{
rewriter.getMultiDimIdentityMap(inputType.getRank())};
SmallVector<utils::IteratorType> gridIterators(
gridRank, utils::IteratorType::parallel);
SmallVector<int64_t> resultShape{inputShape[0], inputShape[1], gridShape[1],
gridShape[2]};
auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA,
auto createExtract = [](OpBuilder &b, Location loc, Value input, Value idxA,
Value idxB, Value idxC, Value idxD) -> Value {
SmallVector<Value> index{idxA, idxB, idxC, idxD};
Value result = b.create<tensor::ExtractOp>(loc, input, index);
Expand Down Expand Up @@ -2486,25 +2454,30 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {

auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult().getType()));
SmallVector<Value> resultSize{};
Value alignCorners = adaptor.getAlignCorners();
Value interMode = adaptor.getInterpolationMode();
SmallVector<Value> dynamicSizes{};
if (resultType.isDynamicDim(0))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
if (resultType.isDynamicDim(1))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
if (resultType.isDynamicDim(2))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 1));
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, grid, 1));
if (resultType.isDynamicDim(3))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
Value alignCorners = adaptor.getAlignCorners();
Value interMode = adaptor.getInterpolationMode();
Value resultFinal =
rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize);
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
tensor::EmptyOp emptyOp =
rewriter.create<tensor::EmptyOp>(loc, resultType, dynamicSizes);
auto sGrid = rewriter.create<linalg::GenericOp>(
loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1},
ValueRange(resultFinal), gridMaps, gridIterators,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value gr0 = args[1];
Value gr1 = args[0];
loc, TypeRange{resultType}, ValueRange(), ValueRange(emptyOp), gridMaps,
gridIterators, [&](OpBuilder &b, Location loc, ValueRange args) {
Value iterIdxZero = b.create<linalg::IndexOp>(loc, 0);
Value iterIdxTwo = b.create<linalg::IndexOp>(loc, 2);
Value iterIdxThree = b.create<linalg::IndexOp>(loc, 3);

Value gr1 = createExtract(b, loc, grid, iterIdxZero, iterIdxTwo,
iterIdxThree, zeroIndex);
Value gr0 = createExtract(b, loc, grid, iterIdxZero, iterIdxTwo,
iterIdxThree, oneIndex);
Value gr0Half = b.create<arith::DivFOp>(loc, gr0, twoFloat);
Value gr1Half = b.create<arith::DivFOp>(loc, gr1, twoFloat);
Value gr0HalfSelect =
Expand Down Expand Up @@ -2561,22 +2534,22 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
b.create<arith::IndexCastOp>(loc, b.getIndexType(), upperValid1);
Value N = b.create<linalg::IndexOp>(loc, 0);
Value C = b.create<linalg::IndexOp>(loc, 1);
Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1);
Value result00 = createExtract(b, loc, input, N, C, lw0, lw1);
Value result00a = b.create<arith::SelectOp>(loc, checkLowerBound0,
zeroFloat, result00);
Value result00b = b.create<arith::SelectOp>(loc, checkLowerBound1,
zeroFloat, result00a);
Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1);
Value result01 = createExtract(b, loc, input, N, C, lw0, up1);
Value result01a = b.create<arith::SelectOp>(loc, notValidUpper1,
zeroFloat, result01);
Value result01b = b.create<arith::SelectOp>(loc, checkLowerBound0,
zeroFloat, result01a);
Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1);
Value result10 = createExtract(b, loc, input, N, C, up0, lw1);
Value result10a = b.create<arith::SelectOp>(loc, notValidUpper0,
zeroFloat, result10);
Value result10b = b.create<arith::SelectOp>(loc, checkLowerBound1,
zeroFloat, result10a);
Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1);
Value result11 = createExtract(b, loc, input, N, C, up0, up1);
Value result11a = b.create<arith::SelectOp>(loc, notValidUpper0,
zeroFloat, result11);
Value result11b = b.create<arith::SelectOp>(loc, notValidUpper1,
Expand Down
2 changes: 0 additions & 2 deletions test/Conversion/TorchToLinalg/gridsampler.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32>
// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32>
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
Expand Down

0 comments on commit 77f4eea

Please sign in to comment.