Skip to content

Commit b7c3836

Browse files
committed
[HLSL] Add matrix constructors using initalizer lists
fixes #159434 In HLSL matrices are matrix_type in all respects except that they support a constructor style syntax for initializing matrices. This change adds a translation of vector constructor arguments into initializer lists. This supports the following HLSL syntax: (1) HLSL matrices support constructor syntax (2) HLSL matrices are expanded to constituate components in constructor using the same initalizer list behavior defined in transformInitList allows us to support struct element initalization via HLSLElementwiseCast
1 parent e29cf8e commit b7c3836

File tree

7 files changed

+608
-31
lines changed

7 files changed

+608
-31
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2418,9 +2418,9 @@ def err_init_conversion_failed : Error<
24182418
"cannot initialize %select{a variable|a parameter|template parameter|"
24192419
"return object|statement expression result|an "
24202420
"exception object|a member subobject|an array element|a new value|a value|a "
2421-
"base class|a constructor delegation|a vector element|a block element|a "
2422-
"block element|a complex element|a lambda capture|a compound literal "
2423-
"initializer|a related result|a parameter of CF audited function|a "
2421+
"base class|a constructor delegation|a vector element|a matrix element|a "
2422+
"block element|a block element|a complex element|a lambda capture|a compound"
2423+
" literal initializer|a related result|a parameter of CF audited function|a "
24242424
"structured binding|a member subobject}0 "
24252425
"%diff{of type $ with an %select{rvalue|lvalue}2 of type $|"
24262426
"with an %select{rvalue|lvalue}2 of incompatible type}1,3"
@@ -6546,9 +6546,9 @@ def warn_extern_init : Warning<"'extern' variable has an initializer">,
65466546
def err_variable_object_no_init : Error<
65476547
"variable-sized object may not be initialized">;
65486548
def err_excess_initializers : Error<
6549-
"excess elements in %select{array|vector|scalar|union|struct}0 initializer">;
6549+
"excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">;
65506550
def ext_excess_initializers : ExtWarn<
6551-
"excess elements in %select{array|vector|scalar|union|struct}0 initializer">,
6551+
"excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">,
65526552
InGroup<ExcessInitializers>;
65536553
def err_excess_initializers_for_sizeless_type : Error<
65546554
"excess elements in initializer for indivisible sizeless type %0">;
@@ -11089,8 +11089,8 @@ def err_first_argument_to_cwsc_pdtor_call : Error<
1108911089
def err_second_argument_to_cwsc_not_pointer : Error<
1109011090
"second argument to __builtin_call_with_static_chain must be of pointer type">;
1109111091

11092-
def err_vector_incorrect_num_elements : Error<
11093-
"%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;
11092+
def err_tensor_incorrect_num_elements : Error<
11093+
"%select{too many|too few}0 elements in %select{vector|matrix}1 %select{initialization|operand}4 (expected %2 elements, have %3)">;
1109411094
def err_altivec_empty_initializer : Error<"expected initializer">;
1109511095

1109611096
def err_vector_incorrect_bit_count : Error<

