Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -731,4 +731,56 @@ def LLVM_Prefetch : LLVM_ZeroResultOp<"intr.prefetch">,
}];
}

def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0>;
def AtomicBinOpAdd : I64EnumAttrCase<"add", 1>;
def AtomicBinOpSub : I64EnumAttrCase<"sub", 2>;
def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3>;
def AtomicBinOpNand : I64EnumAttrCase<"nand", 4>;
def AtomicBinOpOr : I64EnumAttrCase<"_or", 5>;
def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6>;
def AtomicBinOpMax : I64EnumAttrCase<"max", 7>;
def AtomicBinOpMin : I64EnumAttrCase<"min", 8>;
def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9>;
def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10>;
def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11>;
def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12>;
def AtomicBinOp : I64EnumAttr<
"AtomicBinOp",
"llvm.atomicrmw binary operations",
[AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax,
AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd,
AtomicBinOpFSub]> {
let cppNamespace = "::mlir::LLVM";
}

def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0>;
def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1>;
def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2>;
def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 4>;
def AtomicOrderingRelease : I64EnumAttrCase<"release", 5>;
def AtomicOrderingAcquireRelease : I64EnumAttrCase<"acq_rel", 6>;
def AtomicOrderingSequentiallyConsistent : I64EnumAttrCase<"seq_cst", 7>;
def AtomicOrdering : I64EnumAttr<
"AtomicOrdering",
"Atomic ordering for LLVM's memory model",
[AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcquireRelease,
AtomicOrderingSequentiallyConsistent]> {
let cppNamespace = "::mlir::LLVM";
}

def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">,
Arguments<(ins AtomicBinOp:$bin_op, LLVM_Type:$ptr, LLVM_Type:$val,
AtomicOrdering:$ordering)>,
Results<(outs LLVM_Type:$res)> {
let llvmBuilder = [{
$res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($bin_op), $ptr, $val,
getLLVMAtomicOrdering($ordering));
}];
let parser = [{ return parseAtomicRMWOp(parser, result); }];
let printer = [{ printAtomicRMWOp(p, *this); }];
let verifier = "return ::verify(*this);";
}

#endif // LLVMIR_OPS
101 changes: 101 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,107 @@ static LogicalResult verify(LLVM::NullOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// Printer, parser and verifier for LLVM::AtomicRMWOp.
//===----------------------------------------------------------------------===//

static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
p << op.getOperationName() << " ";
p << '"' << stringifyAtomicBinOp(op.bin_op()) << "\" ";
p << '"' << stringifyAtomicOrdering(op.ordering()) << "\" ";
p << op.ptr() << ", " << op.val();
p.printOptionalAttrDict(op.getAttrs(), {"bin_op", "ordering"});
p << " : (" << op.ptr().getType() << ", " << op.val().getType() << ") -> "
<< op.res().getType();
}

// <operation> ::= `llvm.atomicrmw` string-literal string-literal
// ssa-use `,` ssa-use attribute-dict? `:` type
static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
OperationState &result) {
Type type;
StringAttr binOp, ordering;
llvm::SMLoc binOpLoc, orderingLoc, trailingTypeLoc;
OpAsmParser::OperandType ptr, val;
if (parser.getCurrentLocation(&binOpLoc) ||
parser.parseAttribute(binOp, "bin_op", result.attributes) ||
parser.getCurrentLocation(&orderingLoc) ||
parser.parseAttribute(ordering, "ordering", result.attributes) ||
parser.parseOperand(ptr) || parser.parseComma() ||
parser.parseOperand(val) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
return failure();

// Extract the result type from the trailing function type.
auto funcType = type.dyn_cast<FunctionType>();
if (!funcType || funcType.getNumInputs() != 2 ||
funcType.getNumResults() != 1)
return parser.emitError(
trailingTypeLoc,
"expected trailing function type with two arguments and one result");

if (parser.resolveOperand(ptr, funcType.getInput(0), result.operands) ||
parser.resolveOperand(val, funcType.getInput(1), result.operands))
return failure();

// Replace the string attribute `bin_op` with an integer attribute.
auto binOpKind = symbolizeAtomicBinOp(binOp.getValue());
if (!binOpKind) {
return parser.emitError(binOpLoc)
<< "'" << binOp.getValue()
<< "' is an incorrect value of the 'bin_op' attribute";
}

auto binOpValue = static_cast<int64_t>(binOpKind.getValue());
auto binOpAttr = parser.getBuilder().getI64IntegerAttr(binOpValue);
result.attributes[0].second = binOpAttr;

// Replace the string attribute `ordering` with an integer attribute.
auto orderingKind = symbolizeAtomicOrdering(ordering.getValue());
if (!orderingKind) {
return parser.emitError(orderingLoc)
<< "'" << ordering.getValue()
<< "' is an incorrect value of the 'ordering' attribute";
}

auto orderingValue = static_cast<int64_t>(orderingKind.getValue());
auto orderingAttr = parser.getBuilder().getI64IntegerAttr(orderingValue);
result.attributes[1].second = orderingAttr;

result.addTypes(funcType.getResults());
return success();
}

static LogicalResult verify(AtomicRMWOp op) {
auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
if (!ptrType.isPointerTy())
return op.emitOpError("expected LLVM IR pointer type for operand #0");
auto valType = op.val().getType().cast<LLVM::LLVMType>();
if (valType != ptrType.getPointerElementTy())
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for operand #1");
auto resType = op.res().getType().cast<LLVM::LLVMType>();
if (resType != valType)
return op.emitOpError(
"expected LLVM IR result type to match type for operand #1");
if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
if (!valType.getUnderlyingType()->isFloatingPointTy())
return op.emitOpError("expected LLVM IR floating point type");
} else if (op.bin_op() == AtomicBinOp::xchg) {
if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
!valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
!valType.getUnderlyingType()->isHalfTy() && !valType.isFloatTy() &&
!valType.isDoubleTy())
return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
} else {
if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
!valType.isIntegerTy(32) && !valType.isIntegerTy(64))
return op.emitOpError("expected LLVM IR integer type");
}
return success();
}

