1616include "mlir/Dialect/GPU/IR/GPUBase.td"
1717include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
1818
19+ //===----------------------------------------------------------------------===//
20+ // GPU kernel metadata attribute
21+ //===----------------------------------------------------------------------===//
22+
23+ def GPU_KernelMetadataAttr : GPU_Attr<"KernelMetadata", "kernel_metadata"> {
24+ let description = [{
25+ GPU attribute for storing metadata related to a compiled kernel. The
26+ attribute contains the name and arguments type of the kernel.
27+
28+ The attribute also contains optional parameters for storing the arguments
29+ attributes as well as a dictionary for additional metadata, like occupancy
30+ information or other function attributes.
31+
32+ Note: The `arg_attrs` parameter is expected to follow all the constraints
33+ imposed by the `mlir::FunctionOpInterface` interface.
34+
35+ Examples:
36+ ```mlir
37+ #gpu.kernel_metadata<@kernel1, (i32) -> (), arg_attrs = [...], metadata = {reg_count = 255, ...}>
38+ #gpu.kernel_metadata<@kernel2, (i32, f64) -> ()>
39+ ```
40+ }];
41+ let parameters = (ins
42+ "StringAttr":$name,
43+ "Type":$function_type,
44+ OptionalParameter<"ArrayAttr", "arguments attributes">:$arg_attrs,
45+ OptionalParameter<"DictionaryAttr", "metadata dictionary">:$metadata
46+ );
47+ let assemblyFormat = [{
48+ `<` $name `,` $function_type (`,` struct($arg_attrs, $metadata)^)? `>`
49+ }];
50+ let builders = [
51+ AttrBuilderWithInferredContext<(ins "StringAttr":$name,
52+ "Type":$functionType,
53+ CArg<"ArrayAttr", "nullptr">:$argAttrs,
54+ CArg<"DictionaryAttr",
55+ "nullptr">:$metadata), [{
56+ assert(name && "invalid name");
57+ return $_get(name.getContext(), name, functionType, argAttrs, metadata);
58+ }]>,
59+ AttrBuilderWithInferredContext<(ins "FunctionOpInterface":$kernel,
60+ CArg<"DictionaryAttr",
61+ "nullptr">:$metadata)>
62+ ];
63+ let genVerifyDecl = 1;
64+ let extraClassDeclaration = [{
65+ /// Compare two kernels based on the name.
66+ bool operator<(const KernelMetadataAttr& other) const {
67+ return getName().getValue() < other.getName().getValue();
68+ }
69+
70+ /// Returns the metadata attribute corresponding to `key` or `nullptr`
71+ /// if missing.
72+ Attribute getAttr(StringRef key) const {
73+ DictionaryAttr attrs = getMetadata();
74+ return attrs ? attrs.get(key) : nullptr;
75+ }
76+ template <typename ConcreteAttr>
77+ ConcreteAttr getAttr(StringRef key) const {
78+ return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
79+ }
80+ Attribute getAttr(StringAttr key) const {
81+ DictionaryAttr attrs = getMetadata();
82+ return attrs ? attrs.get(key) : nullptr;
83+ }
84+ template <typename ConcreteAttr>
85+ ConcreteAttr getAttr(StringAttr key) const {
86+ return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
87+ }
88+
89+ /// Returns the attribute dictionary at position `index`.
90+ DictionaryAttr getArgAttrDict(unsigned index) {
91+ ArrayAttr argArray = getArgAttrs();
92+ return argArray ? llvm::cast<DictionaryAttr>(argArray[index]) : nullptr;
93+ }
94+
95+ /// Return the specified attribute, if present, for the argument at 'index',
96+ /// null otherwise.
97+ Attribute getArgAttr(unsigned index, StringAttr name) {
98+ DictionaryAttr argDict = getArgAttrDict(index);
99+ return argDict ? argDict.get(name) : nullptr;
100+ }
101+ Attribute getArgAttr(unsigned index, StringRef name) {
102+ DictionaryAttr argDict = getArgAttrDict(index);
103+ return argDict ? argDict.get(name) : nullptr;
104+ }
105+
106+ /// Returns a new KernelMetadataAttr that contains `attrs` in the metadata dictionary.
107+ KernelMetadataAttr appendMetadata(ArrayRef<NamedAttribute> attrs) const;
108+ }];
109+ }
110+
111+ //===----------------------------------------------------------------------===//
112+ // GPU kernel table attribute
113+ //===----------------------------------------------------------------------===//
114+
115+ def GPU_KernelTableAttr : GPU_Attr<"KernelTable", "kernel_table"> {
116+ let description = [{
117+ GPU attribute representing a list of `#gpu.kernel_metadata` attributes. This
118+ attribute supports searching kernels by name. All kernels in the table must
119+ have an unique name.
120+
121+ Examples:
122+ ```mlir
123+ // Empty table.
124+ #gpu.kernel_table<>
125+
126+ // Table with a single kernel.
127+ #gpu.kernel_table<[#gpu.kernel_metadata<kernel0, () -> () >]>
128+
129+ // Table with multiple kernels.
130+ #gpu.kernel_table<[
131+ #gpu.kernel_metadata<"kernel0", (i32, f32) -> (), metadata = {sgpr_count = 255}>,
132+ #gpu.kernel_metadata<"kernel1", (i32) -> ()>
133+ ]>
134+ ```
135+ }];
136+ let parameters = (ins
137+ OptionalArrayRefParameter<"KernelMetadataAttr", "array of kernels">:$kernel_table
138+ );
139+ let assemblyFormat = [{
140+ `<` (`[` qualified($kernel_table)^ `]`)? `>`
141+ }];
142+ let builders = [
143+ AttrBuilder<(ins "ArrayRef<KernelMetadataAttr>":$kernels,
144+ CArg<"bool", "false">:$isSorted)>
145+ ];
146+ let skipDefaultBuilders = 1;
147+ let genVerifyDecl = 1;
148+ let extraClassDeclaration = [{
149+ llvm::ArrayRef<KernelMetadataAttr>::iterator begin() const {
150+ return getKernelTable().begin();
151+ }
152+ llvm::ArrayRef<KernelMetadataAttr>::iterator end() const {
153+ return getKernelTable().end();
154+ }
155+ size_t size() const {
156+ return getKernelTable().size();
157+ }
158+ bool empty() const {
159+ return getKernelTable().empty();
160+ }
161+
162+ /// Returns the kernel with name `key` or `nullptr` if not present.
163+ KernelMetadataAttr lookup(StringRef key) const;
164+ KernelMetadataAttr lookup(StringAttr key) const;
165+ }];
166+ }
167+
19168//===----------------------------------------------------------------------===//
20169// GPU object attribute.
21170//===----------------------------------------------------------------------===//
@@ -36,8 +185,9 @@ def GPU_CompilationTargetEnum : GPU_I32Enum<
36185def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
37186 let description = [{
38187 A GPU object attribute glues together a GPU target, the object kind, a
39- binary string with the object, and the object properties, encapsulating how
40- the object was generated and its properties with the object itself.
188+ binary string with the object, the object properties, and kernel metadata,
189+ encapsulating how the object was generated and its properties with the
190+ object itself.
41191
42192 There are four object formats:
43193 1. `Offload`: represents generic objects not described by the other three
@@ -55,6 +205,10 @@ def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
55205
56206 Object properties are specified through the `properties` dictionary
57207 attribute and can be used to define additional information.
208+
209+ Kernel metadata is specified through the `kernels` parameter, and can be
210+ used to specify additional information on a kernel by kernel basis.
211+
58212 The target attribute must implement or promise the `TargetAttrInterface`
59213 interface.
60214
@@ -63,16 +217,29 @@ def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
63217 #gpu.object<#nvvm.target, properties = {O = 3 : i32}, assembly = "..."> // An assembly object with additional properties.
64218 #gpu.object<#rocdl.target, bin = "..."> // A binary object.
65219 #gpu.object<#nvvm.target, "..."> // A fatbin object.
220+ #gpu.object<#nvvm.target, kernels = #gpu.kernel_table<...>, "..."> // An object with a kernel table.
66221 ```
67222 }];
68223 let parameters = (ins
69224 "Attribute":$target,
70225 DefaultValuedParameter<"CompilationTarget", "CompilationTarget::Fatbin">:$format,
71226 "StringAttr":$object,
72- OptionalParameter<"DictionaryAttr">:$properties
227+ OptionalParameter<"DictionaryAttr">:$properties,
228+ OptionalParameter<"KernelTableAttr">:$kernels
73229 );
230+ let builders = [
231+ AttrBuilderWithInferredContext<(ins "Attribute":$target,
232+ "CompilationTarget":$format,
233+ "StringAttr":$object,
234+ CArg<"DictionaryAttr", "nullptr">:$properties,
235+ CArg<"KernelTableAttr", "nullptr">:$kernels), [{
236+ assert(target && "invalid target");
237+ return $_get(target.getContext(), target, format, object, properties, kernels);
238+ }]>
239+ ];
74240 let assemblyFormat = [{ `<`
75- $target `,` (`properties` `=` $properties ^ `,`)?
241+ $target `,` (`properties` `=` $properties^ `,`)?
242+ (`kernels` `=` $kernels^ `,`)?
76243 custom<Object>($format, $object)
77244 `>`
78245 }];
0 commit comments