Skip to content

Commit 2b65afc

Browse files
committed
[MLIR][NVVM] Update mbarrier.init Op
This patch updates the mbarrier.init/inval Ops to use the AnyTypeOf[] construct for their `addr` argument. This enables us to have a single Op that can take a pointer in either generic or shared memory space and generate the right intrinsics during the lowering. * Existing tests are updated. * Locally verified that there are no new regressions in the Integration tests. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent edab721 commit 2b65afc

File tree

6 files changed

+82
-59
lines changed

6 files changed

+82
-59
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 28 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,8 @@ def NVVM_PMEventOp : NVVM_PTXBuilder_Op<"pmevent">,
579579

580580
/// mbarrier.init instruction with generic pointer type
581581
def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
582-
Arguments<(ins LLVM_AnyPointer:$addr, I32:$count, PtxPredicate:$predicate)> {
582+
Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
583+
I32:$count, PtxPredicate:$predicate)> {
583584
let summary = "MBarrier Initialization Op";
584585
let description = [{
585586
The `nvvm.mbarrier.init` operation initializes an *mbarrier object* at the specified
@@ -592,48 +593,35 @@ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
592593
- Transaction count (tx-count): 0
593594

594595
The operation takes the following operands:
595-
- `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic
596-
addressing, but the address must still be in the shared memory space.
596+
- `addr`: A pointer to the memory location of the *mbarrier object*. The `addr`
597+
must be a pointer to generic or shared::cta memory. When it is generic, the
598+
underlying address must be within the shared::cta memory space; otherwise
599+
the behavior is undefined.
597600
- `count`: Integer specifying the number of threads that will participate in barrier
598601
synchronization. Must be in the range [1, 2²⁰ - 1].
599602
- `predicate`: Optional predicate for conditional execution.
600603

601604
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init)
602605
}];
603-
string llvmBuilder = [{
604-
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
605-
}];
606606
let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
607+
607608
let extraClassDeclaration = [{
608609
bool hasIntrinsic() { if(getPredicate()) return false; return true; }
609-
}];
610-
let extraClassDefinition = [{
611-
std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
612-
}];
613-
}
614610

615-
/// mbarrier.init instruction with shared pointer type
616-
def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared", [NVVMRequiresSM<80>, DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
617-
Arguments<(ins LLVM_PointerShared:$addr, I32:$count, PtxPredicate:$predicate)> {
618-
let summary = "Shared MBarrier Initialization Op";
619-
let description = [{
620-
This Op is the same as `nvvm.mbarrier.init` except that the *mbarrier object*
621-
should be accessed using a shared-memory pointer instead of a generic-memory pointer.
622-
623-
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init)
611+
static mlir::NVVM::IDArgPair
612+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
613+
llvm::IRBuilderBase& builder);
624614
}];
615+
625616
string llvmBuilder = [{
626-
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
627-
}];
628-
let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
629-
let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
630-
let extraClassDefinition = [{
631-
std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
617+
auto [id, args] = NVVM::MBarrierInitOp::getIntrinsicIDAndArgs(
618+
*op, moduleTranslation, builder);
619+
createIntrinsicCall(builder, id, args);
632620
}];
633621
}
634622

635623
def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
636-
Arguments<(ins LLVM_AnyPointer:$addr)> {
624+
Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr)> {
637625
let summary = "MBarrier Invalidation Operation";
638626
let description = [{
639627
The `nvvm.mbarrier.inval` operation invalidates an *mbarrier object* at the
@@ -644,30 +632,27 @@ def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
644632
It is undefined behavior if the *mbarrier object* is already invalid.
645633

646634
The operation takes the following operand:
647-
- `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic
648-
addressing, but the address must still be in the shared memory space.
635+
- `addr`: A pointer to the memory location of the *mbarrier object*. The `addr`
636+
must be a pointer to generic or shared::cta memory. When it is generic, the
637+
underlying address must be within the shared::cta memory space; otherwise
638+
the behavior is undefined.
649639

650640
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval)
651641
}];
652-
string llvmBuilder = [{
653-
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_inval, {$addr});
654-
}];
655-
let assemblyFormat = "$addr attr-dict `:` type(operands)";
656-
}
657642

658-
def NVVM_MBarrierInvalSharedOp : NVVM_Op<"mbarrier.inval.shared">,
659-
Arguments<(ins LLVM_PointerShared:$addr)> {
660-
let summary = "Shared MBarrier Invalidation Operation";
661-
let description = [{
662-
This Op is the same as `nvvm.mbarrier.inval` except that the *mbarrier object*
663-
should be accessed using a shared-memory pointer instead of a generic-memory pointer.
643+
let assemblyFormat = "$addr attr-dict `:` type(operands)";
664644

665-
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval)
645+
let extraClassDeclaration = [{
646+
static mlir::NVVM::IDArgPair
647+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
648+
llvm::IRBuilderBase& builder);
666649
}];
650+
667651
string llvmBuilder = [{
668-
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_inval_shared, {$addr});
652+
auto [id, args] = NVVM::MBarrierInvalOp::getIntrinsicIDAndArgs(
653+
*op, moduleTranslation, builder);
654+
createIntrinsicCall(builder, id, args);
669655
}];
670-
let assemblyFormat = "$addr attr-dict `:` type(operands)";
671656
}
672657

