Skip to content

Commit

Permalink
[CIR][Lowering] Partially lower variadic builtins
Browse files Browse the repository at this point in the history
Implement lowering steps for va_start, va_end, and va_copy. The va_arg
was not implemented because it requires ABI-specific lowering.

ghstack-source-id: 1ab2923027143aa28bb7361b884a5c8ee04cfbc9
Pull Request resolved: #95
  • Loading branch information
sitio-couto authored and lanza committed Dec 20, 2023
1 parent 5a977c8 commit fa1511f
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
74 changes: 73 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

using namespace cir;
Expand Down Expand Up @@ -544,6 +545,66 @@ class CIRConstantLowering
}
};

class CIRVAStartLowering
: public mlir::OpConversionPattern<mlir::cir::VAStartOp> {
public:
using OpConversionPattern<mlir::cir::VAStartOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VAStartOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto i8PtrTy = mlir::LLVM::LLVMPointerType::get(getContext());
auto vaList = rewriter.create<mlir::LLVM::BitcastOp>(
op.getLoc(), i8PtrTy, adaptor.getOperands().front());
rewriter.replaceOpWithNewOp<mlir::LLVM::VaStartOp>(op, vaList);
return mlir::success();
}
};

class CIRVAEndLowering : public mlir::OpConversionPattern<mlir::cir::VAEndOp> {
public:
using OpConversionPattern<mlir::cir::VAEndOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VAEndOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto i8PtrTy = mlir::LLVM::LLVMPointerType::get(getContext());
auto vaList = rewriter.create<mlir::LLVM::BitcastOp>(
op.getLoc(), i8PtrTy, adaptor.getOperands().front());
rewriter.replaceOpWithNewOp<mlir::LLVM::VaEndOp>(op, vaList);
return mlir::success();
}
};

class CIRVACopyLowering
: public mlir::OpConversionPattern<mlir::cir::VACopyOp> {
public:
using OpConversionPattern<mlir::cir::VACopyOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VACopyOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto i8PtrTy = mlir::LLVM::LLVMPointerType::get(getContext());
auto dstList = rewriter.create<mlir::LLVM::BitcastOp>(
op.getLoc(), i8PtrTy, adaptor.getOperands().front());
auto srcList = rewriter.create<mlir::LLVM::BitcastOp>(
op.getLoc(), i8PtrTy, adaptor.getOperands().back());
rewriter.replaceOpWithNewOp<mlir::LLVM::VaCopyOp>(op, dstList, srcList);
return mlir::success();
}
};

class CIRVAArgLowering : public mlir::OpConversionPattern<mlir::cir::VAArgOp> {
public:
using OpConversionPattern<mlir::cir::VAArgOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VAArgOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
return op.emitError("cir.vaarg lowering is NYI");
}
};

class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
public:
using OpConversionPattern<mlir::cir::FuncOp>::OpConversionPattern;
Expand Down Expand Up @@ -997,7 +1058,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRBinOpLowering, CIRLoadLowering, CIRConstantLowering,
CIRStoreLowering, CIRAllocaLowering, CIRFuncLowering,
CIRScopeOpLowering, CIRCastOpLowering, CIRIfLowering,
CIRGlobalOpLowering, CIRGetGlobalOpLowering>(
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRVAStartLowering,
CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering>(
converter, patterns.getContext());
}

Expand All @@ -1018,6 +1080,16 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter) {
// LLVM doesn't work with signed types, so we drop the CIR signs here.
return mlir::IntegerType::get(type.getContext(), type.getWidth());
});
converter.addConversion([&](mlir::cir::StructType type) -> mlir::Type {
llvm::SmallVector<mlir::Type> llvmMembers;
for (auto ty : type.getMembers())
llvmMembers.push_back(converter.convertType(ty));
auto llvmStruct = mlir::LLVM::LLVMStructType::getIdentified(
type.getContext(), type.getTypeName());
if (llvmStruct.setBody(llvmMembers, /*isPacked=*/type.getPacked()).failed())
llvm_unreachable("Failed to set body of struct");
return llvmStruct;
});
}
} // namespace

