diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index eff4071a0d7d9..06270c9ddde05 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -69,7 +69,8 @@ def test_baichuan_lora(baichuan_lora_files): @pytest.mark.skip("Requires multiple GPUs") -def test_baichuan_tensor_parallel_equality(baichuan_lora_files): +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded): # Cannot use as it will initialize torch.cuda too early... # if torch.cuda.device_count() < 4: # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") @@ -80,7 +81,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files): max_loras=4, max_lora_rank=64, tensor_parallel_size=1, - trust_remote_code=True) + trust_remote_code=True, + fully_sharded_loras=fully_sharded) output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1) del llm_tp1 @@ -92,7 +94,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files): max_loras=4, max_lora_rank=64, tensor_parallel_size=2, - trust_remote_code=True) + trust_remote_code=True, + fully_sharded_loras=fully_sharded) output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2) del llm_tp2 @@ -106,7 +109,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files): max_loras=4, max_lora_rank=64, tensor_parallel_size=4, - trust_remote_code=True) + trust_remote_code=True, + fully_sharded_loras=fully_sharded) output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2) del llm_tp4 diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 7e782144b07f7..e97666ba6d2ca 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -13,7 +13,8 @@ from vllm.lora.fully_sharded_layers import ( ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA) + MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA) # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, @@ -704,7 +705,9 @@ def create_column_parallel_packed_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = QKVParallelLinearWithLora(linear) + lora_linear = QKVParallelLinearWithLora( + linear + ) if not fully_shard else QKVParallelLinearWithShardedLora(linear) @dataclass class FakeConfig: diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index ffdc32b7339af..d27171f720832 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -12,6 +12,7 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, RowParallelLinearWithLoRA) from vllm.lora.punica import bgmv, dispatch_bgmv_low_level @@ -90,11 +91,11 @@ def can_replace_layer(cls, source_layer: nn.Module, def _mcp_apply(x, bias, layer): """ MergedColumnParallelLinearWithShardedLoRA and - QKVParallelLinearWithShardedLora share the same + MergedQKVParallelLinearWithShardedLora share the same LoRa weight application method. The main difference is the step by shard_size for lora_b which can - vary for QKVParallelLinearWithShardedLora but is constant for + vary for MergedQKVParallelLinearWithShardedLora but is constant for MergedColumnParallelLinearWithShardedLoRA. """ # expecting 2 for column parallel and 3 for qkv @@ -167,7 +168,7 @@ def can_replace_layer(cls, source_layer: nn.Module, ) -class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): +class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): """ Differs from QKVParallelLinearWithLora by slicing the LoRA A's also. @@ -175,6 +176,57 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): Based on S-LoRA, slicing happens along the rank dim. """ + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked.shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + + bgmv(buffer, x, self.lora_a_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + buffer = tensor_model_parallel_all_gather(buffer) + bgmv(output, buffer, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + # now have column partitioned output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): + """ + Differs from MergedQKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + def slice_lora_a( self, lora_a: List[Union[torch.Tensor, None]] ) -> List[Union[torch.Tensor, None]]: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index e3ab1708c3fdf..e4a23273f7282 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -641,6 +641,24 @@ def __init__(self, base_layer: QKVParallelLinear) -> None: self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * self.base_layer.head_size) + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + def set_lora( self, index: int, @@ -650,21 +668,8 @@ def set_lora( ): self.reset_lora(index) if self.tp_size > 1: - tp_rank = get_tensor_model_parallel_rank() - self.q_shard_id = tp_rank - self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas - lora_b_q = lora_b[:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - k_offset = self.q_proj_total_size - lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size * - self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - v_offset = k_offset + self.kv_proj_total_size - lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size * - self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -674,6 +679,7 @@ def set_lora( lora_b.T, non_blocking=True) @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 4a86c16cf64db..ab3b99eee6fc1 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -8,7 +8,8 @@ from vllm.lora.fully_sharded_layers import ( ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA) + MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA) # being imported for _all_lora_classes below # yapf conflicts with isort for this block # yapf: disable @@ -35,6 +36,7 @@ RowParallelLinearWithLoRA, LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, + QKVParallelLinearWithShardedLora, MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA,