Skip to content

Commit ab7d5ac

Browse files
[Test] Add ut for files in /multistream (#1947)
### What this PR does / why we need it? Add some uts for files in folder /multistream ### Does this PR introduce _any_ user-facing change? No - vLLM version: v0.9.2 - vLLM main: vllm-project/vllm@b77c7d3 Signed-off-by: lwq <liwenquan5@huawei.com> Co-authored-by: lwq <liwenquan5@huawei.com>
1 parent 34571ea commit ab7d5ac

File tree

3 files changed

+425
-0
lines changed

3 files changed

+425
-0
lines changed

tests/ut/multistream/test_base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from tests.ut.base import TestBase
2+
from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
3+
MSEventKey)
4+
5+
6+
class Testbase(TestBase):
7+
8+
def test_ms_event_key(self):
9+
self.assertEqual(MSEventKey.ATTN_COM_FINISH.value, 0)
10+
self.assertEqual(MSEventKey.ATTN_AR_FINISH.value, 1)
11+
self.assertEqual(MSEventKey.FFN_COM_FINISH.value, 2)
12+
self.assertEqual(MSEventKey.FFN_AR_FINISH.value, 3)
13+
self.assertEqual(MSEventKey.MOE_BEFORE_COMM.value, 4)
14+
self.assertEqual(MSEventKey.MOE_AFTER_COMM.value, 5)
15+
self.assertEqual(MSEventKey.MOE_SE_COMM_FINISH.value, 6)
16+
self.assertEqual(MSEventKey.MOE_SE_COMP_FINISH.value, 7)
17+
self.assertEqual(MSEventKey.MOE_GATE_FINISH.value, 8)
18+
19+
def test_ms_attention_metadata_split_config_default(self):
20+
config = MSAttentionMetadataSplitConfig()
21+
self.assertEqual(config.num_micro_batches, 2)
22+
self.assertEqual(config.min_total_tokens_to_split, 256)
23+
self.assertEqual(config.min_prefill_tokens_to_split, 64)
24+
25+
def test_ms_attention_metadata_split_config_custom(self):
26+
config = MSAttentionMetadataSplitConfig(
27+
num_micro_batches=4,
28+
min_total_tokens_to_split=512,
29+
min_prefill_tokens_to_split=128)
30+
self.assertEqual(config.num_micro_batches, 4)
31+
self.assertEqual(config.min_total_tokens_to_split, 512)
32+
self.assertEqual(config.min_prefill_tokens_to_split, 128)
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import torch
4+
5+
from tests.ut.base import TestBase
6+
from vllm_ascend.multistream.base import MSEventKey
7+
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
8+
MultiStreamMetadata,
9+
MultiStreamStepMetadata,
10+
split_micro_batches_tensors)
11+
12+
13+
class TestMetaData(TestBase):
14+
15+
def setUp(self):
16+
self.test_tensors_list = [torch.randn(100, 1024) for i in range(3)]
17+
self.test_tensors = torch.randn(100, 1024)
18+
self.test_tensors_dict = {
19+
'query': torch.randn(100, 1024),
20+
'key': torch.randn(100, 1024),
21+
'value': torch.randn(100, 1024)
22+
}
23+
self.split_index = 50
24+
25+
mock_stream = MagicMock(spec=torch.npu.Stream)
26+
event_keys = [MagicMock(spec=MSEventKey)]
27+
multistream_config = MagicMock(spec=MultiStreamConfig)
28+
29+
self.metadata = MultiStreamMetadata(
30+
calculate_stream=mock_stream,
31+
communicate_stream=mock_stream,
32+
start_layer=1,
33+
end_layer=3,
34+
event_keys=event_keys,
35+
multistream_config=multistream_config)
36+
37+
def test_split_micro_batches_tensors(self):
38+
test_tensors_list_res = split_micro_batches_tensors(
39+
self.test_tensors_list, self.split_index)
40+
test_tensors_res = split_micro_batches_tensors(self.test_tensors,
41+
self.split_index)
42+
keys = ['query', 'key', 'value']
43+
test_tensors_dict_res = split_micro_batches_tensors(
44+
self.test_tensors_dict, self.split_index, keys)
45+
for i in range(3):
46+
self.assertEqual(len(test_tensors_list_res[i][0]),
47+
self.split_index)
48+
49+
self.assertEqual(
50+
len(test_tensors_list_res[i][0]) +
51+
len(test_tensors_list_res[i][1]), 100)
52+
53+
self.assertEqual(len(test_tensors_res[0]), self.split_index)
54+
self.assertEqual(
55+
len(test_tensors_res[0]) + len(test_tensors_res[1]), 100)
56+
57+
for key in keys:
58+
self.assertEqual(len(test_tensors_dict_res[0][key]),
59+
self.split_index)
60+
self.assertEqual(
61+
len(test_tensors_dict_res[0][key]) +
62+
len(test_tensors_dict_res[1][key]), 100)
63+
64+
def test_default_init_multistream_step_metadata(self):
65+
metadata = MultiStreamStepMetadata()
66+
self.assertIsNone(metadata.comm_stream)
67+
self.assertIsNone(metadata.before_comm_event)
68+
self.assertIsNone(metadata.after_comm_event)
69+
70+
def test_custom_init_multistream_step_metadata(self):
71+
mockStream = MagicMock(spec=torch.npu.Stream)
72+
mockEvent1 = MagicMock(spec=torch.npu.Event)
73+
mockEvent2 = MagicMock(spec=torch.npu.Event)
74+
75+
metadata = MultiStreamStepMetadata(mockStream, mockEvent1, mockEvent2)
76+
self.assertEqual(metadata.comm_stream, mockStream)
77+
self.assertEqual(metadata.before_comm_event, mockEvent1)
78+
self.assertEqual(metadata.after_comm_event, mockEvent2)
79+
80+
def test_default_init_multistream_config(self):
81+
config = MultiStreamConfig()
82+
self.assertEqual(config.min_total_tokens_to_split, 256)
83+
self.assertEqual(config.min_prefill_tokens_to_split, 64)
84+
self.assertEqual(config.num_micro_batches, 2)
85+
self.assertEqual(config.imbalance_ratio, 0.1)
86+
87+
def test_custom_init_multistream_config(self):
88+
config = MultiStreamConfig(512, 128, 1, 0.2)
89+
self.assertEqual(config.min_total_tokens_to_split, 512)
90+
self.assertEqual(config.min_prefill_tokens_to_split, 128)
91+
self.assertEqual(config.num_micro_batches, 1)
92+
self.assertEqual(config.imbalance_ratio, 0.2)
93+
94+
def test_init_multistream_metadata(self):
95+
mock_stream = MagicMock(spec=torch.npu.Stream)
96+
97+
event_keys = [MagicMock()]
98+
multistream_config = MagicMock(spec=MultiStreamConfig)
99+
100+
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
101+
communicate_stream=mock_stream,
102+
start_layer=1,
103+
end_layer=3,
104+
event_keys=event_keys,
105+
multistream_config=multistream_config)
106+
107+
self.assertEqual(metadata.calculate_stream, mock_stream)
108+
self.assertEqual(metadata.communicate_stream, mock_stream)
109+
self.assertEqual(metadata.start_layer, 1)
110+
self.assertEqual(metadata.end_layer, 3)
111+
self.assertEqual(metadata.ms_config, multistream_config)
112+
self.assertTrue(metadata.causal_lm)
113+
114+
def test_build_events(self):
115+
mock_stream = MagicMock(spec=torch.npu.Stream)
116+
mock_event = MagicMock(spec=torch.npu.Event)
117+
with patch('torch.npu.Event', return_value=mock_event):
118+
event_keys = [MagicMock(spec=MSEventKey)]
119+
multistream_config = MultiStreamConfig(
120+
num_micro_batches=2,
121+
min_total_tokens_to_split=256,
122+
min_prefill_tokens_to_split=64)
123+
124+
metadata = MultiStreamMetadata(
125+
calculate_stream=mock_stream,
126+
communicate_stream=mock_stream,
127+
start_layer=1,
128+
end_layer=3,
129+
event_keys=event_keys,
130+
multistream_config=multistream_config)
131+
132+
expected_events = {
133+
0: {
134+
0: {
135+
event_keys[0]: mock_event
136+
},
137+
1: {
138+
event_keys[0]: mock_event
139+
}
140+
},
141+
1: {
142+
0: {
143+
event_keys[0]: mock_event
144+
},
145+
1: {
146+
event_keys[0]: mock_event
147+
}
148+
},
149+
2: {
150+
0: {
151+
event_keys[0]: mock_event
152+
},
153+
1: {
154+
event_keys[0]: mock_event
155+
}
156+
}
157+
}
158+
self.assertEqual(metadata.ms_events, expected_events)
159+
160+
def test_build_ms_split_config(self):
161+
mock_stream = MagicMock(spec=torch.npu.Stream)
162+
event_keys = [MagicMock(spec=MSEventKey)]
163+
multistream_config = MagicMock(spec=MultiStreamConfig)
164+
multistream_config.num_micro_batches = 2
165+
multistream_config.min_total_tokens_to_split = 256
166+
multistream_config.min_prefill_tokens_to_split = 64
167+
168+
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
169+
communicate_stream=mock_stream,
170+
start_layer=1,
171+
end_layer=3,
172+
event_keys=event_keys,
173+
multistream_config=multistream_config)
174+
175+
self.assertIsNotNone(metadata.ms_split_config)
176+
self.assertEqual(metadata.ms_split_config.num_micro_batches,
177+
multistream_config.num_micro_batches)
178+
self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split,
179+
multistream_config.min_total_tokens_to_split)
180+
self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split,
181+
multistream_config.min_prefill_tokens_to_split)
182+
183+
def test_try_wait_event(self):
184+
mock_stream = MagicMock(spec=torch.npu.Stream)
185+
mock_event = MagicMock(spec=torch.npu.Event)
186+
event_keys = [MagicMock(spec=MSEventKey)]
187+
multistream_config = MagicMock(spec=MultiStreamConfig)
188+
with patch('torch.npu.Event', return_value=mock_event):
189+
metadata = MultiStreamMetadata(
190+
calculate_stream=mock_stream,
191+
communicate_stream=mock_stream,
192+
start_layer=1,
193+
end_layer=3,
194+
event_keys=event_keys,
195+
multistream_config=multistream_config)
196+
197+
metadata.try_wait_event(layer_index=1,
198+
micro_batch_index=0,
199+
event_key=event_keys[0])
200+
mock_event.wait.assert_called_once()
201+
202+
def test_try_record_event(self):
203+
mock_stream = MagicMock(spec=torch.npu.Stream)
204+
mock_event = MagicMock(spec=torch.npu.Event)
205+
event_keys = [MagicMock(spec=MSEventKey)]
206+
multistream_config = MagicMock(spec=MultiStreamConfig)
207+
with patch('torch.npu.Event', return_value=mock_event):
208+
metadata = MultiStreamMetadata(
209+
calculate_stream=mock_stream,
210+
communicate_stream=mock_stream,
211+
start_layer=1,
212+
end_layer=3,
213+
event_keys=event_keys,
214+
multistream_config=multistream_config)
215+
216+
metadata.try_record_event(layer_index=1,
217+
micro_batch_index=0,
218+
event_key=event_keys[0])
219+
mock_event.record.assert_called_once()
220+
221+
def test_merge_batches_none_input(self):
222+
input_tensors = None
223+
result = self.metadata.merge_micro_batches(input_tensors)
224+
self.assertIsNone(result)
225+
226+
def test_merge_batches_single_tensor_input(self):
227+
input_tensors = [torch.tensor([1, 2, 3])]
228+
result = self.metadata.merge_micro_batches(input_tensors)
229+
self.assertEqual(len(result), 1)
230+
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3])))
231+
232+
def test_merge_batches_list_of_tensors_input(self):
233+
input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
234+
result = self.metadata.merge_micro_batches(input_tensors)
235+
self.assertEqual(len(result), 2)
236+
self.assertEqual(result, input_tensors)
237+
238+
def test_merge_batches_nested_list_input(self):
239+
input_tensors = [[torch.tensor([1, 2]),
240+
torch.tensor([3, 4])],
241+
[torch.tensor([5, 6]),
242+
torch.tensor([7, 8])]]
243+
result = self.metadata.merge_micro_batches(input_tensors)
244+
self.assertEqual(len(result), 2)
245+
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4])))
246+
self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8])))

0 commit comments

Comments
 (0)