Skip to content

Average pooling clamped divisor should be done on all conditions where the kernel can go out of bounds #4144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 163 additions & 64 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,8 +856,9 @@ namespace {
// used in the divisor of the average pooling operator.
template <int NumOfDims> class PoolSizeCalculator {
public:
PoolSizeCalculator(Value self, Value sumPool,
ConversionPatternRewriter &rewriter, Location loc);
PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad,
bool ceilMode, ConversionPatternRewriter &rewriter,
Location loc);

// The algorithm for computing the divisor with
// count_include_pad equal is mainly based on pytorch
Expand All @@ -871,18 +872,20 @@ template <int NumOfDims> class PoolSizeCalculator {
SmallVectorImpl<int64_t> &paddingInts);

private:
int64_t DimSizeFromSumPoolType[NumOfDims];
Value InputSpatialDimValues[NumOfDims];
int64_t SumPoolTypeDimIndex[NumOfDims];
Value InputSpatialDimSizes[NumOfDims];
Location location;
bool isCountIncludePad;
bool isCeilMode;
};

} // namespace

template <int NumOfDims>
PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
Value self, Value sumPool, ConversionPatternRewriter &rewriter,
Location loc)
: location(loc) {
Value self, Value sumPool, bool countIncludePad, bool ceilMode,
ConversionPatternRewriter &rewriter, Location loc)
: location(loc), isCountIncludePad(countIncludePad), isCeilMode(ceilMode) {
auto selfType = cast<RankedTensorType>(self.getType());
const int64_t selfRank = selfType.getRank();
RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
Expand All @@ -891,57 +894,124 @@ PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
// Store dimensions in this order:
// 0 => width, 1 => height, 2 => depth
for (int i = 0; i < NumOfDims; ++i) {
int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank);
InputSpatialDimValues[i] =
getDimOp(rewriter, location, self, DimSizeFromSelfType);
DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank);
int64_t inputSpatialDimIndex = toPositiveDim(-(i + 1), selfRank);
InputSpatialDimSizes[i] =
getDimOp(rewriter, location, self, inputSpatialDimIndex);
SumPoolTypeDimIndex[i] = toPositiveDim(-(i + 1), rank);
}
}

template <int NumOfDims>
Value PoolSizeCalculator<NumOfDims>::getPoolSize(
OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues,
OpBuilder &b, SmallVectorImpl<Value> &kernelDimSizes,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts) {
Value poolSize;

Value cstZero =
b.createOrFold<arith::ConstantOp>(location, b.getI64IntegerAttr(0));
Value cstOne =
b.createOrFold<arith::ConstantOp>(location, b.getI64IntegerAttr(1));
Value cstTwo =
b.createOrFold<arith::ConstantOp>(location, b.getI64IntegerAttr(2));

for (int i = 0; i < NumOfDims; ++i) {
// See the link below for the PyTorch implementation where this is
// derived from:
// https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
// Dim below stands for spatial dimension. Prior to the February 2025
// change, these variables used "height" and "width" (or "h" and "w")
// in these intermediate variables instead of "Dim".
Value IndexODim =
// The following code computes the clamped kernel size used to compute
// the divisor of the average pooling operator. Here is the formula that
// it represents:
//
// indexStartOffset = ceil((kernelSize - 1)/2) - padding
//
// clampedKernelSize =
// min(outIntIndex * stride + indexStartOffset + floor((kernelSize - 1)/2)
// + 1,
// InputSpatialDimSize + padding) -
// max(outIntIndex * stride + indexStartOffset - ceil((kernelSize - 1)/2),
// -padding)
//
// The outIntIndex is the current iteration value coming from the
// linalg.generic op and it represents the center of the kernel window.
// The padding above becomes zero if count_include_pad is false.
// The kernelSize - 1 is used to subtract the center element of the kernel
// from the kernel size before dividing by two. Note that PyTorch even
// kernel dimensions are biased to the lower side of the dimension. Hence
// the lower length uses ceiling. While the upper length uses floor.
//
// If count_include_pad is true, in most cases the divisor is just the
// product of kernel dimensions. But we still need this logic for the
// case in which the ceiling mode is true since the kernel window
// center can go into the padding outside of the input tensor. This
// introduces an implicit padding that is not controlled by the
// count_include_pad parameter. See the
// AvgPool2dCeilPaddingStridedIncludePadding E2E test for details.

// The average pool properties of kernel size, strides, and padding are
// stored in the reverse order of the input tensor dimensions. The
// following code computes the index of the average pool property that
// corresponds to the current spatial dimension.
int avgPoolPropIdx = NumOfDims - i - 1;

Value padding = b.createOrFold<arith::ConstantOp>(
location, b.getI64IntegerAttr(paddingInts[avgPoolPropIdx]));
Value InputSpatialDimSize =
castIndexToInt64(b, location, InputSpatialDimSizes[i]);
// Subtract center element from kernel size before division by two.
Value kernelSizeMinusOne = b.createOrFold<arith::SubIOp>(
location, kernelDimSizes[avgPoolPropIdx], cstOne);
// PyTorch even kernel dimensions are biased to the lower side of the
// dimension. Hence the lower lenght uses ceiling.
Value kernelLowerLength = b.createOrFold<arith::CeilDivSIOp>(
location, kernelSizeMinusOne, cstTwo);
// While the upper length uses floor.
Value kernelUpperLength = b.createOrFold<arith::FloorDivSIOp>(
location, kernelSizeMinusOne, cstTwo);

// The more padding the closest we can read from the lower bound of
// the input tensor.
Value indexStartOffset =
b.createOrFold<arith::SubIOp>(location, kernelLowerLength, padding);

Value outIndex =
b.createOrFold<linalg::IndexOp>(location,
/*value=*/DimSizeFromSumPoolType[i]);
Value ODim = castIndexToInt64(b, location, IndexODim);
Value DDim = b.createOrFold<arith::ConstantOp>(
location, b.getI64IntegerAttr(strideInts[i]));
Value PadDim = b.createOrFold<arith::ConstantOp>(
location, b.getI64IntegerAttr(paddingInts[i]));
Value ODimDDim = b.createOrFold<arith::MulIOp>(location, ODim, DDim);
Value IDim0 = b.createOrFold<arith::SubIOp>(location, ODimDDim, PadDim);
Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
Value IDim0KDim =
b.createOrFold<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
Value IDimPadDim = b.createOrFold<arith::AddIOp>(location, IDim, PadDim);
Value IDim1 =
b.createOrFold<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);

Value IDim0Clamped =
b.createOrFold<arith::MaxSIOp>(location, IDim0, cstZero);
Value IDim1Clamped = b.createOrFold<arith::MinSIOp>(location, IDim1, IDim);
Value IDim1_IDim0_Clamped =
b.createOrFold<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
/*value=*/SumPoolTypeDimIndex[i]);
Value outIntIndex = castIndexToInt64(b, location, outIndex);

