Skip to content

Commit dbfc304

Browse files
committed
[XPU] add stream & event unittests
1 parent e1222a1 commit dbfc304

File tree

4 files changed

+94
-18
lines changed

4 files changed

+94
-18
lines changed

paddle/fluid/pybind/xpu_streams_py.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ void BindXpuStream(py::module *m_ptr) {
152152
[](phi::XPUStreamHandle &self, phi::XPUEventHandle &other) {
153153
self.wait_event(other.get_event());
154154
})
155-
.def("query", [](phi::XPUStreamHandle &self) { return self.query(); })
155+
.def("query",
156+
[](phi::XPUStreamHandle &self) {
157+
PADDLE_THROW(common::errors::Unavailable(
158+
"Query function for XPUStream is not supported now"));
159+
})
156160
.def("record_event",
157161
[](phi::XPUStreamHandle &self, phi::XPUEventHandle *event) {
158162
if (event == nullptr) {
@@ -170,9 +174,9 @@ void BindXpuStream(py::module *m_ptr) {
170174
Examples:
171175
.. code-block:: python
172176
173-
>>> # doctest: +REQUIRES(env:GPU)
177+
>>> # doctest: +REQUIRES(env:XPU)
174178
>>> import paddle
175-
>>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
179+
>>> s = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
176180
>>> s.synchronize()
177181
178182
)DOC")

paddle/phi/backends/xpu/xpu_context.cc

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -645,21 +645,14 @@ void XPUStreamHandle::synchronize() const {
645645

646646
void XPUStreamHandle::set_stream(XPUStream stream_) { stream = stream_; }
647647

648-
bool XPUStreamHandle::query() const {
649-
XPUEvent event = XPUEventPool::Instance().CreateEventFromPool();
648+
void XPUStreamHandle::record_event(XPUEvent event) const {
650649
int r = xpu_event_record(event, stream);
651650
PADDLE_ENFORCE_XRE_SUCCESS(r);
652-
r = xpu_event_query(event);
653-
if (r == XPU_SUCCESS) {
654-
return true;
655-
} else {
656-
return false;
657-
}
658651
}
659652

660-
void XPUStreamHandle::record_event(XPUEvent event) const {
661-
int r = xpu_event_record(event, stream);
662-
PADDLE_ENFORCE_XRE_SUCCESS(r);
653+
XPUStreamHandle get_current_stream_handle(int device_id) {
654+
auto* dev_ctx = get_xpu_context(device_id);
655+
return *dev_ctx->get_current_stream_handle();
663656
}
664657

665658
XPUStreamHandle get_stream_handle(int device_id) {
@@ -736,7 +729,6 @@ XPUEventHandle::XPUEventHandle(XPUStream stream) {
736729

737730
void XPUEventHandle::record(XPUStream stream) {
738731
int r = xpu_event_query(event_);
739-
printf("====r: %d\n", r);
740732
PADDLE_ENFORCE_XRE_SUCCESS(xpu_event_record(event_, stream));
741733
}
742734

paddle/phi/backends/xpu/xpu_context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ class XPUStreamHandle {
5252
XPUStream raw_stream() const { return stream; }
5353
void wait_event(XPUEvent event) const;
5454
void synchronize() const;
55-
bool query() const;
5655
void record_event(XPUEvent event) const;
5756
void set_stream(XPUStream stream);
5857

@@ -186,6 +185,7 @@ class XPUEventHandle {
186185
XPUEvent event_;
187186
};
188187

188+
XPUStreamHandle get_current_stream_handle(int device_id = -1);
189189
XPUStreamHandle get_stream_handle(int device_id = -1);
190190
void set_current_stream(XPUStreamHandle* s);
191191

test/xpu/test_xpu_stream_event.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import ctypes
1516
import unittest
1617

18+
import numpy as np
19+
1720
import paddle
1821
from paddle.device import xpu
1922

@@ -56,7 +59,6 @@ def test_xpu_stream_synchronize(self):
5659
e2 = paddle.device.xpu.Event()
5760

5861
e1.record(s)
59-
print("1111")
6062
e1.query()
6163
tensor1 = paddle.to_tensor(paddle.rand([1000, 1000]))
6264
tensor2 = paddle.matmul(tensor1, tensor1)
@@ -78,7 +80,85 @@ def test_xpu_stream_wait_event_and_record_event(self):
7880
s2.wait_event(e1)
7981
s2.synchronize()
8082

81-
self.assertTrue(e1.query() and s1.query() and s2.query())
83+
self.assertTrue(e1.query())
84+
85+
86+
class TestXPUEvent(unittest.TestCase):
87+
def test_xpu_event(self):
88+
if paddle.is_compiled_with_xpu():
89+
e = paddle.device.xpu.Event()
90+
self.assertIsNotNone(e)
91+
s = paddle.device.xpu.current_stream()
92+
93+
def test_xpu_event_methods(self):
94+
if paddle.is_compiled_with_xpu():
95+
e = paddle.device.xpu.Event()
96+
s = paddle.device.xpu.current_stream()
97+
event_query_1 = e.query()
98+
tensor1 = paddle.to_tensor(paddle.rand([1000, 1000]))
99+
tensor2 = paddle.matmul(tensor1, tensor1)
100+
s.record_event(e)
101+
e.synchronize()
102+
event_query_2 = e.query()
103+
104+
self.assertTrue(event_query_1)
105+
self.assertTrue(event_query_2)
106+
107+
108+
class TestStreamGuard(unittest.TestCase):
109+
'''
110+
Note:
111+
The asynchronous execution property of XPU Stream can only be tested offline.
112+
'''
113+
114+
def test_stream_guard_normal(self):
115+
if paddle.is_compiled_with_xpu():
116+
s = paddle.device.Stream()
117+
a = paddle.to_tensor(np.array([0, 2, 4], dtype="int32"))
118+
b = paddle.to_tensor(np.array([1, 3, 5], dtype="int32"))
119+
c = a + b
120+
with paddle.device.stream_guard(s):
121+
d = a + b
122+
s.synchronize()
123+
124+
np.testing.assert_array_equal(np.array(c), np.array(d))
125+
126+
def test_stream_guard_default_stream(self):
127+
if paddle.is_compiled_with_xpu():
128+
s1 = paddle.device.current_stream()
129+
with paddle.device.stream_guard(s1):
130+
pass
131+
s2 = paddle.device.current_stream()
132+
133+
self.assertTrue(id(s1.stream_base) == id(s2.stream_base))
134+
135+
def test_set_current_stream_default_stream(self):
136+
if paddle.is_compiled_with_xpu():
137+
cur_stream = paddle.device.current_stream()
138+
new_stream = paddle.device.set_stream(cur_stream)
139+
140+
self.assertTrue(
141+
id(cur_stream.stream_base) == id(new_stream.stream_base)
142+
)
143+
144+
def test_stream_guard_raise_error(self):
145+
if paddle.is_compiled_with_xpu():
146+
147+
def test_not_correct_stream_guard_input():
148+
tmp = np.zeros(5)
149+
with paddle.device.stream_guard(tmp):
150+
pass
151+
152+
self.assertRaises(TypeError, test_not_correct_stream_guard_input)
153+
154+
155+
class TestRawStream(unittest.TestCase):
156+
def test_xpu_stream(self):
157+
if paddle.is_compiled_with_xpu():
158+
xpu_stream = paddle.device.xpu.current_stream().xpu_stream
159+
print(xpu_stream)
160+
self.assertTrue(type(xpu_stream) is int)
161+
ptr = ctypes.c_void_p(xpu_stream)
82162

83163

84164
if __name__ == "__main__":

0 commit comments

Comments
 (0)