Skip to content

Commit

Permalink
[HLSL] Vector Usual Arithmetic Conversions (#108659)
Browse files Browse the repository at this point in the history
HLSL has a different set of usual arithmetic conversions for vector
types to resolve a common type for binary operator expressions.

This PR implements the current spec proposal from:
microsoft/hlsl-specs#311

There is one case that may need additional handling for implicitly
truncating `vector<T,1>` to `T` early to allow other transformations.

Fixes #106253
  • Loading branch information
llvm-beanz authored and puja2196 committed Oct 2, 2024
1 parent 2a898de commit 9777451
Show file tree
Hide file tree
Showing 7 changed files with 594 additions and 4 deletions.
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12395,6 +12395,9 @@ def err_hlsl_operator_unsupported : Error<

def err_hlsl_param_qualifier_mismatch :
Error<"conflicting parameter qualifier %0 on parameter %1">;
def err_hlsl_vector_compound_assignment_truncation : Error<
"left hand operand of type %0 to compound assignment cannot be truncated "
"when used with right hand operand of type %1">;

def warn_hlsl_impcast_vector_truncation : Warning<
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -2978,7 +2978,7 @@ def flax_vector_conversions_EQ : Joined<["-"], "flax-vector-conversions=">, Grou
"LangOptions::LaxVectorConversionKind::Integer",
"LangOptions::LaxVectorConversionKind::All"]>,
MarshallingInfoEnum<LangOpts<"LaxVectorConversions">,
open_cl.KeyPath #
!strconcat("(", open_cl.KeyPath, " || ", hlsl.KeyPath, ")") #
" ? LangOptions::LaxVectorConversionKind::None" #
" : LangOptions::LaxVectorConversionKind::All">;
def flax_vector_conversions : Flag<["-"], "flax-vector-conversions">, Group<f_Group>,
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -7423,7 +7423,8 @@ class Sema final : public SemaBase {
SourceLocation Loc,
BinaryOperatorKind Opc);
QualType CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
SourceLocation Loc);
SourceLocation Loc,
BinaryOperatorKind Opc);

/// Context in which we're performing a usual arithmetic conversion.
enum ArithConvKind {
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ class SemaHLSL : public SemaBase {
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);

QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
QualType LHSType, QualType RHSType,
bool IsCompAssign);
void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc);

void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
Expand Down
18 changes: 16 additions & 2 deletions clang/lib/Sema/SemaExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10133,6 +10133,10 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
const VectorType *RHSVecType = RHSType->getAs<VectorType>();
assert(LHSVecType || RHSVecType);

if (getLangOpts().HLSL)
return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
IsCompAssign);

// AltiVec-style "vector bool op vector bool" combinations are allowed
// for some operators but not others.
if (!AllowBothBool && LHSVecType &&
Expand Down Expand Up @@ -12863,7 +12867,8 @@ static void diagnoseXorMisusedAsPow(Sema &S, const ExprResult &XorLHS,
}

QualType Sema::CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
SourceLocation Loc) {
SourceLocation Loc,
BinaryOperatorKind Opc) {
// Ensure that either both operands are of the same vector type, or
// one operand is of a vector type and the other is of its element type.
QualType vType = CheckVectorOperands(LHS, RHS, Loc, false,
Expand All @@ -12883,6 +12888,15 @@ QualType Sema::CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
if (!getLangOpts().CPlusPlus &&
!(isa<ExtVectorType>(vType->getAs<VectorType>())))
return InvalidLogicalVectorOperands(Loc, LHS, RHS);
// Beginning with HLSL 2021, HLSL disallows logical operators on vector
// operands and instead requires the use of the `and`, `or`, `any`, `all`, and
// `select` functions.
if (getLangOpts().HLSL &&
getLangOpts().getHLSLVersion() >= LangOptionsBase::HLSL_2021) {
(void)InvalidOperands(Loc, LHS, RHS);
HLSL().emitLogicalOperatorFixIt(LHS.get(), RHS.get(), Opc);
return QualType();
}

return GetSignedVectorType(LHS.get()->getType());
}
Expand Down Expand Up @@ -13054,7 +13068,7 @@ inline QualType Sema::CheckLogicalOperands(ExprResult &LHS, ExprResult &RHS,
// Check vector operands differently.
if (LHS.get()->getType()->isVectorType() ||
RHS.get()->getType()->isVectorType())
return CheckVectorLogicalOperands(LHS, RHS, Loc);
return CheckVectorLogicalOperands(LHS, RHS, Loc, Opc);

bool EnumConstantInBoolContext = false;
for (const ExprResult &HS : {LHS, RHS}) {
Expand Down
188 changes: 188 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,194 @@ void SemaHLSL::DiagnoseAttrStageMismatch(
<< (AllowedStages.size() != 1) << join(StageStrings, ", ");
}

template <CastKind Kind>
static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
if (const auto *VTy = Ty->getAs<VectorType>())
Ty = VTy->getElementType();
Ty = S.getASTContext().getExtVectorType(Ty, Sz);
E = S.ImpCastExprToType(E.get(), Ty, Kind);
}

template <CastKind Kind>
static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
E = S.ImpCastExprToType(E.get(), Ty, Kind);
return Ty;
}

static QualType handleFloatVectorBinOpConversion(
Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
bool LHSFloat = LElTy->isRealFloatingType();
bool RHSFloat = RElTy->isRealFloatingType();

if (LHSFloat && RHSFloat) {
if (IsCompAssign ||
SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);

return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
}

if (LHSFloat)
return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);

assert(RHSFloat);
if (IsCompAssign)
return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);

