Skip to content

Commit

Permalink
Add support for the padding variations of conv op (#3883)
Browse files Browse the repository at this point in the history
ConvOp defined with padding = "same"/"valid" produces the padding
variant of the op, such as `conv2d.padding` for 2d convolution. This PR
adds these conv variations to torch-mlir registry and a decomposition of
these ops to `aten.convolution` to be able to go through the different
pass pipelines.
  • Loading branch information
sahas3 authored Dec 5, 2024
1 parent 92d0f04 commit c1892de
Show file tree
Hide file tree
Showing 7 changed files with 459 additions and 0 deletions.
87 changes: 87 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6684,6 +6684,35 @@ def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [
}];
}

def Torch_AtenConv3dPaddingOp : Torch_Op<"aten.conv3d.padding", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
Torch_StringType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_IntType:$groups
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConv3dPaddingOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenConv3dPaddingOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -6713,6 +6742,35 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
}];
}

def Torch_AtenConv2dPaddingOp : Torch_Op<"aten.conv2d.padding", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv2d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
Torch_StringType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_IntType:$groups
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConv2dPaddingOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenConv2dPaddingOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -6742,6 +6800,35 @@ def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [
}];
}

def Torch_AtenConv1dPaddingOp : Torch_Op<"aten.conv1d.padding", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv1d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
Torch_StringType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_IntType:$groups
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConv1dPaddingOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenConv1dPaddingOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
63 changes: 63 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10024,10 +10024,65 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv2d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @__torch__._conv_padding(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %int-1 = torch.constant.int -1\n"
" %str = torch.constant.str \"same\"\n"
" %none = torch.constant.none\n"
" %str_0 = torch.constant.str \"AssertionError: conv: weight must be at least 3 dimensional.\"\n"
" %int2 = torch.constant.int 2\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.sub.int %0, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
" %4 = torch.aten.mul.left_t %3, %2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %5 = torch.aten.eq.str %arg2, %str : !torch.str, !torch.str -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" %6 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %8 = torch.aten.__range_length %6, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %10 = torch.prim.min.self_int %9 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %10, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__derive_index %arg3, %6, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %13 = torch.aten.add.int %int2, %12 : !torch.int, !torch.int -> !torch.int\n"
" %14 = torch.aten.__getitem__.t %arg0, %13 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %16 = torch.aten.mul.int %11, %15 : !torch.int, !torch.int -> !torch.int\n"
" %17 = torch.aten.floordiv.int %16, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %18 = torch.aten._set_item.t %4, %12, %17 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv3d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose2d.input\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg3 : !torch.list<int> to !torch.optional<list<int>>\n"
" %1 = torch.derefine %arg4 : !torch.list<int> to !torch.optional<list<int>>\n"
Expand Down Expand Up @@ -10097,6 +10152,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv1d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %int1 = torch.constant.int 1\n"
" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n"
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %2 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %false, %1, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
Expand Down
82 changes: 82 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5175,6 +5175,82 @@ class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
};
} // namespace

// Decompose aten.conv(1/2/3)d.padding to aten.convolution
namespace {
template <typename ConvPaddingOp>
class DecomposeAtenConvPaddingOp : public OpRewritePattern<ConvPaddingOp> {
public:
using OpRewritePattern<ConvPaddingOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConvPaddingOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();

Value weight = op.getWeight();
std::optional<unsigned> maybeRank = getTensorRank(weight);
if (!maybeRank) {
return rewriter.notifyMatchFailure(op, "expected weight to have a rank");
}
unsigned rank = *maybeRank;
// first 2 dimensions of weight are out_channels and in_channels / groups
if (rank < 3)
return rewriter.notifyMatchFailure(
op, "ConvPaddingOp weight must be at least 3 dimensional.");

std::string padding_str;
if (!matchPattern(op.getPadding(), m_TorchConstantStr(padding_str)))
return rewriter.notifyMatchFailure(op,
"padding must be a constant string");

Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));

