Skip to content

Commit b4688e3

Browse files
committed
[XPU] Support Python streams API for XPU
1 parent cf30d11 commit b4688e3

File tree

7 files changed

+510
-25
lines changed

7 files changed

+510
-25
lines changed

paddle/fluid/pybind/tensor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
902902
const auto &device_id =
903903
paddle::platform::GetXPUCurrentDeviceId();
904904
auto stream = paddle::platform::get_current_stream(device_id);
905-
xpu_wait(stream);
905+
xpu_wait(stream->raw_stream());
906906
int type_idx = static_cast<int>(self.type());
907907
size_t data_size = self.numel() *
908908
framework::SizeOfType(

paddle/fluid/pybind/xpu_streams_py.cc

Lines changed: 254 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,55 @@ namespace py = pybind11;
3333
namespace paddle {
3434
namespace platform {
3535
#ifdef PADDLE_WITH_XPU
36-
XPUStream get_current_stream(int device_id) {
37-
if (device_id == -1) {
38-
device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
39-
}
36+
phi::XPUStreamHandle* get_current_stream(int device_id) {
4037
auto place = phi::XPUPlace(device_id);
4138
auto *dev_ctx = static_cast<phi::XPUContext *>(
4239
phi::DeviceContextPool::Instance().Get(place));
43-
dev_ctx->Wait();
44-
return dev_ctx->stream();
40+
auto handle = new phi::XPUStreamHandle();
41+
return handle;
4542
}
4643

44+
// phi::XPUContext* get_context(int device_id) {
45+
// // int curr_device_id = platform::GetXPUCurrentDeviceId();
46+
// auto place_tmp = phi::XPUPlace(device_id > -1 ? device_id : platform::GetXPUCurrentDeviceId());
47+
// phi::XPUContext *dev_ctx = static_cast<phi::XPUContext *>(
48+
// phi::DeviceContextPool::Instance().Get(place_tmp));
49+
50+
// return dev_ctx;
51+
// }
52+
53+
phi::XPUStreamHandle* set_current_stream(int idx) {
54+
int device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
55+
auto original_stream = get_current_stream(device_id);
56+
auto place = phi::XPUPlace(device_id);
57+
auto *dev_ctx = static_cast<phi::XPUContext *>(
58+
phi::DeviceContextPool::Instance().Get(place));
59+
dev_ctx->SetCurrentStream(idx);
60+
// return original_stream;
61+
return original_stream;
62+
}
63+
64+
65+
// XPUStream get_stream_by_idx
66+
// ::paddle::phi::XPUCUDAStream get_current_cuda_stream(int device_id) {
67+
// if (device_id == -1) {
68+
// device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
69+
// }
70+
// auto place = phi::XPUPlace(device_id);
71+
72+
// }
73+
4774
#endif
4875
} // namespace platform
76+
77+
78+
// namespace phi{
79+
// void phi::XPUEventHandle::synchronize() {
80+
// auto *dev_ctx = paddle::platform::get_context();
81+
// dev_ctx->StreamWaitEvent(event_, 0);
82+
// }
83+
// }
84+
4985
namespace pybind {
5086
void BindXpuStream(py::module *m_ptr) {
5187
auto &m = *m_ptr;
@@ -69,7 +105,7 @@ void BindXpuStream(py::module *m_ptr) {
69105
#endif
70106
});
71107
m.def(
72-
"_get_current_stream",
108+
"_xpu_get_current_stream",
73109
[](int device_id) {
74110
#ifdef PADDLE_WITH_XPU
75111
if (device_id == -1) {
@@ -79,11 +115,16 @@ void BindXpuStream(py::module *m_ptr) {
79115
return platform::get_current_stream(device_id);
80116
#else
81117
PADDLE_THROW(
82-
common::errors::Unavailable("Paddle is not compiled with CUDA. "
118+
common::errors::Unavailable("Paddle is not compiled with XPU. "
83119
"Cannot visit device synchronize."));
84120
#endif
85121
},
86122
py::return_value_policy::reference);
123+
m.def("_xpu_set_current_stream",
124+
[](int stream_id) {
125+
return platform::set_current_stream(stream_id);
126+
},
127+
py::return_value_policy::reference);
87128
m.def("_device_synchronize", [](int device_id) {
88129
#ifdef PADDLE_WITH_XPU
89130
if (device_id == -1) {
@@ -101,11 +142,190 @@ void BindXpuStream(py::module *m_ptr) {
101142
});
102143

103144
#ifdef PADDLE_WITH_XPU
104-
py::class_<XPUStream>(m, "XPUStream", R"DOC(
145+
py::class_<phi::XPUStreamHandle>(m, "XPUStream", R"DOC(
105146
The handle of the XPU stream.
106147
107148
Parameters:
108-
device(paddle.CUDAPlace()|int|None, optional): The device which wanted to allocate the stream.
149+
device(paddle.XPUPlace()|int|None, optional): The device on which to allocate the stream.
150+
If device is None or negative integer, device will be the current device.
151+
If device is a positive integer, it must be less than the device count. Default: None.
152+
priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
153+
If priority is None, the priority is 2(normal). Default: None.
154+
155+
Examples:
156+
.. code-block:: python
157+
158+
>>> # doctest: +REQUIRES(env:XPU)
159+
>>> import paddle
160+
>>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
161+
>>> s2 = paddle.device.xpu.Stream(0, 1)
162+
>>> s3 = paddle.device.xpu.Stream()
163+
164+
)DOC")
165+
.def(
166+
"__init__",
167+
[](phi::XPUStreamHandle &self) {
168+
// int curr_device_id = platform::GetXPUCurrentDeviceId();
169+
// auto place_tmp = phi::XPUPlace(curr_device_id);
170+
// auto *dev_ctx = static_cast<phi::XPUContext *>(
171+
// phi::DeviceContextPool::Instance().Get(place_tmp));
172+
auto *dev_ctx = phi::get_xpu_context();
173+
// new (&self) phi::XPUStreamHandle(dev_ctx->get_idle_stream());
174+
new (&self) phi::XPUStreamHandle();
175+
// self.idx = dev_ctx->get_idle_stream();
176+
// printf("empty init, idx: %d \n", self.id());
177+
// self.id =
178+
// PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&self));
179+
// if (dev_ctx->stream() == nullptr) {
180+
// dev_ctx->SetStream(self);
181+
// }
182+
}
183+
)
184+
.def_property_readonly(
185+
"xpu_stream",
186+
[](phi::XPUStreamHandle &self) {
187+
// printf("call pybind xpu_stream()\n");
188+
// return self.raw_stream();
189+
return self;
190+
}
191+
)
192+
.def(
193+
"wait_stream",
194+
[](phi::XPUStreamHandle &self, phi::XPUStreamHandle &other){
195+
// XPUEvent event;
196+
auto *dev_ctx = phi::get_xpu_context();
197+
dev_ctx->StreamWaitStreamInPool(self.id(), other.id());
198+
// dev_ctx->StreamWaitStream()
199+
}
200+
)
201+
// .def(
202+
// "__init__",
203+
// [](XPUStream &self, phi::XPUPlace *place) {
204+
// // int device_count = platform::GetXPUDeviceCount();
205+
// auto *dev_ctx = static_cast<phi::XPUContext *>(
206+
// phi::DeviceContextPool::Instance().Get(place_tmp));
207+
// PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&self));
208+
// if (place == nullptr) {
209+
// int curr_device_id = platform::GetXPUCurrentDeviceId();
210+
// auto place_tmp = phi::XPUPlace(curr_device_id);
211+
// }
212+
213+
// if (dev_ctx->stream() == nullptr) {
214+
// dev_ctx->SetStream(self);
215+
// }
216+
// }
217+
// )
218+
.def(
219+
"__init__",
220+
[](phi::XPUStreamHandle &self, int device) {
221+
// int curr_device_id = platform::GetXPUCurrentDeviceId();
222+
auto place_tmp = phi::XPUPlace(device);
223+
auto *dev_ctx = static_cast<phi::XPUContext *>(
224+
phi::DeviceContextPool::Instance().Get(place_tmp));
225+
// auto *dev_ctx = platform::get_context();
226+
new (&self) phi::XPUStreamHandle();
227+
// new (&self) phi::XPUStreamHandle(dev_ctx->get_idle_stream());
228+
// new (&self) phi::XPUStreamHandle;
229+
// self.idx = dev_ctx->get_idle_stream();
230+
// printf("device init, idx: %d \n", self.id());
231+
232+
// int device_count = platform::GetXPUDeviceCount();
233+
// if (device < 0) {
234+
// device = platform::GetXPUCurrentDeviceId();
235+
// }
236+
// if (device >= device_count) {
237+
// PADDLE_THROW(common::errors::InvalidArgument(
238+
// "The device id must be inside [0, %d), but input device=%d.",
239+
// device_count,
240+
// device));
241+
// }
242+
// auto place_tmp = phi::XPUPlace(device);
243+
// auto *dev_ctx = static_cast<phi::XPUContext *>(
244+
// phi::DeviceContextPool::Instance().Get(place_tmp));
245+
// new (&self) XPUStream;
246+
// PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&self));
247+
// if (dev_ctx->stream() == nullptr) {
248+
// dev_ctx->SetStream(self);
249+
// }
250+
},
251+
py::arg("device") = -1
252+
)
253+
.def_property_readonly(
254+
"place",
255+
[](phi::XPUStreamHandle &self) { return phi::XPUPlace(platform::GetXPUCurrentDeviceId()); })
256+
.def_property_readonly(
257+
"idx",
258+
[](phi::XPUStreamHandle &self) { return self.id(); }
259+
)
260+
// .def(
261+
// "synchronize",
262+
// [](XPUStream &self) { xpu_wait(self); },
263+
// R"DOC(
264+
// Waits for stream tasks to complete.
265+
266+
// Examples:
267+
// .. code-block:: python
268+
269+
// >>> # doctest: +REQUIRES(env:GPU)
270+
// >>> import paddle
271+
// >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
272+
// >>> s.synchronize()
273+
274+
// )DOC"
275+
// )
276+
;
277+
py::class_<phi::XPUEventHandle>(m, "XPUEvent", R"DOC(
278+
The handle of the XPU event.
279+
280+
Parameters:
281+
enable_timing(bool, optional): Whether the event will measure time. Default: False.
282+
blocking(bool, optional): Whether the wait() function will be blocking. Default: False.
283+
interprocess(bool, optional): Whether the event can be shared between processes. Default: False.
284+
285+
Examples:
286+
.. code-block:: python
287+
288+
>>> # doctest: +REQUIRES(env:XPU)
289+
>>> import paddle
290+
>>> event = paddle.device.xpu.Event()
291+
292+
)DOC")
293+
.def(
294+
"__init__",
295+
[](phi::XPUEventHandle &self) {
296+
new (&self) phi::XPUEventHandle();
297+
}
298+
)
299+
.def(
300+
"record",
301+
[](phi::XPUEventHandle &self, phi::XPUStreamHandle* stream) {
302+
auto *dev_ctx = phi::get_xpu_context();
303+
XPUStream raw_stream = dev_ctx->get_stream_from_pool(stream->id());
304+
int r = xpu_event_record(self.get_event(), raw_stream);
305+
PADDLE_ENFORCE_XRE_SUCCESS(r);
306+
307+
},
308+
py::arg("stream") = nullptr
309+
)
310+
.def(
311+
"query",
312+
[](phi::XPUEventHandle &self) {return xpu_event_query(self.get_event());}
313+
)
314+
.def(
315+
"synchronize",
316+
[](phi::XPUEventHandle &self) {
317+
auto *dev_ctx = phi::get_xpu_context();
318+
dev_ctx->StreamWaitEvent(self.get_event(), 0);
319+
// self.synchronize();
320+
}
321+
);
322+
323+
324+
py::class_<phi::XPUCUDAStream>(m, "XPUCUDAStream", R"DOC(
325+
The handle of the XPU stream.
326+
327+
Parameters:
328+
device(paddle.XPUPlace()|int|None, optional): The device on which to allocate the stream.
109329
If device is None or negative integer, device will be the current device.
110330
If device is positive integer, it must less than the device count. Default: None.
111331
priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
@@ -116,14 +336,14 @@ void BindXpuStream(py::module *m_ptr) {
116336
117337
>>> # doctest: +REQUIRES(env:GPU)
118338
>>> import paddle
119-
>>> s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
120-
>>> s2 = paddle.device.cuda.Stream(0, 1)
121-
>>> s3 = paddle.device.cuda.Stream()
339+
>>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
340+
>>> s2 = paddle.device.xpu.Stream(0, 1)
341+
>>> s3 = paddle.device.xpu.Stream()
122342
123343
)DOC")
124344
.def(
125345
"synchronize",
126-
[](XPUStream &self) { xpu_wait(self); },
346+
[](phi::XPUCUDAStream &self) { self.Synchronize(); },
127347
R"DOC(
128348
Waits for stream tasks to complete.
129349
@@ -135,7 +355,26 @@ void BindXpuStream(py::module *m_ptr) {
135355
>>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
136356
>>> s.synchronize()
137357
138-
)DOC");
358+
)DOC"
359+
)
360+
.def(
361+
"__init__",
362+
[](phi::XPUCUDAStream &self, phi::XPUPlace *place, int priority) {
363+
if (priority != 1 && priority != 2) {
364+
PADDLE_THROW(common::errors::InvalidArgument(
365+
"Priority should be 1(high) or 2(normal) "));
366+
}
367+
auto stream_flag = phi::XPUCUDAStream::StreamFlag::kStreamNonBlocking;
368+
if (place == nullptr) {
369+
int curr_device_id = platform::GetXPUCurrentDeviceId();
370+
auto place_tmp = phi::XPUPlace(curr_device_id);
371+
new (&self) phi::XPUCUDAStream(place_tmp, priority - 2, stream_flag);
372+
} else {
373+
new (&self) phi::XPUCUDAStream(*place, priority - 2, stream_flag);
374+
}
375+
376+
}
377+
);
139378
#endif
140379
}
141380
} // namespace pybind

paddle/fluid/pybind/xpu_streams_py.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,33 @@
1919

2020
#ifdef PADDLE_WITH_XPU
2121
#include "paddle/phi/core/xpu_cuda_stream.h"
22+
#include "paddle/phi/backends/xpu/xpu_context.h"
2223
#include "xpu/runtime.h"
2324
#include "xpu/runtime_ex.h"
25+
namespace phi {
26+
// class XPUStreamHandle {
27+
// public:
28+
// int idx;
29+
// XPUStreamHandle(int stream_id): idx(stream_id) {};
30+
// };
31+
32+
// class XPUEventHandle {
33+
// public:
34+
// XPUEventHandle() {
35+
// int r = xpu_event_create(&event_);
36+
// PADDLE_ENFORCE_XRE_SUCCESS(r); }
37+
// void record(XPUStream stream_) {
38+
// PADDLE_ENFORCE_XRE_SUCCESS(xpu_event_record(event_, stream_));
39+
// }
40+
41+
// XPUEvent event() const { return event_; }
42+
43+
// private:
44+
// XPUEvent event_;
45+
// };
46+
47+
}
48+
2449
#else
2550
namespace phi {
2651
class XPUCUDAStream {};
@@ -32,7 +57,9 @@ namespace py = pybind11;
3257
namespace paddle {
3358
namespace platform {
3459
#ifdef PADDLE_WITH_XPU
35-
XPUStream get_current_stream(int device_id = -1);
60+
phi::XPUStreamHandle* get_current_stream(int device_id = -1);
61+
phi::XPUStreamHandle* set_current_stream(int idx) ;
62+
// phi::XPUContext* get_context(int device_id = -1);
3663
#endif
3764
} // namespace platform
3865
namespace pybind {

0 commit comments

Comments
 (0)