@@ -33,19 +33,27 @@ namespace py = pybind11;
3333namespace paddle {
3434namespace platform {
3535#ifdef PADDLE_WITH_XPU
36- XPUStream get_current_stream (int device_id) {
37- if (device_id == -1 ) {
38- device_id = phi::backends::xpu::GetXPUCurrentDeviceId ();
39- }
36+ phi::XPUStreamHandle *get_current_stream (int device_id) {
4037 auto place = phi::XPUPlace (device_id);
4138 auto *dev_ctx = static_cast <phi::XPUContext *>(
4239 phi::DeviceContextPool::Instance ().Get (place));
4340 dev_ctx->Wait ();
44- return dev_ctx->stream ();
41+ return dev_ctx->get_current_stream_handle ();
42+ }
43+
44+ phi::XPUStreamHandle *set_current_stream (int idx) {
45+ int device_id = phi::backends::xpu::GetXPUCurrentDeviceId ();
46+ auto original_stream = get_current_stream (device_id);
47+ auto place = phi::XPUPlace (device_id);
48+ auto *dev_ctx = static_cast <phi::XPUContext *>(
49+ phi::DeviceContextPool::Instance ().Get (place));
50+ dev_ctx->SetCurrentStream (idx);
51+ return original_stream;
4552}
4653
4754#endif
4855} // namespace platform
56+
4957namespace pybind {
5058void BindXpuStream (py::module *m_ptr) {
5159 auto &m = *m_ptr;
@@ -69,7 +77,7 @@ void BindXpuStream(py::module *m_ptr) {
6977#endif
7078 });
7179 m.def (
72- " _get_current_stream " ,
80+ " _xpu_get_current_stream " ,
7381 [](int device_id) {
7482#ifdef PADDLE_WITH_XPU
7583 if (device_id == -1 ) {
@@ -79,7 +87,19 @@ void BindXpuStream(py::module *m_ptr) {
7987 return platform::get_current_stream (device_id);
8088#else
8189 PADDLE_THROW (
82- common::errors::Unavailable (" Paddle is not compiled with CUDA. "
90+ common::errors::Unavailable (" Paddle is not compiled with XPU. "
91+ " Cannot visit device synchronize." ));
92+ #endif
93+ },
94+ py::return_value_policy::reference);
95+ m.def (
96+ " _xpu_set_current_stream" ,
97+ [](int stream_id) {
98+ #ifdef PADDLE_WITH_XPU
99+ return platform::set_current_stream (stream_id);
100+ #else
101+ PADDLE_THROW (
102+ common::errors::Unavailable (" Paddle is not compiled with XPU. "
83103 " Cannot visit device synchronize." ));
84104#endif
85105 },
@@ -100,12 +120,167 @@ void BindXpuStream(py::module *m_ptr) {
100120#endif
101121 });
102122
123+ py::class_<phi::XPUStreamHandle>(m, " XPUStream" , R"DOC(
124+ The handle of the XPU stream.
125+
126+ Parameters:
127+ device(paddle.XPUPlace()|int|None, optional): The device which wanted to allocate the stream.
128+ If device is None or negative integer, device will be the current device.
129+ If device is positive integer, it must less than the device count. Default: None.
130+
131+ Examples:
132+ .. code-block:: python
133+
134+ >>> # doctest: +REQUIRES(env:XPU)
135+ >>> import paddle
136+ >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0))
137+ >>> s2 = paddle.device.xpu.Stream(0)
138+ >>> s3 = paddle.device.xpu.Stream()
139+
140+ )DOC" )
141+ #ifdef PADDLE_WITH_XPU
142+ .def_property_readonly (
143+ " xpu_stream" ,
144+ [](phi::XPUStreamHandle &self) {
145+ return reinterpret_cast <std::uintptr_t >(self.raw_stream ());
146+ })
147+ .def (" wait_stream" ,
148+ [](phi::XPUStreamHandle &self, phi::XPUStreamHandle &other) {
149+ auto *dev_ctx = phi::get_xpu_context ();
150+ dev_ctx->StreamWaitStreamInPool (self.id (), other.id ());
151+ })
152+ .def (" wait_event" ,
153+ [](phi::XPUStreamHandle &self, phi::XPUEventHandle &other) {
154+ self.wait_event (other.get_event ());
155+ })
156+ .def (" query" ,
157+ [](phi::XPUStreamHandle &self) {
158+ PADDLE_THROW (common::errors::Unavailable (
159+ " Query function for XPUStream is not supported now" ));
160+ })
161+ .def (" record_event" ,
162+ [](phi::XPUStreamHandle &self, phi::XPUEventHandle *event) {
163+ if (event == nullptr ) {
164+ event = new phi::XPUEventHandle ();
165+ }
166+ self.record_event (event->get_event ());
167+ return event;
168+ })
169+ .def (
170+ " synchronize" ,
171+ [](phi::XPUStreamHandle &self) { self.synchronize (); },
172+ R"DOC(
173+ Waits for stream tasks to complete.
174+
175+ Examples:
176+ .. code-block:: python
177+
178+ >>> # doctest: +REQUIRES(env:XPU)
179+ >>> import paddle
180+ >>> s = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
181+ >>> s.synchronize()
182+
183+ )DOC" )
184+ .def_property_readonly (
185+ " place" ,
186+ [](phi::XPUStreamHandle &self) {
187+ return phi::XPUPlace (platform::GetXPUCurrentDeviceId ());
188+ })
189+ .def_property_readonly (
190+ " idx" , [](phi::XPUStreamHandle &self) { return self.id (); })
191+ #endif
192+
193+ .def (" __init__" ,
194+ [](phi::XPUStreamHandle &self) {
195+ #ifdef PADDLE_WITH_XPU
196+ new (&self) phi::XPUStreamHandle ();
197+ self.Init ();
198+ #else
199+ PADDLE_THROW (common::errors::Unavailable (
200+ " Class XPUStream can only be initialized on the XPU "
201+ " platform." ));
202+ #endif
203+ })
204+ .def (
205+ " __init__" ,
206+ [](phi::XPUStreamHandle &self, phi::XPUPlace *place) {
207+ #ifdef PADDLE_WITH_XPU
208+ if (place == nullptr ) {
209+ int curr_device_id = platform::GetXPUCurrentDeviceId ();
210+ auto place_tmp = phi::XPUPlace (curr_device_id);
211+ new (&self) phi::XPUStreamHandle (place_tmp);
212+ } else {
213+ new (&self) phi::XPUStreamHandle (*place);
214+ }
215+ #else
216+ PADDLE_THROW (common::errors::Unavailable (
217+ " Class XPUStream can only be initialized on the XPU "
218+ " platform." ));
219+ #endif
220+ },
221+ py::arg (" device" ) = nullptr )
222+ .def (
223+ " __init__" ,
224+ [](phi::XPUStreamHandle &self, int device) {
225+ #ifdef PADDLE_WITH_XPU
226+ if (device < 0 ) {
227+ device = platform::GetXPUCurrentDeviceId ();
228+ }
229+ auto place_tmp = phi::XPUPlace (device);
230+ new (&self) phi::XPUStreamHandle (place_tmp);
231+ #else
232+ PADDLE_THROW (common::errors::Unavailable (
233+ " Class XPUStream can only be initialized on the XPU "
234+ " platform." ));
235+ #endif
236+ },
237+ py::arg (" device" ) = -1 );
238+ py::class_<phi::XPUEventHandle>(m, " XPUEvent" , R"DOC(
239+ The handle of the XPU event.
240+
241+ Examples:
242+ .. code-block:: python
243+
244+ >>> # doctest: +REQUIRES(env:XPU)
245+ >>> import paddle
246+ >>> event = paddle.device.xpu.Event()
247+
248+ )DOC" )
249+ #ifdef PADDLE_WITH_XPU
250+ .def (
251+ " record" ,
252+ [](phi::XPUEventHandle &self, phi::XPUStreamHandle *stream) {
253+ if (stream == nullptr ) {
254+ auto *dev_ctx = phi::get_xpu_context ();
255+ auto stream_handle = dev_ctx->get_current_stream_handle ();
256+ self.record (stream_handle->raw_stream ());
257+ } else {
258+ self.record (stream->raw_stream ());
259+ }
260+ },
261+ py::arg (" stream" ) = nullptr )
262+ .def (" query" , [](phi::XPUEventHandle &self) { return self.query (); })
263+ .def (" elapsed_time" ,
264+ [](phi::XPUEventHandle &self) {
265+ PADDLE_THROW (common::errors::Unavailable (
266+ " XPUEvent elapsed_time is not supported now" ));
267+ })
268+ .def (" synchronize" , [](phi::XPUEventHandle &self) { self.synchronize (); })
269+ #endif
270+ .def (" __init__" , [](phi::XPUEventHandle &self) {
271+ #ifdef PADDLE_WITH_XPU
272+ new (&self) phi::XPUEventHandle ();
273+ #else
274+ PADDLE_THROW (common::errors::Unavailable (
275+ " Class XPUEvent can only be initialized on the XPU platform." ));
276+ #endif
277+ });
103278#ifdef PADDLE_WITH_XPU
104- py::class_<XPUStream >(m, " XPUStream " , R"DOC(
105- The handle of the CUDA stream.
279+ py::class_<phi::XPUCUDAStream >(m, " XPUCUDAStream " , R"DOC(
280+ The handle of the XPU stream.
106281
107282 Parameters:
108- device(paddle.CUDAPlace ()|int|None, optional): The device which wanted to allocate the stream.
283+ device(paddle.XPUPlace ()|int|None, optional): The device which wanted to allocate the stream.
109284 If device is None or negative integer, device will be the current device.
110285 If device is positive integer, it must less than the device count. Default: None.
111286 priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
@@ -114,16 +289,16 @@ void BindXpuStream(py::module *m_ptr) {
114289 Examples:
115290 .. code-block:: python
116291
117- >>> # doctest: +REQUIRES(env:GPU )
292+ >>> # doctest: +REQUIRES(env:XPU )
118293 >>> import paddle
119- >>> s1 = paddle.device.cuda .Stream(paddle.CUDAPlace (0), 1)
120- >>> s2 = paddle.device.cuda .Stream(0, 1)
121- >>> s3 = paddle.device.cuda .Stream()
294+ >>> s1 = paddle.device.xpu .Stream(paddle.XPUPlace (0), 1)
295+ >>> s2 = paddle.device.xpu .Stream(0, 1)
296+ >>> s3 = paddle.device.xpu .Stream()
122297
123298 )DOC" )
124299 .def (
125300 " synchronize" ,
126- [](XPUStream &self) { xpu_wait ( self); },
301+ [](phi::XPUCUDAStream &self) { self. Synchronize ( ); },
127302 R"DOC(
128303 Waits for stream tasks to complete.
129304
@@ -135,7 +310,25 @@ void BindXpuStream(py::module *m_ptr) {
135310 >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
136311 >>> s.synchronize()
137312
138- )DOC" );
313+ )DOC" )
314+ .def (" __init__" ,
315+ [](phi::XPUCUDAStream &self, phi::XPUPlace *place, int priority) {
316+ if (priority != 1 && priority != 2 ) {
317+ PADDLE_THROW (common::errors::InvalidArgument (
318+ " Priority should be 1(high) or 2(normal) " ));
319+ }
320+ auto stream_flag =
321+ phi::XPUCUDAStream::StreamFlag::kStreamNonBlocking ;
322+ if (place == nullptr ) {
323+ int curr_device_id = platform::GetXPUCurrentDeviceId ();
324+ auto place_tmp = phi::XPUPlace (curr_device_id);
325+ new (&self)
326+ phi::XPUCUDAStream (place_tmp, priority - 2 , stream_flag);
327+ } else {
328+ new (&self)
329+ phi::XPUCUDAStream (*place, priority - 2 , stream_flag);
330+ }
331+ });
139332#endif
140333}
141334} // namespace pybind
0 commit comments