Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding splitdouble HLSL function #109331

Merged
merged 18 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4871,6 +4871,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_splitdouble"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
82 changes: 82 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "CGObjCRuntime.h"
#include "CGOpenCLRuntime.h"
#include "CGRecordLayout.h"
#include "CGValue.h"
#include "CodeGenFunction.h"
#include "CodeGenModule.h"
#include "ConstantEmitter.h"
Expand All @@ -25,8 +26,10 @@
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/OSLog.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Type.h"
#include "clang/Basic/TargetBuiltins.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Basic/TargetOptions.h"
Expand Down Expand Up @@ -67,6 +70,7 @@
#include "llvm/TargetParser/X86TargetParser.h"
#include <optional>
#include <sstream>
#include <utility>

using namespace clang;
using namespace CodeGen;
Expand Down Expand Up @@ -95,6 +99,76 @@ static void initializeAlloca(CodeGenFunction &CGF, AllocaInst *AI, Value *Size,
I->addAnnotationMetadata("auto-init");
}

static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));

CallArgList Args;
LValue Op1TmpLValue =
CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
LValue Op2TmpLValue =
CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());

if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
Args.reverseWritebacks();

Value *LowBits = nullptr;
Value *HighBits = nullptr;

if (CGF->CGM.getTarget().getTriple().isDXIL()) {

llvm::Type *RetElementTy = CGF->Int32Ty;
if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
RetElementTy = llvm::VectorType::get(
CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);

CallInst *CI = CGF->Builder.CreateIntrinsic(
RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");

LowBits = CGF->Builder.CreateExtractValue(CI, 0);
HighBits = CGF->Builder.CreateExtractValue(CI, 1);

} else {
// For Non DXIL targets we generate the instructions.

if (!Op0->getType()->isVectorTy()) {
FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);

LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
} else {
int NumElements = 1;
if (const auto *VecTy =
E->getArg(0)->getType()->getAs<clang::VectorType>())
NumElements = VecTy->getNumElements();

FixedVectorType *Uint32VecTy =
FixedVectorType::get(CGF->Int32Ty, NumElements * 2);
Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy);
if (NumElements == 1) {
LowBits = CGF->Builder.CreateExtractElement(Uint32Vec, (uint64_t)0);
HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1);
} else {
SmallVector<int> EvenMask, OddMask;
for (int I = 0, E = NumElements; I != E; ++I) {
EvenMask.push_back(I * 2);
OddMask.push_back(I * 2 + 1);
}
LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask);
HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask);
}
}
}
CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());
auto *LastInst =
CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());
CGF->EmitWritebacks(Args);
return LastInst;
}

/// getBuiltinLibFunction - Given a builtin id for a function like
/// "__builtin_fabsf", return a Function* for "fabsf".
llvm::Constant *CodeGenModule::getBuiltinLibFunction(const FunctionDecl *FD,
Expand Down Expand Up @@ -18959,6 +19033,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
nullptr, "hlsl.radians");
}
case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {

assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
"asuint operands types mismatch");
return handleHlslSplitdouble(E, this);
}
}
return nullptr;
}
Expand Down
14 changes: 7 additions & 7 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Path.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>
using namespace clang;
Expand Down Expand Up @@ -4207,12 +4208,6 @@ static void emitWriteback(CodeGenFunction &CGF,
CGF.EmitBlock(contBB);
}

static void emitWritebacks(CodeGenFunction &CGF,
const CallArgList &args) {
for (const auto &I : args.writebacks())
emitWriteback(CGF, I);
}

static void deactivateArgCleanupsBeforeCall(CodeGenFunction &CGF,
const CallArgList &CallArgs) {
ArrayRef<CallArgList::CallArgCleanup> Cleanups =
Expand Down Expand Up @@ -4681,6 +4676,11 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
IsUsed = true;
}

void CodeGenFunction::EmitWritebacks(const CallArgList &args) {
for (const auto &I : args.writebacks())
emitWriteback(*this, I);
}

void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
QualType type) {
DisableDebugLocationUpdates Dis(*this, E);
Expand Down Expand Up @@ -5897,7 +5897,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
// Emit any call-associated writebacks immediately. Arguably this
// should happen after any return-value munging.
if (CallArgs.hasWritebacks())
emitWritebacks(*this, CallArgs);
EmitWritebacks(CallArgs);

// The stack cleanup for inalloca arguments has to run out of the normal
// lexical order, so deactivate it and run it manually here.
Expand Down
13 changes: 10 additions & 3 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5460,9 +5460,8 @@ LValue CodeGenFunction::EmitOpaqueValueLValue(const OpaqueValueExpr *e) {
return getOrCreateOpaqueLValueMapping(e);
}

void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
CallArgList &Args, QualType Ty) {

std::pair<LValue, LValue>
CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
Comment on lines +5463 to +5464
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is EmitHLSLOutArgLValues used any other than in CodeGenFunction::EmitHLSLOutArgExpr at this point? I think we can avoid splitting this method in two.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@llvm-beanz, any reason we couldn't merge those methods together?

// Emitting the casted temporary through an opaque value.
LValue BaseLV = EmitLValue(E->getArgLValue());
OpaqueValueMappingData::bind(*this, E->getOpaqueArgLValue(), BaseLV);
Expand All @@ -5476,6 +5475,13 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
TempLV);