//===----------------------------------------------------------------------===//
// LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
Expand Down
52 changes: 52 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,58 @@ static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
llvm_unreachable("incorrect comparison predicate");
}

static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
switch (op) {
case LLVM::AtomicBinOp::xchg:
return llvm::AtomicRMWInst::BinOp::Xchg;
case LLVM::AtomicBinOp::add:
return llvm::AtomicRMWInst::BinOp::Add;
case LLVM::AtomicBinOp::sub:
return llvm::AtomicRMWInst::BinOp::Sub;
case LLVM::AtomicBinOp::_and:
return llvm::AtomicRMWInst::BinOp::And;
case LLVM::AtomicBinOp::nand:
return llvm::AtomicRMWInst::BinOp::Nand;
case LLVM::AtomicBinOp::_or:
return llvm::AtomicRMWInst::BinOp::Or;
case LLVM::AtomicBinOp::_xor:
return llvm::AtomicRMWInst::BinOp::Xor;
case LLVM::AtomicBinOp::max:
return llvm::AtomicRMWInst::BinOp::Max;
case LLVM::AtomicBinOp::min:
return llvm::AtomicRMWInst::BinOp::Min;
case LLVM::AtomicBinOp::umax:
return llvm::AtomicRMWInst::BinOp::UMax;
case LLVM::AtomicBinOp::umin:
return llvm::AtomicRMWInst::BinOp::UMin;
case LLVM::AtomicBinOp::fadd:
return llvm::AtomicRMWInst::BinOp::FAdd;
case LLVM::AtomicBinOp::fsub:
return llvm::AtomicRMWInst::BinOp::FSub;
}
llvm_unreachable("incorrect atomic binary operator");
}

static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
switch (ordering) {
case LLVM::AtomicOrdering::not_atomic:
return llvm::AtomicOrdering::NotAtomic;
case LLVM::AtomicOrdering::unordered:
return llvm::AtomicOrdering::Unordered;
case LLVM::AtomicOrdering::monotonic:
return llvm::AtomicOrdering::Monotonic;
case LLVM::AtomicOrdering::acquire:
return llvm::AtomicOrdering::Acquire;
case LLVM::AtomicOrdering::release:
return llvm::AtomicOrdering::Release;
case LLVM::AtomicOrdering::acq_rel:
return llvm::AtomicOrdering::AcquireRelease;
case LLVM::AtomicOrdering::seq_cst:
return llvm::AtomicOrdering::SequentiallyConsistent;
}
llvm_unreachable("incorrect atomic ordering");
}