SmallVector<Value> paddingValues;
if (padding_str == "valid") {
// valid means no padding
for (unsigned iRank = 2; iRank < rank; iRank++) {
paddingValues.push_back(zero);
}
} else {

SmallVector<Value> dilation;
getListConstructElements(op.getDilation(), dilation);

Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value two =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
for (unsigned iRank = 2; iRank < rank; iRank++) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(iRank));
Value kernelSize =
rewriter.create<Torch::AtenSizeIntOp>(loc, weight, dim);
Value kernelSizeMinusOne =
rewriter.create<Torch::AtenSubIntOp>(loc, kernelSize, one);
Value padding = rewriter.create<Torch::AtenMulIntOp>(
loc, dilation[iRank - 2], kernelSizeMinusOne);
padding = rewriter.create<AtenFloordivIntOp>(loc, padding, two);
paddingValues.push_back(padding);
}
}

Value emptyList = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
SmallVector<Value>());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
Value padding = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
paddingValues);
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
op.getStride(), padding, op.getDilation(), cstFalse, emptyList,
op.getGroups());

return success();
}
};
} // namespace

// Decompose aten.conv3d to aten.convolution
namespace {
class DecomposeAtenConv3dOp : public OpRewritePattern<AtenConv3dOp> {
Expand Down Expand Up @@ -11377,6 +11453,12 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenConv1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConv3dOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvPaddingOp<AtenConv1dPaddingOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvPaddingOp<AtenConv2dPaddingOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvPaddingOp<AtenConv3dPaddingOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenThresholdOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloatPowerTensorTensorOp>(
patterns);
Expand Down
20 changes: 20 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,6 +2065,8 @@
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"Conv2dWithPaddingModule_basic",
"Conv2dWithValidPaddingModule_basic",
"Conv2dWithSamePaddingModule_basic",
"Convolution2DStaticModule_basic",
"CosineSimilarityStaticModule_basic",
"DetachModule_basic",
Expand Down Expand Up @@ -2557,6 +2559,8 @@
"Conv2dNoPaddingModule_basic",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingModule_basic",
"Conv2dWithSamePaddingModule_basic",
"Conv2dWithValidPaddingModule_basic",
# failed to legalize operation 'torch.operator'
"ElementwisePreluModule_basic",
"ElementwisePreluStaticModule_basic",
Expand Down Expand Up @@ -2886,6 +2890,8 @@
"ContainsIntList_False",
"ContainsIntList_True",
"Conv1dModule_basic",
"Conv1dWithSamePaddingModule_basic",
"Conv1dWithValidPaddingModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic",
"Conv2dNoPaddingModule_basic",
Expand All @@ -2898,7 +2904,11 @@
"Conv2dQInt8PerChannelModule_grouped",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingModule_basic",
"Conv2dWithSamePaddingModule_basic",
"Conv2dWithValidPaddingModule_basic",
"Conv3dModule_basic",
"Conv3dWithSamePaddingModule_basic",
"Conv3dWithValidPaddingModule_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
Expand Down Expand Up @@ -3585,6 +3595,8 @@
"ContainsIntList_True",
"Conv1dModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv1dWithSamePaddingModule_basic",
"Conv1dWithValidPaddingModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
Expand All @@ -3595,6 +3607,8 @@
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Conv3dModule_basic",
"Conv3dWithSamePaddingModule_basic",
"Conv3dWithValidPaddingModule_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
Expand Down Expand Up @@ -4178,6 +4192,8 @@
"ContainsIntList_False",
"ContainsIntList_True",
"Conv1dModule_basic",
"Conv1dWithSamePaddingModule_basic",
"Conv1dWithValidPaddingModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic",
Expand All @@ -4193,7 +4209,11 @@
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Conv2dWithPaddingModule_basic",
"Conv2dWithSamePaddingModule_basic",
"Conv2dWithValidPaddingModule_basic",
"Conv3dModule_basic",
"Conv3dWithSamePaddingModule_basic",
"Conv3dWithValidPaddingModule_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
Expand Down
Loading

0 comments on commit c1892de

Please sign in to comment.