Skip to content

Commit 1e3b68c

Browse files
committed
[XPU] support python memory api in XPU
1 parent c0d2715 commit 1e3b68c

File tree

7 files changed

+360
-27
lines changed

7 files changed

+360
-27
lines changed

paddle/fluid/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2930,6 +2930,7 @@ All parameter, weight, gradient are variables in Paddle.
29302930

29312931
#ifdef PADDLE_WITH_XPU
29322932
m.def("get_xpu_device_count", platform::GetXPUDeviceCount);
2933+
m.def("get_xpu_current_device_id", &platform::GetXPUCurrentDeviceId);
29332934
m.def("xpu_empty_cache", platform::EmptyCache);
29342935
m.def("get_xpu_device_utilization_rate",
29352936
platform::GetXPUDeviceUtilizationRate);

paddle/phi/backends/xpu/xpu_info.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ int GetXPUDeviceUtilizationRate(int dev_id) {
237237
return dev_util.xpu;
238238
}
239239

240-
int GetXPUDeviceTotalMemory(int dev_id) {
240+
int64_t GetXPUDeviceTotalMemory(int dev_id) {
241241
std::call_once(xpuml_init_flag, xpumlInit);
242242
if (dev_id == -1) {
243243
dev_id = GetXPUCurrentDeviceId();
@@ -249,10 +249,10 @@ int GetXPUDeviceTotalMemory(int dev_id) {
249249
xpumlMemory_t dev_mem_info;
250250
PADDLE_ENFORCE_XPUML_SUCCESS(
251251
xpumlDeviceGetMemoryInfo(dev_handle, &dev_mem_info));
252-
return dev_mem_info.totalGlobalMemory / 1024 / 1024; // MB
252+
return dev_mem_info.totalGlobalMemory; // with Byte
253253
}
254254

255-
int GetXPUDeviceUsedMemory(int dev_id) {
255+
int64_t GetXPUDeviceUsedMemory(int dev_id) {
256256
std::call_once(xpuml_init_flag, xpumlInit);
257257
if (dev_id == -1) {
258258
dev_id = GetXPUCurrentDeviceId();
@@ -264,7 +264,7 @@ int GetXPUDeviceUsedMemory(int dev_id) {
264264
xpumlMemory_t dev_mem_info;
265265
PADDLE_ENFORCE_XPUML_SUCCESS(
266266
xpumlDeviceGetMemoryInfo(dev_handle, &dev_mem_info));
267-
return dev_mem_info.usedGlobalMemory / 1024 / 1024; // MB
267+
return dev_mem_info.usedGlobalMemory; // with Byte
268268
}
269269

270270
XPUVersion get_xpu_version(int dev_id) {

paddle/phi/backends/xpu/xpu_info.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ void set_xpu_debug_level(int level);
115115

116116
int get_xpu_max_ptr_size(int dev_id);
117117
int GetXPUDeviceUtilizationRate(int dev_id);
118-
int GetXPUDeviceTotalMemory(int dev_id);
119-
int GetXPUDeviceUsedMemory(int dev_id);
118+
int64_t GetXPUDeviceTotalMemory(int dev_id);
119+
int64_t GetXPUDeviceUsedMemory(int dev_id);
120120

121121
} // namespace xpu
122122
} // namespace backends

paddle/phi/core/platform/device/xpu/xpu_info.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,11 @@ int GetXPUDeviceUtilizationRate(int dev_id) {
123123
return phi::backends::xpu::GetXPUDeviceUtilizationRate(dev_id);
124124
}
125125

126-
int GetXPUDeviceTotalMemory(int dev_id) {
126+
int64_t GetXPUDeviceTotalMemory(int dev_id) {
127127
return phi::backends::xpu::GetXPUDeviceTotalMemory(dev_id);
128128
}
129129

130-
int GetXPUDeviceUsedMemory(int dev_id) {
130+
int64_t GetXPUDeviceUsedMemory(int dev_id) {
131131
return phi::backends::xpu::GetXPUDeviceUsedMemory(dev_id);
132132
}
133133

paddle/phi/core/platform/device/xpu/xpu_info.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ bool IsXPUMallocRecorded(int dev_id);
9595
void EmptyCache(void);
9696

9797
int GetXPUDeviceUtilizationRate(int dev_id);
98-
int GetXPUDeviceTotalMemory(int dev_id);
99-
int GetXPUDeviceUsedMemory(int dev_id);
98+
int64_t GetXPUDeviceTotalMemory(int dev_id);
99+
int64_t GetXPUDeviceUsedMemory(int dev_id);
100100

101101
} // namespace platform
102102
} // namespace paddle

0 commit comments

Comments
 (0)