diff --git a/tests/ut/multistream/test_base.py b/tests/ut/multistream/test_base.py
new file mode 100644
index 0000000000..4bdd29b8a5
--- /dev/null
+++ b/tests/ut/multistream/test_base.py
@@ -0,0 +1,32 @@
+from tests.ut.base import TestBase
+from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
+                                          MSEventKey)
+
+
+class TestMSBase(TestBase):
+
+    def test_ms_event_key(self):
+        self.assertEqual(MSEventKey.ATTN_COM_FINISH.value, 0)
+        self.assertEqual(MSEventKey.ATTN_AR_FINISH.value, 1)
+        self.assertEqual(MSEventKey.FFN_COM_FINISH.value, 2)
+        self.assertEqual(MSEventKey.FFN_AR_FINISH.value, 3)
+        self.assertEqual(MSEventKey.MOE_BEFORE_COMM.value, 4)
+        self.assertEqual(MSEventKey.MOE_AFTER_COMM.value, 5)
+        self.assertEqual(MSEventKey.MOE_SE_COMM_FINISH.value, 6)
+        self.assertEqual(MSEventKey.MOE_SE_COMP_FINISH.value, 7)
+        self.assertEqual(MSEventKey.MOE_GATE_FINISH.value, 8)
+
+    def test_ms_attention_metadata_split_config_default(self):
+        config = MSAttentionMetadataSplitConfig()
+        self.assertEqual(config.num_micro_batches, 2)
+        self.assertEqual(config.min_total_tokens_to_split, 256)
+        self.assertEqual(config.min_prefill_tokens_to_split, 64)
+
+    def test_ms_attention_metadata_split_config_custom(self):
+        config = MSAttentionMetadataSplitConfig(
+            num_micro_batches=4,
+            min_total_tokens_to_split=512,
+            min_prefill_tokens_to_split=128)
+        self.assertEqual(config.num_micro_batches, 4)
+        self.assertEqual(config.min_total_tokens_to_split, 512)
+        self.assertEqual(config.min_prefill_tokens_to_split, 128)
diff --git a/tests/ut/multistream/test_metadata.py b/tests/ut/multistream/test_metadata.py
new file mode 100644
index 0000000000..79fd703d14
--- /dev/null
+++ b/tests/ut/multistream/test_metadata.py
@@ -0,0 +1,246 @@
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.multistream.base import MSEventKey
+from vllm_ascend.multistream.metadata import (MultiStreamConfig,
+                                              MultiStreamMetadata,
+                                              MultiStreamStepMetadata,
+                                              split_micro_batches_tensors)
+
+
+class TestMetaData(TestBase):
+
+    def setUp(self):
+        self.test_tensors_list = [torch.randn(100, 1024) for _ in range(3)]
+        self.test_tensors = torch.randn(100, 1024)
+        self.test_tensors_dict = {
+            'query': torch.randn(100, 1024),
+            'key': torch.randn(100, 1024),
+            'value': torch.randn(100, 1024)
+        }
+        self.split_index = 50
+
+        mock_stream = MagicMock(spec=torch.npu.Stream)
+        event_keys = [MagicMock(spec=MSEventKey)]
+        multistream_config = MagicMock(spec=MultiStreamConfig)
+
+        self.metadata = MultiStreamMetadata(
+            calculate_stream=mock_stream,
+            communicate_stream=mock_stream,
+            start_layer=1,
+            end_layer=3,
+            event_keys=event_keys,
+            multistream_config=multistream_config)
+
+    def test_split_micro_batches_tensors(self):
+        test_tensors_list_res = split_micro_batches_tensors(
+            self.test_tensors_list, self.split_index)
+        test_tensors_res = split_micro_batches_tensors(self.test_tensors,
+                                                       self.split_index)
+        keys = ['query', 'key', 'value']
+        test_tensors_dict_res = split_micro_batches_tensors(
+            self.test_tensors_dict, self.split_index, keys)
+        for i in range(3):
+            self.assertEqual(len(test_tensors_list_res[i][0]),
+                             self.split_index)
+
+            self.assertEqual(
+                len(test_tensors_list_res[i][0]) +
+                len(test_tensors_list_res[i][1]), 100)
+
+        self.assertEqual(len(test_tensors_res[0]), self.split_index)
+        self.assertEqual(
+            len(test_tensors_res[0]) + len(test_tensors_res[1]), 100)
+
+        for key in keys:
+            self.assertEqual(len(test_tensors_dict_res[0][key]),
+                             self.split_index)
+            self.assertEqual(
+                len(test_tensors_dict_res[0][key]) +
+                len(test_tensors_dict_res[1][key]), 100)
+
+    def test_default_init_multistream_step_metadata(self):
+        metadata = MultiStreamStepMetadata()
+        self.assertIsNone(metadata.comm_stream)
+        self.assertIsNone(metadata.before_comm_event)
+        self.assertIsNone(metadata.after_comm_event)
+
+    def test_custom_init_multistream_step_metadata(self):
+        mockStream = MagicMock(spec=torch.npu.Stream)
+        mockEvent1 = MagicMock(spec=torch.npu.Event)
+        mockEvent2 = MagicMock(spec=torch.npu.Event)
+
+        metadata = MultiStreamStepMetadata(mockStream, mockEvent1, mockEvent2)
+        self.assertEqual(metadata.comm_stream, mockStream)
+        self.assertEqual(metadata.before_comm_event, mockEvent1)
+        self.assertEqual(metadata.after_comm_event, mockEvent2)
+
+    def test_default_init_multistream_config(self):
+        config = MultiStreamConfig()
+        self.assertEqual(config.min_total_tokens_to_split, 256)
+        self.assertEqual(config.min_prefill_tokens_to_split, 64)
+        self.assertEqual(config.num_micro_batches, 2)
+        self.assertEqual(config.imbalance_ratio, 0.1)
+
+    def test_custom_init_multistream_config(self):
+        config = MultiStreamConfig(512, 128, 1, 0.2)
+        self.assertEqual(config.min_total_tokens_to_split, 512)
+        self.assertEqual(config.min_prefill_tokens_to_split, 128)
+        self.assertEqual(config.num_micro_batches, 1)
+        self.assertEqual(config.imbalance_ratio, 0.2)
+
+    def test_init_multistream_metadata(self):
+        mock_stream = MagicMock(spec=torch.npu.Stream)
+
+        event_keys = [MagicMock()]
+        multistream_config = MagicMock(spec=MultiStreamConfig)
+
+        metadata = MultiStreamMetadata(calculate_stream=mock_stream,
+                                       communicate_stream=mock_stream,
+                                       start_layer=1,
+                                       end_layer=3,
+                                       event_keys=event_keys,
+                                       multistream_config=multistream_config)
+
+        self.assertEqual(metadata.calculate_stream, mock_stream)
+        self.assertEqual(metadata.communicate_stream, mock_stream)
+        self.assertEqual(metadata.start_layer, 1)
+        self.assertEqual(metadata.end_layer, 3)
+        self.assertEqual(metadata.ms_config, multistream_config)
+        self.assertTrue(metadata.causal_lm)
+
+    def test_build_events(self):
+        mock_stream = MagicMock(spec=torch.npu.Stream)
+        mock_event = MagicMock(spec=torch.npu.Event)
+        with patch('torch.npu.Event', return_value=mock_event):
+            event_keys = [MagicMock(spec=MSEventKey)]
+            multistream_config = MultiStreamConfig(
+                num_micro_batches=2,
+                min_total_tokens_to_split=256,
+                min_prefill_tokens_to_split=64)
+
+            metadata = MultiStreamMetadata(
+                calculate_stream=mock_stream,
+                communicate_stream=mock_stream,
+                start_layer=1,
+                end_layer=3,
+                event_keys=event_keys,
+                multistream_config=multistream_config)
+
+            expected_events = {
+                0: {
+                    0: {
+                        event_keys[0]: mock_event
+                    },
+                    1: {
+                        event_keys[0]: mock_event
+                    }
+                },
+                1: {
+                    0: {
+                        event_keys[0]: mock_event
+                    },
+                    1: {
+                        event_keys[0]: mock_event
+                    }
+                },
+                2: {
+                    0: {
+                        event_keys[0]: mock_event
+                    },
+                    1: {
+                        event_keys[0]: mock_event
+                    }
+                }
+            }
+            self.assertEqual(metadata.ms_events, expected_events)
+
+    def test_build_ms_split_config(self):
+        mock_stream = MagicMock(spec=torch.npu.Stream)
+        event_keys = [MagicMock(spec=MSEventKey)]
+        multistream_config = MagicMock(spec=MultiStreamConfig)
+        multistream_config.num_micro_batches = 2
+        multistream_config.min_total_tokens_to_split = 256
+        multistream_config.min_prefill_tokens_to_split = 64
+
+        metadata = MultiStreamMetadata(calculate_stream=mock_stream,
+                                       communicate_stream=mock_stream,
+                                       start_layer=1,
+                                       end_layer=3,
+                                       event_keys=event_keys,
+                                       multistream_config=multistream_config)
+
+        self.assertIsNotNone(metadata.ms_split_config)
+        self.assertEqual(metadata.ms_split_config.num_micro_batches,
+                         multistream_config.num_micro_batches)
+        self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split,
+                         multistream_config.min_total_tokens_to_split)
+        self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split,
+                         multistream_config.min_prefill_tokens_to_split)
+
+    def test_try_wait_event(self):
+        mock_stream = MagicMock(spec=torch.npu.Stream)
+        mock_event = MagicMock(spec=torch.npu.Event)
+        event_keys = [MagicMock(spec=MSEventKey)]
+        multistream_config = MagicMock(spec=MultiStreamConfig)
+        with patch('torch.npu.Event', return_value=mock_event):
+            metadata = MultiStreamMetadata(
+                calculate_stream=mock_stream,
+                communicate_stream=mock_stream,
+                start_layer=1,
+                end_layer=3,
+                event_keys=event_keys,
+                multistream_config=multistream_config)
+
+            metadata.try_wait_event(layer_index=1,
+                                    micro_batch_index=0,
+                                    event_key=event_keys[0])
+            mock_event.wait.assert_called_once()
+
+    def test_try_record_event(self):
+        mock_stream = MagicMock(spec=torch.npu.Stream)
+        mock_event = MagicMock(spec=torch.npu.Event)
+        event_keys = [MagicMock(spec=MSEventKey)]
+        multistream_config = MagicMock(spec=MultiStreamConfig)
+        with patch('torch.npu.Event', return_value=mock_event):
+            metadata = MultiStreamMetadata(
+                calculate_stream=mock_stream,
+                communicate_stream=mock_stream,
+                start_layer=1,
+                end_layer=3,
+                event_keys=event_keys,
+                multistream_config=multistream_config)
+
+            metadata.try_record_event(layer_index=1,
+                                      micro_batch_index=0,
+                                      event_key=event_keys[0])
+            mock_event.record.assert_called_once()
+
+    def test_merge_batches_none_input(self):
+        input_tensors = None
+        result = self.metadata.merge_micro_batches(input_tensors)
+        self.assertIsNone(result)
+
+    def test_merge_batches_single_tensor_input(self):
+        input_tensors = [torch.tensor([1, 2, 3])]
+        result = self.metadata.merge_micro_batches(input_tensors)
+        self.assertEqual(len(result), 1)
+        self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3])))
+
+    def test_merge_batches_list_of_tensors_input(self):
+        input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
+        result = self.metadata.merge_micro_batches(input_tensors)
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result, input_tensors)
+
+    def test_merge_batches_nested_list_input(self):
+        input_tensors = [[torch.tensor([1, 2]),
+                          torch.tensor([3, 4])],
+                         [torch.tensor([5, 6]),
+                          torch.tensor([7, 8])]]
+        result = self.metadata.merge_micro_batches(input_tensors)
+        self.assertEqual(len(result), 2)
+        self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4])))
+        self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8])))
diff --git a/tests/ut/multistream/test_ms_split.py b/tests/ut/multistream/test_ms_split.py
new file mode 100644
index 0000000000..e76321a6e5
--- /dev/null
+++ b/tests/ut/multistream/test_ms_split.py
@@ -0,0 +1,147 @@
+from unittest.mock import MagicMock
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.ms_split import (compute_split_seq_index,
+                                              model_input_split_v1_mla_attn,
+                                              split_attn_int_type,
+                                              split_attn_tensor_type)
+
+
+class TestMsSplit(TestBase):
+
+    def test_decode_only(self):
+        result = compute_split_seq_index(
+            query_lens=None,
+            attn_state=AscendAttentionState.DecodeOnly,
+            num_tokens=10)
+        self.assertEqual(result, [5, 5])
+
+    def test_perfect_balance(self):
+        query_lens = [2, 3, 5]
+        result = compute_split_seq_index(
+            query_lens=query_lens,
+            attn_state=AscendAttentionState.PrefillNoCache,
+            num_tokens=10)
+        self.assertEqual(result, [5, 2])
+
+    def test_imbalance(self):
+        query_lens = [1, 2, 3, 4]
+        result = compute_split_seq_index(
+            query_lens=query_lens,
+            attn_state=AscendAttentionState.PrefillNoCache,
+            num_tokens=10)
+        self.assertEqual(result, [0, 0])
+
+    def test_query_lens_none(self):
+        with self.assertRaises(AssertionError):
+            compute_split_seq_index(
+                query_lens=None,
+                attn_state=AscendAttentionState.PrefillNoCache,
+                num_tokens=10)
+
+    def test_empty_query_lens(self):
+        query_lens: list[int] = []
+        result = compute_split_seq_index(
+            query_lens=query_lens,
+            attn_state=AscendAttentionState.PrefillNoCache,
+            num_tokens=10)
+        self.assertEqual(result, [0, 0])
+
+    def test_single_query_len(self):
+        query_lens = [10]
+        result = compute_split_seq_index(
+            query_lens=query_lens,
+            attn_state=AscendAttentionState.PrefillNoCache,
+            num_tokens=10)
+        self.assertEqual(result, [0, 0])
+
+    def test_split_attn_tensor_type_middle(self):
+        input_tensor = torch.tensor([1, 2, 3, 4, 5])
+        index = 3
+        expected_result = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
+        result = split_attn_tensor_type(input_tensor, index)
+        self.assertEqual(len(result), 2)
+        self.assertTrue(torch.equal(result[0], expected_result[0]))
+        self.assertTrue(torch.equal(result[1], expected_result[1]))
+
+    def test_split_attn_tensor_type_start(self):
+        input_tensor = torch.tensor([1, 2, 3, 4, 5])
+        index = 0
+        expected_result = [torch.tensor([]), torch.tensor([1, 2, 3, 4, 5])]
+        result = split_attn_tensor_type(input_tensor, index)
+        self.assertEqual(len(result), 2)
+        self.assertTrue(torch.equal(result[0], expected_result[0]))
+        self.assertTrue(torch.equal(result[1], expected_result[1]))
+
+    def test_split_attn_tensor_type_end(self):
+        input_tensor = torch.tensor([1, 2, 3, 4, 5])
+        index = 5
+        expected_result = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([])]
+        result = split_attn_tensor_type(input_tensor, index)
+        self.assertEqual(len(result), 2)
+        self.assertTrue(torch.equal(result[0], expected_result[0]))
+        self.assertTrue(torch.equal(result[1], expected_result[1]))
+
+    def test_split_attn_tensor_type_empty_tensor(self):
+        input_tensor = torch.tensor([])
+        index = 0
+        expected_result = [torch.tensor([]), torch.tensor([])]
+        result = split_attn_tensor_type(input_tensor, index)
+        self.assertEqual(len(result), 2)
+        self.assertTrue(torch.equal(result[0], expected_result[0]))
+        self.assertTrue(torch.equal(result[1], expected_result[1]))
+
+    def test_split_attn_int_type_index_greater_than_var(self):
+        var = 5
+        index = 10
+        expected_result = [5, 0]
+        result = split_attn_int_type(var, index)
+        self.assertEqual(result, expected_result)
+
+    def test_split_attn_int_type_index_equal_to_var(self):
+        var = 5
+        index = 5
+        expected_result = [5, 0]
+        result = split_attn_int_type(var, index)
+        self.assertEqual(result, expected_result)
+
+    def test_split_attn_int_type_index_less_than_var(self):
+        var = 10
+        index = 5
+        expected_result = [5, 5]
+        result = split_attn_int_type(var, index)
+        self.assertEqual(result, expected_result)
+
+    def test_split_attn_int_type_index_zero(self):
+        var = 10
+        index = 0
+        expected_result = [0, 10]
+        result = split_attn_int_type(var, index)
+        self.assertEqual(result, expected_result)
+
+    def test_split_attn_int_type_var_zero(self):
+        var = 0
+        index = 5
+        expected_result = [0, 0]
+        result = split_attn_int_type(var, index)
+        self.assertEqual(result, expected_result)
+
+    def test_split_attn_int_type_both_zero(self):
+        var = 0
+        index = 0
+        expected_result = [0, 0]
+        result = split_attn_int_type(var, index)
+        self.assertEqual(result, expected_result)
+
+    def test_split_v1_mla_attn_input_none(self):
+        attn_metadata = None
+        ascendMLAPrefillMetadata = MagicMock()
+        ms_split_config = MSAttentionMetadataSplitConfig(num_micro_batches=1)
+        result = model_input_split_v1_mla_attn(attn_metadata,
+                                               ascendMLAPrefillMetadata,
+                                               ms_split_config)
+        self.assertEqual(result, [None])