@@ -797,14 +797,24 @@ pi_result _pi_device::initialize(int SubSubDeviceOrdinal,
797797
798798 ZeDeviceMemoryProperties.Compute =
799799 [ZeDevice](
800- std::vector<ZeStruct<ze_device_memory_properties_t >> &Properties) {
800+ std::pair<std::vector<ZeStruct<ze_device_memory_properties_t >>,
801+ std::vector<ZeStruct<ze_device_memory_ext_properties_t >>>
802+ &Properties) {
801803 uint32_t Count = 0 ;
802804 ZE_CALL_NOCHECK (zeDeviceGetMemoryProperties,
803805 (ZeDevice, &Count, nullptr ));
804806
805- Properties.resize (Count);
807+ auto &PropertiesVector = Properties.first ;
808+ auto &PropertiesExtVector = Properties.second ;
809+
810+ PropertiesVector.resize (Count);
811+ PropertiesExtVector.resize (Count);
812+ // Request for extended memory properties be read in
813+ for (uint32_t I = 0 ; I < Count; ++I)
814+ PropertiesVector[I].pNext = (void *)&PropertiesExtVector[I];
815+
806816 ZE_CALL_NOCHECK (zeDeviceGetMemoryProperties,
807- (ZeDevice, &Count, Properties .data ()));
817+ (ZeDevice, &Count, PropertiesVector .data ()));
808818 };
809819
810820 ZeDeviceMemoryAccessProperties.Compute =
@@ -2962,9 +2972,9 @@ pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,
29622972 return ReturnValue (pi_uint64{Device->ZeDeviceProperties ->maxMemAllocSize });
29632973 case PI_DEVICE_INFO_GLOBAL_MEM_SIZE: {
29642974 uint64_t GlobalMemSize = 0 ;
2965- for (uint32_t I = 0 ; I < Device-> ZeDeviceMemoryProperties -> size (); I++) {
2966- GlobalMemSize +=
2967- (*Device-> ZeDeviceMemoryProperties . operator ->())[I]. totalSize ;
2975+ for (const auto &ZeDeviceMemoryExtProperty :
2976+ Device-> ZeDeviceMemoryProperties -> second ) {
2977+ GlobalMemSize += ZeDeviceMemoryExtProperty. physicalSize ;
29682978 }
29692979 return ReturnValue (pi_uint64{GlobalMemSize});
29702980 }
@@ -3337,21 +3347,32 @@ pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,
33373347 // Only report device memory which zeMemAllocDevice can allocate from.
33383348 // Currently this is only the one enumerated with ordinal 0.
33393349 uint64_t FreeMemory = 0 ;
3340- uint32_t MemCount = 1 ;
3341- zes_mem_handle_t ZesMemHandle;
3342- ZE_CALL (zesDeviceEnumMemoryModules, (ZeDevice, &MemCount, &ZesMemHandle));
3350+ uint32_t MemCount = 0 ;
3351+ ZE_CALL (zesDeviceEnumMemoryModules, (ZeDevice, &MemCount, nullptr ));
33433352 if (MemCount != 0 ) {
3344- ZesStruct<zes_mem_properties_t > ZeMemProperties;
3345- ZE_CALL (zesMemoryGetProperties, (ZesMemHandle, &ZeMemProperties));
3346- ZesStruct<zes_mem_state_t > ZeMemState;
3347- ZE_CALL (zesMemoryGetState, (ZesMemHandle, &ZeMemState));
3348- FreeMemory += ZeMemState.free ;
3353+ std::vector<zes_mem_handle_t > ZesMemHandles (MemCount);
3354+ ZE_CALL (zesDeviceEnumMemoryModules,
3355+ (ZeDevice, &MemCount, ZesMemHandles.data ()));
3356+ for (auto &ZesMemHandle : ZesMemHandles) {
3357+ ZesStruct<zes_mem_properties_t > ZesMemProperties;
3358+ ZE_CALL (zesMemoryGetProperties, (ZesMemHandle, &ZesMemProperties));
3359+ // For root-device report memory from all memory modules since that
3360+ // is what totally available in the default implicit scaling mode.
3361+ // For sub-devices only report memory local to them.
3362+ if (!Device->isSubDevice () || Device->ZeDeviceProperties ->subdeviceId ==
3363+ ZesMemProperties.subdeviceId ) {
3364+
3365+ ZesStruct<zes_mem_state_t > ZesMemState;
3366+ ZE_CALL (zesMemoryGetState, (ZesMemHandle, &ZesMemState));
3367+ FreeMemory += ZesMemState.free ;
3368+ }
3369+ }
33493370 }
33503371 return ReturnValue (FreeMemory);
33513372 }
33523373 case PI_EXT_INTEL_DEVICE_INFO_MEMORY_CLOCK_RATE: {
33533374 // If there are not any memory modules then return 0.
3354- if (Device->ZeDeviceMemoryProperties ->empty ())
3375+ if (Device->ZeDeviceMemoryProperties ->first . empty ())
33553376 return ReturnValue (pi_uint32{0 });
33563377
33573378 // If there are multiple memory modules on the device then we have to report
@@ -3361,13 +3382,13 @@ pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,
33613382 return A.maxClockRate < B.maxClockRate ;
33623383 };
33633384 auto MinIt =
3364- std::min_element (Device->ZeDeviceMemoryProperties ->begin (),
3365- Device->ZeDeviceMemoryProperties ->end (), Comp);
3385+ std::min_element (Device->ZeDeviceMemoryProperties ->first . begin (),
3386+ Device->ZeDeviceMemoryProperties ->first . end (), Comp);
33663387 return ReturnValue (pi_uint32{MinIt->maxClockRate });
33673388 }
33683389 case PI_EXT_INTEL_DEVICE_INFO_MEMORY_BUS_WIDTH: {
33693390 // If there are not any memory modules then return 0.
3370- if (Device->ZeDeviceMemoryProperties ->empty ())
3391+ if (Device->ZeDeviceMemoryProperties ->first . empty ())
33713392 return ReturnValue (pi_uint32{0 });
33723393
33733394 // If there are multiple memory modules on the device then we have to report
@@ -3377,8 +3398,8 @@ pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,
33773398 return A.maxBusWidth < B.maxBusWidth ;
33783399 };
33793400 auto MinIt =
3380- std::min_element (Device->ZeDeviceMemoryProperties ->begin (),
3381- Device->ZeDeviceMemoryProperties ->end (), Comp);
3401+ std::min_element (Device->ZeDeviceMemoryProperties ->first . begin (),
3402+ Device->ZeDeviceMemoryProperties ->first . end (), Comp);
33823403 return ReturnValue (pi_uint32{MinIt->maxBusWidth });
33833404 }
33843405 case PI_EXT_INTEL_DEVICE_INFO_MAX_COMPUTE_QUEUE_INDICES: {
0 commit comments