Skip to content

Commit

Permalink
[mlir][linalg] Implement TilingInterface for winograd operators (llvm…
Browse files Browse the repository at this point in the history
…#96184)

In order to support arbitrary size input data of conv2d, implement
TilingInterface for winograd operations. Before converting winograd
operations into nested loops with matrix multiply, tile the input of
conv2d into the supported size first.

Add a transform operation structured.decompose_winograd_op to decompose
winograd operations. Before applying the transform op, use
tile_using_for to tile the input data into supported size. The test case
shows how to tile and decompose winograd operations.
  • Loading branch information
Hsiangkai authored Aug 16, 2024
1 parent 2fe59d5 commit c4bf949
Show file tree
Hide file tree
Showing 8 changed files with 1,330 additions and 40 deletions.
141 changes: 135 additions & 6 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,13 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}

def Linalg_WinogradFilterTransformOp :
Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
[AllElementTypesMatch<["filter", "output"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Winograd filter transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
Expand Down Expand Up @@ -190,11 +195,42 @@ def Linalg_WinogradFilterTransformOp :
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let extraClassDeclaration = [{
ShapedType getFilterOperandType() {
return cast<ShapedType>(getFilter().getType());
}
ShapedType getOutputOperandType() {
return cast<ShapedType>(getOutput().getType());
}
int64_t getFilterOperandRank() {
return getFilterOperandType().getRank();
}
int64_t getOutputOperandRank() {
return getOutputOperandType().getRank();
}
int64_t getFilterFDim() {
return 0;
}
int64_t getFilterHDim() {
return 1;
}
int64_t getFilterWDim() {
return 2;
}
int64_t getFilterCDim() {
return 3;
}
}];
let hasVerifier = 1;
}

def Linalg_WinogradInputTransformOp :
Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
[AllElementTypesMatch<["input", "output"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Winograd input transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
Expand Down Expand Up @@ -229,11 +265,60 @@ def Linalg_WinogradInputTransformOp :
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let extraClassDeclaration = [{
ShapedType getInputOperandType() {
return cast<ShapedType>(getInput().getType());
}
ShapedType getOutputOperandType() {
return cast<ShapedType>(getOutput().getType());
}
int64_t getInputOperandRank() {
return getInputOperandType().getRank();
}
int64_t getOutputOperandRank() {
return getOutputOperandType().getRank();
}
int64_t getInputNDim() {
return 0;
}
int64_t getInputHDim() {
return 1;
}
int64_t getInputWDim() {
return 2;
}
int64_t getInputCDim() {
return 3;
}
int64_t getOutputAlphaHDim() {
return 0;
}
int64_t getOutputAlphaWDim() {
return 1;
}
int64_t getOutputTileHDim() {
return 2;
}
int64_t getOutputTileWDim() {
return 3;
}
int64_t getOutputNDim() {
return 4;
}
int64_t getOutputCDim() {
return 5;
}
}];
let hasVerifier = 1;
}

def Linalg_WinogradOutputTransformOp :
Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
[AllElementTypesMatch<["value", "output"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Winograd output transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
Expand Down Expand Up @@ -268,6 +353,50 @@ def Linalg_WinogradOutputTransformOp :
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let extraClassDeclaration = [{
ShapedType getValueOperandType() {
return cast<ShapedType>(getValue().getType());
}
ShapedType getOutputOperandType() {
return cast<ShapedType>(getOutput().getType());
}
int64_t getValueOperandRank() {
return getValueOperandType().getRank();
}
int64_t getOutputOperandRank() {
return getOutputOperandType().getRank();
}
int64_t getValueAlphaHDim() {
return 0;
}
int64_t getValueAlphaWDim() {
return 1;
}
int64_t getValueTileHDim() {
return 2;
}
int64_t getValueTileWDim() {
return 3;
}
int64_t getValueNDim() {
return 4;
}
int64_t getValueFDim() {
return 5;
}
int64_t getOutputNDim() {
return 0;
}
int64_t getOutputHDim() {
return 1;
}
int64_t getOutputWDim() {
return 2;
}
int64_t getOutputFDim() {
return 3;
}
}];
let hasVerifier = 1;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
}];
}

def DecomposeWinogradOp : Op<Transform_Dialect,
"structured.decompose_winograd_op",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Decompose winograd operations. It will convert filter, input and output
transform operations into a combination of scf, tensor, and linalg
equivalent operations. Before applying this transform operations, users
need to tile winograd transform operations into supported sizes.

#### Return modes:

This operation fails if `target` is unsupported. Otherwise, the operation
succeeds and returns a handle of the sequence that replaces the original
operations.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat =
"$target attr-dict `:` functional-type($target, results)";

let builders = [
OpBuilder<(ins "Value":$target)>
];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

#endif // LINALG_TRANSFORM_OPS
57 changes: 57 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,63 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcOp op, int64_t m,
int64_t r);

/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
/// the rewriting, we get
///
/// scf.for %f = lo_f to hi_f step 1
/// scf.for %c = lo_c to hi_c step 1
/// %extracted = extract filter<h x w> from filter<f x h x w x c>
/// %ret = linalg.matmul G, %extracted
/// %ret = linalg.matmul %ret, GT
/// %inserted = insert %ret into filter<h x w x c x f>
FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
linalg::WinogradFilterTransformOp op);

/// Rewrite linalg.winograd_input_transform. The data layout of the input is
/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
/// and tileW. After the rewriting, we get
///
/// scf.for %h = 0 to tileH step 1
/// scf.for %w = 0 to tileW step 1
/// scf.for %n = 0 to N step 1
/// scf.for %c = 0 to C step 1
/// %extracted = extract %extracted<alphaH x alphaW> from
/// %input<N x H x W x C>
/// at [%n, (%h x m), (%w x m), %c]
/// %ret = linalg.matmul BT, %extracted
/// %ret = linalg.matmul %ret, B
/// %inserted = insert %ret<alphaH x alphaW> into
/// %output<alphaH x alphaW x tileH x tileW x N x C>
/// at [0, 0, %h, %w, %n, %c]
FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
linalg::WinogradInputTransformOp op);

/// Rewrite linalg.winograd_output_transform. The data layout of the output is
/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
/// and tileW. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
/// scf.for %w = 0 to tileW step 1
/// scf.for %n = 0 to N step 1
/// scf.for %f = 0 to F step 1
/// %extracted = extract %extracted<alphaH x alphaW> from
/// %input<alphaH x alphaW x tileH x tileW x N x F>
/// at [0, 0, %h, %w, %n, %f]
/// %ret = linalg.matmul AT, %extracted
/// %ret = linalg.matmul %ret, A
/// %inserted = insert %ret<alphaH x alphaW> into
/// output<N x H x W x F>
/// at [%n, (%h x m), (%w x m), %f]
FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
linalg::WinogradOutputTransformOp op);

//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
Expand Down
Loading

0 comments on commit c4bf949

Please sign in to comment.