Value stride = b.createOrFold<arith::ConstantOp>(
location, b.getI64IntegerAttr(strideInts[avgPoolPropIdx]));

Value indexStrided = b.createOrFold<arith::AddIOp>(
location, b.createOrFold<arith::MulIOp>(location, outIntIndex, stride),
indexStartOffset);

Value inputUpperBound = isCountIncludePad
? b.createOrFold<arith::AddIOp>(
location, InputSpatialDimSize, padding)
: InputSpatialDimSize;

Value inputLowerBound =
isCountIncludePad
? b.createOrFold<arith::SubIOp>(location, cstZero, padding)
: cstZero;

Value upperBoundMinusOne = b.createOrFold<arith::AddIOp>(
location, indexStrided, kernelUpperLength);
Value upperBound =
b.createOrFold<arith::AddIOp>(location, upperBoundMinusOne, cstOne);
Value upperBoundClamped =
b.createOrFold<arith::MinSIOp>(location, upperBound, inputUpperBound);

Value lowerBound = b.createOrFold<arith::SubIOp>(location, indexStrided,
kernelLowerLength);
Value lowerBoundClamped =
b.createOrFold<arith::MaxSIOp>(location, lowerBound, inputLowerBound);
Value clampedKernelSize = b.createOrFold<arith::SubIOp>(
location, upperBoundClamped, lowerBoundClamped);

if (i == 0) {
poolSize = IDim1_IDim0_Clamped;
poolSize = clampedKernelSize;
} else {
poolSize = b.createOrFold<arith::MulIOp>(location, poolSize,
IDim1_IDim0_Clamped);
poolSize =
b.createOrFold<arith::MulIOp>(location, poolSize, clampedKernelSize);
}
}
return poolSize;
Expand All @@ -961,10 +1031,10 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
// count_include_pad parameter is equal to false.
static std::optional<LogicalResult>
createAvgPoolValueCountIncludePadFalseCase(
bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
bool ceilMode, bool countIncludePad, OpTy op,
typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
Value self, Value sumPool, Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
SmallVector<AffineMap> &indexingMapsAvg,
Expand All @@ -976,7 +1046,7 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
OpTy op, typename OpTy::Adaptor &adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVector<AffineMap> &indexingMapsAvg,
SmallVector<utils::IteratorType> &iteratorTypesAvg);
};
Expand Down Expand Up @@ -1041,9 +1111,9 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::matchAndRewrite(
Dim + 2, utils::IteratorType::parallel);

auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase(
countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor,
resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg,
iteratorTypesAvg);
ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool,
outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts,
indexingMapsAvg, iteratorTypesAvg);
if (divisorOpResult)
return *divisorOpResult;

