Skip to content

Commit

Permalink
[MLIR][OpenMP][OMPIRBuilder] Error propagation across callbacks
Browse files Browse the repository at this point in the history
This is a small proof of concept showing an approach to communicate errors
between MLIR to LLVM IR translation of the OpenMP dialect and the OMPIRBuilder.
It only implements the approach for a single case, so it doesn't compile or
run, since it's only intended to show how it could look like and discuss it
before investing too much effort on a full implementation.

The main idea is to use `llvm::Error` objects returned by callbacks passed to
`OMPIRBuilder` codegen functions that they can then check and forward back to
the caller to avoid continuing after an error has been hit. The caller then
emits an MLIR error diagnostic based on that and stops the translation process.

This should prevent encountering any unsupported operations or arguments, or
any other unexpected error from resulting in a compiler crash. Instead, a
descriptive error message is presented to users.
  • Loading branch information
skatrak committed Oct 16, 2024
1 parent 15d8576 commit 82d2acd
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 25 deletions.
24 changes: 15 additions & 9 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -589,15 +589,19 @@ class OpenMPIRBuilder {
/// not be split.
/// \param CodeGenIP is the insertion point at which the body code should be
/// placed.
///
/// \return an error, if any were triggered during execution.
using BodyGenCallbackTy =
function_ref<void(InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;
function_ref<Error(InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;

// This is created primarily for sections construct as llvm::function_ref
// (BodyGenCallbackTy) is not storable (as described in the comments of
// function_ref class - function_ref contains non-ownable reference
// to the callable.
///
/// \return an error, if any were triggered during execution.
using StorableBodyGenCallbackTy =
std::function<void(InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;
std::function<Error(InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;

/// Callback type for loop body code generation.
///
Expand All @@ -607,8 +611,10 @@ class OpenMPIRBuilder {
/// terminated with an unconditional branch to the loop
/// latch.
/// \param IndVar is the induction variable usable at the insertion point.
///
/// \return an error, if any were triggered during execution.
using LoopBodyGenCallbackTy =
function_ref<void(InsertPointTy CodeGenIP, Value *IndVar)>;
function_ref<Error(InsertPointTy CodeGenIP, Value *IndVar)>;

/// Callback type for variable privatization (think copy & default
/// constructor).
Expand All @@ -626,9 +632,9 @@ class OpenMPIRBuilder {
/// \param ReplVal The replacement value, thus a copy or new created version
/// of \p Inner.
///
/// \returns The new insertion point where code generation continues and
/// \p ReplVal the replacement value.
using PrivatizeCallbackTy = function_ref<InsertPointTy(
/// \returns The new insertion point where code generation continues or an
/// error, and \p ReplVal the replacement value.
using PrivatizeCallbackTy = function_ref<Expected<InsertPointTy>(
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &Original,
Value &Inner, Value *&ReplVal)>;

Expand Down Expand Up @@ -1262,9 +1268,9 @@ class OpenMPIRBuilder {
/// \param Loc The location where the taskgroup construct was encountered.
/// \param AllocaIP The insertion point to be used for alloca instructions.
/// \param BodyGenCB Callback that will generate the region code.
InsertPointTy createTaskgroup(const LocationDescription &Loc,
InsertPointTy AllocaIP,
BodyGenCallbackTy BodyGenCB);
Expected<InsertPointTy> createTaskgroup(const LocationDescription &Loc,
InsertPointTy AllocaIP,
BodyGenCallbackTy BodyGenCB);

using FileIdentifierInfoCallbackTy =
std::function<std::tuple<std::string, uint64_t>()>;
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2048,7 +2048,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy
Expected<OpenMPIRBuilder::InsertPointTy>
OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
InsertPointTy AllocaIP,
BodyGenCallbackTy BodyGenCB) {
Expand All @@ -2066,7 +2066,8 @@ OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});

BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
BodyGenCB(AllocaIP, Builder.saveIP());
if (auto Err = BodyGenCB(AllocaIP, Builder.saveIP()))
return std::move(Err);

Builder.SetInsertPoint(TaskgroupExitBB);
// Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
Expand Down
32 changes: 18 additions & 14 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
/// region, and a branch from any block with an successor-less OpenMP terminator
/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
/// of the continuation block if provided.
static llvm::BasicBlock *convertOmpOpRegions(
static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus,
LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
llvm::BasicBlock *continuationBlock =
splitBB(builder, true, "omp.region.cont");
Expand Down Expand Up @@ -215,10 +215,8 @@ static llvm::BasicBlock *convertOmpOpRegions(

llvm::IRBuilderBase::InsertPointGuard guard(builder);
if (failed(
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
bodyGenStatus = failure();
return continuationBlock;
}
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
return llvm::createStringError("failed region translation");

// Special handling for `omp.yield` and `omp.terminator` (we may have more
// than one): they return the control to the parent OpenMP dialect operation
Expand Down Expand Up @@ -1145,20 +1143,26 @@ static LogicalResult
convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
LogicalResult bodyGenStatus = success();
if (!tgOp.getTaskReductionVars().empty() || !tgOp.getAllocateVars().empty()) {
if (!tgOp.getTaskReductionVars().empty() || !tgOp.getAllocateVars().empty())
return tgOp.emitError("unhandled clauses for translation to LLVM IR");
}

auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
builder.restoreIP(codegenIP);
convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region", builder,
moduleTranslation, bodyGenStatus);
return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
builder, moduleTranslation)
.takeError();
};

InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTaskgroup(
ompLoc, allocaIP, bodyCB));
return bodyGenStatus;
auto result = moduleTranslation.getOpenMPBuilder()->createTaskgroup(
ompLoc, allocaIP, bodyCB);

if (auto error = result.takeError())
return tgOp.emitError(llvm::toString(std::move(error)));

builder.restoreIP(*result);
return success();
}

static LogicalResult
Expand Down

0 comments on commit 82d2acd

Please sign in to comment.