Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tests/ut/distributed/kv_transfer/test_simple_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import unittest
import zlib
from unittest.mock import MagicMock

import torch

from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
int32_hash)


class MockSimplePipe:
    """Minimal stand-in for the pipe object that ``SimpleBuffer`` is built on.

    Exposes a fixed ``cluster_id`` plus one independent ``MagicMock`` per
    pipe operation so tests can inspect how the buffer used the pipe.
    """

    def __init__(self):
        self.cluster_id = 0
        # One separate recording mock per pipe operation.
        for operation in ("send_tensor", "recv_tensor", "deallocate_buffer"):
            setattr(self, operation, MagicMock())


class TestSimpleBuffer(unittest.TestCase):
    """Exercise ``SimpleBuffer`` and ``int32_hash`` against a mocked pipe."""

    def setUp(self):
        self.pipe = MockSimplePipe()
        self.buffer = SimpleBuffer(self.pipe)

    def _apply_model_dims(self):
        """Set the model-shape attributes shared by the insert/select tests."""
        shared_attrs = (
            ("num_layers", 2),
            ("num_heads", 4),
            ("head_size", 5),
            ("hidden_size", 6),
            ("dtype", torch.float32),
        )
        for attr_name, attr_value in shared_attrs:
            setattr(self.buffer, attr_name, attr_value)

    def test_int32_hash(self):
        """``int32_hash`` of a string equals adler32 of the same bytes."""
        expected = zlib.adler32(b"test")
        self.assertEqual(int32_hash("test"), expected)

    def test_insert(self):
        """``insert`` sends at least one tensor through the pipe."""
        tokens = torch.tensor([1, 2, 3])
        mask = torch.tensor([1, 0, 1])
        keys = torch.randn(2, 3, 4, 5)
        values = torch.randn(2, 3, 4, 5)
        hidden_states = torch.randn(3, 6)

        self._apply_model_dims()

        self.buffer.insert(tokens, mask, keys, values, hidden_states, "req1")

        self.pipe.send_tensor.assert_called()

    def test_drop_select(self):
        """``drop_select`` returns three tensors followed by ``None``."""
        tokens = torch.tensor([1, 2, 3])

        self._apply_model_dims()

        # Three consecutive receives: two 60-wide tensors, then a 6-wide one.
        kv_shape = (1, 2, 3 * 4 * 5)
        received = []
        received.append((MagicMock(), torch.randn(*kv_shape)))
        received.append((MagicMock(), torch.randn(*kv_shape)))
        received.append((MagicMock(), torch.randn(1, 3, 6)))
        self.pipe.recv_tensor.side_effect = received

        result = self.buffer.drop_select(tokens, None, "req1")

        self.assertEqual(len(result), 4)
        for position in range(3):
            self.assertIsInstance(result[position], torch.Tensor)
        self.assertIsNone(result[3])
        self.assertEqual(result[0].shape, (2, 3, 4, 5))

    def test_close(self):
        """``close`` can be called on a freshly built buffer."""
        self.buffer.close()
146 changes: 146 additions & 0 deletions tests/ut/distributed/kv_transfer/test_simple_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import unittest
from unittest.mock import MagicMock, patch

import torch
from vllm.config import VllmConfig
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
from vllm_ascend.distributed.kv_transfer.simple_connector import \
SimpleConnector
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe


