From 77f4eeade2d8007b13689890baa784b696cb6c19 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Thu, 20 Jun 2024 17:31:24 +0000 Subject: [PATCH] remove extract slice from grid sample --- .../TorchToLinalg/Uncategorized.cpp | 81 +++++++------------ .../Conversion/TorchToLinalg/gridsampler.mlir | 2 - 2 files changed, 27 insertions(+), 56 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 5e5f86065201..5e46bb1890a7 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -7,6 +7,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "PopulatePatterns.h" @@ -2380,7 +2381,6 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Type floatType = rewriter.getF32Type(); Value zeroIndex = rewriter.create(loc, 0); Value oneIndex = rewriter.create(loc, 1); - Value twoIndex = rewriter.create(loc, 2); Value zeroFloat = rewriter.create( loc, rewriter.getFloatAttr(floatType, 0.0)); Value oneFloat = rewriter.create( @@ -2389,7 +2389,6 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { loc, rewriter.getFloatAttr(floatType, 2.0)); Value input = adaptor.getInput(); auto inputType = cast(input.getType()); - auto inputShape = inputType.getShape(); Value innerDim0a = rewriter.create(loc, input, 2); Value innerDim1a = rewriter.create(loc, input, 3); Value innerDim0b = @@ -2410,43 +2409,12 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { rewriter.create(loc, innerDim1d, twoFloat); Value grid = adaptor.getGrid(); auto gridType = cast(grid.getType()); - auto gridShape = gridType.getShape(); auto gridRank = gridType.getRank(); - SmallVector extractGridOffsets0(gridRank, zeroIndex); - SmallVector extractGridShape = getTensorSizes(rewriter, loc, grid); - SmallVector extractGridStride(gridRank, oneIndex); - int64_t lastGridDim = gridRank - 1; - extractGridShape[lastGridDim] = oneIndex; - extractGridStride[lastGridDim] = twoIndex; - SmallVector extractGridOffsets1(gridRank, zeroIndex); - extractGridOffsets1[lastGridDim] = oneIndex; - SmallVector gridShapeExtracted(gridShape); - gridShapeExtracted.back() = 1; - SmallVector gridShapeCollapsed{gridShape[0], gridShape[1], - gridShape[2]}; - auto grid0 = rewriter.create( - loc, grid, extractGridOffsets0, extractGridShape, extractGridStride); - auto grid1 = rewriter.create( - loc, grid, extractGridOffsets1, extractGridShape, extractGridStride); - SmallVector associations{ReassociationIndices{0}, - ReassociationIndices{1}, - ReassociationIndices{2, 3}}; - auto gridCollapsed0 = - rewriter.create(loc, grid0, associations); - auto gridCollapsed1 = - rewriter.create(loc, grid1, associations); - AffineMap gridMap = AffineMap::get(4, 0, - {rewriter.getAffineDimExpr(0), - rewriter.getAffineDimExpr(2), - rewriter.getAffineDimExpr(3)}, - op->getContext()); - SmallVector gridMaps{gridMap, gridMap, - rewriter.getMultiDimIdentityMap(gridRank)}; + SmallVector gridMaps{ + rewriter.getMultiDimIdentityMap(inputType.getRank())}; SmallVector gridIterators( gridRank, utils::IteratorType::parallel); - SmallVector resultShape{inputShape[0], inputShape[1], gridShape[1], - gridShape[2]}; - auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA, + auto createExtract = [](OpBuilder &b, Location loc, Value input, Value idxA, Value idxB, Value idxC, Value idxD) -> Value { SmallVector index{idxA, idxB, idxC, idxD}; Value result = b.create(loc, input, index); @@ -2486,25 +2454,30 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { auto resultType = cast( getTypeConverter()->convertType(op.getResult().getType())); - SmallVector resultSize{}; + Value alignCorners = adaptor.getAlignCorners(); + Value interMode = adaptor.getInterpolationMode(); + SmallVector dynamicSizes{}; if (resultType.isDynamicDim(0)) - resultSize.push_back(rewriter.create(loc, input, 0)); + dynamicSizes.push_back(rewriter.create(loc, input, 0)); if (resultType.isDynamicDim(1)) - resultSize.push_back(rewriter.create(loc, input, 1)); + dynamicSizes.push_back(rewriter.create(loc, input, 1)); if (resultType.isDynamicDim(2)) - resultSize.push_back(rewriter.create(loc, grid, 1)); + dynamicSizes.push_back(rewriter.create(loc, grid, 1)); if (resultType.isDynamicDim(3)) - resultSize.push_back(rewriter.create(loc, grid, 2)); - Value alignCorners = adaptor.getAlignCorners(); - Value interMode = adaptor.getInterpolationMode(); - Value resultFinal = - rewriter.create(loc, resultType, resultSize); + dynamicSizes.push_back(rewriter.create(loc, grid, 2)); + tensor::EmptyOp emptyOp = + rewriter.create(loc, resultType, dynamicSizes); auto sGrid = rewriter.create( - loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1}, - ValueRange(resultFinal), gridMaps, gridIterators, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value gr0 = args[1]; - Value gr1 = args[0]; + loc, TypeRange{resultType}, ValueRange(), ValueRange(emptyOp), gridMaps, + gridIterators, [&](OpBuilder &b, Location loc, ValueRange args) { + Value iterIdxZero = b.create(loc, 0); + Value iterIdxTwo = b.create(loc, 2); + Value iterIdxThree = b.create(loc, 3); + + Value gr1 = createExtract(b, loc, grid, iterIdxZero, iterIdxTwo, + iterIdxThree, zeroIndex); + Value gr0 = createExtract(b, loc, grid, iterIdxZero, iterIdxTwo, + iterIdxThree, oneIndex); Value gr0Half = b.create(loc, gr0, twoFloat); Value gr1Half = b.create(loc, gr1, twoFloat); Value gr0HalfSelect = @@ -2561,22 +2534,22 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { b.create(loc, b.getIndexType(), upperValid1); Value N = b.create(loc, 0); Value C = b.create(loc, 1); - Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1); + Value result00 = createExtract(b, loc, input, N, C, lw0, lw1); Value result00a = b.create(loc, checkLowerBound0, zeroFloat, result00); Value result00b = b.create(loc, checkLowerBound1, zeroFloat, result00a); - Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1); + Value result01 = createExtract(b, loc, input, N, C, lw0, up1); Value result01a = b.create(loc, notValidUpper1, zeroFloat, result01); Value result01b = b.create(loc, checkLowerBound0, zeroFloat, result01a); - Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1); + Value result10 = createExtract(b, loc, input, N, C, up0, lw1); Value result10a = b.create(loc, notValidUpper0, zeroFloat, result10); Value result10b = b.create(loc, checkLowerBound1, zeroFloat, result10a); - Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1); + Value result11 = createExtract(b, loc, input, N, C, up0, up1); Value result11a = b.create(loc, notValidUpper0, zeroFloat, result11); Value result11b = b.create(loc, notValidUpper1, diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir index 7c099c5ce4f6..2a291f721fed 100644 --- a/test/Conversion/TorchToLinalg/gridsampler.mlir +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -5,9 +5,7 @@ // CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32> // CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32> // CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32