Expand Down
40 changes: 40 additions & 0 deletions clang/test/CIR/Lowering/variadics.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: cir-tool %s -cir-to-llvm -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=MLIR

!s32i = !cir.int<s, 32>
!u32i = !cir.int<u, 32>
!u8i = !cir.int<u, 8>

!ty_22struct2E__va_list_tag22 = !cir.struct<"struct.__va_list_tag", !u32i, !u32i, !cir.ptr<!u8i>, !cir.ptr<!u8i>, #cir.recdecl.ast>

module {
cir.func @average(%arg0: !s32i, ...) -> !s32i {
%0 = cir.alloca !s32i, cir.ptr <!s32i>, ["count", init] {alignment = 4 : i64}
%1 = cir.alloca !s32i, cir.ptr <!s32i>, ["__retval"] {alignment = 4 : i64}
%2 = cir.alloca !cir.array<!ty_22struct2E__va_list_tag22 x 1>, cir.ptr <!cir.array<!ty_22struct2E__va_list_tag22 x 1>>, ["args"] {alignment = 16 : i64}
%3 = cir.alloca !cir.array<!ty_22struct2E__va_list_tag22 x 1>, cir.ptr <!cir.array<!ty_22struct2E__va_list_tag22 x 1>>, ["args_copy"] {alignment = 16 : i64}
cir.store %arg0, %0 : !s32i, cir.ptr <!s32i>
%4 = cir.cast(array_to_ptrdecay, %2 : !cir.ptr<!cir.array<!ty_22struct2E__va_list_tag22 x 1>>), !cir.ptr<!ty_22struct2E__va_list_tag22>
cir.va.start %4 : !cir.ptr<!ty_22struct2E__va_list_tag22>
// MLIR: %{{[0-9]+}} = llvm.getelementptr %{{[0-9]+}}[0] : (!llvm.ptr) -> !llvm.ptr
// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : !llvm.ptr to !llvm.ptr
// MLIR-NEXT: llvm.intr.vastart %{{[0-9]+}} : !llvm.ptr
%5 = cir.cast(array_to_ptrdecay, %3 : !cir.ptr<!cir.array<!ty_22struct2E__va_list_tag22 x 1>>), !cir.ptr<!ty_22struct2E__va_list_tag22>
%6 = cir.cast(array_to_ptrdecay, %2 : !cir.ptr<!cir.array<!ty_22struct2E__va_list_tag22 x 1>>), !cir.ptr<!ty_22struct2E__va_list_tag22>
cir.va.copy %6 to %5 : !cir.ptr<!ty_22struct2E__va_list_tag22>, !cir.ptr<!ty_22struct2E__va_list_tag22>
// MLIR: %{{[0-9]+}} = llvm.getelementptr %{{[0-9]+}}[0] : (!llvm.ptr) -> !llvm.ptr
// MLIR-NEXT: %{{[0-9]+}} = llvm.getelementptr %{{[0-9]+}}[0] : (!llvm.ptr) -> !llvm.ptr
// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : !llvm.ptr to !llvm.ptr
// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : !llvm.ptr to !llvm.ptr
// MLIR-NEXT: llvm.intr.vacopy %13 to %{{[0-9]+}} : !llvm.ptr, !llvm.ptr
%7 = cir.cast(array_to_ptrdecay, %2 : !cir.ptr<!cir.array<!ty_22struct2E__va_list_tag22 x 1>>), !cir.ptr<!ty_22struct2E__va_list_tag22>
cir.va.end %7 : !cir.ptr<!ty_22struct2E__va_list_tag22>
// MLIR: %{{[0-9]+}} = llvm.getelementptr %{{[0-9]+}}[0] : (!llvm.ptr) -> !llvm.ptr
// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : !llvm.ptr to !llvm.ptr
// MLIR-NEXT: llvm.intr.vaend %{{[0-9]+}} : !llvm.ptr
%8 = cir.const(#cir.int<0> : !s32i) : !s32i
cir.store %8, %1 : !s32i, cir.ptr <!s32i>
%9 = cir.load %1 : cir.ptr <!s32i>, !s32i
cir.return %9 : !s32i
}
}

0 comments on commit fa1511f

Please sign in to comment.