@@ -74,52 +74,6 @@ static Optional<Type> getArrayTy(MLIRContext &context, unsigned dimNum,
7474// Type conversion
7575// ===----------------------------------------------------------------------===//
7676
77- // / Converts SYCL array type to LLVM type.
78- static Optional<Type> convertArrayType (sycl::ArrayType type,
79- LLVMTypeConverter &converter) {
80- assert (type.getBody ().size () == 1 &&
81- " Expecting SYCL array body to have size 1" );
82- assert (type.getBody ()[0 ].isa <MemRefType>() &&
83- " Expecting SYCL array body entry to be MemRefType" );
84- assert (type.getBody ()[0 ].cast <MemRefType>().getElementType () ==
85- converter.getIndexType () &&
86- " Expecting SYCL array body entry element type to be the index type" );
87- return getArrayTy (converter.getContext (), type.getDimension (),
88- converter.getIndexType ());
89- }
90-
91- // / Converts SYCL range or id type to LLVM type, given \p dimNum - number of
92- // / dimensions, \p name - the expected LLVM type name, \p converter - LLVM type
93- // / converter.
94- static Optional<Type> convertRangeOrIDTy (unsigned dimNum, StringRef name,
95- LLVMTypeConverter &converter) {
96- auto convertedTy = LLVM::LLVMStructType::getIdentified (
97- &converter.getContext (), name.str () + " ." + std::to_string (dimNum));
98- if (!convertedTy.isInitialized ()) {
99- auto arrayTy =
100- getArrayTy (converter.getContext (), dimNum, converter.getIndexType ());
101- if (!arrayTy.hasValue ())
102- return llvm::None;
103- if (failed (convertedTy.setBody (arrayTy.getValue (), /* isPacked=*/ false )))
104- return llvm::None;
105- }
106- return convertedTy;
107- }
108-
109- // / Converts SYCL id type to LLVM type.
110- static Optional<Type> convertIDType (sycl::IDType type,
111- LLVMTypeConverter &converter) {
112- return convertRangeOrIDTy (type.getDimension (), " class.cl::sycl::id" ,
113- converter);
114- }
115-
116- // / Converts SYCL range type to LLVM type.
117- static Optional<Type> convertRangeType (sycl::RangeType type,
118- LLVMTypeConverter &converter) {
119- return convertRangeOrIDTy (type.getDimension (), " class.cl::sycl::range" ,
120- converter);
121- }
122-
12377// / Create a LLVM struct type with name \p name, and the converted \p body as
12478// / the body.
12579static Optional<Type> convertBodyType (StringRef name,
@@ -172,6 +126,53 @@ static Optional<Type> convertAccessorType(sycl::AccessorType type,
172126 return convertedTy;
173127}
174128
129+ // / Converts SYCL array type to LLVM type.
130+ static Optional<Type> convertArrayType (sycl::ArrayType type,
131+ LLVMTypeConverter &converter) {
132+ assert (type.getBody ().size () == 1 &&
133+ " Expecting SYCL array body to have size 1" );
134+ assert (type.getBody ()[0 ].isa <MemRefType>() &&
135+ " Expecting SYCL array body entry to be MemRefType" );
136+ assert (type.getBody ()[0 ].cast <MemRefType>().getElementType () ==
137+ converter.getIndexType () &&
138+ " Expecting SYCL array body entry element type to be the index type" );
139+ return getArrayTy (converter.getContext (), type.getDimension (),
140+ converter.getIndexType ());
141+ }
142+
143+ // / Converts SYCL group type to LLVM type.
144+ static Optional<Type> convertGroupType (sycl::GroupType type,
145+ LLVMTypeConverter &converter) {
146+ return convertBodyType (" class.cl::sycl::group." +
147+ std::to_string (type.getDimension ()),
148+ type.getBody (), converter);
149+ }
150+
151+ // / Converts SYCL range or id type to LLVM type, given \p dimNum - number of
152+ // / dimensions, \p name - the expected LLVM type name, \p converter - LLVM type
153+ // / converter.
154+ static Optional<Type> convertRangeOrIDTy (unsigned dimNum, StringRef name,
155+ LLVMTypeConverter &converter) {
156+ auto convertedTy = LLVM::LLVMStructType::getIdentified (
157+ &converter.getContext (), name.str () + " ." + std::to_string (dimNum));
158+ if (!convertedTy.isInitialized ()) {
159+ auto arrayTy =
160+ getArrayTy (converter.getContext (), dimNum, converter.getIndexType ());
161+ if (!arrayTy.hasValue ())
162+ return llvm::None;
163+ if (failed (convertedTy.setBody (arrayTy.getValue (), /* isPacked=*/ false )))
164+ return llvm::None;
165+ }
166+ return convertedTy;
167+ }
168+
169+ // / Converts SYCL id type to LLVM type.
170+ static Optional<Type> convertIDType (sycl::IDType type,
171+ LLVMTypeConverter &converter) {
172+ return convertRangeOrIDTy (type.getDimension (), " class.cl::sycl::id" ,
173+ converter);
174+ }
175+
175176// / Converts SYCL item base type to LLVM type.
176177static Optional<Type> convertItemBaseType (sycl::ItemBaseType type,
177178 LLVMTypeConverter &converter) {
@@ -190,6 +191,21 @@ static Optional<Type> convertItemType(sycl::ItemType type,
190191 type.getBody (), converter);
191192}
192193
194+ // / Converts SYCL nd item type to LLVM type.
195+ static Optional<Type> convertNdItemType (sycl::NdItemType type,
196+ LLVMTypeConverter &converter) {
197+ return convertBodyType (" class.cl::sycl::nd_item." +
198+ std::to_string (type.getDimension ()),
199+ type.getBody (), converter);
200+ }
201+
202+ // / Converts SYCL range type to LLVM type.
203+ static Optional<Type> convertRangeType (sycl::RangeType type,
204+ LLVMTypeConverter &converter) {
205+ return convertRangeOrIDTy (type.getDimension (), " class.cl::sycl::range" ,
206+ converter);
207+ }
208+
193209// ===----------------------------------------------------------------------===//
194210// ConstructorPattern - Converts `sycl.constructor` to LLVM.
195211// ===----------------------------------------------------------------------===//
@@ -263,8 +279,7 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
263279 return convertArrayType (type, typeConverter);
264280 });
265281 typeConverter.addConversion ([&](sycl::GroupType type) {
266- llvm_unreachable (" SYCLToLLVM - sycl::GroupType not handle (yet)" );
267- return llvm::None;
282+ return convertGroupType (type, typeConverter);
268283 });
269284 typeConverter.addConversion (
270285 [&](sycl::IDType type) { return convertIDType (type, typeConverter); });
@@ -275,8 +290,7 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
275290 return convertItemType (type, typeConverter);
276291 });
277292 typeConverter.addConversion ([&](sycl::NdItemType type) {
278- llvm_unreachable (" SYCLToLLVM - sycl::NdItemType not handle (yet)" );
279- return llvm::None;
293+ return convertNdItemType (type, typeConverter);
280294 });
281295 typeConverter.addConversion ([&](sycl::RangeType type) {
282296 return convertRangeType (type, typeConverter);
0 commit comments