Expand All @@ -1057,10 +1127,10 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::matchAndRewrite(
template <typename OpTy, typename PoolingOpTy, int Dim>
std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
createAvgPoolValueCountIncludePadFalseCase(
bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
bool ceilMode, bool countIncludePad, OpTy op,
typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
Value self, Value sumPool, Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
SmallVector<AffineMap> &indexingMapsAvg,
Expand All @@ -1069,8 +1139,37 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::

constexpr int avgPoolDims = getAvgPoolNumOfDims<OpTy>();

bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
if (countIncludePad || noPadding) {
bool hasPadding =
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
bool allStridesUnitary =
llvm::all_of(strideInts, [](int64_t s) { return s == 1; });

// If the condition below is true, the divisor total must subtract the
// elements not counted (clamped divisor count). If false, the divisor
// is just the product of kernel dimensions.
bool divisorIsClamped =
(!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary);
// There are two ways to get the divisor clamped: through padding or
// ceiling mode. For the case when there is padding, the padding elements
// are omitted if count_include_pad == False (divisor is clamped). If
// there is no padding (padding == 0) then the count_include_pad value
// does not take effect.
// The divisor count can be clamped also through the ceil_mode. In this
// case, according to the Hout and Wout formula in this page:
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d,
// the ceil_mode will round up on the stride division. The round up
// will give an extra element that will go out of bounds which PyTorch
// adds zero padding in it. It also does not count the implicit zero
// padding elements in the divisor, and it is not controlled by the
// count_include_pad argument.
// But also note that if all strides are 1 there is not fractions to
// round up, hence there is no ceiling rounding and the window will
// not go out of bounds. For this case the divisor is just the
// product of kernel dimensions.
// Search for torch.nn.AvgPool2d E2E tests for coverage of these
// conditions.

if (!divisorIsClamped) {
// These cases are not handled here.
return std::nullopt;
}
Expand All @@ -1082,8 +1181,8 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::

Type resultElementType = cast<RankedTensorType>(resultType).getElementType();

PoolSizeCalculator<avgPoolDims> poolSizeCalculator(self, sumPool, rewriter,
loc);
PoolSizeCalculator<avgPoolDims> poolSizeCalculator(
self, sumPool, countIncludePad, ceilMode, rewriter, loc);

// AtenAvgPool2/3dOp has an optional divisor_override
// attribute while AtenAvgPool1dOp does not.
Expand All @@ -1104,7 +1203,7 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
[&](OpBuilder &b, Location loc, ValueRange args) {
if (!poolSize) {
poolSize = poolSizeCalculator.getPoolSize(
b, kernelSizeIntValues, strideInts, paddingInts);
b, kernelDimSizes, strideInts, paddingInts);
}
Value divisor =
convertScalarToDtype(b, loc, poolSize, resultElementType);
Expand All @@ -1126,17 +1225,17 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
OpTy op, typename OpTy::Adaptor &adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVector<AffineMap> &indexingMapsAvg,
SmallVector<utils::IteratorType> &iteratorTypesAvg) {
Location loc = op->getLoc();

Type resultElementType = cast<RankedTensorType>(resultType).getElementType();

Value divisor = kernelSizeIntValues[0];
for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) {
divisor = rewriter.createOrFold<arith::MulIOp>(loc, divisor,
kernelSizeIntValues[i]);
Value divisor = kernelDimSizes[0];
for (uint32_t i = 1; i < kernelDimSizes.size(); ++i) {
divisor =
rewriter.createOrFold<arith::MulIOp>(loc, divisor, kernelDimSizes[i]);
}
// Only average pooling 2D/3D have optional divisor override.
if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
Expand Down
19 changes: 19 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,13 @@
"Aten_EmbeddingBagExample_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic",
"AvgPool2dCeilPadNonUnitaryStrides_basic",
"AvgPool2dCeilNoPadStridedIncludePadding_basic",
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
"AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool2dDivisorOverrideModule_basic",
"BernoulliTensorModule_basic",
"BincountMinlengthModule_basic",
Expand Down Expand Up @@ -2788,6 +2795,10 @@
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"AvgPool2dSingleIntTupleParamsModule_basic",
"AvgPool2dWithoutPadModule_basic",
"AvgPool1dNoPadCeilPadNotIncluded_basic",
"AvgPool1dPadCeilPadNotIncluded_basic",
"AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"BatchMlpLayerModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
Expand Down Expand Up @@ -3527,6 +3538,11 @@
"AvgPool1dIntModule_basic",
"AvgPool1dStaticModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
"AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic",
"AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic",
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool2dDivisorOverrideModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
Expand Down Expand Up @@ -3932,6 +3948,9 @@
"AtenKthvalueFloat64Module_basic",
"AtenKthvalueKeepDimModule_basic",
"AtenKthvalueModule_basic",
"AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic",
"AvgPool2dCeilNoPadUnitaryStrides_basic",
"AvgPool2dCeilPadNonUnitaryStrides_basic",
"AvgPool2dCountIncludePadFalseStaticModule_basic",
"AvgPool3dStaticModule_basic",
"Conv_Transpose1dModule_basic",
Expand Down
Loading
Loading