Skip to content

Commit

Permalink
Add tests for call directly to builtin
Browse files Browse the repository at this point in the history
Add more robustness to SemaChecking
  • Loading branch information
farzonl committed Feb 21, 2024
1 parent 595ce41 commit 3524651
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 83 deletions.
2 changes: 1 addition & 1 deletion clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4526,7 +4526,7 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {

def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -14121,6 +14121,8 @@ class Sema final {

bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);

bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
Expand Down
10 changes: 6 additions & 4 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17914,26 +17914,28 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
llvm::Type *T1 = Op1->getType();
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy()) {
return Builder.CreateFMul(Op0, Op1, "dx.dot");
return Builder.CreateFMul(Op0, Op1, "dx.dot");
}

if (T0->isIntegerTy()) {
return Builder.CreateMul(Op0, Op1, "dx.dot");
return Builder.CreateMul(Op0, Op1, "dx.dot");
}
// Bools should have been promoted
assert(
false &&
"Dot product on a scalar is only supported on integers and floats.");
}
// A VectorSplat should have happened
assert(T0->isVectorTy() && T1->isVectorTy() &&
"Dot product of vector and scalar is not supported.");

// NOTE: this assert will need to be revisited after overload resoltion
// PR merges.
// A vector sext or sitofp should have happened
assert(T0->getScalarType() == T1->getScalarType() &&
"Dot product of vectors need the same element types.");

auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
// A HLSLVectorTruncation should have happend
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");

Expand Down
225 changes: 172 additions & 53 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5163,69 +5163,167 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
return false;
}

// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_dot: {
if (checkArgCount(*this, TheCall, 2)) {
bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
unsigned NumArgs = TheCall->getNumArgs();

for (unsigned i = 0; i < NumArgs; ++i) {
ExprResult A = TheCall->getArg(i);
if (!A.get()->getType()->isBooleanType())
return false;
}
// if we got here all args are bool
for (unsigned i = 0; i < NumArgs; ++i) {
ExprResult A = TheCall->getArg(i);
ExprResult ResA = S->PerformImplicitConversion(A.get(), S->Context.IntTy,
Sema::AA_Converting);
if (ResA.isInvalid())
return true;
}
Expr *Arg0 = TheCall->getArg(0);
QualType ArgTy0 = Arg0->getType();
TheCall->setArg(0, ResA.get());
}
return false;
}

Expr *Arg1 = TheCall->getArg(1);
QualType ArgTy1 = Arg1->getType();
int overloadOrder(Sema *S, QualType ArgTyA) {
auto kind = ArgTyA->getAs<BuiltinType>()->getKind();
switch (kind) {
case BuiltinType::Short:
case BuiltinType::UShort:
return 1;
case BuiltinType::Int:
case BuiltinType::UInt:
return 2;
case BuiltinType::Long:
case BuiltinType::ULong:
return 3;
case BuiltinType::LongLong:
case BuiltinType::ULongLong:
return 4;
case BuiltinType::Float16:
case BuiltinType::Half:
return 5;
case BuiltinType::Float:
return 6;
default:
break;
}
return 0;
}

auto *VecTy0 = ArgTy0->getAs<VectorType>();
auto *VecTy1 = ArgTy1->getAs<VectorType>();
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
QualType getVecLargestBitness(Sema *S, QualType ArgTyA, QualType ArgTyB) {
auto *VecTyA = ArgTyA->getAs<VectorType>();
auto *VecTyB = ArgTyB->getAs<VectorType>();
QualType VecTyAElem = VecTyA->getElementType();
QualType VecTyBElem = VecTyB->getElementType();
int vecAElemWidth = overloadOrder(S, VecTyAElem);
int vecBElemWidth = overloadOrder(S, VecTyBElem);
return vecAElemWidth > vecBElemWidth ? ArgTyA : ArgTyB;
}

// if arg0 is bool then call Diag with err_builtin_invalid_arg_type
if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
return true;
}
void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
QualType ArgTyA = A.get()->getType();
QualType ArgTyB = B.get()->getType();

// if arg1 is bool then call Diag with err_builtin_invalid_arg_type
if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
return true;
auto *VecTyA = ArgTyA->getAs<VectorType>();
auto *VecTyB = ArgTyB->getAs<VectorType>();
if (VecTyA == nullptr && VecTyB == nullptr)
return;
if (VecTyA == nullptr || VecTyB == nullptr)
return;
if (VecTyA->getNumElements() == VecTyB->getNumElements())
return;