class TestSimpleConnector(unittest.TestCase):
    """Tests for ``SimpleConnector`` with pipe, buffer and LLMDataDist patched.

    Stacked ``@patch`` decorators are applied bottom-up, so the mock created
    by the decorator closest to the ``def`` is the first argument after
    ``self``. The test signatures below list their mocks in that order.
    """

    def setUp(self):
        self.mock_pipe = MagicMock(spec=SimplePipe)
        self.mock_buffer = MagicMock(spec=SimpleBuffer)

        # NOTE(review): this patches the name inside ``simple_buffer``; each
        # test additionally patches ``simple_connector.SimpleBuffer`` itself.
        patcher = patch(
            'vllm_ascend.distributed.kv_transfer.simple_buffer.SimpleBuffer')
        self.addCleanup(patcher.stop)
        self.MockSimpleBuffer = patcher.start()
        self.MockSimpleBuffer.return_value = self.mock_buffer

    def _create_mock_config(self, kv_role):
        """Build the mocked config passed to ``SimpleConnector``.

        Args:
            kv_role: Role string to record on the mocks, e.g.
                ``"kv_producer"`` or ``"kv_consumer"``.

        Returns:
            A loose ``MagicMock`` carrying ``kv_role``, the connector extra
            config and ``kv_port``. A ``VllmConfig``-spec'd mock is also
            stored on ``self.mock_config`` as a side effect.
        """
        mock_config = MagicMock()
        # Honour the requested role instead of always reporting a producer.
        mock_config.kv_role = kv_role
        mock_config.kv_connector_extra_config = {
            "prefill_device_ips": ["127.0.0.1"],
            "decode_device_ips": ["127.0.0.1"],
            "llmdatadist_comm_port": 26000,
            "http_port": 8000,
            "proxy_ip": "127.0.0.1",
            "proxy_port": "8000",
            "port": 5500
        }
        mock_config.kv_port = 5500
        self.mock_config = MagicMock(spec=VllmConfig)
        self.mock_config.kv_transfer_config.is_kv_producer = True
        self.mock_config.model_config.hf_config.hidden_size = 128
        self.mock_config.model_config.hf_config.num_attention_heads = 8
        self.mock_config.model_config.hf_config.num_key_value_heads = 8
        self.mock_config.model_config.hf_config.qk_rope_head_dim = 16
        self.mock_config.model_config.hf_config.kv_lora_rank = 16
        self.mock_config.model_config.is_deepseek_mla = True
        # Mock parallel_config
        self.mock_config.parallel_config = MagicMock()
        self.mock_config.parallel_config.tensor_parallel_size = 1
        self.mock_config.parallel_config.get_num_layers.return_value = 4

        if kv_role == "kv_producer":
            self.mock_config.kv_transfer_config.kv_role = "kv_producer"
        else:
            self.mock_config.kv_transfer_config.kv_role = "kv_consumer"
        # NOTE(review): the loose mock is returned, not the spec'd
        # ``self.mock_config`` configured above -- confirm this is intended.
        return mock_config

    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
    @patch('llm_datadist.LLMDataDist')
    def test_select_init(self, MockLLMDataDist, mock_buffer, mock_pipe):
        """A producer connector ends up with a data pipe and a buffer."""
        # Configure the LLMDataDist mock before the constructor can use it.
        mock_data_dist = MockLLMDataDist.return_value
        mock_data_dist.init.return_value = None

        connector = SimpleConnector(
            rank=0,
            local_rank=0,
            config=self._create_mock_config("kv_producer"))

        assert connector.producer_data_pipe is not None
        assert connector.producer_buffer is not None

    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
    @patch('llm_datadist.LLMDataDist')
    def test_select_select(self, MockLLMDataDist, mock_buffer, mock_pipe):
        """``select`` runs once consumer pipe and buffer are in place."""
        connector = SimpleConnector(
            rank=0,
            local_rank=0,
            config=self._create_mock_config("kv_consumer"))
        connector.consumer_data_pipe = mock_pipe
        connector.consumer_buffer = mock_buffer
        assert connector.consumer_data_pipe is not None
        assert connector.consumer_buffer is not None

        input_tokens = torch.tensor([1, 2, 3])
        roi = torch.tensor([True, True, True])
        req_id = "test_req"
        connector.select(input_tokens, roi, req_id)

    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
    @patch('llm_datadist.LLMDataDist')
    def test_insert(self, MockLLMDataDist, mock_buffer, mock_pipe):
        """``insert`` forwards its arguments unchanged to the buffer."""
        connector = SimpleConnector(
            rank=0,
            local_rank=0,
            config=self._create_mock_config("kv_producer"))

        connector.producer_buffer = mock_buffer

        input_tokens = torch.randint(0, 1000, (5, ))
        roi = torch.ones_like(input_tokens, dtype=torch.bool)
        keys = torch.randn(3, 5, 1, 96)
        values = torch.randn(3, 5, 1, 96)
        hidden = torch.randn(5, 768)
        req_id = "test_req"

        connector.insert(input_tokens, roi, keys, values, hidden, req_id)

        mock_buffer.insert.assert_called_once_with(input_tokens, roi, keys,
                                                   values, hidden, req_id)

    @patch.object(SimpleConnector, 'insert')
    @patch('torch.distributed.get_rank', return_value=0)
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
    @patch('llm_datadist.LLMDataDist')
    def test_send_kv_caches_and_hidden_states(self, MockLLMDataDist,
                                              mock_buffer, mock_pipe,
                                              mock_rank, mock_insert):
        """Sending KV caches and hidden states completes without raising."""
        connector = SimpleConnector(
            rank=0,
            local_rank=0,
            config=self._create_mock_config("kv_producer"))

        mock_model_executable = MagicMock()
        mock_model_executable.model.start_layer = 0
        mock_model_executable.model.end_layer = 3

        mock_model_input = MagicMock(spec=ModelInputForGPUWithSamplingMetadata)
        mock_model_input.input_tokens = torch.randint(0, 1000, (10, ))
        mock_model_input.attn_metadata.seq_lens = [5, 5]
        mock_model_input.attn_metadata.slot_mapping = torch.randint(
            0, 100, (10, ))
        mock_model_input.attn_metadata.num_prefill_tokens = 10
        mock_model_input.request_ids_to_seq_ids = {"req1": [0], "req2": [1]}

        kv_caches = [torch.randn(2, 100, 1, 96) for _ in range(3)]

        hidden_states = torch.randn(10, 768)

        connector.send_kv_caches_and_hidden_states(mock_model_executable,
                                                   mock_model_input, kv_caches,
                                                   hidden_states)
