|
22 | 22 | #include "flang/Runtime/CUDA/memory.h" |
23 | 23 | #include "flang/Runtime/CUDA/pointer.h" |
24 | 24 | #include "flang/Runtime/allocatable.h" |
| 25 | +#include "flang/Runtime/allocator-registry-consts.h" |
25 | 26 | #include "flang/Support/Fortran.h" |
26 | 27 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
27 | 28 | #include "mlir/Dialect/DLTI/DLTI.h" |
@@ -923,6 +924,34 @@ struct CUFSyncDescriptorOpConversion |
923 | 924 | } |
924 | 925 | }; |
925 | 926 |
|
| 927 | +struct CUFSetAllocatorIndexOpConversion |
| 928 | + : public mlir::OpRewritePattern<cuf::SetAllocatorIndexOp> { |
| 929 | + using OpRewritePattern::OpRewritePattern; |
| 930 | + |
| 931 | + mlir::LogicalResult |
| 932 | + matchAndRewrite(cuf::SetAllocatorIndexOp op, |
| 933 | + mlir::PatternRewriter &rewriter) const override { |
| 934 | + auto mod = op->getParentOfType<mlir::ModuleOp>(); |
| 935 | + fir::FirOpBuilder builder(rewriter, mod); |
| 936 | + mlir::Location loc = op.getLoc(); |
| 937 | + int idx = kDefaultAllocator; |
| 938 | + if (op.getDataAttr() == cuf::DataAttribute::Device) { |
| 939 | + idx = kDeviceAllocatorPos; |
| 940 | + } else if (op.getDataAttr() == cuf::DataAttribute::Managed) { |
| 941 | + idx = kManagedAllocatorPos; |
| 942 | + } else if (op.getDataAttr() == cuf::DataAttribute::Unified) { |
| 943 | + idx = kUnifiedAllocatorPos; |
| 944 | + } else if (op.getDataAttr() == cuf::DataAttribute::Pinned) { |
| 945 | + idx = kPinnedAllocatorPos; |
| 946 | + } |
| 947 | + mlir::Value index = |
| 948 | + builder.createIntegerConstant(loc, builder.getI32Type(), idx); |
| 949 | + fir::runtime::cuda::genSetAllocatorIndex(builder, loc, op.getBox(), index); |
| 950 | + op.erase(); |
| 951 | + return mlir::success(); |
| 952 | + } |
| 953 | +}; |
| 954 | + |
926 | 955 | class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> { |
927 | 956 | public: |
928 | 957 | void runOnOperation() override { |
@@ -984,8 +1013,8 @@ void cuf::populateCUFToFIRConversionPatterns( |
984 | 1013 | const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { |
985 | 1014 | patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter); |
986 | 1015 | patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion, |
987 | | - CUFFreeOpConversion, CUFSyncDescriptorOpConversion>( |
988 | | - patterns.getContext()); |
| 1016 | + CUFFreeOpConversion, CUFSyncDescriptorOpConversion, |
| 1017 | + CUFSetAllocatorIndexOpConversion>(patterns.getContext()); |
989 | 1018 | patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab, |
990 | 1019 | &dl, &converter); |
991 | 1020 | patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>( |
|
0 commit comments