@@ -33,19 +33,55 @@ namespace py = pybind11;
3333namespace paddle {
3434namespace platform {
3535#ifdef PADDLE_WITH_XPU
36- XPUStream get_current_stream (int device_id) {
37- if (device_id == -1 ) {
38- device_id = phi::backends::xpu::GetXPUCurrentDeviceId ();
39- }
36+ phi::XPUStreamHandle* get_current_stream (int device_id) {
4037 auto place = phi::XPUPlace (device_id);
4138 auto *dev_ctx = static_cast <phi::XPUContext *>(
4239 phi::DeviceContextPool::Instance ().Get (place));
43- dev_ctx-> Wait ();
44- return dev_ctx-> stream () ;
40+ auto handle = new phi::XPUStreamHandle ();
41+ return handle ;
4542}
4643
44+ // phi::XPUContext* get_context(int device_id) {
45+ // // int curr_device_id = platform::GetXPUCurrentDeviceId();
46+ // auto place_tmp = phi::XPUPlace(device_id > -1 ? device_id : platform::GetXPUCurrentDeviceId());
47+ // phi::XPUContext *dev_ctx = static_cast<phi::XPUContext *>(
48+ // phi::DeviceContextPool::Instance().Get(place_tmp));
49+
50+ // return dev_ctx;
51+ // }
52+
53+ phi::XPUStreamHandle* set_current_stream (int idx) {
54+ int device_id = phi::backends::xpu::GetXPUCurrentDeviceId ();
55+ auto original_stream = get_current_stream (device_id);
56+ auto place = phi::XPUPlace (device_id);
57+ auto *dev_ctx = static_cast <phi::XPUContext *>(
58+ phi::DeviceContextPool::Instance ().Get (place));
59+ dev_ctx->SetCurrentStream (idx);
60+ // return original_stream;
61+ return original_stream;
62+ }
63+
64+
65+ // XPUStream get_stream_by_idx
66+ // ::paddle::phi::XPUCUDAStream get_current_cuda_stream(int device_id) {
67+ // if (device_id == -1) {
68+ // device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
69+ // }
70+ // auto place = phi::XPUPlace(device_id);
71+
72+ // }
73+
4774#endif
4875} // namespace platform
76+
77+
78+ // namespace phi{
79+ // void phi::XPUEventHandle::synchronize() {
80+ // auto *dev_ctx = paddle::platform::get_context();
81+ // dev_ctx->StreamWaitEvent(event_, 0);
82+ // }
83+ // }
84+
4985namespace pybind {
5086void BindXpuStream (py::module *m_ptr) {
5187 auto &m = *m_ptr;
@@ -69,7 +105,7 @@ void BindXpuStream(py::module *m_ptr) {
69105#endif
70106 });
71107 m.def (
72- " _get_current_stream " ,
108+ " _xpu_get_current_stream " ,
73109 [](int device_id) {
74110#ifdef PADDLE_WITH_XPU
75111 if (device_id == -1 ) {
@@ -79,11 +115,16 @@ void BindXpuStream(py::module *m_ptr) {
79115 return platform::get_current_stream (device_id);
80116#else
81117 PADDLE_THROW (
82- common::errors::Unavailable (" Paddle is not compiled with CUDA . "
118+ common::errors::Unavailable (" Paddle is not compiled with XPU . "
83119 " Cannot visit device synchronize." ));
84120#endif
85121 },
86122 py::return_value_policy::reference);
123+ m.def (" _xpu_set_current_stream" ,
124+ [](int stream_id) {
125+ return platform::set_current_stream (stream_id);
126+ },
127+ py::return_value_policy::reference);
87128 m.def (" _device_synchronize" , [](int device_id) {
88129#ifdef PADDLE_WITH_XPU
89130 if (device_id == -1 ) {
@@ -101,11 +142,190 @@ void BindXpuStream(py::module *m_ptr) {
101142 });
102143
103144#ifdef PADDLE_WITH_XPU
104- py::class_<XPUStream >(m, " XPUStream" , R"DOC(
145+ py::class_<phi::XPUStreamHandle >(m, " XPUStream" , R"DOC(
105146 The handle of the XPU stream.
106147
107148 Parameters:
108- device(paddle.CUDAPlace()|int|None, optional): The device which wanted to allocate the stream.
149+ device(paddle.XPUPlace()|int|None, optional): The device which wanted to allocate the stream.
150+ If device is None or negative integer, device will be the current device.
151+ If device is positive integer, it must less than the device count. Default: None.
152+ priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
153+ If priority is None, the priority is 2(normal). Default: None.
154+
155+ Examples:
156+ .. code-block:: python
157+
158+ >>> # doctest: +REQUIRES(env:XPU)
159+ >>> import paddle
160+ >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
161+ >>> s2 = paddle.device.xpu.Stream(0, 1)
162+ >>> s3 = paddle.device.xpu.Stream()
163+
164+ )DOC" )
165+ .def (
166+ " __init__" ,
167+ [](phi::XPUStreamHandle &self) {
168+ // int curr_device_id = platform::GetXPUCurrentDeviceId();
169+ // auto place_tmp = phi::XPUPlace(curr_device_id);
170+ // auto *dev_ctx = static_cast<phi::XPUContext *>(
171+ // phi::DeviceContextPool::Instance().Get(place_tmp));
172+ auto *dev_ctx = phi::get_xpu_context ();
173+ // new (&self) phi::XPUStreamHandle(dev_ctx->get_idle_stream());
174+ new (&self) phi::XPUStreamHandle ();
175+ // self.idx = dev_ctx->get_idle_stream();
176+ // printf("empty init, idx: %d \n", self.id());
177+ // self.id =
178+ // PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&self));
179+ // if (dev_ctx->stream() == nullptr) {
180+ // dev_ctx->SetStream(self);
181+ // }
182+ }
183+ )
184+ .def_property_readonly (
185+ " xpu_stream" ,
186+ [](phi::XPUStreamHandle &self) {
187+ // printf("call pybind xpu_stream()\n");
188+ // return self.raw_stream();
189+ return self;
190+ }
191+ )
192+ .def (
193+ " wait_stream" ,
194+ [](phi::XPUStreamHandle &self, phi::XPUStreamHandle &other){
195+ // XPUEvent event;
196+ auto *dev_ctx = phi::get_xpu_context ();
197+ dev_ctx->StreamWaitStreamInPool (self.id (), other.id ());
198+ // dev_ctx->StreamWaitStream()
199+ }
200+ )
201+ // .def(
202+ // "__init__",
203+ // [](XPUStream &self, phi::XPUPlace *place) {
204+ // // int device_count = platform::GetXPUDeviceCount();
205+ // auto *dev_ctx = static_cast<phi::XPUContext *>(
206+ // phi::DeviceContextPool::Instance().Get(place_tmp));
207+ // PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&self));
208+ // if (place == nullptr) {
209+ // int curr_device_id = platform::GetXPUCurrentDeviceId();
210+ // auto place_tmp = phi::XPUPlace(curr_device_id);
211+ // }
212+
213+ // if (dev_ctx->stream() == nullptr) {
214+ // dev_ctx->SetStream(self);
215+ // }
216+ // }
217+ // )
218+ .def (
219+ " __init__" ,
220+ [](phi::XPUStreamHandle &self, int device) {
221+ // int curr_device_id = platform::GetXPUCurrentDeviceId();
222+ auto place_tmp = phi::XPUPlace (device);
223+ auto *dev_ctx = static_cast <phi::XPUContext *>(
224+ phi::DeviceContextPool::Instance ().Get (place_tmp));
225+ // auto *dev_ctx = platform::get_context();
226+ new (&self) phi::XPUStreamHandle ();
227+ // new (&self) phi::XPUStreamHandle(dev_ctx->get_idle_stream());
228+ // new (&self) phi::XPUStreamHandle;
229+ // self.idx = dev_ctx->get_idle_stream();
230+ // printf("device init, idx: %d \n", self.id());
231+
232+ // int device_count = platform::GetXPUDeviceCount();
233+ // if (device < 0) {
234+ // device = platform::GetXPUCurrentDeviceId();
235+ // }
236+ // if (device >= device_count) {
237+ // PADDLE_THROW(common::errors::InvalidArgument(
238+ // "The device id must be inside [0, %d), but input device=%d.",
239+ // device_count,
240+ // device));
241+ // }
242+ // auto place_tmp = phi::XPUPlace(device);
243+ // auto *dev_ctx = static_cast<phi::XPUContext *>(
244+ // phi::DeviceContextPool::Instance().Get(place_tmp));
245+ // new (&self) XPUStream;
246+ // PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&self));
247+ // if (dev_ctx->stream() == nullptr) {
248+ // dev_ctx->SetStream(self);
249+ // }
250+ },
251+ py::arg (" device" ) = -1
252+ )
253+ .def_property_readonly (
254+ " place" ,
255+ [](phi::XPUStreamHandle &self) { return phi::XPUPlace (platform::GetXPUCurrentDeviceId ()); })
256+ .def_property_readonly (
257+ " idx" ,
258+ [](phi::XPUStreamHandle &self) { return self.id (); }
259+ )
260+ // .def(
261+ // "synchronize",
262+ // [](XPUStream &self) { xpu_wait(self); },
263+ // R"DOC(
264+ // Waits for stream tasks to complete.
265+
266+ // Examples:
267+ // .. code-block:: python
268+
269+ // >>> # doctest: +REQUIRES(env:GPU)
270+ // >>> import paddle
271+ // >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
272+ // >>> s.synchronize()
273+
274+ // )DOC"
275+ // )
276+ ;
277+ py::class_<phi::XPUEventHandle>(m, " XPUEvent" , R"DOC(
278+ The handle of the XPU event.
279+
280+ Parameters:
281+ enable_timing(bool, optional): Whether the event will measure time. Default: False.
282+ blocking(bool, optional): Whether the wait() func will be blocking. Default: False;
283+ interprocess(bool, optional): Whether the event can be shared between processes. Default: False.
284+
285+ Examples:
286+ .. code-block:: python
287+
288+ >>> # doctest: +REQUIRES(env:XPU)
289+ >>> import paddle
290+ >>> event = paddle.device.xpu.Event()
291+
292+ )DOC" )
293+ .def (
294+ " __init__" ,
295+ [](phi::XPUEventHandle &self) {
296+ new (&self) phi::XPUEventHandle ();
297+ }
298+ )
299+ .def (
300+ " record" ,
301+ [](phi::XPUEventHandle &self, phi::XPUStreamHandle* stream) {
302+ auto *dev_ctx = phi::get_xpu_context ();
303+ XPUStream raw_stream = dev_ctx->get_stream_from_pool (stream->id ());
304+ int r = xpu_event_record (self.get_event (), raw_stream);
305+ PADDLE_ENFORCE_XRE_SUCCESS (r);
306+
307+ },
308+ py::arg (" stream" ) = nullptr
309+ )
310+ .def (
311+ " query" ,
312+ [](phi::XPUEventHandle &self) {return xpu_event_query (self.get_event ());}
313+ )
314+ .def (
315+ " synchronize" ,
316+ [](phi::XPUEventHandle &self) {
317+ auto *dev_ctx = phi::get_xpu_context ();
318+ dev_ctx->StreamWaitEvent (self.get_event (), 0 );
319+ // self.synchronize();
320+ }
321+ );
322+
323+
324+ py::class_<phi::XPUCUDAStream>(m, " XPUCUDAStream" , R"DOC(
325+ The handle of the XPU stream.
326+
327+ Parameters:
328+ device(paddle.XPUPlace()|int|None, optional): The device which wanted to allocate the stream.
109329 If device is None or negative integer, device will be the current device.
110330 If device is positive integer, it must less than the device count. Default: None.
111331 priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
@@ -116,14 +336,14 @@ void BindXpuStream(py::module *m_ptr) {
116336
117337 >>> # doctest: +REQUIRES(env:GPU)
118338 >>> import paddle
119- >>> s1 = paddle.device.cuda .Stream(paddle.CUDAPlace (0), 1)
120- >>> s2 = paddle.device.cuda .Stream(0, 1)
121- >>> s3 = paddle.device.cuda .Stream()
339+ >>> s1 = paddle.device.xpu .Stream(paddle.XPUPlace (0), 1)
340+ >>> s2 = paddle.device.xpu .Stream(0, 1)
341+ >>> s3 = paddle.device.xpu .Stream()
122342
123343 )DOC" )
124344 .def (
125345 " synchronize" ,
126- [](XPUStream &self) { xpu_wait ( self); },
346+ [](phi::XPUCUDAStream &self) { self. Synchronize ( ); },
127347 R"DOC(
128348 Waits for stream tasks to complete.
129349
@@ -135,7 +355,26 @@ void BindXpuStream(py::module *m_ptr) {
135355 >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
136356 >>> s.synchronize()
137357
138- )DOC" );
358+ )DOC"
359+ )
360+ .def (
361+ " __init__" ,
362+ [](phi::XPUCUDAStream &self, phi::XPUPlace *place, int priority) {
363+ if (priority != 1 && priority != 2 ) {
364+ PADDLE_THROW (common::errors::InvalidArgument (
365+ " Priority should be 1(high) or 2(normal) " ));
366+ }
367+ auto stream_flag = phi::XPUCUDAStream::StreamFlag::kStreamNonBlocking ;
368+ if (place == nullptr ) {
369+ int curr_device_id = platform::GetXPUCurrentDeviceId ();
370+ auto place_tmp = phi::XPUPlace (curr_device_id);
371+ new (&self) phi::XPUCUDAStream (place_tmp, priority - 2 , stream_flag);
372+ } else {
373+ new (&self) phi::XPUCUDAStream (*place, priority - 2 , stream_flag);
374+ }
375+
376+ }
377+ );
139378#endif
140379}
141380} // namespace pybind
0 commit comments