145 changes: 145 additions & 0 deletions tests/ut/distributed/kv_transfer/test_simple_pipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import unittest
from unittest.mock import MagicMock, patch

import torch

from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe


class TestSimplePipe(unittest.TestCase):
    """Construction tests for ``SimplePipe`` with threads/LLMDataDist mocked.

    Stacked ``@patch`` decorators are applied bottom-up, so the mock created
    by the decorator closest to the ``def`` is the first argument after
    ``self``. The test signatures below list their mocks in that order.
    """

    @classmethod
    def _create_mock_config(cls):
        """Return a mocked producer KV-transfer config using 127.0.0.1."""
        mock_config = MagicMock()
        mock_config.kv_role = "kv_producer"
        mock_config.kv_connector_extra_config = {
            "prefill_device_ips": ["127.0.0.1"],
            "decode_device_ips": ["127.0.0.1"],
            "llmdatadist_comm_port": 26000,
            "http_port": 8000,
            "proxy_ip": "127.0.0.1",
            "proxy_port": "8000",
            "port": 5500
        }
        mock_config.kv_port = 5500
        return mock_config

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_init_success(self, MockLLMDataDist, mock_thread):
        """A producer pipe can be built from the default mock config."""
        mock_config = self._create_mock_config()

        self.pipe = SimplePipe(rank=5,
                               local_rank=0,
                               kv_transfer_config=mock_config,
                               hostname="127.0.0.1",
                               port_offset=0)

        self.pipe.router_socket.close()

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_prepare_data_dist(self, MockLLMDataDist, mock_thread):
        """Construction succeeds when ``LLMDataDist.init`` returns ``None``."""
        # Configure the LLMDataDist mock before the constructor can use it.
        mock_data_dist = MockLLMDataDist.return_value
        mock_data_dist.init.return_value = None

        self.pipe = SimplePipe(rank=5,
                               local_rank=0,
                               kv_transfer_config=self._create_mock_config(),
                               hostname="127.0.0.1",
                               port_offset=0)
        self.pipe.router_socket.close()

    def test_init_with_invalid_kv_role(self):
        """An unrecognised ``kv_role`` raises ``NotImplementedError``."""
        mock_config = MagicMock()
        mock_config.kv_role = "err_role"
        mock_config.kv_connector_extra_config = {
            "prefill_device_ips": ["127.0.0.1"],
            "decode_device_ips": ["127.0.0.1"],
            "llmdatadist_comm_port": 26000,
            "http_port": 8000,
            "proxy_ip": "127.0.0.1",
            "proxy_port": "8000",
            "port": 5500
        }
        # Only the raising call belongs in the context manager; a statement
        # placed after it there can never run (coverage flagged the old
        # ``router_socket.close()`` line as never executed).
        with self.assertRaises(NotImplementedError):
            SimplePipe(rank=5,
                       local_rank=0,
                       kv_transfer_config=mock_config,
                       hostname="127.0.0.1",
                       port_offset=0)

    def test_init_with_missing_device_ips(self):
        """A config without device IP lists raises ``ValueError``."""
        mock_config = MagicMock()
        mock_config.kv_role = "kv_producer"
        mock_config.kv_connector_extra_config = {
            "llmdatadist_comm_port": 26000,
            "http_port": 8000,
            "proxy_ip": "127.0.0.1",
            "proxy_port": "8000",
            "port": 5500
        }
        # Only the raising call belongs in the context manager; a statement
        # placed after it there can never run (coverage flagged the old
        # ``router_socket.close()`` line as never executed).
        with self.assertRaises(ValueError):
            SimplePipe(rank=0,
                       local_rank=0,
                       kv_transfer_config=mock_config,
                       hostname="127.0.0.1",
                       port_offset=0)

    # NOTE(review): this test uses non-empty IPs while the "is_not_empty"
    # test below uses empty strings -- the two names look swapped; confirm.
    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_create_register_thread_address_is_empty(self, MockLLMDataDist,
                                                     MockThread):
        """A register thread exists after building from the default config."""
        mock_data_dist = MockLLMDataDist.return_value
        mock_data_dist.init.return_value = None

        mock_config = self._create_mock_config()
        pipe = SimplePipe(rank=5,
                          local_rank=0,
                          kv_transfer_config=mock_config,
                          hostname="127.0.0.1",
                          port_offset=0)
        self.assertIsNotNone(pipe._register_thread)
        pipe.router_socket.close()

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_create_register_thread_address_is_not_empty(
            self, MockLLMDataDist, MockThread):
        """A register thread exists when device IPs are empty strings."""
        mock_data_dist = MockLLMDataDist.return_value
        mock_data_dist.init.return_value = None

        mock_config = MagicMock()
        mock_config.kv_role = "kv_producer"
        mock_config.kv_connector_extra_config = {
            "prefill_device_ips": [""],
            "decode_device_ips": [""],
            "llmdatadist_comm_port": 26000,
            "http_port": 8000,
            "proxy_ip": "127.0.0.1",
            "proxy_port": "8000",
            "port": 5500
        }
        pipe = SimplePipe(rank=5,
                          local_rank=0,
                          kv_transfer_config=mock_config,
                          hostname="127.0.0.1",
                          port_offset=0)
        self.assertIsNotNone(pipe._register_thread)
        pipe.router_socket.close()

    @patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe')
    @patch('llm_datadist.LLMDataDist')
    def test_should_send_tensor_when_valid_input(self, MockLLMDataDist,
                                                 MockSimplePipe):
        """``send_tensor`` on a patched pipe returns a non-``None`` value.

        NOTE(review): ``SimplePipe`` itself is patched here, so the call
        goes to a ``MagicMock`` and no real ``SimplePipe`` code runs;
        consider exercising a real instance instead.
        """
        pipe = MockSimplePipe()
        tensor = torch.randn(3, 3)
        tensor_desc = MockLLMDataDist.CacheDesc(
            num_tensors=1,
            shape=(3, 3),
            data_type=MockLLMDataDist.DataType.DT_FLOAT,
            seq_len_dim_index=1)
        tensor_key = MockLLMDataDist.CacheKey(1, 0, 1)
        result = pipe.send_tensor(tensor, tensor_desc, tensor_key)
        self.assertIsNotNone(result)
Loading