diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index d4c2efca4ecfa..0f7752eda6d66 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -102,6 +102,7 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/SandboxIR/Tracker.h" +#include "llvm/SandboxIR/Type.h" #include "llvm/SandboxIR/Use.h" #include "llvm/Support/raw_ostream.h" #include @@ -386,7 +387,7 @@ class Value { return Cnt == Num; } - Type *getType() const { return Val->getType(); } + Type *getType() const; Context &getContext() const { return Ctx; } @@ -574,8 +575,7 @@ class ConstantInt : public Constant { public: /// If Ty is a vector type, return a Constant with a splat of the given /// value. Otherwise return a ConstantInt for the given value. - static ConstantInt *get(Type *Ty, uint64_t V, Context &Ctx, - bool IsSigned = false); + static ConstantInt *get(Type *Ty, uint64_t V, bool IsSigned = false); // TODO: Implement missing functions. @@ -1024,10 +1024,7 @@ class ExtractElementInst final Value *getIndexOperand() { return getOperand(1); } const Value *getVectorOperand() const { return getOperand(0); } const Value *getIndexOperand() const { return getOperand(1); } - - VectorType *getVectorOperandType() const { - return cast(getVectorOperand()->getType()); - } + VectorType *getVectorOperandType() const; }; class ShuffleVectorInst final @@ -1072,9 +1069,7 @@ class ShuffleVectorInst final } /// Overload to return most specific vector type. - VectorType *getType() const { - return cast(Val)->getType(); - } + VectorType *getType() const; /// Return the shuffle mask value of this instruction for the given element /// index. Return PoisonMaskElem if the element is undef. @@ -1100,7 +1095,7 @@ class ShuffleVectorInst final Constant *getShuffleMaskForBitcode() const; static Constant *convertShuffleMaskForBitcode(ArrayRef Mask, - Type *ResultTy, Context &Ctx); + Type *ResultTy); void setShuffleMask(ArrayRef Mask); @@ -1646,9 +1641,7 @@ class ExtractValueInst : public UnaryInstruction { /// with an extractvalue instruction with the specified parameters. /// /// Null is returned if the indices are invalid for the specified type. - static Type *getIndexedType(Type *Agg, ArrayRef Idxs) { - return llvm::ExtractValueInst::getIndexedType(Agg, Idxs); - } + static Type *getIndexedType(Type *Agg, ArrayRef Idxs); using idx_iterator = llvm::ExtractValueInst::idx_iterator; @@ -1843,9 +1836,7 @@ class CallBase : public SingleLLVMInstructionImpl { Opc == Instruction::ClassID::CallBr; } - FunctionType *getFunctionType() const { - return cast(Val)->getFunctionType(); - } + FunctionType *getFunctionType() const; op_iterator data_operands_begin() { return op_begin(); } const_op_iterator data_operands_begin() const { @@ -2261,12 +2252,8 @@ class GetElementPtrInst final return From->getSubclassID() == ClassID::GetElementPtr; } - Type *getSourceElementType() const { - return cast(Val)->getSourceElementType(); - } - Type *getResultElementType() const { - return cast(Val)->getResultElementType(); - } + Type *getSourceElementType() const; + Type *getResultElementType() const; unsigned getAddressSpace() const { return cast(Val)->getAddressSpace(); } @@ -2290,9 +2277,7 @@ class GetElementPtrInst final static unsigned getPointerOperandIndex() { return llvm::GetElementPtrInst::getPointerOperandIndex(); } - Type *getPointerOperandType() const { - return cast(Val)->getPointerOperandType(); - } + Type *getPointerOperandType() const; unsigned getPointerAddressSpace() const { return cast(Val)->getPointerAddressSpace(); } @@ -2843,9 +2828,7 @@ class AllocaInst final : public UnaryInstruction { return const_cast(this)->getArraySize(); } /// Overload to return most specific pointer type. - PointerType *getType() const { - return cast(Val)->getType(); - } + PointerType *getType() const; /// Return the address space for the allocation. unsigned getAddressSpace() const { return cast(Val)->getAddressSpace(); @@ -2861,9 +2844,7 @@ class AllocaInst final : public UnaryInstruction { return cast(Val)->getAllocationSizeInBits(DL); } /// Return the type that is being allocated by the instruction. - Type *getAllocatedType() const { - return cast(Val)->getAllocatedType(); - } + Type *getAllocatedType() const; /// for use only in special circumstances that need to generically /// transform a whole instruction (eg: IR linking and vectorization). void setAllocatedType(Type *Ty); @@ -2945,8 +2926,8 @@ class CastInst : public UnaryInstruction { const Twine &Name = ""); /// For isa/dyn_cast. static bool classof(const Value *From); - Type *getSrcTy() const { return cast(Val)->getSrcTy(); } - Type *getDestTy() const { return cast(Val)->getDestTy(); } + Type *getSrcTy() const; + Type *getDestTy() const; }; /// Instruction that can have a nneg flag (zext/uitofp). @@ -3126,6 +3107,8 @@ class OpaqueInst : public SingleLLVMInstructionImpl { class Context { protected: LLVMContext &LLVMCtx; + friend class Type; // For LLVMCtx. + friend class PointerType; // For LLVMCtx. Tracker IRTracker; /// Maps LLVM Value to the corresponding sandboxir::Value. Owns all @@ -3133,6 +3116,16 @@ class Context { DenseMap> LLVMValueToValueMap; + /// Type has a protected destructor to prohibit the user from managing the + /// lifetime of the Type objects. Context is friend of Type, and this custom + /// deleter can destroy Type. + struct TypeDeleter { + void operator()(Type *Ty) { delete Ty; } + }; + /// Maps LLVM Type to the corresonding sandboxir::Type. Owns all Sandbox IR + /// Type objects. + DenseMap> LLVMTypeToTypeMap; + /// Remove \p V from the maps and returns the unique_ptr. std::unique_ptr detachLLVMValue(llvm::Value *V); /// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively @@ -3167,7 +3160,6 @@ class Context { /// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will /// also create all contents of the block. BasicBlock *createBasicBlock(llvm::BasicBlock *BB); - friend class BasicBlock; // For getOrCreateValue(). IRBuilder LLVMIRBuilder; @@ -3257,6 +3249,17 @@ class Context { const sandboxir::Value *getValue(const llvm::Value *V) const { return getValue(const_cast(V)); } + + Type *getType(llvm::Type *LLVMTy) { + if (LLVMTy == nullptr) + return nullptr; + auto Pair = LLVMTypeToTypeMap.insert({LLVMTy, nullptr}); + auto It = Pair.first; + if (Pair.second) + It->second = std::unique_ptr(new Type(LLVMTy, *this)); + return It->second.get(); + } + /// Create a sandboxir::Function for an existing LLVM IR \p F, including all /// blocks and instructions. /// This is the main API function for creating Sandbox IR. @@ -3303,9 +3306,7 @@ class Function : public Constant { LLVMBBToBB BBGetter(Ctx); return iterator(cast(Val)->end(), BBGetter); } - FunctionType *getFunctionType() const { - return cast(Val)->getFunctionType(); - } + FunctionType *getFunctionType() const; #ifndef NDEBUG void verify() const final { diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h new file mode 100644 index 0000000000000..4588cd2f73887 --- /dev/null +++ b/llvm/include/llvm/SandboxIR/Type.h @@ -0,0 +1,299 @@ +//===- llvm/SandboxIR/Type.h - Classes for handling data types --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a thin wrapper over llvm::Type. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_SANDBOXIR_TYPE_H +#define LLVM_SANDBOXIR_TYPE_H + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +namespace llvm::sandboxir { + +class Context; +// Forward declare friend classes for MSVC. +class PointerType; +class VectorType; +class FunctionType; +#define DEF_INSTR(ID, OPCODE, CLASS) class CLASS; +#include "llvm/SandboxIR/SandboxIRValues.def" + +/// Just like llvm::Type these are immutable, unique, never get freed and can +/// only be created via static factory methods. +class Type { +protected: + llvm::Type *LLVMTy; + friend class VectorType; // For LLVMTy. + friend class PointerType; // For LLVMTy. + friend class FunctionType; // For LLVMTy. + friend class Function; // For LLVMTy. + friend class CallBase; // For LLVMTy. + friend class ConstantInt; // For LLVMTy. + // Friend all instruction classes because `create()` functions use LLVMTy. +#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS; + // TODO: Friend DEF_CONST() +#include "llvm/SandboxIR/SandboxIRValues.def" + Context &Ctx; + + Type(llvm::Type *LLVMTy, Context &Ctx) : LLVMTy(LLVMTy), Ctx(Ctx) {} + friend class Context; // For constructor and ~Type(). + ~Type() = default; + +public: + /// Print the current type. + /// Omit the type details if \p NoDetails == true. + /// E.g., let %st = type { i32, i16 } + /// When \p NoDetails is true, we only print %st. + /// Put differently, \p NoDetails prints the type as if + /// inlined with the operands when printing an instruction. + void print(raw_ostream &OS, bool IsForDebug = false, + bool NoDetails = false) const { + LLVMTy->print(OS, IsForDebug, NoDetails); + } + + Context &getContext() const { return Ctx; } + + /// Return true if this is 'void'. + bool isVoidTy() const { return LLVMTy->isVoidTy(); } + + /// Return true if this is 'half', a 16-bit IEEE fp type. + bool isHalfTy() const { return LLVMTy->isHalfTy(); } + + /// Return true if this is 'bfloat', a 16-bit bfloat type. + bool isBFloatTy() const { return LLVMTy->isBFloatTy(); } + + /// Return true if this is a 16-bit float type. + bool is16bitFPTy() const { return LLVMTy->is16bitFPTy(); } + + /// Return true if this is 'float', a 32-bit IEEE fp type. + bool isFloatTy() const { return LLVMTy->isFloatTy(); } + + /// Return true if this is 'double', a 64-bit IEEE fp type. + bool isDoubleTy() const { return LLVMTy->isDoubleTy(); } + + /// Return true if this is x86 long double. + bool isX86_FP80Ty() const { return LLVMTy->isX86_FP80Ty(); } + + /// Return true if this is 'fp128'. + bool isFP128Ty() const { return LLVMTy->isFP128Ty(); } + + /// Return true if this is powerpc long double. + bool isPPC_FP128Ty() const { return LLVMTy->isPPC_FP128Ty(); } + + /// Return true if this is a well-behaved IEEE-like type, which has a IEEE + /// compatible layout as defined by APFloat::isIEEE(), and does not have + /// non-IEEE values, such as x86_fp80's unnormal values. + bool isIEEELikeFPTy() const { return LLVMTy->isIEEELikeFPTy(); } + + /// Return true if this is one of the floating-point types + bool isFloatingPointTy() const { return LLVMTy->isFloatingPointTy(); } + + /// Returns true if this is a floating-point type that is an unevaluated sum + /// of multiple floating-point units. + /// An example of such a type is ppc_fp128, also known as double-double, which + /// consists of two IEEE 754 doubles. + bool isMultiUnitFPType() const { return LLVMTy->isMultiUnitFPType(); } + + const fltSemantics &getFltSemantics() const { + return LLVMTy->getFltSemantics(); + } + + /// Return true if this is X86 AMX. + bool isX86_AMXTy() const { return LLVMTy->isX86_AMXTy(); } + + /// Return true if this is a target extension type. + bool isTargetExtTy() const { return LLVMTy->isTargetExtTy(); } + + /// Return true if this is a target extension type with a scalable layout. + bool isScalableTargetExtTy() const { return LLVMTy->isScalableTargetExtTy(); } + + /// Return true if this is a type whose size is a known multiple of vscale. + bool isScalableTy() const { return LLVMTy->isScalableTy(); } + + /// Return true if this is a FP type or a vector of FP. + bool isFPOrFPVectorTy() const { return LLVMTy->isFPOrFPVectorTy(); } + + /// Return true if this is 'label'. + bool isLabelTy() const { return LLVMTy->isLabelTy(); } + + /// Return true if this is 'metadata'. + bool isMetadataTy() const { return LLVMTy->isMetadataTy(); } + + /// Return true if this is 'token'. + bool isTokenTy() const { return LLVMTy->isTokenTy(); } + + /// True if this is an instance of IntegerType. + bool isIntegerTy() const { return LLVMTy->isIntegerTy(); } + + /// Return true if this is an IntegerType of the given width. + bool isIntegerTy(unsigned Bitwidth) const { + return LLVMTy->isIntegerTy(Bitwidth); + } + + /// Return true if this is an integer type or a vector of integer types. + bool isIntOrIntVectorTy() const { return LLVMTy->isIntOrIntVectorTy(); } + + /// Return true if this is an integer type or a vector of integer types of + /// the given width. + bool isIntOrIntVectorTy(unsigned BitWidth) const { + return LLVMTy->isIntOrIntVectorTy(BitWidth); + } + + /// Return true if this is an integer type or a pointer type. + bool isIntOrPtrTy() const { return LLVMTy->isIntOrPtrTy(); } + + /// True if this is an instance of FunctionType. + bool isFunctionTy() const { return LLVMTy->isFunctionTy(); } + + /// True if this is an instance of StructType. + bool isStructTy() const { return LLVMTy->isStructTy(); } + + /// True if this is an instance of ArrayType. + bool isArrayTy() const { return LLVMTy->isArrayTy(); } + + /// True if this is an instance of PointerType. + bool isPointerTy() const { return LLVMTy->isPointerTy(); } + + /// Return true if this is a pointer type or a vector of pointer types. + bool isPtrOrPtrVectorTy() const { return LLVMTy->isPtrOrPtrVectorTy(); } + + /// True if this is an instance of VectorType. + inline bool isVectorTy() const { return LLVMTy->isVectorTy(); } + + /// Return true if this type could be converted with a lossless BitCast to + /// type 'Ty'. For example, i8* to i32*. BitCasts are valid for types of the + /// same size only where no re-interpretation of the bits is done. + /// Determine if this type could be losslessly bitcast to Ty + bool canLosslesslyBitCastTo(Type *Ty) const { + return LLVMTy->canLosslesslyBitCastTo(Ty->LLVMTy); + } + + /// Return true if this type is empty, that is, it has no elements or all of + /// its elements are empty. + bool isEmptyTy() const { return LLVMTy->isEmptyTy(); } + + /// Return true if the type is "first class", meaning it is a valid type for a + /// Value. + bool isFirstClassType() const { return LLVMTy->isFirstClassType(); } + + /// Return true if the type is a valid type for a register in codegen. This + /// includes all first-class types except struct and array types. + bool isSingleValueType() const { return LLVMTy->isSingleValueType(); } + + /// Return true if the type is an aggregate type. This means it is valid as + /// the first operand of an insertvalue or extractvalue instruction. This + /// includes struct and array types, but does not include vector types. + bool isAggregateType() const { return LLVMTy->isAggregateType(); } + + /// Return true if it makes sense to take the size of this type. To get the + /// actual size for a particular target, it is reasonable to use the + /// DataLayout subsystem to do this. + bool isSized(SmallPtrSetImpl *Visited = nullptr) const { + SmallPtrSet LLVMVisited; + LLVMVisited.reserve(Visited->size()); + for (Type *Ty : *Visited) + LLVMVisited.insert(Ty->LLVMTy); + return LLVMTy->isSized(&LLVMVisited); + } + + /// Return the basic size of this type if it is a primitive type. These are + /// fixed by LLVM and are not target-dependent. + /// This will return zero if the type does not have a size or is not a + /// primitive type. + /// + /// If this is a scalable vector type, the scalable property will be set and + /// the runtime size will be a positive integer multiple of the base size. + /// + /// Note that this may not reflect the size of memory allocated for an + /// instance of the type or the number of bytes that are written when an + /// instance of the type is stored to memory. The DataLayout class provides + /// additional query functions to provide this information. + /// + TypeSize getPrimitiveSizeInBits() const { + return LLVMTy->getPrimitiveSizeInBits(); + } + + /// If this is a vector type, return the getPrimitiveSizeInBits value for the + /// element type. Otherwise return the getPrimitiveSizeInBits value for this + /// type. + unsigned getScalarSizeInBits() const { return LLVMTy->getScalarSizeInBits(); } + + /// Return the width of the mantissa of this type. This is only valid on + /// floating-point types. If the FP type does not have a stable mantissa (e.g. + /// ppc long double), this method returns -1. + int getFPMantissaWidth() const { return LLVMTy->getFPMantissaWidth(); } + + /// Return whether the type is IEEE compatible, as defined by the eponymous + /// method in APFloat. + bool isIEEE() const { return LLVMTy->isIEEE(); } + + /// If this is a vector type, return the element type, otherwise return + /// 'this'. + Type *getScalarType() const; + + // TODO: ADD MISSING + + static Type *getInt64Ty(Context &Ctx); + static Type *getInt32Ty(Context &Ctx); + static Type *getInt16Ty(Context &Ctx); + static Type *getInt8Ty(Context &Ctx); + static Type *getInt1Ty(Context &Ctx); + static Type *getDoubleTy(Context &Ctx); + static Type *getFloatTy(Context &Ctx); + // TODO: missing get* + + /// Get the address space of this pointer or pointer vector type. + inline unsigned getPointerAddressSpace() const { + return LLVMTy->getPointerAddressSpace(); + } + +#ifndef NDEBUG + void dumpOS(raw_ostream &OS) { LLVMTy->print(OS); } + LLVM_DUMP_METHOD void dump() { + dumpOS(dbgs()); + dbgs() << "\n"; + } +#endif // NDEBUG +}; + +class PointerType : public Type { +public: + // TODO: add missing functions + static PointerType *get(Type *ElementType, unsigned AddressSpace); + static PointerType *get(Context &Ctx, unsigned AddressSpace); + + static bool classof(const Type *From) { + return isa(From->LLVMTy); + } +}; + +class VectorType : public Type { +public: + // TODO: add missing functions + static bool classof(const Type *From) { + return isa(From->LLVMTy); + } +}; + +class FunctionType : public Type { +public: + // TODO: add missing functions + static bool classof(const Type *From) { + return isa(From->LLVMTy); + } +}; + +} // namespace llvm::sandboxir + +#endif // LLVM_SANDBOXIR_TYPE_H diff --git a/llvm/lib/SandboxIR/CMakeLists.txt b/llvm/lib/SandboxIR/CMakeLists.txt index 6c0666b186b8a..d94f0642ccc4a 100644 --- a/llvm/lib/SandboxIR/CMakeLists.txt +++ b/llvm/lib/SandboxIR/CMakeLists.txt @@ -1,6 +1,7 @@ add_llvm_component_library(LLVMSandboxIR SandboxIR.cpp Tracker.cpp + Type.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/SandboxIR diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index b5786cdafd630..bf224b73f3bad 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -135,6 +135,8 @@ Value::user_iterator Value::user_begin() { unsigned Value::getNumUses() const { return range_size(Val->users()); } +Type *Value::getType() const { return Ctx.getType(Val->getType()); } + void Value::replaceUsesWithIf( Value *OtherV, llvm::function_ref ShouldReplace) { assert(getType() == OtherV->getType() && "Can't replace with different type"); @@ -583,7 +585,8 @@ VAArgInst *VAArgInst::create(Value *List, Type *Ty, BBIterator WhereIt, Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction()); else Builder.SetInsertPoint(cast(WhereBB->Val)); - auto *LLVMI = cast(Builder.CreateVAArg(List->Val, Ty, Name)); + auto *LLVMI = + cast(Builder.CreateVAArg(List->Val, Ty->LLVMTy, Name)); return Ctx.createVAArgInst(LLVMI); } @@ -754,7 +757,7 @@ LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align, auto &Builder = Ctx.getLLVMIRBuilder(); Builder.SetInsertPoint(BeforeIR); auto *NewLI = - Builder.CreateAlignedLoad(Ty, Ptr->Val, Align, IsVolatile, Name); + Builder.CreateAlignedLoad(Ty->LLVMTy, Ptr->Val, Align, IsVolatile, Name); auto *NewSBI = Ctx.createLoadInst(NewLI); return NewSBI; } @@ -771,7 +774,7 @@ LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align, auto &Builder = Ctx.getLLVMIRBuilder(); Builder.SetInsertPoint(cast(InsertAtEnd->Val)); auto *NewLI = - Builder.CreateAlignedLoad(Ty, Ptr->Val, Align, IsVolatile, Name); + Builder.CreateAlignedLoad(Ty->LLVMTy, Ptr->Val, Align, IsVolatile, Name); auto *NewSBI = Ctx.createLoadInst(NewLI); return NewSBI; } @@ -886,6 +889,11 @@ Value *ReturnInst::getReturnValue() const { return LLVMRetVal != nullptr ? Ctx.getValue(LLVMRetVal) : nullptr; } +FunctionType *CallBase::getFunctionType() const { + return cast( + Ctx.getType(cast(Val)->getFunctionType())); +} + Value *CallBase::getCalledOperand() const { return Ctx.getValue(cast(Val)->getCalledOperand()); } @@ -911,8 +919,9 @@ void CallBase::setCalledFunction(Function *F) { // Note: This may break if `setCalledFunction()` early returns if `F` // is already set, but we do have a unit test for it. setCalledOperand(F); - cast(Val)->setCalledFunction(F->getFunctionType(), - cast(F->Val)); + cast(Val)->setCalledFunction( + cast(F->getFunctionType()->LLVMTy), + cast(F->Val)); } CallInst *CallInst::create(FunctionType *FTy, Value *Func, @@ -928,7 +937,8 @@ CallInst *CallInst::create(FunctionType *FTy, Value *Func, LLVMArgs.reserve(Args.size()); for (Value *Arg : Args) LLVMArgs.push_back(Arg->Val); - llvm::CallInst *NewCI = Builder.CreateCall(FTy, Func->Val, LLVMArgs, NameStr); + llvm::CallInst *NewCI = Builder.CreateCall( + cast(FTy->LLVMTy), Func->Val, LLVMArgs, NameStr); return Ctx.createCallInst(NewCI); } @@ -961,7 +971,8 @@ InvokeInst *InvokeInst::create(FunctionType *FTy, Value *Func, for (Value *Arg : Args) LLVMArgs.push_back(Arg->Val); llvm::InvokeInst *Invoke = Builder.CreateInvoke( - FTy, Func->Val, cast(IfNormal->Val), + cast(FTy->LLVMTy), Func->Val, + cast(IfNormal->Val), cast(IfException->Val), LLVMArgs, NameStr); return Ctx.createInvokeInst(Invoke); } @@ -1032,9 +1043,10 @@ CallBrInst *CallBrInst::create(FunctionType *FTy, Value *Func, for (Value *Arg : Args) LLVMArgs.push_back(Arg->Val); - llvm::CallBrInst *CallBr = Builder.CreateCallBr( - FTy, Func->Val, cast(DefaultDest->Val), - LLVMIndirectDests, LLVMArgs, NameStr); + llvm::CallBrInst *CallBr = + Builder.CreateCallBr(cast(FTy->LLVMTy), Func->Val, + cast(DefaultDest->Val), + LLVMIndirectDests, LLVMArgs, NameStr); return Ctx.createCallBrInst(CallBr); } @@ -1107,7 +1119,7 @@ LandingPadInst *LandingPadInst::create(Type *RetTy, unsigned NumReservedClauses, else Builder.SetInsertPoint(cast(WhereBB->Val)); llvm::LandingPadInst *LLVMI = - Builder.CreateLandingPad(RetTy, NumReservedClauses, Name); + Builder.CreateLandingPad(RetTy->LLVMTy, NumReservedClauses, Name); return Ctx.createLandingPadInst(LLVMI); } @@ -1288,7 +1300,8 @@ Value *GetElementPtrInst::create(Type *Ty, Value *Ptr, LLVMIdxList.reserve(IdxList.size()); for (Value *Idx : IdxList) LLVMIdxList.push_back(Idx->Val); - llvm::Value *NewV = Builder.CreateGEP(Ty, Ptr->Val, LLVMIdxList, NameStr); + llvm::Value *NewV = + Builder.CreateGEP(Ty->LLVMTy, Ptr->Val, LLVMIdxList, NameStr); if (auto *NewGEP = dyn_cast(NewV)) return Ctx.createGetElementPtrInst(NewGEP); assert(isa(NewV) && "Expected constant"); @@ -1312,10 +1325,25 @@ Value *GetElementPtrInst::create(Type *Ty, Value *Ptr, InsertAtEnd, Ctx, NameStr); } +Type *GetElementPtrInst::getSourceElementType() const { + return Ctx.getType( + cast(Val)->getSourceElementType()); +} + +Type *GetElementPtrInst::getResultElementType() const { + return Ctx.getType( + cast(Val)->getResultElementType()); +} + Value *GetElementPtrInst::getPointerOperand() const { return Ctx.getValue(cast(Val)->getPointerOperand()); } +Type *GetElementPtrInst::getPointerOperandType() const { + return Ctx.getType( + cast(Val)->getPointerOperandType()); +} + BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const { return cast(Ctx.getValue(LLVMBB)); } @@ -1323,8 +1351,9 @@ BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const { PHINode *PHINode::create(Type *Ty, unsigned NumReservedValues, Instruction *InsertBefore, Context &Ctx, const Twine &Name) { - llvm::PHINode *NewPHI = llvm::PHINode::Create( - Ty, NumReservedValues, Name, InsertBefore->getTopmostLLVMInstruction()); + llvm::PHINode *NewPHI = + llvm::PHINode::Create(Ty->LLVMTy, NumReservedValues, Name, + InsertBefore->getTopmostLLVMInstruction()); return Ctx.createPHINode(NewPHI); } @@ -1943,7 +1972,8 @@ AllocaInst *AllocaInst::create(Type *Ty, unsigned AddrSpace, BBIterator WhereIt, Builder.SetInsertPoint(cast(WhereBB->Val)); else Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction()); - auto *NewAlloca = Builder.CreateAlloca(Ty, AddrSpace, ArraySize->Val, Name); + auto *NewAlloca = + Builder.CreateAlloca(Ty->LLVMTy, AddrSpace, ArraySize->Val, Name); return Ctx.createAllocaInst(NewAlloca); } @@ -1961,11 +1991,15 @@ AllocaInst *AllocaInst::create(Type *Ty, unsigned AddrSpace, Name); } +Type *AllocaInst::getAllocatedType() const { + return Ctx.getType(cast(Val)->getAllocatedType()); +} + void AllocaInst::setAllocatedType(Type *Ty) { Ctx.getTracker() .emplaceIfTracking>(this); - cast(Val)->setAllocatedType(Ty); + cast(Val)->setAllocatedType(Ty->LLVMTy); } void AllocaInst::setAlignment(Align Align) { @@ -1987,6 +2021,10 @@ Value *AllocaInst::getArraySize() { return Ctx.getValue(cast(Val)->getArraySize()); } +PointerType *AllocaInst::getType() const { + return cast(Ctx.getType(cast(Val)->getType())); +} + Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand, BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx, const Twine &Name) { @@ -1997,7 +2035,7 @@ Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand, else Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction()); auto *NewV = - Builder.CreateCast(getLLVMCastOp(Op), Operand->Val, DestTy, Name); + Builder.CreateCast(getLLVMCastOp(Op), Operand->Val, DestTy->LLVMTy, Name); if (auto *NewCI = dyn_cast(NewV)) return Ctx.createCastInst(NewCI); assert(isa(NewV) && "Expected constant"); @@ -2022,6 +2060,14 @@ bool CastInst::classof(const Value *From) { return From->getSubclassID() == ClassID::Cast; } +Type *CastInst::getSrcTy() const { + return Ctx.getType(cast(Val)->getSrcTy()); +} + +Type *CastInst::getDestTy() const { + return Ctx.getType(cast(Val)->getDestTy()); +} + void PossiblyNonNegInst::setNonNeg(bool B) { Ctx.getTracker() .emplaceIfTracking Mask) { cast(Val)->setShuffleMask(Mask); } +VectorType *ShuffleVectorInst::getType() const { + return cast( + Ctx.getType(cast(Val)->getType())); +} + Constant *ShuffleVectorInst::getShuffleMaskForBitcode() const { return Ctx.getOrCreateConstant( cast(Val)->getShuffleMaskForBitcode()); } -Constant *ShuffleVectorInst::convertShuffleMaskForBitcode( - llvm::ArrayRef Mask, llvm::Type *ResultTy, Context &Ctx) { - return Ctx.getOrCreateConstant( - llvm::ShuffleVectorInst::convertShuffleMaskForBitcode(Mask, ResultTy)); +Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(ArrayRef Mask, + Type *ResultTy) { + return ResultTy->getContext().getOrCreateConstant( + llvm::ShuffleVectorInst::convertShuffleMaskForBitcode(Mask, + ResultTy->LLVMTy)); +} + +VectorType *ExtractElementInst::getVectorOperandType() const { + return cast(Ctx.getType(getVectorOperand()->getType()->LLVMTy)); } Value *ExtractValueInst::create(Value *Agg, ArrayRef Idxs, @@ -2160,6 +2216,11 @@ Value *ExtractValueInst::create(Value *Agg, ArrayRef Idxs, return Ctx.getOrCreateConstant(cast(NewV)); } +Type *ExtractValueInst::getIndexedType(Type *Agg, ArrayRef Idxs) { + auto *LLVMTy = llvm::ExtractValueInst::getIndexedType(Agg->LLVMTy, Idxs); + return Agg->getContext().getType(LLVMTy); +} + Value *InsertValueInst::create(Value *Agg, Value *Val, ArrayRef Idxs, BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx, const Twine &Name) { @@ -2182,10 +2243,14 @@ void Constant::dumpOS(raw_ostream &OS) const { } #endif // NDEBUG -ConstantInt *ConstantInt::get(Type *Ty, uint64_t V, Context &Ctx, - bool IsSigned) { - auto *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned); - return cast(Ctx.getOrCreateConstant(LLVMC)); +ConstantInt *ConstantInt::get(Type *Ty, uint64_t V, bool IsSigned) { + auto *LLVMC = llvm::ConstantInt::get(Ty->LLVMTy, V, IsSigned); + return cast(Ty->getContext().getOrCreateConstant(LLVMC)); +} + +FunctionType *Function::getFunctionType() const { + return cast( + Ctx.getType(cast(Val)->getFunctionType())); } #ifndef NDEBUG diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp new file mode 100644 index 0000000000000..6f850b82d2e99 --- /dev/null +++ b/llvm/lib/SandboxIR/Type.cpp @@ -0,0 +1,48 @@ +//===- Type.cpp - Sandbox IR Type -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/SandboxIR/Type.h" +#include "llvm/SandboxIR/SandboxIR.h" + +using namespace llvm::sandboxir; + +Type *Type::getScalarType() const { + return Ctx.getType(LLVMTy->getScalarType()); +} + +Type *Type::getInt64Ty(Context &Ctx) { + return Ctx.getType(llvm::Type::getInt64Ty(Ctx.LLVMCtx)); +} +Type *Type::getInt32Ty(Context &Ctx) { + return Ctx.getType(llvm::Type::getInt32Ty(Ctx.LLVMCtx)); +} +Type *Type::getInt16Ty(Context &Ctx) { + return Ctx.getType(llvm::Type::getInt16Ty(Ctx.LLVMCtx)); +} +Type *Type::getInt8Ty(Context &Ctx) { + return Ctx.getType(llvm::Type::getInt8Ty(Ctx.LLVMCtx)); +} +Type *Type::getInt1Ty(Context &Ctx) { + return Ctx.getType(llvm::Type::getInt1Ty(Ctx.LLVMCtx)); +} +Type *Type::getDoubleTy(Context &Ctx) { + return Ctx.getType(llvm::Type::getDoubleTy(Ctx.LLVMCtx)); +} +Type *Type::getFloatTy(Context &Ctx) { + return Ctx.getType(llvm::Type::getFloatTy(Ctx.LLVMCtx)); +} + +PointerType *PointerType::get(Type *ElementType, unsigned AddressSpace) { + return cast(ElementType->getContext().getType( + llvm::PointerType::get(ElementType->LLVMTy, AddressSpace))); +} + +PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) { + return cast( + Ctx.getType(llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace))); +} diff --git a/llvm/unittests/SandboxIR/CMakeLists.txt b/llvm/unittests/SandboxIR/CMakeLists.txt index 3f43f6337b919..2da936bffa02b 100644 --- a/llvm/unittests/SandboxIR/CMakeLists.txt +++ b/llvm/unittests/SandboxIR/CMakeLists.txt @@ -7,4 +7,5 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(SandboxIRTests SandboxIRTest.cpp TrackerTest.cpp + TypesTest.cpp ) diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 8bf4b24c48ee0..c543846eb2686 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -121,10 +121,12 @@ define void @foo(i32 %v0) { auto *FortyTwo = cast(Add0->getOperand(1)); // Check that creating an identical constant gives us the same object. - auto *NewCI = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx); + auto *NewCI = + sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42); EXPECT_EQ(NewCI, FortyTwo); // Check new constant. - auto *FortyThree = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 43, Ctx); + auto *FortyThree = + sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 43); EXPECT_NE(FortyThree, FortyTwo); } @@ -603,7 +605,7 @@ define void @foo(ptr %va) { EXPECT_EQ(sandboxir::VAArgInst::getPointerOperandIndex(), llvm::VAArgInst::getPointerOperandIndex()); // Check create(). - auto *NewVATy = Type::getInt8Ty(C); + auto *NewVATy = sandboxir::Type::getInt8Ty(Ctx); auto *NewVA = sandboxir::VAArgInst::create(Arg, NewVATy, Ret->getIterator(), Ret->getParent(), Ctx, "NewVA"); EXPECT_EQ(NewVA->getNextNode(), Ret); @@ -743,10 +745,10 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) { } { // Check SelectInst::create() Folded. - auto *False = sandboxir::ConstantInt::get(llvm::Type::getInt1Ty(C), 0, Ctx, - /*IsSigned=*/false); + auto *False = sandboxir::ConstantInt::get(sandboxir::Type::getInt1Ty(Ctx), + 0, /*IsSigned=*/false); auto *FortyTwo = - sandboxir::ConstantInt::get(llvm::Type::getInt1Ty(C), 42, Ctx, + sandboxir::ConstantInt::get(sandboxir::Type::getInt1Ty(Ctx), 42, /*IsSigned=*/false); auto *NewSel = sandboxir::SelectInst::create(False, FortyTwo, FortyTwo, Ret, Ctx); @@ -838,7 +840,7 @@ define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) { auto *LLVMArg0 = LLVMF.getArg(0); auto *LLVMArgVec = LLVMF.getArg(2); - auto *Zero = sandboxir::ConstantInt::get(Type::getInt8Ty(C), 0, Ctx); + auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 0); auto *LLVMZero = llvm::ConstantInt::get(Type::getInt8Ty(C), 0); EXPECT_EQ( sandboxir::InsertElementInst::isValidOperands(ArgVec, Arg0, Zero), @@ -950,7 +952,7 @@ define void @foo(<2 x i8> %v1, <2 x i8> %v2) { // convertShuffleMaskForBitcode { auto *C = sandboxir::ShuffleVectorInst::convertShuffleMaskForBitcode( - ArrayRef({2, 3}), ArgV1->getType(), Ctx); + ArrayRef({2, 3}), ArgV1->getType()); SmallVector Result; sandboxir::ShuffleVectorInst::getShuffleMask(C, Result); EXPECT_THAT(Result, testing::ElementsAre(2, 3)); @@ -1271,6 +1273,12 @@ define void @foo({i32, float} %agg) { } )IR"); Function &LLVMF = *M->getFunction("foo"); + auto *LLVMBB = &*LLVMF.begin(); + auto LLVMIt = LLVMBB->begin(); + [[maybe_unused]] auto *LLVMExtSimple = + cast(&*LLVMIt++); + auto *LLVMExtNested = cast(&*LLVMIt++); + sandboxir::Context Ctx(C); auto &F = *Ctx.createFunction(&LLVMF); auto *ArgAgg = F.getArg(0); @@ -1307,15 +1315,16 @@ define void @foo({i32, float} %agg) { Const1->getOperand(0), ArrayRef({0}), BB->end(), BB, Ctx); EXPECT_TRUE(isa(ShouldBeConstant)); - auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx); + auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0); EXPECT_EQ(ShouldBeConstant, Zero); // getIndexedType - Type *AggType = ExtNested->getAggregateOperand()->getType(); + sandboxir::Type *AggType = ExtNested->getAggregateOperand()->getType(); + llvm::Type *LLVMAggType = LLVMExtNested->getAggregateOperand()->getType(); EXPECT_EQ(sandboxir::ExtractValueInst::getIndexedType( AggType, ArrayRef({1, 0})), - llvm::ExtractValueInst::getIndexedType(AggType, - ArrayRef({1, 0}))); + Ctx.getType(llvm::ExtractValueInst::getIndexedType( + LLVMAggType, ArrayRef({1, 0})))); EXPECT_EQ(sandboxir::ExtractValueInst::getIndexedType( AggType, ArrayRef({2})), @@ -1410,7 +1419,7 @@ define void @foo({i32, float} %agg, i32 %i) { #endif // NDEBUG // Test the path that creates a folded constant. - auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx); + auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0); auto *ShouldBeConstant = sandboxir::InsertValueInst::create( Const1->getOperand(0), Zero, ArrayRef({0}), BB->end(), BB, Ctx); auto *ExpectedConstant = Const2->getOperand(0); @@ -1811,7 +1820,8 @@ define i8 @foo(i8 %arg0, i32 %arg1, ptr %indirectFoo) { // Check classof(Value *). EXPECT_TRUE(isa((sandboxir::Value *)Call)); // Check getFunctionType(). - EXPECT_EQ(Call->getFunctionType(), LLVMCall->getFunctionType()); + EXPECT_EQ(Call->getFunctionType(), + Ctx.getType(LLVMCall->getFunctionType())); // Check data_ops(). EXPECT_EQ(range_size(Call->data_ops()), range_size(LLVMCall->data_ops())); auto DataOpIt = Call->data_operands_begin(); @@ -1942,7 +1952,7 @@ define i8 @foo(i8 %arg) { auto *Ret = cast(&*It++); EXPECT_EQ(Call->getNumOperands(), 2u); EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Ret); - FunctionType *FTy = F.getFunctionType(); + sandboxir::FunctionType *FTy = F.getFunctionType(); SmallVector Args; Args.push_back(Arg0); { @@ -2231,8 +2241,8 @@ define void @foo() { auto *BBRet = &*BB->begin(); auto *NewLPad = cast(sandboxir::LandingPadInst::create( - Type::getInt8Ty(C), 0, BBRet->getIterator(), BBRet->getParent(), Ctx, - "NewLPad")); + sandboxir::Type::getInt8Ty(Ctx), 0, BBRet->getIterator(), + BBRet->getParent(), Ctx, "NewLPad")); EXPECT_EQ(NewLPad->getNextNode(), BBRet); EXPECT_FALSE(NewLPad->isCleanup()); #ifndef NDEBUG @@ -2491,9 +2501,11 @@ define void @foo(ptr %ptr, <2 x ptr> %ptrs) { // Check classof(). auto *GEP = cast(Ctx.getValue(LLVMGEP)); // Check getSourceElementType(). - EXPECT_EQ(GEP->getSourceElementType(), LLVMGEP->getSourceElementType()); + EXPECT_EQ(GEP->getSourceElementType(), + Ctx.getType(LLVMGEP->getSourceElementType())); // Check getResultElementType(). - EXPECT_EQ(GEP->getResultElementType(), LLVMGEP->getResultElementType()); + EXPECT_EQ(GEP->getResultElementType(), + Ctx.getType(LLVMGEP->getResultElementType())); // Check getAddressSpace(). EXPECT_EQ(GEP->getAddressSpace(), LLVMGEP->getAddressSpace()); // Check indices(). @@ -2509,7 +2521,8 @@ define void @foo(ptr %ptr, <2 x ptr> %ptrs) { // Check getPointerOperandIndex(). EXPECT_EQ(GEP->getPointerOperandIndex(), LLVMGEP->getPointerOperandIndex()); // Check getPointerOperandType(). - EXPECT_EQ(GEP->getPointerOperandType(), LLVMGEP->getPointerOperandType()); + EXPECT_EQ(GEP->getPointerOperandType(), + Ctx.getType(LLVMGEP->getPointerOperandType())); // Check getPointerAddressSpace(). EXPECT_EQ(GEP->getPointerAddressSpace(), LLVMGEP->getPointerAddressSpace()); // Check getNumIndices(). @@ -2870,8 +2883,8 @@ define void @foo(i32 %cond0, i32 %cond1) { Switch->setSuccessor(0, OrigSucc); EXPECT_EQ(Switch->getSuccessor(0), OrigSucc); // Check case_begin(), case_end(), CaseIt. - auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx); - auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx); + auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0); + auto *One = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 1); auto CaseIt = Switch->case_begin(); { sandboxir::SwitchInst::CaseHandle Case = *CaseIt++; @@ -2908,7 +2921,7 @@ define void @foo(i32 %cond0, i32 %cond1) { EXPECT_EQ(Switch->findCaseDest(BB1), One); EXPECT_EQ(Switch->findCaseDest(Entry), nullptr); // Check addCase(). - auto *Two = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 2, Ctx); + auto *Two = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 2); Switch->addCase(Two, Entry); auto CaseTwoIt = Switch->findCaseValue(Two); auto CaseTwo = *CaseTwoIt; @@ -3173,7 +3186,8 @@ define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) { } { // Check create() when it gets folded. - auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx); + auto *FortyTwo = + sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42); auto *NewV = sandboxir::BinaryOperator::create( sandboxir::Instruction::Opcode::Add, FortyTwo, FortyTwo, /*InsertBefore=*/Ret, Ctx, "Folded"); @@ -3229,7 +3243,8 @@ define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) { } { // Check createWithCopiedFlags() when it gets folded. - auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx); + auto *FortyTwo = + sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42); auto *NewV = sandboxir::BinaryOperator::createWithCopiedFlags( sandboxir::Instruction::Opcode::Add, FortyTwo, FortyTwo, CopyFrom, /*InsertBefore=*/Ret, Ctx, "Folded"); @@ -3651,8 +3666,8 @@ define void @foo() { EXPECT_EQ(AllocaArray->getArraySize(), Ctx.getValue(LLVMAllocaArray->getArraySize())); // Check getType(). - EXPECT_EQ(AllocaScalar->getType(), LLVMAllocaScalar->getType()); - EXPECT_EQ(AllocaArray->getType(), LLVMAllocaArray->getType()); + EXPECT_EQ(AllocaScalar->getType(), Ctx.getType(LLVMAllocaScalar->getType())); + EXPECT_EQ(AllocaArray->getType(), Ctx.getType(LLVMAllocaArray->getType())); // Check getAddressSpace(). EXPECT_EQ(AllocaScalar->getAddressSpace(), LLVMAllocaScalar->getAddressSpace()); @@ -3669,12 +3684,12 @@ define void @foo() { LLVMAllocaArray->getAllocationSizeInBits(DL)); // Check getAllocatedType(). EXPECT_EQ(AllocaScalar->getAllocatedType(), - LLVMAllocaScalar->getAllocatedType()); + Ctx.getType(LLVMAllocaScalar->getAllocatedType())); EXPECT_EQ(AllocaArray->getAllocatedType(), - LLVMAllocaArray->getAllocatedType()); + Ctx.getType(LLVMAllocaArray->getAllocatedType())); // Check setAllocatedType(). auto *OrigType = AllocaScalar->getAllocatedType(); - auto *NewType = PointerType::get(C, 0); + auto *NewType = sandboxir::PointerType::get(Ctx, 0); EXPECT_NE(NewType, OrigType); AllocaScalar->setAllocatedType(NewType); EXPECT_EQ(AllocaScalar->getAllocatedType(), NewType); @@ -3705,10 +3720,10 @@ define void @foo() { AllocaScalar->setUsedWithInAlloca(OrigUsedWithInAlloca); EXPECT_EQ(AllocaScalar->isUsedWithInAlloca(), OrigUsedWithInAlloca); - auto *Ty = Type::getInt32Ty(C); + auto *Ty = sandboxir::Type::getInt32Ty(Ctx); unsigned AddrSpace = 42; - auto *PtrTy = PointerType::get(C, AddrSpace); - auto *ArraySize = sandboxir::ConstantInt::get(Ty, 43, Ctx); + auto *PtrTy = sandboxir::PointerType::get(Ctx, AddrSpace); + auto *ArraySize = sandboxir::ConstantInt::get(Ty, 43); { // Check create() WhereIt, WhereBB. auto *NewI = cast(sandboxir::AllocaInst::create( @@ -3785,13 +3800,13 @@ define void @foo(i32 %arg, float %farg, double %darg, ptr %ptr) { auto *BB = &*F->begin(); auto It = BB->begin(); - Type *Ti64 = Type::getInt64Ty(C); - Type *Ti32 = Type::getInt32Ty(C); - Type *Ti16 = Type::getInt16Ty(C); - Type *Tdouble = Type::getDoubleTy(C); - Type *Tfloat = Type::getFloatTy(C); - Type *Tptr = Tfloat->getPointerTo(); - Type *Tptr1 = Tfloat->getPointerTo(1); + auto *Ti64 = sandboxir::Type::getInt64Ty(Ctx); + auto *Ti32 = sandboxir::Type::getInt32Ty(Ctx); + auto *Ti16 = sandboxir::Type::getInt16Ty(Ctx); + auto *Tdouble = sandboxir::Type::getDoubleTy(Ctx); + auto *Tfloat = sandboxir::Type::getFloatTy(Ctx); + auto *Tptr = sandboxir::PointerType::get(Tfloat, 0); + auto *Tptr1 = sandboxir::PointerType::get(Tfloat, 1); // Check classof(), getOpcode(), getSrcTy(), getDstTy() auto *ZExt = cast(&*It++); @@ -4003,10 +4018,13 @@ define void @foo(i32 %arg, float %farg, double %darg, ptr %ptr) { /// CastInst's subclasses are very similar so we can use a common test function /// for them. template -void testCastInst(llvm::Module &M, Type *SrcTy, Type *DstTy) { +void testCastInst(llvm::Module &M, llvm::Type *LLVMSrcTy, + llvm::Type *LLVMDstTy) { Function &LLVMF = *M.getFunction("foo"); sandboxir::Context Ctx(M.getContext()); sandboxir::Function *F = Ctx.createFunction(&LLVMF); + sandboxir::Type *SrcTy = Ctx.getType(LLVMSrcTy); + sandboxir::Type *DstTy = Ctx.getType(LLVMDstTy); unsigned ArgIdx = 0; auto *Arg = F->getArg(ArgIdx++); auto *BB = &*F->begin(); diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp index ca6effb727bf3..c189100fbd694 100644 --- a/llvm/unittests/SandboxIR/TrackerTest.cpp +++ b/llvm/unittests/SandboxIR/TrackerTest.cpp @@ -938,9 +938,10 @@ define void @foo(i32 %cond0, i32 %cond1) { Ctx.revert(); EXPECT_EQ(Switch->getSuccessor(0), OrigSucc); // Check addCase(). - auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx); - auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx); - auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx); + auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0); + auto *One = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 1); + auto *FortyTwo = + sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42); Ctx.save(); Switch->addCase(FortyTwo, Entry); EXPECT_EQ(Switch->getNumCases(), 3u); @@ -1187,7 +1188,7 @@ define void @foo(i8 %arg) { // Check setAllocatedType(). Ctx.save(); auto *OrigTy = Alloca->getAllocatedType(); - auto *NewTy = Type::getInt64Ty(C); + auto *NewTy = sandboxir::Type::getInt64Ty(Ctx); EXPECT_NE(NewTy, OrigTy); Alloca->setAllocatedType(NewTy); EXPECT_EQ(Alloca->getAllocatedType(), NewTy); diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp new file mode 100644 index 0000000000000..cd9e14dced7fe --- /dev/null +++ b/llvm/unittests/SandboxIR/TypesTest.cpp @@ -0,0 +1,253 @@ +//===- TypesTest.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/SandboxIR/SandboxIR.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +struct SandboxTypeTest : public testing::Test { + LLVMContext C; + std::unique_ptr M; + + void parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + M = parseAssemblyString(IR, Err, C); + if (!M) + Err.print("SandboxTypeTest", errs()); + } + BasicBlock *getBasicBlockByName(Function &F, StringRef Name) { + for (BasicBlock &BB : F) + if (BB.getName() == Name) + return &BB; + llvm_unreachable("Expected to find basic block!"); + } +}; + +TEST_F(SandboxTypeTest, Type) { + parseIR(C, R"IR( +define void @foo(i32 %v0) { + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + sandboxir::Type *I32Ty = F->getArg(0)->getType(); + + auto *LLVMInt32Ty = llvm::Type::getInt32Ty(C); + auto *LLVMFloatTy = llvm::Type::getFloatTy(C); + auto *LLVMInt8Ty = llvm::Type::getInt8Ty(C); + + auto *Int32Ty = Ctx.getType(LLVMInt32Ty); + auto *FloatTy = Ctx.getType(LLVMFloatTy); + + // Check print(). + std::string Buff1; + raw_string_ostream BS1(Buff1); + Int32Ty->print(BS1, /*IsForDebug=*/true, /*NoDetails=*/false); + std::string Buff2; + raw_string_ostream BS2(Buff2); + LLVMInt32Ty->print(BS2, /*IsForDebug=*/true, /*NoDetails=*/false); + EXPECT_EQ(Buff1, Buff2); + + // Check getContext(). + EXPECT_EQ(&I32Ty->getContext(), &Ctx); + // Check that Ctx.getType(nullptr) == nullptr. + EXPECT_EQ(Ctx.getType(nullptr), nullptr); + +#define CHK(LLVMCreate, SBCheck) \ + Ctx.getType(llvm::Type::LLVMCreate(C))->SBCheck() + // Check isVoidTy(). + EXPECT_TRUE(Ctx.getType(llvm::Type::getVoidTy(C))->isVoidTy()); + EXPECT_TRUE(CHK(getVoidTy, isVoidTy)); + // Check isHalfTy(). + EXPECT_TRUE(CHK(getHalfTy, isHalfTy)); + // Check isBFloatTy(). + EXPECT_TRUE(CHK(getBFloatTy, isBFloatTy)); + // Check is16bitFPTy(). + EXPECT_TRUE(CHK(getHalfTy, is16bitFPTy)); + // Check isFloatTy(). + EXPECT_TRUE(CHK(getFloatTy, isFloatTy)); + // Check isDoubleTy(). + EXPECT_TRUE(CHK(getDoubleTy, isDoubleTy)); + // Check isX86_FP80Ty(). + EXPECT_TRUE(CHK(getX86_FP80Ty, isX86_FP80Ty)); + // Check isFP128Ty(). + EXPECT_TRUE(CHK(getFP128Ty, isFP128Ty)); + // Check isPPC_FP128Ty(). + EXPECT_TRUE(CHK(getPPC_FP128Ty, isPPC_FP128Ty)); + // Check isIEEELikeFPTy(). + EXPECT_TRUE(CHK(getFloatTy, isIEEELikeFPTy)); + // Check isFloatingPointTy(). + EXPECT_TRUE(CHK(getFloatTy, isFloatingPointTy)); + EXPECT_TRUE(CHK(getDoubleTy, isFloatingPointTy)); + // Check isMultiUnitFPType(). + EXPECT_TRUE(CHK(getPPC_FP128Ty, isMultiUnitFPType)); + EXPECT_FALSE(CHK(getFloatTy, isMultiUnitFPType)); + // Check getFltSemantics(). + EXPECT_EQ(&sandboxir::Type::getFloatTy(Ctx)->getFltSemantics(), + &llvm::Type::getFloatTy(C)->getFltSemantics()); + // Check isX86_AMXTy(). + EXPECT_TRUE(CHK(getX86_AMXTy, isX86_AMXTy)); + // Check isTargetExtTy(). + EXPECT_TRUE(Ctx.getType(llvm::TargetExtType::get(C, "foo"))->isTargetExtTy()); + // Check isScalableTargetExtTy(). + EXPECT_TRUE(Ctx.getType(llvm::TargetExtType::get(C, "aarch64.svcount")) + ->isScalableTargetExtTy()); + // Check isScalableTy(). + EXPECT_TRUE(Ctx.getType(llvm::ScalableVectorType::get(LLVMInt32Ty, 2u)) + ->isScalableTy()); + // Check isFPOrFPVectorTy(). + EXPECT_TRUE(CHK(getFloatTy, isFPOrFPVectorTy)); + EXPECT_FALSE(CHK(getInt32Ty, isFPOrFPVectorTy)); + // Check isLabelTy(). + EXPECT_TRUE(CHK(getLabelTy, isLabelTy)); + // Check isMetadataTy(). + EXPECT_TRUE(CHK(getMetadataTy, isMetadataTy)); + // Check isTokenTy(). + EXPECT_TRUE(CHK(getTokenTy, isTokenTy)); + // Check isIntegerTy(). + EXPECT_TRUE(CHK(getInt32Ty, isIntegerTy)); + EXPECT_FALSE(CHK(getFloatTy, isIntegerTy)); + // Check isIntegerTy(Bitwidth). + EXPECT_TRUE(LLVMInt32Ty->isIntegerTy(32u)); + EXPECT_FALSE(LLVMInt32Ty->isIntegerTy(31u)); + EXPECT_FALSE(Ctx.getType(llvm::Type::getFloatTy(C))->isIntegerTy(32u)); + // Check isIntOrIntVectorTy(). + EXPECT_TRUE(LLVMInt32Ty->isIntOrIntVectorTy()); + EXPECT_TRUE(Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 8)) + ->isIntOrIntVectorTy()); + EXPECT_FALSE(Ctx.getType(LLVMFloatTy)->isIntOrIntVectorTy()); + EXPECT_FALSE(Ctx.getType(llvm::FixedVectorType::get(LLVMFloatTy, 8)) + ->isIntOrIntVectorTy()); + // Check isIntOrPtrTy(). + EXPECT_TRUE(Int32Ty->isIntOrPtrTy()); + EXPECT_TRUE(Ctx.getType(llvm::PointerType::get(C, 0u))->isIntOrPtrTy()); + EXPECT_FALSE(FloatTy->isIntOrPtrTy()); + // Check isFunctionTy(). + EXPECT_TRUE(Ctx.getType(llvm::FunctionType::get(LLVMInt32Ty, {}, false)) + ->isFunctionTy()); + // Check isStructTy(). + EXPECT_TRUE(Ctx.getType(llvm::StructType::get(C))->isStructTy()); + // Check isArrayTy(). + EXPECT_TRUE(Ctx.getType(llvm::ArrayType::get(LLVMInt32Ty, 10))->isArrayTy()); + // Check isPointerTy(). + EXPECT_TRUE(Ctx.getType(llvm::PointerType::get(C, 0u))->isPointerTy()); + // Check isPtrOrPtrVectroTy(). + EXPECT_TRUE( + Ctx.getType(llvm::FixedVectorType::get(llvm::PointerType::get(C, 0u), 8u)) + ->isPtrOrPtrVectorTy()); + // Check isVectorTy(). + EXPECT_TRUE( + Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 8u))->isVectorTy()); + // Check canLosslesslyBitCastTo(). + auto *VecTy32x4 = Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 4u)); + auto *VecTy32x2 = Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 2u)); + auto *VecTy8x16 = Ctx.getType(llvm::FixedVectorType::get(LLVMInt8Ty, 16u)); + EXPECT_TRUE(VecTy32x4->canLosslesslyBitCastTo(VecTy8x16)); + EXPECT_FALSE(VecTy32x4->canLosslesslyBitCastTo(VecTy32x2)); + // Check isEmptyTy(). + EXPECT_TRUE(Ctx.getType(llvm::StructType::get(C))->isEmptyTy()); + // Check isFirstClassType(). + EXPECT_TRUE(Int32Ty->isFirstClassType()); + // Check isSingleValueType(). + EXPECT_TRUE(Int32Ty->isSingleValueType()); + // Check isAggregateType(). + EXPECT_FALSE(Int32Ty->isAggregateType()); + // Check isSized(). + SmallPtrSet Visited; + EXPECT_TRUE(Int32Ty->isSized(&Visited)); + // Check getPrimitiveSizeInBits(). + EXPECT_EQ(VecTy32x2->getPrimitiveSizeInBits(), 32u * 2); + // Check getScalarSizeInBits(). + EXPECT_EQ(VecTy32x2->getScalarSizeInBits(), 32u); + // Check getFPMantissaWidth(). + EXPECT_EQ(FloatTy->getFPMantissaWidth(), LLVMFloatTy->getFPMantissaWidth()); + // Check isIEEE(). + EXPECT_EQ(FloatTy->isIEEE(), LLVMFloatTy->isIEEE()); + // Check getScalarType(). + EXPECT_EQ( + Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 8u))->getScalarType(), + Int32Ty); + +#define CHK_GET(TY) \ + EXPECT_EQ(Ctx.getType(llvm::Type::get##TY##Ty(C)), \ + sandboxir::Type::get##TY##Ty(Ctx)) + // Check getInt64Ty(). + CHK_GET(Int64); + // Check getInt32Ty(). + CHK_GET(Int32); + // Check getInt16Ty(). + CHK_GET(Int16); + // Check getInt8Ty(). + CHK_GET(Int8); + // Check getInt1Ty(). + CHK_GET(Int1); + // Check getDoubleTy(). + CHK_GET(Double); + // Check getFloatTy(). + CHK_GET(Float); +} + +TEST_F(SandboxTypeTest, PointerType) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + // Check classof(), creation. + auto *PtrTy = cast(F->getArg(0)->getType()); + // Check get(ElementType, AddressSpace). + auto *NewPtrTy = + sandboxir::PointerType::get(sandboxir::Type::getInt32Ty(Ctx), 0u); + EXPECT_EQ(NewPtrTy, PtrTy); + // Check get(Ctx, AddressSpace). + auto *NewPtrTy2 = sandboxir::PointerType::get(Ctx, 0u); + EXPECT_EQ(NewPtrTy2, PtrTy); +} + +TEST_F(SandboxTypeTest, VectorType) { + parseIR(C, R"IR( +define void @foo(<2 x i8> %v0) { + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + // Check classof(), creation. + [[maybe_unused]] auto *VecTy = + cast(F->getArg(0)->getType()); +} + +TEST_F(SandboxTypeTest, FunctionType) { + parseIR(C, R"IR( +define void @foo() { + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + // Check classof(), creation. + [[maybe_unused]] auto *FTy = + cast(F->getFunctionType()); +}