diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 684ce37b2398c..f05f1c2dc3388 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern { /// * `i32 (i8*, ...)` static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { auto llvmI32Ty = IntegerType::get(context, 32); - auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); - auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, + auto llvmPtrTy = LLVM::LLVMPointerType::get(context); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, /*isVarArg=*/true); return llvmFnType; } @@ -162,8 +162,7 @@ class PrintOpLowering : public ConversionPattern { Value cst0 = builder.create(loc, builder.getI64Type(), builder.getIndexAttr(0)); return builder.create( - loc, - LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)), + loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), globalPtr, ArrayRef({cst0, cst0})); } }; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 684ce37b2398c..f05f1c2dc3388 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern { /// * `i32 (i8*, ...)` static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { auto llvmI32Ty = IntegerType::get(context, 32); - auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); - auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, + auto llvmPtrTy = LLVM::LLVMPointerType::get(context); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, /*isVarArg=*/true); return llvmFnType; } @@ -162,8 +162,7 @@ class PrintOpLowering : public ConversionPattern { Value cst0 = builder.create(loc, builder.getI64Type(), builder.getIndexAttr(0)); return builder.create( - loc, - LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)), + loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), globalPtr, ArrayRef({cst0, cst0})); } }; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 8745d14c8d483..2a572ab4de706 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1071,7 +1071,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", CArg<"ArrayRef", "{}">:$attrs), [{ build($_builder, $_state, - LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()), + LLVM::LLVMPointerType::get($_builder.getContext(), global.getAddrSpace()), global.getSymName()); $_state.addAttributes(attrs); }]>, @@ -1079,7 +1079,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", CArg<"ArrayRef", "{}">:$attrs), [{ build($_builder, $_state, - LLVM::LLVMPointerType::get(func.getFunctionType()), func.getName()); + LLVM::LLVMPointerType::get($_builder.getContext()), func.getName()); $_state.addAttributes(attrs); }]> ]; diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 96d8fceba7066..6d2585aa30ab4 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -441,7 +441,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( Location loc = gpuPrintfOp->getLoc(); mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); - mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8); + mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); // Note: this is the GPUModule op, not the ModuleOp that surrounds it // This ensures that global constants and declarations are placed within @@ -449,7 +449,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( auto moduleOp = gpuPrintfOp->getParentOfType(); auto vprintfType = - LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr}); + LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType}); LLVM::LLVMFuncOp vprintfDecl = getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType); @@ -473,7 +473,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( // Get a pointer to the format string's first element Value globalPtr = rewriter.create(loc, global); Value stringStart = rewriter.create( - loc, i8Ptr, globalPtr, ArrayRef{0, 0}); + loc, getTypeConverter()->getPointerType(globalType), globalType, + globalPtr, ArrayRef{0, 0}); SmallVector types; SmallVector args; // Promote and pack the arguments into a stack allocation. @@ -490,18 +491,17 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( } Type structType = LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); - Type structPtrType = LLVM::LLVMPointerType::get(structType); Value one = rewriter.create(loc, rewriter.getI64Type(), rewriter.getIndexAttr(1)); - Value tempAlloc = rewriter.create(loc, structPtrType, one, - /*alignment=*/0); + Value tempAlloc = + rewriter.create(loc, ptrType, structType, one, + /*alignment=*/0); for (auto [index, arg] : llvm::enumerate(args)) { Value ptr = rewriter.create( - loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc, - ArrayRef{0, index}); + loc, getTypeConverter()->getPointerType(structType), structType, + tempAlloc, ArrayRef{0, index}); rewriter.create(loc, arg, ptr); } - tempAlloc = rewriter.create(loc, i8Ptr, tempAlloc); std::array printfArgs = {stringStart, tempAlloc}; rewriter.create(loc, vprintfDecl, printfArgs); diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index 391ccd74841dc..a8c02e32ef92b 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -542,16 +542,15 @@ gpu.module @test_module_28 { gpu.module @test_module_29 { // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00") // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00") - // CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32 + // CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32 // CHECK-LABEL: func @test_const_printf gpu.func @test_const_printf() { - // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr> - // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr>) -> !llvm.ptr + // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr + // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8> // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64 - // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr> - // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr> to !llvm.ptr - // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 + // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr + // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32 gpu.printf "Hello, world\n" gpu.return } @@ -559,17 +558,16 @@ gpu.module @test_module_29 { // CHECK-LABEL: func @test_printf // CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32) gpu.func @test_printf(%arg0: i32, %arg1: f32) { - // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr> - // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr>) -> !llvm.ptr + // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr + // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<11 x i8> // CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64 // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64 - // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr> - // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr>) -> !llvm.ptr - // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr - // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr>) -> !llvm.ptr - // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr - // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr> to !llvm.ptr - // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 + // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr + // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)> + // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : i32, !llvm.ptr + // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)> + // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr + // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32 gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32 gpu.return }