OpaqueValueMappingData::bind(*this, E->getCastedTemporary(), TempLV);
return std::make_pair(BaseLV, TempLV);
}

LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
CallArgList &Args, QualType Ty) {

auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);

llvm::Value *Addr = TempLV.getAddress().getBasePointer();
llvm::Type *ElTy = ConvertTypeForMem(TempLV.getType());
Expand All @@ -5488,6 +5494,7 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
LifetimeSize);
Args.add(RValue::get(TmpAddr, *this), Ty);
return TempLV;
}

LValue
Expand Down
10 changes: 8 additions & 2 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4296,8 +4296,11 @@ class CodeGenFunction : public CodeGenTypeCache {
LValue EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E);
LValue EmitOpaqueValueLValue(const OpaqueValueExpr *e);
LValue EmitHLSLArrayAssignLValue(const BinaryOperator *E);
void EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
QualType Ty);

std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
QualType Ty);
LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
QualType Ty);

Address EmitExtVectorElementLValue(LValue V);

Expand Down Expand Up @@ -5147,6 +5150,9 @@ class CodeGenFunction : public CodeGenTypeCache {
SourceLocation ArgLoc, AbstractCallee AC,
unsigned ParmNum);

/// EmitWriteback - Emit callbacks for function.
void EmitWritebacks(const CallArgList &Args);

/// EmitCallArg - Emit a single call argument.
void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);

Expand Down
18 changes: 18 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,24 @@ template <typename T> constexpr uint asuint(T F) {
return __detail::bit_cast<uint, T>(F);
}

//===----------------------------------------------------------------------===//
// asuint splitdouble builtins
//===----------------------------------------------------------------------===//

/// \fn void asuint(double D, out uint lowbits, out int highbits)
/// \brief Split and interprets the lowbits and highbits of double D into uints.
/// \param D The input double.
/// \param lowbits The output lowbits of D.
/// \param highbits The output highbits of D.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
void asuint(double, out uint, out uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
void asuint(double2, out uint2, out uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
void asuint(double3, out uint3, out uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
void asuint(double4, out uint4, out uint4);

//===----------------------------------------------------------------------===//
// atan builtins
//===----------------------------------------------------------------------===//
Expand Down
74 changes: 55 additions & 19 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1698,18 +1698,27 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return true;
}

static bool CheckArgsTypesAreCorrect(
bool CheckArgTypeIsCorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
QualType PassedType = Arg->getType();
if (Check(PassedType)) {
if (auto *VecTyA = PassedType->getAs<VectorType>())
ExpectedType = S->Context.getVectorType(
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
<< PassedType << ExpectedType << 1 << 0 << 0;
return true;
}
return false;
}

bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
QualType PassedType = TheCall->getArg(i)->getType();
if (Check(PassedType)) {
if (auto *VecTyA = PassedType->getAs<VectorType>())
ExpectedType = S->Context.getVectorType(
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
S->Diag(TheCall->getArg(0)->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< PassedType << ExpectedType << 1 << 0 << 0;
Expr *Arg = TheCall->getArg(i);
if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
Expand All @@ -1720,8 +1729,8 @@ static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasFloatingRepresentation();
};
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkAllFloatTypes);
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkAllFloatTypes);
}

static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
Expand All @@ -1732,8 +1741,19 @@ static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
: PassedType;
return !BaseType->isHalfType() && !BaseType->isFloat32Type();
};
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkFloatorHalf);
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkFloatorHalf);
}

static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
unsigned ArgIndex) {
auto *Arg = TheCall->getArg(ArgIndex);
SourceLocation OrigLoc = Arg->getExprLoc();
if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
Expr::MLV_Valid)
return false;
S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
return true;
}

static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
Expand All @@ -1742,24 +1762,24 @@ static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
return VecTy->getElementType()->isDoubleType();
return false;
};
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkDoubleVector);
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkDoubleVector);
}
static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasIntegerRepresentation() &&
!PassedType->hasFloatingRepresentation();
};
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.IntTy,
checkAllSignedTypes);
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
checkAllSignedTypes);
}

static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasUnsignedIntegerRepresentation();
};
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
checkAllUnsignedTypes);
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
checkAllUnsignedTypes);
}

static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
Expand Down Expand Up @@ -2074,6 +2094,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;

if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
1) ||
CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
2))
return true;
joaosaffran marked this conversation as resolved.
Show resolved Hide resolved

if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
CheckModifiableLValue(&SemaRef, TheCall, 2))
return true;
break;
}
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
Expand Down
Loading