
Commit b43ea70

Merge d37eaf1 into e38cc7f (2 parents: e38cc7f + d37eaf1)
2 files changed (+185, -61 lines)

compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp (+63, -61)
@@ -7,8 +7,6 @@
 #include "iree-dialects/Transforms/TransformMatchers.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/GlobalOptimization/PassDetail.h"
 #include "iree/compiler/GlobalOptimization/Passes.h"
 #include "llvm/ADT/STLExtras.h"
@@ -701,67 +699,71 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern<OpTy> {
     // Returns true if the operand was updated to inform the pattern rewriter
     // of a change.
     Type outElementType = getElementTypeOrSelf(namedOp->getResultTypes()[0]);
-    bool didChangeOperand = false;
-    {
-      OpBuilder::InsertionGuard guard(rewriter);
-      Block *block = &namedOp->getRegion(0).front();
-      rewriter.setInsertionPointToStart(block);
-      auto replaceOperandWithTypeCast = [&](OpOperand &operand) {
-        // If the op already has implicit casting semantics for this operand,
-        // do not fuse.
-        if (getElementTypeOrSelf(operand.get().getType()) != outElementType) {
-          return false;
-        }
-        auto producer = operand.get().getDefiningOp<linalg::GenericOp>();
-        if (!producer) {
-          return false;
-        }
-        if (!linalg::isElementwise(producer) ||
-            producer.getNumDpsInputs() != 1 || producer.getNumDpsInits() != 1) {
-          return false;
-        }
+    OpBuilder::InsertionGuard guard(rewriter);
+    Block *block = &namedOp->getRegion(0).front();
+    rewriter.setInsertionPointToStart(block);
+    bool didChangeOperand =
+        replaceOperandWithTypeCast(namedOp->getOpOperand(0), outElementType,
+                                   namedOp, block, rewriter) ||
+        replaceOperandWithTypeCast(namedOp->getOpOperand(1), outElementType,
+                                   namedOp, block, rewriter);
+    return success(didChangeOperand);
+  }
 
-        if (!llvm::hasSingleElement(
-                producer.getBlock()->without_terminator())) {
-          return false;
-        }
-        // We only handle arith.extf here for two reasons:
-        // 1) This pattern is being applied to convolution/contraction
-        //    interfaces. Extension semantics for integers depend on the named
-        //    op and requires a slightly different pattern.
-        // 2) Truncating operations like `arith.truncf` should not be fused
-        //    with consumers; it would be preferred to fuse those with
-        //    producers (and the consumer fusion is arguably the less canonical
-        //    form).
-        if (!llvm::isa<arith::ExtFOp>(
-                *producer.getBlock()->without_terminator().begin())) {
-          return false;
-        }
-        Type producerElementType = getElementTypeOrSelf(
-            producer.getDpsInputOperand(0)->get().getType());
-        int64_t operandNumber = operand.getOperandNumber();
-        // Set the operand to the linalg op to the smaller one.
-        namedOp->setOperand(operandNumber, producer->getOperand(0));
-
-        // Insert a new block argument into the body of the named op with the
-        // correct type.
-        Value blockArg = block->insertArgument(
-            operandNumber, producerElementType, namedOp.getLoc());
-        // Create the extf.
-        auto ext = rewriter.create<arith::ExtFOp>(namedOp.getLoc(),
-                                                  outElementType, blockArg);
-        // Replace uses of the old argument with the extended value.
-        rewriter.replaceAllUsesWith(block->getArgument(operandNumber + 1),
-                                    ext.getResult());
-        // Erase the old argument.
-        block->eraseArgument(operandNumber + 1);
-        return true;
-      };
-
-      didChangeOperand = replaceOperandWithTypeCast(namedOp->getOpOperand(0));
-      didChangeOperand |= replaceOperandWithTypeCast(namedOp->getOpOperand(1));
+private:
+  static bool replaceOperandWithTypeCast(OpOperand &operand,
+                                         Type outElementType, OpTy namedOp,
+                                         Block *block, RewriterBase &rewriter) {
+    // If the op already has implicit casting semantics for this operand,
+    // do not fuse.
+    if (getElementTypeOrSelf(operand.get().getType()) != outElementType) {
+      return false;
     }
-    return success(didChangeOperand);
+    auto producer = operand.get().getDefiningOp<linalg::GenericOp>();
+    if (!producer) {
+      return false;
+    }
+    if (!linalg::isElementwise(producer) || producer.getNumDpsInputs() != 1 ||
+        producer.getNumDpsInits() != 1) {
+      return false;
+    }
+
+    if (!llvm::hasSingleElement(producer.getBlock()->without_terminator())) {
+      return false;
+    }
+
+    // Note: only extf and extsi are supported
+    //
+    // convolution/contraction ops internally use extsi to cast to the correct
+    // bitwidth.
+    //
+    // Truncating operations like `arith.truncf` should not be fused with
+    // consumers; it would be preferred to fuse those with producers (and the
+    // consumer fusion is arguably the less canonical form).
+    Operation &castOp = *producer.getBlock()->without_terminator().begin();
+    if (!llvm::isa<arith::ExtFOp, arith::ExtSIOp>(castOp)) {
+      return false;
+    }
+    Type producerElementType =
+        getElementTypeOrSelf(producer.getDpsInputOperand(0)->get().getType());
+    int64_t operandNumber = operand.getOperandNumber();
+    // Set the operand to the linalg op to the smaller one.
+    namedOp->setOperand(operandNumber, producer->getOperand(0));
+
+    // Insert a new block argument into the body of the named op with the
+    // correct type.
+    Value blockArg = block->insertArgument(operandNumber, producerElementType,
+                                           namedOp.getLoc());
+    // Create the extf/extsi.
+    IRMapping mapping;
+    mapping.map(castOp.getOperand(0), blockArg);
+    Value ext = rewriter.clone(castOp, mapping)->getResult(0);
+
+    // Replace uses of the old argument with the extended value.
+    rewriter.replaceAllUsesWith(block->getArgument(operandNumber + 1), ext);
+    // Erase the old argument.
+    block->eraseArgument(operandNumber + 1);
+    return true;
   }
 };
 
compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir (+122)
@@ -538,3 +538,125 @@ util.func public @conv_nchw_extf_both(%arg0 : tensor<1x5x10x10xf16>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<5x5x3x3xf16>
 // CHECK: %[[RESULT:.+]] = linalg.conv_2d_nchw_fchw {{.*}} ins(%[[ARG0]], %[[ARG1]]
 // CHECK: util.return %[[RESULT]]
+
+// -----
+
+util.func public @matmul_extsi(%arg0 : tensor<10x20xi32>,
+                               %arg1 : tensor<20x40xi16>) -> tensor<10x40xi32> {
+  %0 = tensor.empty() : tensor<20x40xi32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg1 : tensor<20x40xi16>) outs(%0 : tensor<20x40xi32>) {
+    ^bb0(%b0 : i16, %b1 : i32):
+      %e = arith.extsi %b0 : i16 to i32
+      linalg.yield %e : i32
+  } -> tensor<20x40xi32>
+  %2 = tensor.empty() : tensor<10x40xi32>
+  %3 = arith.constant 0 : i32
+  %4 = linalg.fill ins(%3 : i32) outs(%2 : tensor<10x40xi32>) -> tensor<10x40xi32>
+  %5 = linalg.matmul ins(%arg0, %1 : tensor<10x20xi32>, tensor<20x40xi32>)
+                     outs(%4 : tensor<10x40xi32>) -> tensor<10x40xi32>
+  util.return %5 : tensor<10x40xi32>
+}
+// CHECK-LABEL: util.func public @matmul_extsi
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<20x40xi16>
+// CHECK: %[[RESULT:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]]
+// CHECK: util.return %[[RESULT]]
+// -----
+
+util.func public @matmul_extsi_a(%arg0 : tensor<10x20xi16>,
+                                 %arg1 : tensor<20x40xi32>) -> tensor<10x40xi32> {
+  %0 = tensor.empty() : tensor<10x20xi32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<10x20xi16>) outs(%0 : tensor<10x20xi32>) {
+    ^bb0(%b0 : i16, %b1 : i32):
+      %e = arith.extsi %b0 : i16 to i32
+      linalg.yield %e : i32
+  } -> tensor<10x20xi32>
+  %2 = tensor.empty() : tensor<10x40xi32>
+  %3 = arith.constant 0 : i32
+  %4 = linalg.fill ins(%3 : i32) outs(%2 : tensor<10x40xi32>) -> tensor<10x40xi32>
+  %5 = linalg.matmul ins(%1, %arg1 : tensor<10x20xi32>, tensor<20x40xi32>)
+                     outs(%4 : tensor<10x40xi32>) -> tensor<10x40xi32>
+  util.return %5 : tensor<10x40xi32>
+}
+// CHECK-LABEL: util.func public @matmul_extsi_a
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi16>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<20x40xi32>
+// CHECK: %[[RESULT:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]]
+// CHECK: util.return %[[RESULT]]
+
+// -----
+
+util.func public @matmul_extsi_both(%arg0 : tensor<10x20xi16>,
+                                    %arg1 : tensor<20x40xi16>) -> tensor<10x40xi32> {
+  %0 = tensor.empty() : tensor<10x20xi32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<10x20xi16>) outs(%0 : tensor<10x20xi32>) {
+    ^bb0(%b0 : i16, %b1 : i32):
+      %e = arith.extsi %b0 : i16 to i32
+      linalg.yield %e : i32
+  } -> tensor<10x20xi32>
+  %2 = tensor.empty() : tensor<20x40xi32>
+  %3 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg1 : tensor<20x40xi16>) outs(%2 : tensor<20x40xi32>) {
+    ^bb0(%b2 : i16, %b3 : i32):
+      %e1 = arith.extsi %b2 : i16 to i32
+      linalg.yield %e1 : i32
+  } -> tensor<20x40xi32>
+  %4 = tensor.empty() : tensor<10x40xi32>
+  %5 = arith.constant 0 : i32
+  %6 = linalg.fill ins(%5 : i32) outs(%4 : tensor<10x40xi32>) -> tensor<10x40xi32>
+  %7 = linalg.matmul ins(%1, %3 : tensor<10x20xi32>, tensor<20x40xi32>)
+                     outs(%6 : tensor<10x40xi32>) -> tensor<10x40xi32>
+  util.return %7 : tensor<10x40xi32>
+}
+// CHECK-LABEL: util.func public @matmul_extsi_both
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi16>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<20x40xi16>
+// CHECK: %[[RESULT:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]]
+// CHECK: util.return %[[RESULT]]
+
+// -----
+
+util.func public @conv_nchw_extsi_both(%arg0 : tensor<1x5x10x10xi16>,
+                                       %arg1 : tensor<5x5x3x3xi16>) -> tensor<1x5x8x8xi32> {
+  %0 = tensor.empty() : tensor<1x5x10x10xi32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<1x5x10x10xi16>) outs(%0 : tensor<1x5x10x10xi32>) {
+    ^bb0(%b0 : i16, %b1 : i32):
+      %e = arith.extsi %b0 : i16 to i32
+      linalg.yield %e : i32
+  } -> tensor<1x5x10x10xi32>
+  %2 = tensor.empty() : tensor<5x5x3x3xi32>
+  %3 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg1 : tensor<5x5x3x3xi16>) outs(%2 : tensor<5x5x3x3xi32>) {
+    ^bb0(%b2 : i16, %b3 : i32):
+      %e1 = arith.extsi %b2 : i16 to i32
+      linalg.yield %e1 : i32
+  } -> tensor<5x5x3x3xi32>
+  %4 = tensor.empty() : tensor<1x5x8x8xi32>
+  %5 = arith.constant 0 : i32
+  %6 = linalg.fill ins(%5 : i32) outs(%4 : tensor<1x5x8x8xi32>) -> tensor<1x5x8x8xi32>
+  %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+        ins(%1, %3 : tensor<1x5x10x10xi32>, tensor<5x5x3x3xi32>)
+        outs(%6 : tensor<1x5x8x8xi32>) -> tensor<1x5x8x8xi32>
+  util.return %7 : tensor<1x5x8x8xi32>
+}
+// CHECK-LABEL: util.func public @conv_nchw_extsi_both
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x5x10x10xi16>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<5x5x3x3xi16>
+// CHECK: %[[RESULT:.+]] = linalg.conv_2d_nchw_fchw {{.*}} ins(%[[ARG0]], %[[ARG1]]
+// CHECK: util.return %[[RESULT]]
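
For reference, the CHECK lines above only pin down the raised op's operands. What they imply is that each arith.extsi producer is folded away and the named op consumes the narrower i16 operand directly, relying on the implicit signed-extension semantics of linalg contraction/convolution ops (the same semantics the new comment in RaiseSpecialOps.cpp refers to). A hand-written sketch of the expected raised form of @matmul_extsi, not test output from this commit; SSA value names are illustrative:

// Sketch of the raised IR: the extsi generic is gone and linalg.matmul
// sign-extends %arg1 internally to the i32 accumulator type.
util.func public @matmul_extsi(%arg0 : tensor<10x20xi32>,
                               %arg1 : tensor<20x40xi16>) -> tensor<10x40xi32> {
  %zero = arith.constant 0 : i32
  %empty = tensor.empty() : tensor<10x40xi32>
  %fill = linalg.fill ins(%zero : i32) outs(%empty : tensor<10x40xi32>) -> tensor<10x40xi32>
  %result = linalg.matmul ins(%arg0, %arg1 : tensor<10x20xi32>, tensor<20x40xi16>)
                          outs(%fill : tensor<10x40xi32>) -> tensor<10x40xi32>
  util.return %result : tensor<10x40xi32>
}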