clang/include/clang/Sema/Initialization.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ class alignas(8) InitializedEntity {
9191
/// or vector.
9292
EK_VectorElement,
9393

94+
/// The entity being initialized is an element of a matrix.
95+
/// or matrix.
96+
EK_MatrixElement,
97+
9498
/// The entity being initialized is a field of block descriptor for
9599
/// the copied-in c++ object.
96100
EK_BlockElement,
@@ -205,8 +209,8 @@ class alignas(8) InitializedEntity {
205209
/// virtual base.
206210
llvm::PointerIntPair<const CXXBaseSpecifier *, 1> Base;
207211

208-
/// When Kind == EK_ArrayElement, EK_VectorElement, or
209-
/// EK_ComplexElement, the index of the array or vector element being
212+
/// When Kind == EK_ArrayElement, EK_VectorElement, or EK_MatrixElement,
213+
/// or EK_ComplexElement, the index of the array or vector element being
210214
/// initialized.
211215
unsigned Index;
212216

@@ -536,15 +540,15 @@ class alignas(8) InitializedEntity {
536540
/// element's index.
537541
unsigned getElementIndex() const {
538542
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
539-
getKind() == EK_ComplexElement);
543+
getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
540544
return Index;
541545
}
542546

543547
/// If this is already the initializer for an array or vector
544548
/// element, sets the element index.
545549
void setElementIndex(unsigned Index) {
546550
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
547-
getKind() == EK_ComplexElement);
551+
getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
548552
this->Index = Index;
549553
}
550554

clang/lib/Sema/CheckExprLifetime.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ getEntityLifetime(const InitializedEntity *Entity,
155155
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
156156
case InitializedEntity::EK_LambdaCapture:
157157
case InitializedEntity::EK_VectorElement:
158+
case InitializedEntity::EK_MatrixElement:
158159
case InitializedEntity::EK_ComplexElement:
159160
return {nullptr, LK_FullExpression};
160161

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "clang/AST/Expr.h"
2222
#include "clang/AST/HLSLResource.h"
2323
#include "clang/AST/Type.h"
24+
#include "clang/AST/TypeBase.h"
2425
#include "clang/AST/TypeLoc.h"
2526
#include "clang/Basic/Builtins.h"
2627
#include "clang/Basic/DiagnosticSema.h"
@@ -3351,6 +3352,11 @@ static void BuildFlattenedTypeList(QualType BaseTy,
33513352
List.insert(List.end(), VT->getNumElements(), VT->getElementType());
33523353
continue;
33533354
}
3355+
if (const auto *MT = dyn_cast<ConstantMatrixType>(T)) {
3356+
List.insert(List.end(), MT->getNumElementsFlattened(),
3357+
MT->getElementType());
3358+
continue;
3359+
}
33543360
if (const auto *RD = T->getAsCXXRecordDecl()) {
33553361
if (RD->isStandardLayout())
33563362
RD = RD->getStandardLayoutBaseWithFields();
@@ -4149,6 +4155,32 @@ class InitListTransformer {
41494155
}
41504156
return true;
41514157
}
4158+
if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
4159+
unsigned Rows = MTy->getNumRows();
4160+
unsigned Cols = MTy->getNumColumns();
4161+
QualType ElemTy = MTy->getElementType();
4162+
4163+
for (unsigned C = 0; C < Cols; ++C) {
4164+
for (unsigned R = 0; R < Rows; ++R) {
4165+
// row index literal
4166+
Expr *RowIdx = IntegerLiteral::Create(
4167+
Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy,
4168+
E->getBeginLoc());
4169+
// column index literal
4170+
Expr *ColIdx = IntegerLiteral::Create(
4171+
Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy,
4172+
E->getBeginLoc());
4173+
ExprResult ElExpr = S.CreateBuiltinMatrixSubscriptExpr(
4174+
E, RowIdx, ColIdx, E->getEndLoc());
4175+
if (ElExpr.isInvalid())
4176+
return false;
4177+
if (!buildInitializerListImpl(ElExpr.get()))
4178+
return false;
4179+
ElExpr.get()->setType(ElemTy);
4180+
}
4181+
}
4182+
return true;
4183+
}
41524184

41534185
if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
41544186
uint64_t Size = ArrTy->getZExtSize();
@@ -4202,14 +4234,17 @@ class InitListTransformer {
42024234
return *(ArgIt++);
42034235

42044236
llvm::SmallVector<Expr *> Inits;
4205-
assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL");
42064237
Ty = Ty.getDesugaredType(Ctx);
4207-
if (Ty->isVectorType() || Ty->isConstantArrayType()) {
4238+
if (Ty->isVectorType() || Ty->isConstantArrayType() ||
4239+
Ty->isConstantMatrixType()) {
42084240
QualType ElTy;
42094241
uint64_t Size = 0;
42104242
if (auto *ATy = Ty->getAs<VectorType>()) {
42114243
ElTy = ATy->getElementType();
42124244
Size = ATy->getNumElements();
4245+
} else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
4246+
ElTy = CMTy->getElementType();
4247+
Size = CMTy->getNumElementsFlattened();
42134248
} else {
42144249
auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
42154250
ElTy = VTy->getElementType();

0 commit comments

Comments
 (0)