Skip to content

Commit

Permalink
[HLSL] Implementation of dot intrinsic
Browse files Browse the repository at this point in the history
This change implements #70073

HLSL has a dot intrinsic defined here:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot

The intrinsic itself is defined as a HLSL_LANG LangBuiltin in Builtins.td.
This is used to associate all the dot product typdef defined hlsl_intrinsics.h
with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp.

In IntrinsicsDirectX.td we define the llvmIR for the dot product.
A few goals were in mind for this IR. First it should operate on only
vectors. Second the return type should be the vector element type. Third
the second parameter vector should be of the same size as the first
parameter. Finally `a dot b` should be the same as `b dot a`.

In CGBuiltin.cpp hlsl has built on top of existing clang intrinsics via EmitBuiltinExpr. Dot
product though is language specific intrinsic and so is guarded behind getLangOpts().HLSL.
The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExp

EmitHLSLBuiltinExp dot product intrinsics makes a destinction
between vectors and scalars. This is because HLSL supports dot product on scalars which simplifies down to multiply.

Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a language specific semantic validation that can be expanded for other hlsl specific intrinsics.
  • Loading branch information
farzonl committed Feb 17, 2024
1 parent 3bef17e commit 586ee65
Show file tree
Hide file tree
Showing 10 changed files with 497 additions and 5 deletions.
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void*(unsigned char)";
}

def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -10267,6 +10267,8 @@ def err_vec_builtin_non_vector : Error<
"first two arguments to %0 must be vectors">;
def err_vec_builtin_incompatible_vector : Error<
"first two arguments to %0 must have the same type">;
def err_vec_builtin_incompatible_size : Error<
"first two arguments to %0 must have the same size">;
def err_vsx_builtin_nonconstant_argument : Error<
"argument %0 to %1 must be a 2-bit unsigned literal (i.e. 0, 1, 2 or 3)">;

Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -14055,6 +14055,7 @@ class Sema final {
bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
Expand Down
49 changes: 49 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IntrinsicsBPF.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/IntrinsicsHexagon.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsPowerPC.h"
Expand Down Expand Up @@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
}

// EmitHLSLBuiltinExpr will check getLangOpts().HLSL
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
return RValue::get(V);

if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
return EmitHipStdParUnsupportedBuiltin(this, FD);

Expand Down Expand Up @@ -17896,6 +17901,50 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E) {
if (!getLangOpts().HLSL)
return nullptr;

switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_dot: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
Value *Op1 = EmitScalarExpr(E->getArg(1));
llvm::Type *T0 = Op0->getType();
llvm::Type *T1 = Op1->getType();
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy()) {
return Builder.CreateFMul(Op0, Op1, "dx.dot");
}

if (T0->isIntegerTy()) {
return Builder.CreateMul(Op0, Op1, "dx.dot");
}
assert(
false &&
"Dot product on a scalar is only supported on integers and floats.");
}
assert(T0->isVectorTy() && T1->isVectorTy() &&
"Dot product of vector and scalar is not supported.");

// NOTE: this assert will need to be revisited after overload resoltion
// PR merges.
assert(T0->getScalarType() == T1->getScalarType() &&
"Dot product of vectors need the same element types.");

auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");

return Builder.CreateIntrinsic(
/*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
} break;
}
return nullptr;
}

Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
const CallExpr *E) {
llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
const CallExpr *E);
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
Expand Down
92 changes: 92 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,98 @@ double3 cos(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
double4 cos(double4);

//===----------------------------------------------------------------------===//
// dot product builtins
//===----------------------------------------------------------------------===//
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half, half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half2, half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half3, half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half4, half4);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t, int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t2, int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t3, int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t4, int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t, uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t2, uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t3, uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t4, uint16_t4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float, float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float2, float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float4, float4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
double dot(double, double);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int, int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int2, int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int3, int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int4, int4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint, uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint2, uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint3, uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint4, uint4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t, int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t2, int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t3, int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t4, int64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t, uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t2, uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t3, uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t4, uint64_t4);

//===----------------------------------------------------------------------===//
// floor builtins
//===----------------------------------------------------------------------===//
Expand Down
84 changes: 79 additions & 5 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
// not a valid type, emit an error message and return true. Otherwise return
// false.
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
QualType Ty) {
if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
QualType ArgTy, int ArgIndex) {
if (!ArgTy->getAs<VectorType>() &&
!ConstantMatrixType::isValidElementType(ArgTy)) {
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
<< 1 << /* vector, integer or float ty*/ 0 << Ty;
<< ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
}

return false;
Expand Down Expand Up @@ -2958,6 +2959,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
}
}

if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) {
return ExprError();
}

// Since the target specific builtins for each arch overlap, only check those
// of the arch we are compiling for.
if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) {
Expand Down Expand Up @@ -5158,6 +5163,75 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
return false;
}

// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_dot: {
if (checkArgCount(*this, TheCall, 2)) {
return true;
}
Expr *Arg0 = TheCall->getArg(0);
QualType ArgTy0 = Arg0->getType();

Expr *Arg1 = TheCall->getArg(1);
QualType ArgTy1 = Arg1->getType();

auto *VecTy0 = ArgTy0->getAs<VectorType>();
auto *VecTy1 = ArgTy1->getAs<VectorType>();
SourceLocation BuiltinLoc = TheCall->getBeginLoc();

// if arg0 is bool then call Diag with err_builtin_invalid_arg_type
if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
return true;
}

// if arg1 is bool then call Diag with err_builtin_invalid_arg_type
if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
return true;
}

if (VecTy0 == nullptr && VecTy1 == nullptr) {
if (ArgTy0 != ArgTy1) {
return true;
} else {
return false;
}
}

if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
(VecTy0 != nullptr && VecTy1 == nullptr)) {

Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
return true;
}

if (VecTy0->getElementType() != VecTy1->getElementType()) {
// Note: This case should never happen. If type promotion occurs
// then element types won't be different. This diag error is here
// b\c EmitHLSLBuiltinExpr asserts on this case.
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
return true;
}
if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
return true;
}
break;
}
}
return false;
}

bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
CallExpr *TheCall) {
// position of memory order and scope arguments in the builtin
Expand Down Expand Up @@ -19583,7 +19657,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
TheCall->setArg(0, A.get());
QualType TyA = A.get()->getType();

if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
return true;

TheCall->setType(TyA);
Expand Down Expand Up @@ -19611,7 +19685,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
diag::err_typecheck_call_different_arg_types)
<< TyA << TyB;

if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
return true;

TheCall->setArg(0, A.get());
Expand Down
Loading

0 comments on commit 586ee65

Please sign in to comment.