Skip to content

Commit

Permalink
Revert "[mlir][Transforms] Dialect conversion: Simplify handling of d…
Browse files Browse the repository at this point in the history
…ropped arguments (llvm#97213)"

This reverts commit bbd4af5.
  • Loading branch information
ScottTodd authored and hanhanW committed Jul 30, 2024
1 parent b4444dc commit 65a00e3
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 57 deletions.
173 changes: 118 additions & 55 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,34 @@ class MoveBlockRewrite : public BlockRewrite {
Block *insertBeforeBlock;
};

/// This structure contains the information pertaining to an argument that has
/// been converted.
struct ConvertedArgInfo {
ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
Value castValue = nullptr)
: newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}

/// The start index of in the new argument list that contains arguments that
/// replace the original.
unsigned newArgIdx;

/// The number of arguments that replaced the original argument.
unsigned newArgSize;

/// The cast value that was created to cast from the new arguments to the
/// old. This only used if 'newArgSize' > 1.
Value castValue;
};

/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block, Block *origBlock,
const TypeConverter *converter)
BlockTypeConversionRewrite(
ConversionPatternRewriterImpl &rewriterImpl, Block *block,
Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
const TypeConverter *converter)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
origBlock(origBlock), converter(converter) {}
origBlock(origBlock), argInfo(argInfo), converter(converter) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
Expand All @@ -458,6 +478,10 @@ class BlockTypeConversionRewrite : public BlockRewrite {
/// The original block that was requested to have its signature converted.
Block *origBlock;

/// The conversion information for each of the arguments. The information is
/// std::nullopt if the argument was dropped during conversion.
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;

/// The type converter used to convert the arguments.
const TypeConverter *converter;
};
Expand Down Expand Up @@ -666,16 +690,12 @@ class CreateOperationRewrite : public OperationRewrite {
/// The type of materialization.
enum MaterializationKind {
/// This materialization materializes a conversion for an illegal block
/// argument type, to the original one.
/// argument type, to a legal one.
Argument,

/// This materialization materializes a conversion from an illegal type to a
/// legal one.
Target,

/// This materialization materializes a conversion from a legal type back to
/// an illegal one.
Source
Target
};

/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
Expand Down Expand Up @@ -715,7 +735,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
converterAndKind;
};
} // namespace
Expand Down Expand Up @@ -834,6 +854,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange inputs, Type outputType,
const TypeConverter *converter);

Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
ValueRange inputs,
Type outputType,
const TypeConverter *converter);

Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);
Expand Down Expand Up @@ -963,6 +988,28 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
for (Operation *op : block->getUsers())
listener->notifyOperationModified(op);

// Process the remapping for each of the original arguments.
for (auto [origArg, info] :
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
// Handle the case of a 1->0 value mapping.
if (!info) {
if (Value newArg =
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
rewriter.replaceAllUsesWith(origArg, newArg);
continue;
}

// Otherwise this is a 1->1+ value mapping.
Value castValue = info->castValue;
assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");

// If the argument is still used, replace it with the generated cast.
if (!origArg.use_empty()) {
rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
castValue, origArg.getType()));
}
}
}

void BlockTypeConversionRewrite::rollback() {
Expand All @@ -987,12 +1034,14 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
assert(replacementValue && "replacement value not found");
bool isDroppedArg = replacementValue == origArg;
if (!isDroppedArg)
builder.setInsertionPointAfterValue(replacementValue);
Value newArg;
if (converter) {
builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
builder, origArg.getLoc(), origArg.getType(), replacementValue);
builder, origArg.getLoc(), origArg.getType(),
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
Expand All @@ -1003,6 +1052,8 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
if (!isDroppedArg)
diag << ", with target type " << replacementValue.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
Expand Down Expand Up @@ -1288,64 +1339,73 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Replace all uses of the old block with the new block.
block->replaceAllUsesWith(newBlock);

// Remap each of the original arguments as determined by the signature
// conversion.
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
argInfo.resize(origArgCount);

for (unsigned i = 0; i != origArgCount; ++i) {
BlockArgument origArg = block->getArgument(i);
Type origArgType = origArg.getType();

std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
signatureConversion.getInputMapping(i);
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
MaterializationKind::Source, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/origArgType, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap)
continue;
}
BlockArgument origArg = block->getArgument(i);

