@@ -74,6 +74,12 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
7474 LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
7575}
7676
77+ /// Base class for NVVM ops implementing BasicPtxBuilderOpInterface. The
78+ /// interface is appended to `traits`, so passing extra traits keeps it.
79+ class NVVM_PTXBuilder_Op<string mnemonic, list<Trait> traits = []> :
80+ LLVM_OpBase<NVVM_Dialect, mnemonic, !listconcat(traits,
81+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>])> {}
82+
7783//===----------------------------------------------------------------------===//
7884// NVVM attribute definitions
7985//===----------------------------------------------------------------------===//
@@ -206,21 +212,31 @@ def NVVM_ReduxOp :
206212//===----------------------------------------------------------------------===//
207213
208214/// mbarrier.init instruction with generic pointer type
209- def NVVM_MBarrierInitOp : NVVM_Op <"mbarrier.init">,
210- Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
215+ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op <"mbarrier.init">,
216+ Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count, PtxPredicate:$predicate)> {
211217 string llvmBuilder = [{
212218 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
213219 }];
214- let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
220+ let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
221+ let extraClassDeclaration = [{
222+ bool hasIntrinsic() { return !getPredicate(); } // false whenever the optional predicate operand is present
223+ }];
224+ let extraClassDefinition = [{
225+ std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); } // PTX text for this op; %0/%1 are placeholders (presumably $addr and $count -- confirm)
226+ }];
215227}
216228
217229/// mbarrier.init instruction with shared pointer type
218- def NVVM_MBarrierInitSharedOp : NVVM_Op <"mbarrier.init.shared">,
219- Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
230+ def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op <"mbarrier.init.shared">, // base class supplies the BasicPtxBuilderOpInterface declaration by default
231+ Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count, PtxPredicate:$predicate )> { // $predicate is optional in the assembly format below
220232 string llvmBuilder = [{
221233 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
222234 }];
223- let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
235+ let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
236+ let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }"; // hasIntrinsic() is false whenever the predicate operand is present
237+ let extraClassDefinition = [{
238+ std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); } // PTX text for this op; %0/%1 are placeholders (presumably $addr and $count -- confirm)
239+ }];
224240}
225241
226242def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
@@ -275,26 +291,23 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
275291 let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
276292}
277293
278- def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
279- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
280- Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
281- let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
294+ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">, // interface declaration now comes from the base class default traits
295+ Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount, PtxPredicate:$predicate)> { // adds a $predicate operand after $txcount
296+ let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; // the `, predicate = ...` group is optional, anchored on $predicate
282297 let extraClassDefinition = [{
283298 std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
284299 }];
285300}
286301
287- def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared",
288- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
289- Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
290- let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
302+ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">, // interface declaration now comes from the base class default traits
303+ Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount, PtxPredicate:$predicate)> { // adds a $predicate operand after $txcount
304+ let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; // the `, predicate = ...` group is optional, anchored on $predicate
291305 let extraClassDefinition = [{
292306 std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
293307 }];
294308}
295309
296- def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
297- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
310+ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
298311 Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {
299312 let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
300313 let extraClassDefinition = [{
@@ -313,8 +326,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
313326 }];
314327}
315328
316- def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared",
317- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
329+ def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
318330 Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {
319331 let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
320332 let extraClassDefinition = [{
@@ -488,7 +500,7 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
488500
489501def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
490502
491- def NVVM_CpAsyncOp : NVVM_Op <"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>] >,
503+ def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op <"cp.async.shared.global">,
492504 Arguments<(ins LLVM_i8Ptr_shared:$dst,
493505 LLVM_i8Ptr_global:$src,
494506 I32Attr:$size,
@@ -1359,12 +1371,24 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
13591371// NVVM TMA Ops
13601372//===----------------------------------------------------------------------===//
13611373
1362- def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1374+ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
1375+ NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
1376+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
1377+ AttrSizedOperandSegments]>,
13631378 Arguments<(ins LLVM_i64ptr_shared:$dstMem,
13641379 LLVM_i64ptr_any:$tmaDescriptor,
13651380 LLVM_i64ptr_shared:$mbar,
1366- Variadic<I32>:$coordinates)> {
1367- let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
1381+ Variadic<I32>:$coordinates,
1382+ PtxPredicate:$predicate)> {
1383+ let assemblyFormat = [{
1384+ $dstMem `,`
1385+ $tmaDescriptor `,`
1386+ $mbar `,`
1387+ `box` `[`$coordinates `]`
1388+ (`,` `predicate` `=` $predicate^)?
1389+ attr-dict `:` type(operands)
1390+ }];
1391+
13681392 let extraClassDefinition = [{
13691393 std::string $cppClass::getPtx() {
13701394 int dim = getCoordinates().size();
@@ -1382,11 +1406,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
13821406 let hasVerifier = 1;
13831407}
13841408
1385- def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1409+ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
1410+ NVVM_Op<"cp.async.bulk.tensor.global.shared.cta",
1411+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
1412+ AttrSizedOperandSegments]>,
13861413 Arguments<(ins LLVM_i64ptr_any:$tmaDescriptor,
13871414 LLVM_i64ptr_shared:$srcMem,
1388- Variadic<I32>:$coordinates)> {
1389- let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
1415+ Variadic<I32>:$coordinates,
1416+ PtxPredicate:$predicate)> {
1417+ let assemblyFormat = [{
1418+ $tmaDescriptor `,`
1419+ $srcMem `,`
1420+ `box` `[`$coordinates `]`
1421+ (`,` `predicate` `=` $predicate^)?
1422+ attr-dict `:` type(operands)
1423+ }];
13901424 let extraClassDefinition = [{
13911425 std::string $cppClass::getPtx() {
13921426 int dim = getCoordinates().size();
@@ -1408,8 +1442,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.gl
14081442// NVVM Wgmma Ops
14091443//===----------------------------------------------------------------------===//
14101444
1411- def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
1412- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
1445+ def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
14131446 let arguments = (ins);
14141447 let description = [{
14151448 Enforce an ordering of register accesses between warpgroup level matrix
@@ -1423,8 +1456,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
14231456 }];
14241457}
14251458
1426- def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
1427- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1459+ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
14281460 Arguments<(ins )> {
14291461 let assemblyFormat = "attr-dict";
14301462 let description = [{
@@ -1437,8 +1469,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
14371469 }];
14381470}
14391471
1440- def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
1441- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>{
1472+ def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
14421473 let arguments = (ins I32Attr:$group);
14431474 let assemblyFormat = "attr-dict $group";
14441475 let description = [{
0 commit comments