@@ -1087,55 +1087,28 @@ struct get_device_info_impl<
10871087 return static_cast <size_t >((std::numeric_limits<int >::max)());
10881088 }
10891089};
1090- template <>
1090+ template <int Dims >
10911091struct get_device_info_impl <
1092- id<1 >, ext::oneapi::experimental::info::device::max_work_groups<1 >> {
1093- static id<1 > get (const DeviceImplPtr &Dev) {
1094- size_t result[3 ];
1092+ id<Dims>, ext::oneapi::experimental::info::device::max_work_groups<Dims>> {
1093+ static id<Dims> get (const DeviceImplPtr &Dev) {
10951094 size_t Limit =
10961095 get_device_info_impl<size_t , ext::oneapi::experimental::info::device::
10971096 max_global_work_groups>::get (Dev);
1098- Dev->getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
1099- Dev->getHandleRef (),
1100- UrInfoCode<
1101- ext::oneapi::experimental::info::device::max_work_groups<3 >>::value,
1102- sizeof (result), &result, nullptr );
1103- return id<1 >(std::min (Limit, result[0 ]));
1104- }
1105- };
11061097
1107- template <>
1108- struct get_device_info_impl <
1109- id<2 >, ext::oneapi::experimental::info::device::max_work_groups<2 >> {
1110- static id<2 > get (const DeviceImplPtr &Dev) {
11111098 size_t result[3 ];
1112- size_t Limit =
1113- get_device_info_impl<size_t , ext::oneapi::experimental::info::device::
1114- max_global_work_groups>::get (Dev);
1115- Dev->getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
1116- Dev->getHandleRef (),
1117- UrInfoCode<
1118- ext::oneapi::experimental::info::device::max_work_groups<3 >>::value,
1119- sizeof (result), &result, nullptr );
1120- return id<2 >(std::min (Limit, result[1 ]), std::min (Limit, result[0 ]));
1121- }
1122- };
1123-
1124- template <>
1125- struct get_device_info_impl <
1126- id<3 >, ext::oneapi::experimental::info::device::max_work_groups<3 >> {
1127- static id<3 > get (const DeviceImplPtr &Dev) {
1128- size_t result[3 ];
1129- size_t Limit =
1130- get_device_info_impl<size_t , ext::oneapi::experimental::info::device::
1131- max_global_work_groups>::get (Dev);
11321099 Dev->getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
11331100 Dev->getHandleRef (),
11341101 UrInfoCode<
11351102 ext::oneapi::experimental::info::device::max_work_groups<3 >>::value,
11361103 sizeof (result), &result, nullptr );
1137- return id<3 >(std::min (Limit, result[2 ]), std::min (Limit, result[1 ]),
1138- std::min (Limit, result[0 ]));
1104+ static_assert (1 <= Dims && Dims <= 3 );
1105+ if constexpr (Dims == 1 )
1106+ return id<1 >(std::min (Limit, result[0 ]));
1107+ else if constexpr (Dims == 2 )
1108+ return id<2 >(std::min (Limit, result[1 ]), std::min (Limit, result[0 ]));
1109+ else
1110+ return id<3 >(std::min (Limit, result[2 ]), std::min (Limit, result[1 ]),
1111+ std::min (Limit, result[0 ]));
11391112 }
11401113};
11411114
0 commit comments