@@ -33,19 +33,55 @@ namespace py = pybind11;
 namespace paddle {
 namespace platform {
 #ifdef PADDLE_WITH_XPU
-XPUStream get_current_stream(int device_id) {
+phi::XPUStreamHandle* get_current_stream(int device_id) {
   if (device_id == -1) {
     device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
   }
   auto place = phi::XPUPlace(device_id);
   auto *dev_ctx = static_cast<phi::XPUContext *>(
       phi::DeviceContextPool::Instance().Get(place));
-  dev_ctx->Wait();
-  return dev_ctx->stream();
+  // NOTE: the handle is heap-allocated; the binding below exposes it with
+  // py::return_value_policy::reference, so nothing ever frees it --
+  // take_ownership is likely the safer policy there.
+  auto handle = new phi::XPUStreamHandle();
+  return handle;
 }
 
+// Switch the current device's context to the pooled stream at index `idx`
+// and return a handle to the previously current stream so the caller can
+// restore it later.
+phi::XPUStreamHandle* set_current_stream(int idx) {
+  int device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
+  auto original_stream = get_current_stream(device_id);
+  auto place = phi::XPUPlace(device_id);
+  auto *dev_ctx = static_cast<phi::XPUContext *>(
+      phi::DeviceContextPool::Instance().Get(place));
+  dev_ctx->SetCurrentStream(idx);
+  return original_stream;
+}
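+// A minimal usage sketch from the Python side (assuming the
+// `_xpu_get_current_stream` / `_xpu_set_current_stream` bindings below,
+// where `core` stands for paddle's compiled binding module):
+//   prev = core._xpu_set_current_stream(new_idx)   # switch + save previous
+//   ...                                            # enqueue work on new_idx
+//   core._xpu_set_current_stream(prev.idx)         # restore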
+
 #endif
 }  // namespace platform
+
 namespace pybind {
 void BindXpuStream(py::module *m_ptr) {
   auto &m = *m_ptr;
@@ -69,7 +105,7 @@ void BindXpuStream(py::module *m_ptr) {
 #endif
   });
   m.def(
-      "_get_current_stream",
+      "_xpu_get_current_stream",
       [](int device_id) {
 #ifdef PADDLE_WITH_XPU
         if (device_id == -1) {
@@ -79,11 +115,18 @@ void BindXpuStream(py::module *m_ptr) {
         return platform::get_current_stream(device_id);
 #else
         PADDLE_THROW(
-            common::errors::Unavailable("Paddle is not compiled with CUDA. "
+            common::errors::Unavailable("Paddle is not compiled with XPU. "
                                         "Cannot visit device synchronize."));
 #endif
       },
       py::return_value_policy::reference);
+#ifdef PADDLE_WITH_XPU
+  m.def(
+      "_xpu_set_current_stream",
+      [](int stream_id) { return platform::set_current_stream(stream_id); },
+      py::return_value_policy::reference);
+#endif
   m.def("_device_synchronize", [](int device_id) {
 #ifdef PADDLE_WITH_XPU
     if (device_id == -1) {
@@ -101,11 +144,190 @@ void BindXpuStream(py::module *m_ptr) {
   });
 
 #ifdef PADDLE_WITH_XPU
-  py::class_<XPUStream>(m, "XPUStream", R"DOC(
-      The handle of the CUDA stream.
+  py::class_<phi::XPUStreamHandle>(m, "XPUStream", R"DOC(
+      The handle of the XPU stream.
 
       Parameters:
-        device(paddle.CUDAPlace()|int|None, optional): The device which wanted to allocate the stream.
+        device(paddle.XPUPlace()|int|None, optional): The device on which to allocate the stream.
+          If device is None or a negative integer, the current device is used.
+          If device is a positive integer, it must be less than the device count. Default: None.
+        priority(int|None, optional): The priority of the stream. The priority can be 1(high) or 2(normal).
+          If priority is None, the priority is 2(normal). Default: None.
+
+      Examples:
+        .. code-block:: python
+
+            >>> # doctest: +REQUIRES(env:XPU)
+            >>> import paddle
+            >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
+            >>> s2 = paddle.device.xpu.Stream(0, 1)
+            >>> s3 = paddle.device.xpu.Stream()
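+            >>> # A hedged sketch of cross-stream ordering using the
+            >>> # wait_stream binding added in this change:
+            >>> s2.wait_stream(s1)
+            >>> print(s3.idx)  # pool index of the underlying stream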
+
+  )DOC")
+      .def("__init__",
+           [](phi::XPUStreamHandle &self) {
+             // Ensure the current device's context exists, then
+             // placement-new the handle into pybind's storage.
+             phi::get_xpu_context();
+             new (&self) phi::XPUStreamHandle();
+           })
+      .def_property_readonly(
+          "xpu_stream",
+          // Returns the handle itself; exposing the raw stream pointer
+          // (raw_stream()) appears to be left for a follow-up.
+          [](phi::XPUStreamHandle &self) { return self; })
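+      // Ordering sketch (an assumption about StreamWaitStreamInPool
+      // semantics): `self.wait_stream(other)` should make work queued on
+      // `self` afterwards wait for everything already queued on `other`:
+      //   s2.wait_stream(s1)  # s2's later kernels run only after s1's
+      //                       # queued kernels finish.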
+      .def("wait_stream",
+           [](phi::XPUStreamHandle &self, phi::XPUStreamHandle &other) {
+             auto *dev_ctx = phi::get_xpu_context();
+             dev_ctx->StreamWaitStreamInPool(self.id(), other.id());
+           })
+      .def("__init__",
+           [](phi::XPUStreamHandle &self, int device) {
+             if (device < 0) {
+               device = platform::GetXPUCurrentDeviceId();
+             }
+             // Touch the device context so the target device is initialized
+             // before the handle is constructed in place.
+             auto place_tmp = phi::XPUPlace(device);
+             phi::DeviceContextPool::Instance().Get(place_tmp);
+             new (&self) phi::XPUStreamHandle();
+           },
+           py::arg("device") = -1)
+      .def_property_readonly("place",
+                             [](phi::XPUStreamHandle &self) {
+                               // NOTE: reports the current device, not
+                               // necessarily the device the stream lives on.
+                               return phi::XPUPlace(
+                                   platform::GetXPUCurrentDeviceId());
+                             })
+      .def_property_readonly(
+          "idx", [](phi::XPUStreamHandle &self) { return self.id(); });
+  py::class_<phi::XPUEventHandle>(m, "XPUEvent", R"DOC(
+      The handle of the XPU event.
+
+      Parameters:
+        enable_timing(bool, optional): Whether the event will measure time. Default: False.
+        blocking(bool, optional): Whether the wait() function will be blocking. Default: False.
+        interprocess(bool, optional): Whether the event can be shared between processes. Default: False.
+
+      Examples:
+        .. code-block:: python
+
+            >>> # doctest: +REQUIRES(env:XPU)
+            >>> import paddle
+            >>> event = paddle.device.xpu.Event()
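+            >>> # A hedged sketch of the record/wait flow (record() falls
+            >>> # back to the current stream when no stream is passed):
+            >>> event.record()
+            >>> event.synchronize()
+            >>> print(event.query())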
+
+  )DOC")
+      .def("__init__",
+           [](phi::XPUEventHandle &self) { new (&self) phi::XPUEventHandle(); })
+      .def("record",
+           [](phi::XPUEventHandle &self, phi::XPUStreamHandle *stream) {
+             auto *dev_ctx = phi::get_xpu_context();
+             // Guard the default nullptr argument: fall back to the
+             // context's current stream instead of dereferencing null.
+             XPUStream raw_stream =
+                 stream != nullptr
+                     ? dev_ctx->get_stream_from_pool(stream->id())
+                     : dev_ctx->stream();
+             int r = xpu_event_record(self.get_event(), raw_stream);
+             PADDLE_ENFORCE_XRE_SUCCESS(r);
+           },
+           py::arg("stream") = nullptr)
+      .def("query",
+           [](phi::XPUEventHandle &self) {
+             return xpu_event_query(self.get_event());
+           })
+      .def("synchronize",
+           [](phi::XPUEventHandle &self) {
+             auto *dev_ctx = phi::get_xpu_context();
+             // NOTE: this appears to enqueue a wait on stream 0 rather than
+             // blocking the host thread, unlike CUDA's event.synchronize().
+             dev_ctx->StreamWaitEvent(self.get_event(), 0);
+           });
+
+  py::class_<phi::XPUCUDAStream>(m, "XPUCUDAStream", R"DOC(
+      The handle of the XPU stream.
+
+      Parameters:
+        device(paddle.XPUPlace()|int|None, optional): The device on which to allocate the stream.
           If device is None or negative integer, device will be the current device.
           If device is positive integer, it must be less than the device count. Default: None.
         priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
@@ -116,14 +338,14 @@ void BindXpuStream(py::module *m_ptr) {
 
             >>> # doctest: +REQUIRES(env:XPU)
             >>> import paddle
-            >>> s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
-            >>> s2 = paddle.device.cuda.Stream(0, 1)
-            >>> s3 = paddle.device.cuda.Stream()
+            >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
+            >>> s2 = paddle.device.xpu.Stream(0, 1)
+            >>> s3 = paddle.device.xpu.Stream()
 
   )DOC")
       .def(
           "synchronize",
-          [](XPUStream &self) { xpu_wait(self); },
+          [](phi::XPUCUDAStream &self) { self.Synchronize(); },
           R"DOC(
       Waits for stream tasks to complete.
 
@@ -135,7 +357,26 @@ void BindXpuStream(py::module *m_ptr) {
             >>> s = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
             >>> s.synchronize()
 
-  )DOC");
+  )DOC")
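+      // Note on `priority` (an assumption mirroring Paddle's CUDA stream
+      // bindings): `priority - 2` maps the user-facing 1(high)/2(normal)
+      // onto the underlying values -1/0.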
+      .def("__init__",
+           [](phi::XPUCUDAStream &self, phi::XPUPlace *place, int priority) {
+             if (priority != 1 && priority != 2) {
+               PADDLE_THROW(common::errors::InvalidArgument(
+                   "Priority should be 1(high) or 2(normal)"));
+             }
+             auto stream_flag =
+                 phi::XPUCUDAStream::StreamFlag::kStreamNonBlocking;
+             if (place == nullptr) {
+               int curr_device_id = platform::GetXPUCurrentDeviceId();
+               auto place_tmp = phi::XPUPlace(curr_device_id);
+               new (&self) phi::XPUCUDAStream(place_tmp, priority - 2,
+                                              stream_flag);
+             } else {
+               new (&self) phi::XPUCUDAStream(*place, priority - 2,
+                                              stream_flag);
+             }
+           });
 #endif
 }
 }  // namespace pybind