|
19 | 19 |
|
20 | 20 | using namespace mlir; |
21 | 21 |
|
| 22 | +LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, |
| 23 | + Location loc, OpBuilder &b, |
| 24 | + StringRef name, |
| 25 | + LLVM::LLVMFunctionType type) { |
| 26 | + LLVM::LLVMFuncOp ret; |
| 27 | + if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { |
| 28 | + OpBuilder::InsertionGuard guard(b); |
| 29 | + b.setInsertionPointToStart(moduleOp.getBody()); |
| 30 | + ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External); |
| 31 | + } |
| 32 | + return ret; |
| 33 | +} |
| 34 | + |
| 35 | +static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, |
| 36 | + StringRef prefix) { |
| 37 | + // Get a unique global name. |
| 38 | + unsigned stringNumber = 0; |
| 39 | + SmallString<16> stringConstName; |
| 40 | + do { |
| 41 | + stringConstName.clear(); |
| 42 | + (prefix + Twine(stringNumber++)).toStringRef(stringConstName); |
| 43 | + } while (moduleOp.lookupSymbol(stringConstName)); |
| 44 | + return stringConstName; |
| 45 | +} |
| 46 | + |
| 47 | +LLVM::GlobalOp |
| 48 | +mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, |
| 49 | + gpu::GPUModuleOp moduleOp, Type llvmI8, |
| 50 | + StringRef namePrefix, StringRef str, |
| 51 | + uint64_t alignment, unsigned addrSpace) { |
| 52 | + llvm::SmallString<20> nullTermStr(str); |
| 53 | + nullTermStr.push_back('\0'); // Null terminate for C |
| 54 | + auto globalType = |
| 55 | + LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes()); |
| 56 | + StringAttr attr = b.getStringAttr(nullTermStr); |
| 57 | + |
| 58 | + // Try to find existing global. |
| 59 | + for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) |
| 60 | + if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && |
| 61 | + globalOp.getValueAttr() == attr && |
| 62 | + globalOp.getAlignment().value_or(0) == alignment && |
| 63 | + globalOp.getAddrSpace() == addrSpace) |
| 64 | + return globalOp; |
| 65 | + |
| 66 | + // Not found: create new global. |
| 67 | + OpBuilder::InsertionGuard guard(b); |
| 68 | + b.setInsertionPointToStart(moduleOp.getBody()); |
| 69 | + SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); |
| 70 | + return b.create<LLVM::GlobalOp>(loc, globalType, |
| 71 | + /*isConstant=*/true, LLVM::Linkage::Internal, |
| 72 | + name, attr, alignment, addrSpace); |
| 73 | +} |
| 74 | + |
22 | 75 | LogicalResult |
23 | 76 | GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, |
24 | 77 | ConversionPatternRewriter &rewriter) const { |
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, |
328 | 381 | return success(); |
329 | 382 | } |
330 | 383 |
|
331 | | -static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) { |
332 | | - const char formatStringPrefix[] = "printfFormat_"; |
333 | | - // Get a unique global name. |
334 | | - unsigned stringNumber = 0; |
335 | | - SmallString<16> stringConstName; |
336 | | - do { |
337 | | - stringConstName.clear(); |
338 | | - (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); |
339 | | - } while (moduleOp.lookupSymbol(stringConstName)); |
340 | | - return stringConstName; |
341 | | -} |
342 | | - |
343 | | -/// Create an global that contains the given format string. If a global with |
344 | | -/// the same format string exists already in the module, return that global. |
345 | | -static LLVM::GlobalOp getOrCreateFormatStringConstant( |
346 | | - OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8, |
347 | | - StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) { |
348 | | - llvm::SmallString<20> formatString(str); |
349 | | - formatString.push_back('\0'); // Null terminate for C |
350 | | - auto globalType = |
351 | | - LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); |
352 | | - StringAttr attr = b.getStringAttr(formatString); |
353 | | - |
354 | | - // Try to find existing global. |
355 | | - for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) |
356 | | - if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && |
357 | | - globalOp.getValueAttr() == attr && |
358 | | - globalOp.getAlignment().value_or(0) == alignment && |
359 | | - globalOp.getAddrSpace() == addrSpace) |
360 | | - return globalOp; |
361 | | - |
362 | | - // Not found: create new global. |
363 | | - OpBuilder::InsertionGuard guard(b); |
364 | | - b.setInsertionPointToStart(moduleOp.getBody()); |
365 | | - SmallString<16> name = getUniqueFormatGlobalName(moduleOp); |
366 | | - return b.create<LLVM::GlobalOp>(loc, globalType, |
367 | | - /*isConstant=*/true, LLVM::Linkage::Internal, |
368 | | - name, attr, alignment, addrSpace); |
369 | | -} |
370 | | - |
371 | | -template <typename T> |
372 | | -static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, |
373 | | - ConversionPatternRewriter &rewriter, |
374 | | - StringRef name, |
375 | | - LLVM::LLVMFunctionType type) { |
376 | | - LLVM::LLVMFuncOp ret; |
377 | | - if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { |
378 | | - ConversionPatternRewriter::InsertionGuard guard(rewriter); |
379 | | - rewriter.setInsertionPointToStart(moduleOp.getBody()); |
380 | | - ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type, |
381 | | - LLVM::Linkage::External); |
382 | | - } |
383 | | - return ret; |
384 | | -} |
385 | | - |
386 | 384 | LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( |
387 | 385 | gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, |
388 | 386 | ConversionPatternRewriter &rewriter) const { |
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( |
420 | 418 | Value printfDesc = printfBeginCall.getResult(); |
421 | 419 |
|
422 | 420 | // Create the global op or find an existing one. |
423 | | - LLVM::GlobalOp global = getOrCreateFormatStringConstant( |
424 | | - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat()); |
| 421 | + LLVM::GlobalOp global = getOrCreateStringConstant( |
| 422 | + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); |
425 | 423 |
|
426 | 424 | // Get a pointer to the format string's first element and pass it to printf() |
427 | 425 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>( |
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( |
502 | 500 | getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); |
503 | 501 |
|
504 | 502 | // Create the global op or find an existing one. |
505 | | - LLVM::GlobalOp global = getOrCreateFormatStringConstant( |
506 | | - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0, |
507 | | - addressSpace); |
| 503 | + LLVM::GlobalOp global = getOrCreateStringConstant( |
| 504 | + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(), |
| 505 | + /*alignment=*/0, addressSpace); |
508 | 506 |
|
509 | 507 | // Get a pointer to the format string's first element |
510 | 508 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>( |
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( |
546 | 544 | getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType); |
547 | 545 |
|
548 | 546 | // Create the global op or find an existing one. |
549 | | - LLVM::GlobalOp global = getOrCreateFormatStringConstant( |
550 | | - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat()); |
| 547 | + LLVM::GlobalOp global = getOrCreateStringConstant( |
| 548 | + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); |
551 | 549 |
|
552 | 550 | // Get a pointer to the format string's first element |
553 | 551 | Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); |
|
0 commit comments