Expr *LargerArg = B.get();
Expr *SmallerArg = A.get();
int largerIndex = 1;
if (VecTyA->getNumElements() > VecTyB->getNumElements()) {
LargerArg = A.get();
SmallerArg = B.get();
largerIndex = 0;
}
S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
<< LargerArg->getType() << SmallerArg->getType()
<< LargerArg->getSourceRange() << SmallerArg->getSourceRange();
ExprResult ResLargerArg = S->ImpCastExprToType(
LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation);
TheCall->setArg(largerIndex, ResLargerArg.get());
return;
}

bool PromoteVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
QualType ArgTyA = A.get()->getType();
QualType ArgTyB = B.get()->getType();

auto *VecTyA = ArgTyA->getAs<VectorType>();
auto *VecTyB = ArgTyB->getAs<VectorType>();
if (VecTyA == nullptr && VecTyB == nullptr)
return false;
if (VecTyA && VecTyB) {
if (VecTyA->getElementType() == VecTyB->getElementType()) {
TheCall->setType(VecTyA->getElementType());
return false;
}
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
QualType CastType = getVecLargestBitness(S, ArgTyA, ArgTyB);
if (CastType == ArgTyA) {
ExprResult ResB = S->SemaConvertVectorExpr(
B.get(), S->Context.CreateTypeSourceInfo(ArgTyA), BuiltinLoc,
B.get()->getBeginLoc());
TheCall->setArg(1, ResB.get());
TheCall->setType(VecTyA->getElementType());
return false;
}

if (VecTy0 == nullptr && VecTy1 == nullptr) {
if (ArgTy0 != ArgTy1) {
return true;
} else {
return false;
}
if (CastType == ArgTyB) {
ExprResult ResA = S->SemaConvertVectorExpr(
A.get(), S->Context.CreateTypeSourceInfo(ArgTyB), BuiltinLoc,
A.get()->getBeginLoc());
TheCall->setArg(0, ResA.get());
TheCall->setType(VecTyB->getElementType());
return false;
}
return false;
}

if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
(VecTy0 != nullptr && VecTy1 == nullptr)) {
if (VecTyB) {
// Convert to the vector result type
ExprResult ResA = A;
if (VecTyB->getElementType() != ArgTyA)
ResA = S->ImpCastExprToType(ResA.get(), VecTyB->getElementType(),
CK_FloatingCast);
ResA = S->ImpCastExprToType(ResA.get(), ArgTyB, CK_VectorSplat);
TheCall->setArg(0, ResA.get());
}
if (VecTyA) {
ExprResult ResB = B;
if (VecTyA->getElementType() != ArgTyB)
ResB = S->ImpCastExprToType(ResB.get(), VecTyA->getElementType(),
CK_FloatingCast);
ResB = S->ImpCastExprToType(ResB.get(), ArgTyA, CK_VectorSplat);
TheCall->setArg(1, ResB.get());
}
return false;
}

Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_dot: {
if (checkArgCount(*this, TheCall, 2))
return true;
}

if (VecTy0->getElementType() != VecTy1->getElementType()) {
// Note: This case should never happen. If type promotion occurs
// then element types won't be different. This diag error is here
// b\c EmitHLSLBuiltinExpr asserts on this case.
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
if (PromoteBoolsToInt(this, TheCall))
return true;
}
if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
if (PromoteVectorElementCallArgs(this, TheCall))
return true;
PromoteVectorArgTruncation(this, TheCall);
if (SemaBuiltinVectorToScalarMath(TheCall))
return true;
}
break;
}
}
Expand Down Expand Up @@ -19669,15 +19767,37 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
}

bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
QualType Res;
bool result = SemaBuiltinVectorMath(TheCall, Res);
if (result)
return true;
TheCall->setType(Res);
return false;
}

bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
QualType Res;
bool result = SemaBuiltinVectorMath(TheCall, Res);
if (result)
return true;

if (auto *VecTy0 = Res->getAs<VectorType>()) {
TheCall->setType(VecTy0->getElementType());
} else {
TheCall->setType(Res);
}
return false;
}

bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
if (checkArgCount(*this, TheCall, 2))
return true;

ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
// Do standard promotions between the two arguments, returning their common
// type.
QualType Res =
UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
if (A.isInvalid() || B.isInvalid())
return true;

Expand All @@ -19694,7 +19814,6 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {

TheCall->setArg(0, A.get());
TheCall->setArg(1, B.get());
TheCall->setType(Res);
return false;
}

Expand Down
Loading

0 comments on commit 3524651

Please sign in to comment.