1+ //===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
2+ //
3+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+ // See https://llvm.org/LICENSE.txt for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ //===----------------------------------------------------------------------===//
8+ //
9+ // Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
10+ // automatically. It is used by NVVM to LLVM pass.
11+ //
12+ //===----------------------------------------------------------------------===//
13+
14+ #ifndef BASICPTXBUILDER_OP_INTERFACE
15+ #define BASICPTXBUILDER_OP_INTERFACE
16+
17+ include "mlir/IR/EnumAttr.td"
18+ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
19+ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
20+
21+ //===----------------------------------------------------------------------===//
22+ // Basic PTX Builder Interface
23+ //===----------------------------------------------------------------------===//
24+
25+ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
26+ let description = [{
27+ This interface is used to generate inline assembly with PTX for basic
28+ operations. It's utilized in the `convert-nvvm-to-llvm pass` to lower
29+ NVVM Ops that implement this interface to PTX (parallel thread execution)
30+ using inline assembly Ops. Interface methods play a crucial role in this
31+ lowering process.
32+
33+ Here's an example of an Op with the `BasicPtxBuilderOpInterface`:
34+ ```tablegen
35+ def NVVM_SpecialOp : NVVM_Op<"special.op",
36+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
37+ Results<(outs LLVM_Type:$res)>,
38+ Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
39+ ...
40+ let extraClassDefinition = [{
41+ std::string $cppClass::getPtx() {
42+ return std::string("special.op %0, %1, %2;");
43+ }
44+ } ];
45+ ```
46+
47+ In the above NVVM Op example:
48+ ```mlir
49+ %0 = nvvm.special.op %1, %2 : !llvm.ptr, i32 -> i32
50+ ```
51+
52+ The `convert-nvvm-to-llvm` pass generates the inline assembly like below.
53+ The order of arguments is retained, and the read and write modifiers are
54+ set based on the input and result types:
55+ ```mlir
56+ %0 = llvm.inline_asm
57+ has_side_effects
58+ asm_dialect =
59+ att "special.op %0, %1, %2;", "=r,l,r" %arg0, %arg1
60+ : (!llvm.ptr, i32) -> i32
61+ ```
62+ }];
63+ let cppNamespace = "::mlir::NVVM";
64+ let methods = [
65+ InterfaceMethod<
66+ /*desc=*/[{ Returns PTX assembly with operand number. }],
67+ /*retType=*/"std::string",
68+ /*methodName=*/"getPtx"
69+ >,
70+ InterfaceMethod<
71+ /*desc=*/[{
72+ This function indicates whether the operation is supported by LLVM
73+ intrinsics. It's particularly useful for operations that have
74+ specific cases with LLVM intrinsic support.
75+ }],
76+ /*retType=*/"bool",
77+ /*methodName=*/"hasIntrinsic",
78+ /*args=*/(ins),
79+ /*methodBody=*/"",
80+ /*defaultImplementation=*/"return false;"
81+ >,
82+ InterfaceMethod<
83+ /*desc=*/[{Return whether the operation has memory side effects.}],
84+ /*retType=*/"bool",
85+ /*methodName=*/"hasSideEffect",
86+ /*args=*/(ins),
87+ /*methodBody=*/"",
88+ /*defaultImplementation=*/"return true;"
89+ >,
90+
91+ InterfaceMethod<
92+ /*desc=*/[{Helper function to generate i32 constant value.}],
93+ /*retType=*/"::mlir::Value",
94+ /*methodName=*/"makeConstantI32",
95+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
96+ /*methodBody=*/"",
97+ /*defaultImpl=*/ [{
98+ mlir::Operation* op = $_op;
99+ return rewriter.create<LLVM::ConstantOp>(
100+ op->getLoc(), rewriter.getIntegerType(32), val);
101+ }]
102+ >,
103+ InterfaceMethod<
104+ /*desc=*/[{
105+ This function supplies the necessary arguments for passing PTX code,
106+ following this order:
107+ 1) Adds results
108+ 2) Adds operands
109+ 3) Adds attributes
110+ }],
111+ /*retType=*/"void",
112+ /*methodName=*/"getAsmValues",
113+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
114+ "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
115+ /*methodBody=*/"",
116+ /*defaultImpl=*/ [{
117+ mlir::Operation* op = $_op;
118+
119+ // Step 1. Add results
120+ for (auto val : op->getResults())
121+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
122+
123+ // Step 2. Add operands
124+ for (auto val : op->getOperands())
125+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
126+
127+ // Step 3. Add attributes
128+ for (auto attr : op->getAttrs()) {
129+ if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
130+ ::mlir::Value val = makeConstantI32(rewriter, intAttr.getInt());
131+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
132+ }
133+ }
134+ }]
135+ >
136+ ];
137+ }
138+
139+ #endif // BASICPTXBUILDER_OP_INTERFACE
0 commit comments