673658
def NVVM_MBarrierArriveOp : NVVM_Op<"mbarrier.arrive">,

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -846,13 +846,8 @@ struct NVGPUMBarrierInitLowering
846846
Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
847847
adaptor.getMbarId(), rewriter);
848848
Value count = truncToI32(b, adaptor.getCount());
849-
if (isMbarrierShared(mbarrierType)) {
850-
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
851-
op, barrier, count, adaptor.getPredicate());
852-
} else {
853-
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
854-
adaptor.getPredicate());
855-
}
849+
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
850+
adaptor.getPredicate());
856851
return success();
857852
}
858853
};

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,10 +1607,53 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
16071607
mt.mapValue(thisOp.getRes()) = smemDesc;
16081608
}
16091609

1610+
//===----------------------------------------------------------------------===//
1611+
// getPtx methods
1612+
//===----------------------------------------------------------------------===//
1613+
1614+
std::string NVVM::MBarrierInitOp::getPtx() {
1615+
unsigned addressSpace =
1616+
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1617+
return (addressSpace == NVVMMemorySpace::Shared)
1618+
? std::string("mbarrier.init.shared.b64 [%0], %1;")
1619+
: std::string("mbarrier.init.b64 [%0], %1;");
1620+
}
1621+
16101622
//===----------------------------------------------------------------------===//
16111623
// getIntrinsicID/getIntrinsicIDAndArgs methods
16121624
//===----------------------------------------------------------------------===//
16131625

1626+
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
1627+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1628+
auto thisOp = cast<NVVM::MBarrierInitOp>(op);
1629+
unsigned addressSpace =
1630+
llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
1631+
.getAddressSpace();
1632+
llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
1633+
? llvm::Intrinsic::nvvm_mbarrier_init_shared
1634+
: llvm::Intrinsic::nvvm_mbarrier_init;
1635+
1636+
// Fill the Intrinsic Args
1637+
llvm::SmallVector<llvm::Value *> args;
1638+
args.push_back(mt.lookupValue(thisOp.getAddr()));
1639+
args.push_back(mt.lookupValue(thisOp.getCount()));
1640+
1641+
return {id, std::move(args)};
1642+
}
1643+
1644+
mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
1645+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1646+
auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
1647+
unsigned addressSpace =
1648+
llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
1649+
.getAddressSpace();
1650+
llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
1651+
? llvm::Intrinsic::nvvm_mbarrier_inval_shared
1652+
: llvm::Intrinsic::nvvm_mbarrier_inval;
1653+
1654+
return {id, {mt.lookupValue(thisOp.getAddr())}};
1655+
}
1656+
16141657
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
16151658
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
16161659

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ func.func @mbarrier() {
486486
// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
487487
// CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
488488
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
489-
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
489+
// CHECK: nvvm.mbarrier.init %[[barPtr]]
490490
nvgpu.mbarrier.init %barrier[%c0], %num_threads : !barrierType
491491

492492
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
@@ -516,7 +516,7 @@ func.func @mbarrier_nocomplete() {
516516
// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
517517
// CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
518518
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
519-
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
519+
// CHECK: nvvm.mbarrier.init %[[barPtr]]
520520
nvgpu.mbarrier.init %barrier[%c0], %num_threads : !barrierType
521521

522522
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
@@ -592,7 +592,7 @@ func.func @mbarrier_txcount() {
592592
// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
593593
// CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
594594
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
595-
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
595+
// CHECK: nvvm.mbarrier.init %[[barPtr]]
596596
nvgpu.mbarrier.init %barrier[%c0], %num_threads : !barrierType
597597

598598
%tidxreg = nvvm.read.ptx.sreg.tid.x : i32
@@ -643,7 +643,7 @@ func.func @mbarrier_txcount_pred() {
643643
// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
644644
// CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
645645
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
646-
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]]
646+
// CHECK: nvvm.mbarrier.init %[[barPtr]], {{.*}}, predicate = %[[P]]
647647
nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType
648648

649649
%txcount = arith.constant 256 : index

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
// CHECK-LABEL: @init_mbarrier
99
llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
1010
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.shared.b64 [$0], $1;", "r,r,b"
11-
nvvm.mbarrier.init.shared %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1
11+
nvvm.mbarrier.init %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1
1212
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
1313
nvvm.mbarrier.init %barrier_gen, %count, predicate = %pred : !llvm.ptr, i32, i1
1414
llvm.return

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,8 @@ llvm.func private @mbarrier_init_generic(%barrier: !llvm.ptr) {
419419

420420
llvm.func private @mbarrier_init_shared(%barrier: !llvm.ptr<3>) {
421421
%count = nvvm.read.ptx.sreg.ntid.x : i32
422-
// CHECK: nvvm.mbarrier.init.shared %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32
423-
nvvm.mbarrier.init.shared %barrier, %count : !llvm.ptr<3>, i32
422+
// CHECK: nvvm.mbarrier.init %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32
423+
nvvm.mbarrier.init %barrier, %count : !llvm.ptr<3>, i32
424424
llvm.return
425425
}
426426

@@ -433,8 +433,8 @@ llvm.func private @mbarrier_inval_generic(%barrier: !llvm.ptr) {
433433

434434

435435
llvm.func private @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) {
436-
// CHECK: nvvm.mbarrier.inval.shared %{{.*}} : !llvm.ptr<3>
437-
nvvm.mbarrier.inval.shared %barrier : !llvm.ptr<3>
436+
// CHECK: nvvm.mbarrier.inval %{{.*}} : !llvm.ptr<3>
437+
nvvm.mbarrier.inval %barrier : !llvm.ptr<3>
438438
llvm.return
439439
}
440440

0 commit comments

Comments
 (0)