/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`. LLVM IR Builder does not have a generic interface so
/// this has to be a long chain of `if`s calling different functions with a
Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,51 @@ func @nvvm_invalid_mma_7(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
llvm.return %0 : (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
}

// -----
// CHECK-LABEL: @atomicrmw_expected_ptr
func @atomicrmw_expected_ptr(%f32 : !llvm.float) {
// expected-error@+1 {{expected LLVM IR pointer type for operand #0}}
%0 = llvm.atomicrmw "fadd" "unordered" %f32, %f32 : (!llvm.float, !llvm.float) -> !llvm.float
llvm.return
}

// -----
// CHECK-LABEL: @atomicrmw_mismatched_operands
func @atomicrmw_mismatched_operands(%f32_ptr : !llvm<"float*">, %i32 : !llvm.i32) {
// expected-error@+1 {{expected LLVM IR element type for operand #0 to match type for operand #1}}
%0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %i32 : (!llvm<"float*">, !llvm.i32) -> !llvm.float
llvm.return
}

// -----
// CHECK-LABEL: @atomicrmw_mismatched_result
func @atomicrmw_mismatched_operands(%f32_ptr : !llvm<"float*">, %f32 : !llvm.float) {
// expected-error@+1 {{expected LLVM IR result type to match type for operand #1}}
%0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.i32
llvm.return
}

// -----
// CHECK-LABEL: @atomicrmw_expected_float
func @atomicrmw_expected_float(%i32_ptr : !llvm<"i32*">, %i32 : !llvm.i32) {
// expected-error@+1 {{expected LLVM IR floating point type}}
%0 = llvm.atomicrmw "fadd" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
llvm.return
}

// -----
// CHECK-LABEL: @atomicrmw_unexpected_xchg_type
func @atomicrmw_xchg_type(%i1_ptr : !llvm<"i1*">, %i1 : !llvm.i1) {
// expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}}
%0 = llvm.atomicrmw "xchg" "unordered" %i1_ptr, %i1 : (!llvm<"i1*">, !llvm.i1) -> !llvm.i1
llvm.return
}

// -----
// CHECK-LABEL: @atomicrmw_expected_int
func @atomicrmw_expected_int(%f32_ptr : !llvm<"float*">, %f32 : !llvm.float) {
// expected-error@+1 {{expected LLVM IR integer type}}
%0 = llvm.atomicrmw "max" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
llvm.return
}
7 changes: 7 additions & 0 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,10 @@ func @null() {
%1 = llvm.mlir.null : !llvm<"{void(i32, void()*)*, i64}*">
llvm.return
}

// CHECK-LABEL: @atomics
func @atomics(%arg0 : !llvm<"float*">, %arg1 : !llvm.float) {
// CHECK: llvm.atomicrmw "fadd" "unordered" %{{.*}}, %{{.*}} : (!llvm<"float*">, !llvm.float) -> !llvm.float
%0 = llvm.atomicrmw "fadd" "unordered" %arg0, %arg1 : (!llvm<"float*">, !llvm.float) -> !llvm.float
llvm.return
}
33 changes: 33 additions & 0 deletions mlir/test/Target/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1039,3 +1039,36 @@ llvm.func @null() -> !llvm<"i32*"> {
// CHECK: ret i32* null
llvm.return %0 : !llvm<"i32*">
}

// CHECK-LABEL: @atomics
llvm.func @atomics(
%f32_ptr : !llvm<"float*">, %f32 : !llvm.float,
%i32_ptr : !llvm<"i32*">, %i32 : !llvm.i32) -> !llvm.float {
// CHECK: atomicrmw fadd float* %{{.*}}, float %{{.*}} unordered
%0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
// CHECK: atomicrmw fsub float* %{{.*}}, float %{{.*}} unordered
%1 = llvm.atomicrmw "fsub" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
// CHECK: atomicrmw xchg float* %{{.*}}, float %{{.*}} monotonic
%2 = llvm.atomicrmw "xchg" "monotonic" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
// CHECK: atomicrmw add i32* %{{.*}}, i32 %{{.*}} acquire
%3 = llvm.atomicrmw "add" "acquire" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw sub i32* %{{.*}}, i32 %{{.*}} release
%4 = llvm.atomicrmw "sub" "release" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw and i32* %{{.*}}, i32 %{{.*}} acq_rel
%5 = llvm.atomicrmw "_and" "acq_rel" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw nand i32* %{{.*}}, i32 %{{.*}} seq_cst
%6 = llvm.atomicrmw "nand" "seq_cst" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw or i32* %{{.*}}, i32 %{{.*}} unordered
%7 = llvm.atomicrmw "_or" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw xor i32* %{{.*}}, i32 %{{.*}} unordered
%8 = llvm.atomicrmw "_xor" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw max i32* %{{.*}}, i32 %{{.*}} unordered
%9 = llvm.atomicrmw "max" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw min i32* %{{.*}}, i32 %{{.*}} unordered
%10 = llvm.atomicrmw "min" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw umax i32* %{{.*}}, i32 %{{.*}} unordered
%11 = llvm.atomicrmw "umax" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw umin i32* %{{.*}}, i32 %{{.*}} unordered
%12 = llvm.atomicrmw "umin" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
llvm.return %0 : !llvm.float
}