if (Value repl = inputMap->replacementValue) {
// This block argument was dropped and a replacement value was provided.
// If inputMap->replacementValue is not nullptr, then the argument is
// dropped and a replacement value is provided to be the remappedValue.
if (inputMap->replacementValue) {
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, repl);
mapping.map(origArg, inputMap->replacementValue);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}

// This is a 1->1+ mapping. 1->N mappings are not fully supported in the
// dialect conversion. Therefore, we need an argument materialization to
// turn the replacement block arguments into a single SSA value that can be
// used as a replacement.
// Otherwise, this is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
mapping.map(origArg, argMat);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
Value newArg;

// If this is a 1->1 mapping and the types of new and replacement arguments
// match (i.e. it's an identity map), then the argument is mapped to its
// original type.
// FIXME: We simply pass through the replacement argument if there wasn't a
// converter, which isn't great as it allows implicit type conversions to
// appear. We should properly restructure this code to handle cases where a
// converter isn't provided and also to properly handle the case where an
// argument materialization is actually a temporary source materialization
// (e.g. in the case of 1->N).
Type legalOutputType;
if (converter)
legalOutputType = converter->convertType(origArgType);
if (legalOutputType && legalOutputType != origArgType) {
Value targetMat = buildUnresolvedTargetMaterialization(
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, targetMat);
if (replArgs.size() == 1 &&
(!converter || replArgs[0].getType() == origArg.getType())) {
newArg = replArgs.front();
mapping.map(origArg, newArg);
} else {
// Build argument materialization: new block arguments -> old block
// argument type.
Value argMat = buildUnresolvedArgumentMaterialization(
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
mapping.map(origArg, argMat);

// Build target materialization: old block argument type -> legal type.
// Note: This function returns an "empty" type if no valid conversion to
// a legal type exists. In that case, we continue the conversion with the
// original block argument type.
Type legalOutputType = converter->convertType(origArg.getType());
if (legalOutputType && legalOutputType != origArg.getType()) {
newArg = buildUnresolvedTargetMaterialization(
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, newArg);
} else {
newArg = argMat;
}
}

appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}

appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
converter);

// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
Expand Down Expand Up @@ -1377,6 +1437,13 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
Block *block, Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
block->begin(), loc, inputs, outputType,
converter);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
Expand Down Expand Up @@ -2793,10 +2860,6 @@ static LogicalResult legalizeUnresolvedMaterialization(
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
case MaterializationKind::Source:
newMaterialization = converter->materializeSourceConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
}
if (newMaterialization) {
assert(newMaterialization.getType() == outputType &&
Expand All @@ -2809,8 +2872,8 @@ static LogicalResult legalizeUnresolvedMaterialization(

InFlightDiagnostic diag = op->emitError()
<< "failed to legalize unresolved materialization "
"from ("
<< inputOperands.getTypes() << ") to " << outputType
"from "
<< inputOperands.getTypes() << " to " << outputType
<< " that remained live after conversion";
if (Operation *liveUser = findLiveUser(op->getUsers())) {
diag.attachNote(liveUser->getLoc())
Expand Down
6 changes: 4 additions & 2 deletions mlir/test/Transforms/test-legalize-type-conversion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@


func.func @test_invalid_arg_materialization(
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
%arg0: i16) {
// expected-note@below {{see existing live user here}}
"foo.return"(%arg0) : (i16) -> ()
}

Expand Down Expand Up @@ -103,8 +104,9 @@ func.func @test_block_argument_not_converted() {
// Make sure argument type changes aren't implicitly forwarded.
func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
^bb0(%arg0: f32):
// expected-note@below {{see existing live user here}}
"test.type_consumer"(%arg0) : (f32) -> ()
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()
Expand Down

0 comments on commit 65a00e3

Please sign in to comment.