diff --git a/tests/ut/distributed/kv_transfer/test_simple_buffer.py b/tests/ut/distributed/kv_transfer/test_simple_buffer.py
new file mode 100644
index 0000000000..6f90df923f
--- /dev/null
+++ b/tests/ut/distributed/kv_transfer/test_simple_buffer.py
@@ -0,0 +1,75 @@
+import unittest
+import zlib
+from unittest.mock import MagicMock
+
+import torch
+
+from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
+                                                               int32_hash)
+
+
+class MockSimplePipe:
+    """Stand-in pipe whose transfer methods are MagicMock objects."""
+
+    def __init__(self):
+        self.cluster_id = 0
+        self.send_tensor = MagicMock()
+        self.recv_tensor = MagicMock()
+        self.deallocate_buffer = MagicMock()
+
+
+class TestSimpleBuffer(unittest.TestCase):
+    """Tests for SimpleBuffer backed by a MockSimplePipe."""
+
+    def setUp(self):
+        self.pipe = MockSimplePipe()
+        self.buffer = SimpleBuffer(self.pipe)
+
+    def _set_model_dims(self):
+        """Assign the buffer attributes shared by insert/drop_select tests."""
+        self.buffer.num_layers = 2
+        self.buffer.num_heads = 4
+        self.buffer.head_size = 5
+        self.buffer.hidden_size = 6
+        self.buffer.dtype = torch.float32
+
+    def test_int32_hash(self):
+        """int32_hash("test") must equal zlib.adler32(b"test")."""
+        self.assertEqual(int32_hash("test"), zlib.adler32(b"test"))
+
+    def test_insert(self):
+        """insert() must call the pipe's send_tensor at least once."""
+        input_tokens = torch.tensor([1, 2, 3])
+        roi = torch.tensor([1, 0, 1])
+        key = torch.randn(2, 3, 4, 5)
+        value = torch.randn(2, 3, 4, 5)
+        hidden = torch.randn(3, 6)
+        self._set_model_dims()
+
+        self.buffer.insert(input_tokens, roi, key, value, hidden, "req1")
+
+        self.pipe.send_tensor.assert_called()
+
+    def test_drop_select(self):
+        """drop_select() must return three tensors followed by None."""
+        input_tokens = torch.tensor([1, 2, 3])
+        roi = None
+        self._set_model_dims()
+
+        # Successive recv_tensor calls return these 2-tuples in order.
+        self.pipe.recv_tensor.side_effect = [
+            (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
+            (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
+            (MagicMock(), torch.randn(1, 3, 6)),
+        ]
+
+        result = self.buffer.drop_select(input_tokens, roi, "req1")
+        self.assertEqual(len(result), 4)
+        for idx in range(3):
+            self.assertIsInstance(result[idx], torch.Tensor)
+        self.assertIsNone(result[3])
+        self.assertEqual(result[0].shape, (2, 3, 4, 5))
+
+    def test_close(self):
+        """close() must not raise."""
+        self.buffer.close()
diff --git a/tests/ut/distributed/kv_transfer/test_simple_connector.py b/tests/ut/distributed/kv_transfer/test_simple_connector.py
new file mode 100644
index 0000000000..ac6c4d478d
--- /dev/null
+++ b/tests/ut/distributed/kv_transfer/test_simple_connector.py
@@ -0,0 +1,162 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.config import VllmConfig
+from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
+from vllm_ascend.distributed.kv_transfer.simple_connector import \
+    SimpleConnector
+from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
+
+
+class TestSimpleConnector(unittest.TestCase):
+    """Tests for SimpleConnector with its pipe and buffer classes patched.
+
+    Stacked ``@patch`` decorators are applied bottom-up: the mock created by
+    the decorator closest to ``def`` is the first argument after ``self``.
+    """
+
+    def setUp(self):
+        self.mock_pipe = MagicMock(spec=SimplePipe)
+        self.mock_buffer = MagicMock(spec=SimpleBuffer)
+
+        patcher = patch(
+            'vllm_ascend.distributed.kv_transfer.simple_buffer.SimpleBuffer')
+        self.addCleanup(patcher.stop)
+        self.MockSimpleBuffer = patcher.start()
+        self.MockSimpleBuffer.return_value = self.mock_buffer
+
+    def _create_mock_config(self, kv_role):
+        """Return the mock config that is passed to SimpleConnector.
+
+        As a side effect a second, VllmConfig-spec'd mock is stored on
+        ``self.mock_config``; only that one reflects ``kv_role``.
+
+        NOTE(review): the returned mock always has kv_role "kv_producer" and
+        ``self.mock_config`` is never returned -- confirm which object the
+        connector is meant to receive.
+        """
+        mock_config = MagicMock()
+        mock_config.kv_role = "kv_producer"
+        mock_config.kv_connector_extra_config = {
+            "prefill_device_ips": ["127.0.0.1"],
+            "decode_device_ips": ["127.0.0.1"],
+            "llmdatadist_comm_port": 26000,
+            "http_port": 8000,
+            "proxy_ip": "127.0.0.1",
+            "proxy_port": "8000",
+            "port": 5500
+        }
+        mock_config.kv_port = 5500
+
+        self.mock_config = MagicMock(spec=VllmConfig)
+        self.mock_config.kv_transfer_config.is_kv_producer = True
+        hf_config = self.mock_config.model_config.hf_config
+        hf_config.hidden_size = 128
+        hf_config.num_attention_heads = 8
+        hf_config.num_key_value_heads = 8
+        hf_config.qk_rope_head_dim = 16
+        hf_config.kv_lora_rank = 16
+        self.mock_config.model_config.is_deepseek_mla = True
+        # Mock parallel_config
+        self.mock_config.parallel_config = MagicMock()
+        self.mock_config.parallel_config.tensor_parallel_size = 1
+        self.mock_config.parallel_config.get_num_layers.return_value = 4
+
+        self.mock_config.kv_transfer_config.kv_role = (
+            "kv_producer" if kv_role == "kv_producer" else "kv_consumer")
+        return mock_config
+
+    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
+    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
+    @patch('llm_datadist.LLMDataDist')
+    def test_select_init(self, MockLLMDataDist, mock_buffer, mock_pipe):
+        """The producer config must yield a producer pipe and buffer."""
+        # Stub init() before construction so it is in place if it is called.
+        MockLLMDataDist.return_value.init.return_value = None
+
+        connector = SimpleConnector(
+            rank=0,
+            local_rank=0,
+            config=self._create_mock_config("kv_producer"))
+
+        self.assertIsNotNone(connector.producer_data_pipe)
+        self.assertIsNotNone(connector.producer_buffer)
+
+    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
+    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
+    @patch('llm_datadist.LLMDataDist')
+    def test_select_select(self, MockLLMDataDist, mock_buffer, mock_pipe):
+        """select() must run with mocked consumer pipe and buffer attached."""
+        connector = SimpleConnector(
+            rank=0,
+            local_rank=0,
+            config=self._create_mock_config("kv_consumer"))
+        connector.consumer_data_pipe = mock_pipe
+        connector.consumer_buffer = mock_buffer
+        self.assertIsNotNone(connector.consumer_data_pipe)
+        self.assertIsNotNone(connector.consumer_buffer)
+
+        input_tokens = torch.tensor([1, 2, 3])
+        roi = torch.tensor([True, True, True])
+        req_id = "test_req"
+        connector.select(input_tokens, roi, req_id)
+
+    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
+    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
+    @patch('llm_datadist.LLMDataDist')
+    def test_insert(self, MockLLMDataDist, mock_buffer, mock_pipe):
+        """insert() must forward its arguments to the producer buffer."""
+        connector = SimpleConnector(
+            rank=0,
+            local_rank=0,
+            config=self._create_mock_config("kv_producer"))
+        connector.producer_buffer = mock_buffer
+
+        input_tokens = torch.randint(0, 1000, (5, ))
+        roi = torch.ones_like(input_tokens, dtype=torch.bool)
+        keys = torch.randn(3, 5, 1, 96)
+        values = torch.randn(3, 5, 1, 96)
+        hidden = torch.randn(5, 768)
+        req_id = "test_req"
+
+        connector.insert(input_tokens, roi, keys, values, hidden, req_id)
+
+        mock_buffer.insert.assert_called_once_with(input_tokens, roi, keys,
+                                                   values, hidden, req_id)
+
+    @patch.object(SimpleConnector, 'insert')
+    @patch('torch.distributed.get_rank', return_value=0)
+    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
+    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
+    @patch('llm_datadist.LLMDataDist')
+    def test_send_kv_caches_and_hidden_states(self, MockLLMDataDist,
+                                              mock_buffer, mock_pipe,
+                                              mock_rank, mock_insert):
+        """send_kv_caches_and_hidden_states() must run with mocked inputs."""
+        connector = SimpleConnector(
+            rank=0,
+            local_rank=0,
+            config=self._create_mock_config("kv_producer"))
+
+        mock_model_executable = MagicMock()
+        mock_model_executable.model.start_layer = 0
+        mock_model_executable.model.end_layer = 3
+
+        mock_model_input = MagicMock(spec=ModelInputForGPUWithSamplingMetadata)
+        mock_model_input.input_tokens = torch.randint(0, 1000, (10, ))
+        mock_model_input.attn_metadata.seq_lens = [5, 5]
+        mock_model_input.attn_metadata.slot_mapping = torch.randint(
+            0, 100, (10, ))
+        mock_model_input.attn_metadata.num_prefill_tokens = 10
+        mock_model_input.request_ids_to_seq_ids = {"req1": [0], "req2": [1]}
+
+        kv_caches = [torch.randn(2, 100, 1, 96) for _ in range(3)]
+
+        hidden_states = torch.randn(10, 768)
+
+        connector.send_kv_caches_and_hidden_states(mock_model_executable,
+                                                   mock_model_input, kv_caches,
+                                                   hidden_states)
diff --git a/tests/ut/distributed/kv_transfer/test_simple_pipe.py b/tests/ut/distributed/kv_transfer/test_simple_pipe.py
new file mode 100644
index 0000000000..efd6eddea8
--- /dev/null
+++ b/tests/ut/distributed/kv_transfer/test_simple_pipe.py
@@ -0,0 +1,166 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
+
+
+class TestSimplePipe(unittest.TestCase):
+    """Tests for SimplePipe construction.
+
+    Stacked ``@patch`` decorators are applied bottom-up: the mock created by
+    the decorator closest to ``def`` is the first argument after ``self``.
+    """
+
+    @classmethod
+    def _create_mock_config(cls):
+        """Return a mock kv-transfer config with role "kv_producer"."""
+        mock_config = MagicMock()
+        mock_config.kv_role = "kv_producer"
+        mock_config.kv_connector_extra_config = {
+            "prefill_device_ips": ["127.0.0.1"],
+            "decode_device_ips": ["127.0.0.1"],
+            "llmdatadist_comm_port": 26000,
+            "http_port": 8000,
+            "proxy_ip": "127.0.0.1",
+            "proxy_port": "8000",
+            "port": 5500
+        }
+        mock_config.kv_port = 5500
+        return mock_config
+
+    @patch('threading.Thread')
+    @patch('llm_datadist.LLMDataDist')
+    def test_init_success(self, MockLLMDataDist, mock_thread):
+        """SimplePipe must construct from the default mock config."""
+        mock_config = self._create_mock_config()
+
+        self.pipe = SimplePipe(rank=5,
+                               local_rank=0,
+                               kv_transfer_config=mock_config,
+                               hostname="127.0.0.1",
+                               port_offset=0)
+
+        self.pipe.router_socket.close()
+
+    @patch('threading.Thread')
+    @patch('llm_datadist.LLMDataDist')
+    def test_prepare_data_dist(self, MockLLMDataDist, mock_thread):
+        """SimplePipe must construct when LLMDataDist.init() returns None."""
+        # Stub init() before construction so the stub is in place when used.
+        MockLLMDataDist.return_value.init.return_value = None
+
+        self.pipe = SimplePipe(rank=5,
+                               local_rank=0,
+                               kv_transfer_config=self._create_mock_config(),
+                               hostname="127.0.0.1",
+                               port_offset=0)
+
+        self.pipe.router_socket.close()
+
+    def test_init_with_invalid_kv_role(self):
+        """An unknown kv_role must raise NotImplementedError."""
+        mock_config = MagicMock()
+        mock_config.kv_role = "err_role"
+        mock_config.kv_connector_extra_config = {
+            "prefill_device_ips": ["127.0.0.1"],
+            "decode_device_ips": ["127.0.0.1"],
+            "llmdatadist_comm_port": 26000,
+            "http_port": 8000,
+            "proxy_ip": "127.0.0.1",
+            "proxy_port": "8000",
+            "port": 5500
+        }
+
+        # Only the constructor call sits in the context so the expected
+        # exception cannot be raised by the test setup itself.
+        with self.assertRaises(NotImplementedError):
+            SimplePipe(rank=5,
+                       local_rank=0,
+                       kv_transfer_config=mock_config,
+                       hostname="127.0.0.1",
+                       port_offset=0)
+
+    def test_init_with_missing_device_ips(self):
+        """A config without device IP lists must raise ValueError."""
+        mock_config = MagicMock()
+        mock_config.kv_role = "kv_producer"
+        mock_config.kv_connector_extra_config = {
+            "llmdatadist_comm_port": 26000,
+            "http_port": 8000,
+            "proxy_ip": "127.0.0.1",
+            "proxy_port": "8000",
+            "port": 5500
+        }
+
+        with self.assertRaises(ValueError):
+            SimplePipe(rank=0,
+                       local_rank=0,
+                       kv_transfer_config=mock_config,
+                       hostname="127.0.0.1",
+                       port_offset=0)
+
+    @patch('threading.Thread')
+    @patch('llm_datadist.LLMDataDist')
+    def test_create_register_thread_address_is_empty(self, MockLLMDataDist,
+                                                     MockThread):
+        """A pipe built from the default config must set _register_thread."""
+        MockLLMDataDist.return_value.init.return_value = None
+
+        mock_config = self._create_mock_config()
+        pipe = SimplePipe(rank=5,
+                          local_rank=0,
+                          kv_transfer_config=mock_config,
+                          hostname="127.0.0.1",
+                          port_offset=0)
+
+        self.assertIsNotNone(pipe._register_thread)
+        pipe.router_socket.close()
+
+    @patch('threading.Thread')
+    @patch('llm_datadist.LLMDataDist')
+    def test_create_register_thread_address_is_not_empty(
+            self, MockLLMDataDist, MockThread):
+        """A pipe built with empty-string IPs must set _register_thread."""
+        MockLLMDataDist.return_value.init.return_value = None
+
+        mock_config = MagicMock()
+        mock_config.kv_role = "kv_producer"
+        mock_config.kv_connector_extra_config = {
+            "prefill_device_ips": [""],
+            "decode_device_ips": [""],
+            "llmdatadist_comm_port": 26000,
+            "http_port": 8000,
+            "proxy_ip": "127.0.0.1",
+            "proxy_port": "8000",
+            "port": 5500
+        }
+        pipe = SimplePipe(rank=5,
+                          local_rank=0,
+                          kv_transfer_config=mock_config,
+                          hostname="127.0.0.1",
+                          port_offset=0)
+
+        self.assertIsNotNone(pipe._register_thread)
+        pipe.router_socket.close()
+
+    @patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe')
+    @patch('llm_datadist.LLMDataDist')
+    def test_should_send_tensor_when_valid_input(self, MockLLMDataDist,
+                                                 MockSimplePipe):
+        """send_tensor() on the patched SimplePipe must return a value.
+
+        NOTE(review): both the pipe and LLMDataDist are mocks here, so the
+        assertion cannot fail -- consider exercising the real send_tensor.
+        """
+        pipe = MockSimplePipe()
+        tensor = torch.randn(3, 3)
+        tensor_desc = MockLLMDataDist.CacheDesc(
+            num_tensors=1,
+            shape=(3, 3),
+            data_type=MockLLMDataDist.DataType.DT_FLOAT,
+            seq_len_dim_index=1)
+        tensor_key = MockLLMDataDist.CacheKey(1, 0, 1)
+        result = pipe.send_tensor(tensor, tensor_desc, tensor_key)
+        self.assertIsNotNone(result)