Skip to content

Commit

Permalink
Do not use replaceUsesOfBlockArgument
Browse files Browse the repository at this point in the history
  • Loading branch information
matthias-springer committed Aug 29, 2024
1 parent e6ff0b9 commit 28f0aea
Showing 1 changed file with 9 additions and 57 deletions.
66 changes: 9 additions & 57 deletions compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,29 +229,13 @@ struct GenericOpTypePropagation
signatureConverter.addInputs(index, legalizedArgType.value());
}
rewriter.applySignatureConversion(&modifiedOpRegion.front(),
signatureConverter);
signatureConverter, getTypeConverter());

// 6. Introduce scalar conversion operations to convert back to the
// original scalar type.
{
OpBuilder::InsertionGuard g(rewriter);
Block *entryBlock = modifiedOp.getBlock();
for (auto modifiedOperandIndex : modifiedOperandIndex) {
OpOperand *modifiedOpOperand =
&modifiedOp->getOpOperand(modifiedOperandIndex);
BlockArgument source =
modifiedOp.getMatchingBlockArgument(modifiedOpOperand);
Type destType = getElementTypeOrSelf(
genericOp.getOperand(modifiedOperandIndex).getType());

// 6a. If the value of the argument is used the argument is in the
// legalized type. Convert it to a value that is in the original
// element type for replacement of all uses in the block.
rewriter.setInsertionPointToStart(entryBlock);
Value replacement =
convertElementType(rewriter, source.getLoc(), destType, source);
rewriter.replaceUsesOfBlockArgument(source, replacement);
}

// 6b. If any of the operands modified were outputs, the yield values
// need to be modified as well.
Expand Down Expand Up @@ -372,27 +356,13 @@ struct IREELinalgExtScatterTypePropagation
signatureConverter.addInputs(0, legalizedArgType.value());
signatureConverter.addInputs(1, legalizedArgType.value());
rewriter.applySignatureConversion(&modifiedOpRegion.front(),
signatureConverter);
signatureConverter, getTypeConverter());

{
// Introduce scalar conversion operations to convert back to the original
// scalar type.
OpBuilder::InsertionGuard g(rewriter);
Block *entryBlock = &modifiedOp->getRegion(0).getBlocks().front();
BlockArgument inputArg = entryBlock->getArgument(0);
BlockArgument outputArg = entryBlock->getArgument(1);

auto destType = getElementTypeOrSelf(inputType);
rewriter.setInsertionPointToStart(entryBlock);

Value replacementInput =
convertElementType(rewriter, inputArg.getLoc(), destType, inputArg);
rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(0),
replacementInput);
Value replacementOutput =
convertElementType(rewriter, outputArg.getLoc(), destType, outputArg);
rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(1),
replacementOutput);

// If the output is of an illegal type, the yield value needs to be
// modified
Expand Down Expand Up @@ -449,31 +419,7 @@ struct IREELinalgExtSortTypePropagation
signatureConverter.addInputs(index, legalizedArgType.value());
}
rewriter.applySignatureConversion(&modifiedOpRegion.front(),
signatureConverter);

{
// Introduce scalar conversion operations to convert back to the original
// scalar type.
OpBuilder::InsertionGuard g(rewriter);
Block *entryBlock = &modifiedOp->getRegion(0).getBlocks().front();
for (auto [index, operand] : llvm::enumerate(sortOp->getOpOperands())) {
BlockArgument firstInputArg = entryBlock->getArgument(index * 2);
BlockArgument secondInputArg = entryBlock->getArgument(index * 2 + 1);

auto destType = getElementTypeOrSelf(operand.get().getType());
rewriter.setInsertionPointToStart(entryBlock);
if (destType != getElementTypeOrSelf(legalizedResultTypes[index])) {
Value replacementFirstInput = convertElementType(
rewriter, firstInputArg.getLoc(), destType, firstInputArg);
rewriter.replaceUsesOfBlockArgument(firstInputArg,
replacementFirstInput);
Value replacementSecondInput = convertElementType(
rewriter, secondInputArg.getLoc(), destType, secondInputArg);
rewriter.replaceUsesOfBlockArgument(secondInputArg,
replacementSecondInput);
}
}
}
signatureConverter, getTypeConverter());
rewriter.replaceOp(sortOp, modifiedOp->getResults());
return success();
}
Expand Down Expand Up @@ -580,6 +526,12 @@ struct TypePropagationPass final
RewritePatternSet patterns(context);

TypePropagationTypeConverter typeConverter;
typeConverter.addArgumentMaterialization(
[&](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
assert(inputs.size() == 1 && "expected exactly one input");
return convertElementType(builder, loc, type, inputs[0]);
});

patterns.insert<
ConstantOpTypeConversion, ForwardSourceType<arith::ExtUIOp>,
ForwardSourceType<arith::TruncIOp>, GenericOpTypePropagation,
Expand Down

0 comments on commit 28f0aea

Please sign in to comment.