@@ -722,4 +722,164 @@ def ROCDL_TargettAttr :
722722 }
723723 }];
724724}
725+
726+ //===----------------------------------------------------------------------===//
727+ // ROCDL kernel attribute
728+ //===----------------------------------------------------------------------===//
729+
730+ def ROCDL_KernelAttr :
731+ ROCDL_Attr<"ROCDLKernel", "kernel"> {
732+ let description = [{
733+ ROCDL attribute for storing metadata related to a compiled kernel. It
734+ contains the attribute dictionary of the LLVM function used to generate the
735+ kernel, as well as an optional dictionary for additional metadata, like ELF
736+ related metadata.
737+ For details on the ELF metadata see:
738+ https://llvm.org/docs/AMDGPUUsage.html#code-object-v5-metadata
739+
740+ Examples:
741+ ```mlir
742+ #rocdl.kernel<{sym_name = "test_fusion__part_0", ...},
743+ metadata = {sgpr_count = 255, ...}>
744+ ```
745+ }];
746+ let parameters = (ins
747+ "DictionaryAttr":$func_attrs,
748+ OptionalParameter<"DictionaryAttr", "metadata dictionary">:$metadata
749+ );
750+ let assemblyFormat = [{
751+ `<` $func_attrs (`,` `metadata` `=` $metadata^ )? `>`
752+ }];
753+ let builders = [
754+ AttrBuilderWithInferredContext<(ins "DictionaryAttr":$funcAttrs,
755+ CArg<"DictionaryAttr",
756+ "nullptr">:$metadata), [{
757+ assert(funcAttrs && "invalid function attributes dictionary");
758+ return $_get(funcAttrs.getContext(), funcAttrs, metadata);
759+ }]>
760+ ];
761+ let extraClassDeclaration = [{
762+ /// Returns the function attribute corresponding to key or nullptr if missing.
763+ Attribute getAttr(StringRef key) const {
764+ return getFuncAttrs().get(key);
765+ }
766+ template <typename ConcreteAttr>
767+ ConcreteAttr getAttr(StringRef key) const {
768+ return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
769+ }
770+ Attribute getAttr(StringAttr key) const;
771+ template <typename ConcreteAttr>
772+ ConcreteAttr getAttr(StringAttr key) const {
773+ return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
774+ }
775+
776+ /// Returns the name of the kernel.
777+ StringAttr getName() const {
778+ return getAttr<StringAttr>("sym_name");
779+ }
780+
781+ /// Returns the metadta attribute corresponding to key or nullptr if missing.
782+ Attribute getMDAttr(StringRef key) const {
783+ if (DictionaryAttr attrs = getMetadata())
784+ return attrs.get(key);
785+ return nullptr;
786+ }
787+ template <typename ConcreteAttr>
788+ ConcreteAttr getMDAttr(StringRef key) const {
789+ return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
790+ }
791+ Attribute getMDAttr(StringAttr key) const;
792+ template <typename ConcreteAttr>
793+ ConcreteAttr getMDAttr(StringAttr key) const {
794+ return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
795+ }
796+
797+ /// Returns the number of required scalar registers, or nullptr if the field
798+ /// is missing.
799+ IntegerAttr getSGPR() const {
800+ return getMDAttr<IntegerAttr>("sgpr_count");
801+ }
802+
803+ /// Returns the number of required scalar registers, or nullptr if the field
804+ /// is missing.
805+ IntegerAttr getVGPR() const {
806+ return getMDAttr<IntegerAttr>("vgpr_count");
807+ }
808+
809+ /// Returns the number of required scalar registers, or nullptr if the field
810+ /// is missing.
811+ IntegerAttr getAGPR() const {
812+ return getMDAttr<IntegerAttr>("agpr_count");
813+ }
814+
815+ /// Returns the number of spilled SGPR, or nullptr if the field is missing.
816+ IntegerAttr getSGPRSpill() const {
817+ return getMDAttr<IntegerAttr>("sgpr_spill_count");
818+ }
819+
820+ /// Returns the number of spilled VGPR, or nullptr if the field is missing.
821+ IntegerAttr getVGPRSpill() const {
822+ return getMDAttr<IntegerAttr>("vgpr_spill_count");
823+ }
824+
825+ /// Helper function for appending metadata to a kernel attribute.
826+ ROCDLKernelAttr appendMetadata(ArrayRef<NamedAttribute> attrs) const;
827+ }];
828+ }
829+
830+ //===----------------------------------------------------------------------===//
831+ // ROCDL object metadata
832+ //===----------------------------------------------------------------------===//
833+
834+ def ROCDL_ObjectMDAttr :
835+ ROCDL_Attr<"ROCDLObjectMD", "object_metadata"> {
836+ let description = [{
837+ ROCDL attribute representing a table of kernels metadata. All the attributes
838+ in the dictionary must be of type `#rocdl.kernel`.
839+
840+ Examples:
841+ ```mlir
842+ #rocdl.object_metadata<{kernel0 = #rocdl.kernel<...>}>
843+ ```
844+ }];
845+ let parameters = (ins
846+ "DictionaryAttr":$kernel_table
847+ );
848+ let assemblyFormat = [{
849+ `<` $kernel_table `>`
850+ }];
851+ let builders = [
852+ AttrBuilderWithInferredContext<(ins "DictionaryAttr":$kernel_table), [{
853+ assert(kernel_table && "invalid kernel table");
854+ return $_get(kernel_table.getContext(), kernel_table);
855+ }]>
856+ ];
857+ let skipDefaultBuilders = 1;
858+ let genVerifyDecl = 1;
859+ let extraClassDeclaration = [{
860+ /// Helper iterator class for traversing the kernel table.
861+ struct KernelIterator
862+ : llvm::mapped_iterator_base<KernelIterator,
863+ llvm::ArrayRef<NamedAttribute>::iterator,
864+ std::pair<StringAttr, ROCDLKernelAttr>> {
865+ using llvm::mapped_iterator_base<
866+ KernelIterator, llvm::ArrayRef<NamedAttribute>::iterator,
867+ std::pair<StringAttr, ROCDLKernelAttr>>::mapped_iterator_base;
868+ /// Map the iterator to the kernel name and a KernelAttribute.
869+ std::pair<StringAttr, ROCDLKernelAttr> mapElement(NamedAttribute attr) const {
870+ return {attr.getName(), llvm::cast<ROCDLKernelAttr>(attr.getValue())};
871+ }
872+ };
873+ auto begin() const {
874+ return KernelIterator(getKernelTable().begin());
875+ }
876+ auto end() const {
877+ return KernelIterator(getKernelTable().end());
878+ }
879+ size_t size() const {
880+ return getKernelTable().size();
881+ }
882+ }];
883+ }
884+
725885#endif // ROCDLIR_OPS
0 commit comments