return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
}

static QualType handleIntegerVectorBinOpConversion(
Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {

int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
auto &Ctx = SemaRef.getASTContext();

// If both types have the same signedness, use the higher ranked type.
if (LHSSigned == RHSSigned) {
if (IsCompAssign || IntOrder >= 0)
return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);

return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
}

// If the unsigned type has greater than or equal rank of the signed type, use
// the unsigned type.
if (IntOrder != (LHSSigned ? 1 : -1)) {
if (IsCompAssign || RHSSigned)
return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
}

// At this point the signed type has higher rank than the unsigned type, which
// means it will be the same size or bigger. If the signed type is bigger, it
// can represent all the values of the unsigned type, so select it.
if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
if (IsCompAssign || LHSSigned)
return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
}

// This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
// to C/C++ leaking through. The place this happens today is long vs long
// long. When arguments are vector<unsigned long, N> and vector<long long, N>,
// the long long has higher rank than long even though they are the same size.

// If this is a compound assignment cast the right hand side to the left hand
// side's type.
if (IsCompAssign)
return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);

// If this isn't a compound assignment we convert to unsigned long long.
QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
QualType NewTy = Ctx.getExtVectorType(
ElTy, RHSType->castAs<VectorType>()->getNumElements());
(void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);

return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
}

static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
QualType SrcTy) {
if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
return CK_FloatingCast;
if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
return CK_IntegralCast;
if (DestTy->isRealFloatingType())
return CK_IntegralToFloating;
assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
return CK_FloatingToIntegral;
}

QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
QualType LHSType,
QualType RHSType,
bool IsCompAssign) {
const auto *LVecTy = LHSType->getAs<VectorType>();
const auto *RVecTy = RHSType->getAs<VectorType>();
auto &Ctx = getASTContext();

// If the LHS is not a vector and this is a compound assignment, we truncate
// the argument to a scalar then convert it to the LHS's type.
if (!LVecTy && IsCompAssign) {
QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
RHSType = RHS.get()->getType();
if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
return LHSType;
RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
getScalarCastKind(Ctx, LHSType, RHSType));
return LHSType;
}

unsigned EndSz = std::numeric_limits<unsigned>::max();
unsigned LSz = 0;
if (LVecTy)
LSz = EndSz = LVecTy->getNumElements();
if (RVecTy)
EndSz = std::min(RVecTy->getNumElements(), EndSz);
assert(EndSz != std::numeric_limits<unsigned>::max() &&
"one of the above should have had a value");

// In a compound assignment, the left operand does not change type, the right
// operand is converted to the type of the left operand.
if (IsCompAssign && LSz != EndSz) {
Diag(LHS.get()->getBeginLoc(),
diag::err_hlsl_vector_compound_assignment_truncation)
<< LHSType << RHSType;
return QualType();
}

if (RVecTy && RVecTy->getNumElements() > EndSz)
castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);

if (!RVecTy)
castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
if (!IsCompAssign && !LVecTy)
castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);

// If we're at the same type after resizing we can stop here.
if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
return Ctx.getCommonSugaredType(LHSType, RHSType);

QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
QualType RElTy = RHSType->castAs<VectorType>()->getElementType();

// Handle conversion for floating point vectors.
if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
LElTy, RElTy, IsCompAssign);

assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
"HLSL Vectors can only contain integer or floating point types");
return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
LElTy, RElTy, IsCompAssign);
}

void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
BinaryOperatorKind Opc) {
assert((Opc == BO_LOr || Opc == BO_LAnd) &&
"Called with non-logical operator");
llvm::SmallVector<char, 256> Buff;
llvm::raw_svector_ostream OS(Buff);
PrintingPolicy PP(SemaRef.getLangOpts());
StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
OS << NewFnName << "(";
LHS->printPretty(OS, nullptr, PP);
OS << ", ";
RHS->printPretty(OS, nullptr, PP);
OS << ")";
SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
<< NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
}

void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
llvm::VersionTuple SMVersion =
getASTContext().getTargetInfo().getTriple().getOSVersion();
Expand Down
Loading

0 comments on commit 9777451

Please sign in to comment.