From 4e4be3b38bf405f09bea5e3a65931d470b226810 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 22 Sep 2025 10:02:55 +0000 Subject: [PATCH 1/8] Done Signed-off-by: Jee Jee Li --- vllm/lora/layers/base_linear.py | 10 +++++----- vllm/lora/layers/column_parallel_linear.py | 10 +++++----- vllm/lora/layers/logits_processor.py | 8 ++++---- vllm/lora/layers/vocal_parallel_embedding.py | 10 ++++++---- vllm/lora/lora_weights.py | 4 ++-- vllm/lora/models.py | 19 +++++++++---------- 6 files changed, 31 insertions(+), 30 deletions(-) diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index 85a1f86ce6bf..6cf5815ef12d 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -121,18 +121,18 @@ def set_lora( lora_bias = self.slice_bias(lora_bias) self.lora_a_stacked[0][index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) + 0, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) self.lora_b_stacked[0][index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) + 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( + lora_b, non_blocking=True) if lora_bias is not None: self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) assert len(self.lora_bias_stacked) self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( - lora_bias.T, non_blocking=True) + lora_bias, non_blocking=True) def apply(self, x: torch.Tensor, diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index 658fd23165da..a09f3c87806d 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -285,12 +285,12 @@ def set_lora( for i in range(self.n_slices): if (lora_a_i := lora_a[i]) is not None: self.lora_a_stacked[i][ - index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( - lora_a_i.T, non_blocking=True) + index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_( + lora_a_i, non_blocking=True) if (lora_b_i := lora_b[i]) is not None: self.lora_b_stacked[i][ - index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( - lora_b_i.T, non_blocking=True) + index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_( + lora_b_i, non_blocking=True) if lora_bias is not None: self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], @@ -299,7 +299,7 @@ def set_lora( if (lora_bias_i := lora_bias[i]) is not None: self.lora_bias_stacked[i][index, 0, :lora_bias_i.shape[0]].copy_( - lora_bias_i.T, + lora_bias_i, non_blocking=True) @classmethod diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index a50dcfa748f2..b8fbad3a4af0 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -140,11 +140,11 @@ def set_lora( ): self.reset_lora(index) self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) + 0, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) + 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( + lora_b, non_blocking=True) if embeddings_tensor is not None: self.embeddings_tensors[ index, diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index 4d6218d97097..6ae7de8d33f2 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -95,11 +95,13 @@ def 
set_lora( bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) - self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( - lora_a, non_blocking=True) + # NOTE self.lora_a_stacked is row-major, and lora_a is row-major, + # so we need transpose here + self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) + 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( + lora_b, non_blocking=True) if embeddings_tensor is not None: self.embeddings_tensors[ index, diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py index 958364fca592..e3198fb3d3ae 100644 --- a/vllm/lora/lora_weights.py +++ b/vllm/lora/lora_weights.py @@ -86,11 +86,11 @@ def create_dummy_lora_weights( embeddings_tensor_dim: Optional[int] = None, bias_enabled: Optional[bool] = False) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() - lora_a = torch.zeros([input_dim, rank], + lora_a = torch.zeros([rank, input_dim], dtype=dtype, device=device, pin_memory=pin_memory) - lora_b = torch.zeros([rank, output_dim], + lora_b = torch.zeros([output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 9ea46be65cff..31bcaa9e4770 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -132,6 +132,8 @@ def from_lora_tensors( pin_memory = str(device) == "cpu" and is_pin_memory_available() loras: dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): + if "lm_head" in tensor_name: + pass module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( tensor_name, weights_mapper) if module_name not in loras: @@ -152,30 +154,29 @@ def from_lora_tensors( module_name, peft_helper, lora_embeddings_tensor) if is_bias: - loras[module_name].bias = tensor.to(device=device, - dtype=dtype).t() - bias = tensor.to(device=device, dtype=dtype).t() + loras[module_name].bias = tensor.to(device=device, dtype=dtype) + bias = tensor.to(device=device, dtype=dtype) if pin_memory: bias = bias.pin_memory() loras[module_name].bias = bias elif is_lora_a: loras[module_name].lora_a = tensor.to(device=device, - dtype=dtype).t() + dtype=dtype) if pin_memory: loras[module_name].lora_a = loras[ module_name].lora_a.pin_memory() else: loras[module_name].lora_b = tensor.to(device=device, - dtype=dtype).t() + dtype=dtype) assert embedding_padding_modules is not None if any(name in module_name for name in embedding_padding_modules ) and target_embedding_padding is not None: lora_b = loras[module_name].lora_b - assert target_embedding_padding >= lora_b.shape[1] - addition = target_embedding_padding - lora_b.shape[1] + assert target_embedding_padding >= lora_b.shape[0] + addition = target_embedding_padding - lora_b.shape[0] loras[module_name].lora_b = torch.nn.functional.pad( - lora_b, (0, addition)) + lora_b, (0, 0, 0, addition)) if pin_memory: loras[module_name].lora_b = loras[ module_name].lora_b.pin_memory() @@ -585,7 +586,6 @@ def create_dummy_lora( "cpu", bias_enabled=bias_enabled, ) - lora.optimize() else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] @@ -600,7 +600,6 @@ def create_dummy_lora( "cpu", bias_enabled=bias_enabled, ) - lora.optimize() subloras.append(lora) lora = PackedLoRALayerWeights.pack(subloras) model.loras[module_name] = lora From 3c40d7bb0b0907bd56ca57b1706c906f9f7ef61d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: 
Mon, 22 Sep 2025 16:12:41 +0000 Subject: [PATCH 2/8] Move forward Signed-off-by: Jee Jee Li --- vllm/lora/layers/column_parallel_linear.py | 57 +++++++++++----------- vllm/lora/layers/row_parallel_linear.py | 4 +- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index a09f3c87806d..44f9389355b1 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -99,13 +99,13 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: if self.is_merged_col_linear: tp_rank = get_tensor_model_parallel_rank() shard_size = self.output_size // 2 - offset = lora_b.shape[-1] // 2 + offset = lora_b.shape[0] // 2 - left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * - shard_size] - right_weight = lora_b[:, offset + tp_rank * shard_size:offset + - (tp_rank + 1) * shard_size] - lora_b = torch.cat([left_weight, right_weight], dim=1) + left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) * + shard_size,:] + right_weight = lora_b[offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size,:] + lora_b = torch.cat([left_weight, right_weight], dim=0) # Applicable to cases where the base_layer is # ColumnParallelLinear. else: @@ -113,7 +113,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: shard_size = self.output_size start_idx = tensor_model_parallel_rank * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + lora_b = lora_b[start_idx:end_idx,:] return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: @@ -251,9 +251,8 @@ def slice_lora_b( for i, (shard_id, shard_size) in enumerate( zip(self.output_ids, self.output_slices)): if (lora_b_i := lora_b[i]) is not None: - sliced_lora_b[i] = lora_b_i[:, - shard_size * shard_id:shard_size * - (shard_id + 1)] + sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size * + (shard_id + 1),:] return sliced_lora_b def slice_bias( @@ -345,18 +344,18 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() self.q_shard_id = tp_rank self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas - lora_b_q = lora_b[:, self.q_proj_shard_size * + lora_b_q = lora_b[self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] + (self.q_shard_id + 1),:] k_offset = self.q_proj_total_size - lora_b_k = lora_b[:, k_offset + + lora_b_k = lora_b[k_offset + self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] + self.kv_proj_shard_size * (self.kv_shard_id + 1),:] v_offset = k_offset + self.kv_proj_total_size - lora_b_v = lora_b[:, v_offset + + lora_b_v = lora_b[ v_offset + self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + self.kv_proj_shard_size * (self.kv_shard_id + 1),:] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0) return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: @@ -465,7 +464,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] + lora_a = lora_a[start_idx:start_idx + shard_size,:] return lora_a def apply(self, @@ -508,10 +507,10 @@ def slice_lora_a( 
output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size lora_a = [ - lora_a[0][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[0] is not None else None, - lora_a[1][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[1] is not None else None, + lora_a[0][ output_start_idx:output_start_idx + + output_shard_size,:] if lora_a[0] is not None else None, + lora_a[1][output_start_idx:output_start_idx + + output_shard_size,:] if lora_a[1] is not None else None, ] return lora_a @@ -551,7 +550,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] + lora_a = lora_a[start_idx:start_idx + shard_size,:] return lora_a def apply(self, @@ -589,12 +588,12 @@ def slice_lora_a( shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)] lora_a = [ - lora_a[0][:, start_idx[0]:start_idx[0] + - shard_size[0]] if lora_a[0] is not None else None, - lora_a[1][:, start_idx[1]:start_idx[1] + - shard_size[1]] if lora_a[1] is not None else None, - lora_a[2][:, start_idx[2]:start_idx[2] + - shard_size[2]] if lora_a[2] is not None else None, + lora_a[0][start_idx[0]:start_idx[0] + + shard_size[0],:] if lora_a[0] is not None else None, + lora_a[1][start_idx[1]:start_idx[1] + + shard_size[1],:] if lora_a[1] is not None else None, + lora_a[2][start_idx[2]:start_idx[2] + + shard_size[2],:] if lora_a[2] is not None else None, ] return lora_a diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py index 18ef6fd1ddd7..cac2c92136dc 100644 --- a/vllm/lora/layers/row_parallel_linear.py +++ b/vllm/lora/layers/row_parallel_linear.py @@ -39,7 +39,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: shard_size = self.input_size start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size - lora_a = lora_a[start_idx:end_idx, :] + lora_a = lora_a[:,start_idx:end_idx] return lora_a def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: @@ -122,7 +122,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: shard_size = self.lora_b_stacked[0].shape[2] start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + lora_b = lora_b[ start_idx:end_idx,:] return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: From 4fc32091209b2e1b4700e3552522512adf69fe35 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 22 Sep 2025 16:16:35 +0000 Subject: [PATCH 3/8] FMT Signed-off-by: Jee Jee Li --- vllm/lora/layers/column_parallel_linear.py | 32 +++++++++++----------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index 44f9389355b1..fa4eb272a69f 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -102,9 +102,9 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: offset = lora_b.shape[0] // 2 left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) * - shard_size,:] + shard_size, :] right_weight = lora_b[offset + tp_rank * shard_size:offset + - (tp_rank + 1) * shard_size,:] + (tp_rank + 1) * shard_size, :] lora_b = torch.cat([left_weight, right_weight], dim=0) # Applicable to cases where 
the base_layer is # ColumnParallelLinear. @@ -113,7 +113,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: shard_size = self.output_size start_idx = tensor_model_parallel_rank * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[start_idx:end_idx,:] + lora_b = lora_b[start_idx:end_idx, :] return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: @@ -252,7 +252,7 @@ def slice_lora_b( zip(self.output_ids, self.output_slices)): if (lora_b_i := lora_b[i]) is not None: sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size * - (shard_id + 1),:] + (shard_id + 1), :] return sliced_lora_b def slice_bias( @@ -346,15 +346,15 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas lora_b_q = lora_b[self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1),:] + (self.q_shard_id + 1), :] k_offset = self.q_proj_total_size lora_b_k = lora_b[k_offset + self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1),:] + self.kv_proj_shard_size * (self.kv_shard_id + 1), :] v_offset = k_offset + self.kv_proj_total_size - lora_b_v = lora_b[ v_offset + + lora_b_v = lora_b[v_offset + self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1),:] + self.kv_proj_shard_size * (self.kv_shard_id + 1), :] lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0) return lora_b @@ -464,7 +464,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size - lora_a = lora_a[start_idx:start_idx + shard_size,:] + lora_a = lora_a[start_idx:start_idx + shard_size, :] return lora_a def apply(self, @@ -507,10 +507,10 @@ def slice_lora_a( output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size lora_a = [ - lora_a[0][ output_start_idx:output_start_idx + - output_shard_size,:] if lora_a[0] is not None else None, + lora_a[0][output_start_idx:output_start_idx + + output_shard_size, :] if lora_a[0] is not None else None, lora_a[1][output_start_idx:output_start_idx + - output_shard_size,:] if lora_a[1] is not None else None, + output_shard_size, :] if lora_a[1] is not None else None, ] return lora_a @@ -550,7 +550,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size - lora_a = lora_a[start_idx:start_idx + shard_size,:] + lora_a = lora_a[start_idx:start_idx + shard_size, :] return lora_a def apply(self, @@ -589,11 +589,11 @@ def slice_lora_a( start_idx = [self.tp_rank * shard_size[i] for i in range(3)] lora_a = [ lora_a[0][start_idx[0]:start_idx[0] + - shard_size[0],:] if lora_a[0] is not None else None, + shard_size[0], :] if lora_a[0] is not None else None, lora_a[1][start_idx[1]:start_idx[1] + - shard_size[1],:] if lora_a[1] is not None else None, + shard_size[1], :] if lora_a[1] is not None else None, lora_a[2][start_idx[2]:start_idx[2] + - shard_size[2],:] if lora_a[2] is not None else None, + shard_size[2], :] if lora_a[2] is not None else None, ] return lora_a From 8f1b7b7cb7b289f2abc6f9c5d190aa1db385d351 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 23 Sep 2025 02:14:41 +0000 Subject: [PATCH 4/8] Fix test Signed-off-by: Jee Jee Li --- 
tests/lora/test_layers.py | 26 ++++++++++++++------------ tests/lora/utils.py | 8 ++++---- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 6735b7cd9e43..ced0afc50cb9 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -164,8 +164,8 @@ def populate_loras( weight=layer_weights, generate_embeddings_tensor=generate_embeddings_tensor, ) - sublora.lora_b = sublora.lora_b[:, (sublora_len * - i):(sublora_len * (i + 1))] + sublora.lora_b = sublora.lora_b[(sublora_len * + i):(sublora_len * (i + 1)), :] sublora.optimize() subloras.append(sublora) @@ -304,9 +304,9 @@ def create_random_embedding_layer(): result = embedding(input_) after_a = F.embedding( input_, - lora.lora_a, + lora.lora_a.T, ) - result += (after_a @ lora.lora_b) + result += (after_a @ lora.lora_b.T) expected_results.append(result) expected_result = torch.cat(expected_results) @@ -445,9 +445,9 @@ def create_random_embedding_layer(): result = expanded_embedding(input_) after_a = F.embedding( original_input_, - lora.lora_a, + lora.lora_a.T, ) - result += (after_a @ lora.lora_b) + result += (after_a @ lora.lora_b.T) expected_results.append(result) expected_result = torch.cat(expected_results) @@ -575,7 +575,7 @@ def _pretest(): lm_head=linear, embedding_bias=None) result[:, vocab_size + embeddings_tensor_len:] = float("-inf") - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) logits_processor.org_vocab_size = vocab_size @@ -692,9 +692,10 @@ def create_random_linear_replicated_layer(): expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) @@ -817,7 +818,7 @@ def create_random_linear_parallel_layer(): for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) @@ -965,9 +966,10 @@ class FakeConfig: result = linear(input_)[0] subloras = sublora_dict[lora_id] for i, sublora in enumerate(subloras): - result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * - (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * - sublora.scaling) + result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] * + (i + 1)] += ( + input_ @ sublora.lora_a.T @ sublora.lora_b.T * + sublora.scaling) expected_results.append(result) expected_result = torch.cat(expected_results) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index ab475904d493..0432a1a9bba0 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -36,10 +36,10 @@ def init_random_lora( module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([weight.shape[1], rank], + lora_a=torch.rand([rank, weight.shape[1]], dtype=weight.dtype, device=self._device), - lora_b=torch.rand([rank, weight.shape[0]], + lora_b=torch.rand([weight.shape[0], rank], dtype=weight.dtype, device=self._device), ) @@ -67,8 +67,8 @@ def init_lora( module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([input_dim, rank], 
device="cuda"), - lora_b=torch.rand([rank, output_dim], device="cuda"), + lora_a=torch.rand([rank, input_dim], device="cuda"), + lora_b=torch.rand([output_dim, input_dim], device="cuda"), embeddings_tensor=embeddings_tensor, ) self.set_module_lora(module_name, lora) From fad5ba6ecd549a58032f978d6500eacedd6a2a9e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 23 Sep 2025 03:49:57 +0000 Subject: [PATCH 5/8] Fix test Signed-off-by: Jee Jee Li --- tests/lora/test_lora_manager.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index d7684fbf34ab..19007cae0599 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device): assert lora.lora_b is not None assert lora.lora_a.device == torch.device(device) assert lora.lora_b.device == torch.device(device) - assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] + assert (lora.lora_a.shape[0] == lora.lora_b.shape[1] ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" - assert lora.lora_a.shape[1] == 8 + assert lora.lora_a.shape[0] == 8 embeddings_module = next( (k for k in EMBEDDING_MODULES if k in module_name), None) if embeddings_module: @@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str], name, 8, 16, - torch.rand([w.shape[1], 8], device=device), - torch.rand([8, w.shape[0]], device=device), + torch.rand([8,w.shape[1]], device=device), + torch.rand([w.shape[0],8], device=device), ) return LoRAModel(lora_id, 8, loras) @@ -109,8 +109,8 @@ def create_packed_lora( replaced_module_name, 8, 16, - torch.rand([w.shape[1], 8], device=device), - torch.rand([8, w.shape[0] // len(replaced_module_names)], + torch.rand([8,w.shape[1]], device=device), + torch.rand([w.shape[0] // len(replaced_module_names),8], device=device), ) return LoRAModel(lora_id, 8, loras) From 048d6b571575222038fd745df89e1a59b3e9d88b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 23 Sep 2025 05:34:45 +0000 Subject: [PATCH 6/8] Fix format Signed-off-by: Jee Jee Li --- tests/lora/test_lora_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 19007cae0599..6f0a85231408 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str], name, 8, 16, - torch.rand([8,w.shape[1]], device=device), - torch.rand([w.shape[0],8], device=device), + torch.rand([8, w.shape[1]], device=device), + torch.rand([w.shape[0], 8], device=device), ) return LoRAModel(lora_id, 8, loras) @@ -109,8 +109,8 @@ def create_packed_lora( replaced_module_name, 8, 16, - torch.rand([8,w.shape[1]], device=device), - torch.rand([w.shape[0] // len(replaced_module_names),8], + torch.rand([8, w.shape[1]], device=device), + torch.rand([w.shape[0] // len(replaced_module_names), 8], device=device), ) return LoRAModel(lora_id, 8, loras) From af92bb22d9f505464a19e611afb5d9cc3a82a5c1 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 23 Sep 2025 06:41:25 +0000 Subject: [PATCH 7/8] Fix Signed-off-by: Jee Jee Li --- vllm/lora/models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 31bcaa9e4770..cc64cc78affa 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -132,8 +132,6 @@ def from_lora_tensors( pin_memory = str(device) == "cpu" and 
is_pin_memory_available() loras: dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): - if "lm_head" in tensor_name: - pass module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( tensor_name, weights_mapper) if module_name not in loras: From 97d7d92bde9d6f949c1259a96e42dfd769dfea84 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 23 Sep 2025 06:48:04 +0000 Subject: [PATCH 8/8] Fix comments Signed-off-by: Jee Jee Li --- vllm/lora/layers/vocal_parallel_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index 6ae7de8d33f2..ca01c7e17fff 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -95,7 +95,7 @@ def set_lora( bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) - # NOTE self.lora_a_stacked is row-major, and lora_a is row-major, + # NOTE self.lora_a_stacked is row-major, and lora_a is col-major, # so we need transpose here self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_( lora_a.T, non_blocking=True)
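
Note (illustrative, not part of the patches above): a minimal sketch of the weight layout this series converges on, assuming the PEFT checkpoint convention of lora_a shaped (rank, input_dim) and lora_b shaped (output_dim, rank). With that layout, set_lora() can copy tensors into the stacked buffers without the old .T, and the expected delta in the tests becomes x @ lora_a.T @ lora_b.T. All shapes, buffer sizes, and names below are made-up example values, not the real vLLM configuration.

    # Hypothetical sketch of the row-major LoRA layout; numbers are examples only.
    import torch

    rank, input_dim, output_dim = 8, 32, 64
    max_loras, scaling = 4, 1.0

    # Checkpoint-layout weights, kept as-is (no .t() on load):
    #   lora_a: (rank, input_dim), lora_b: (output_dim, rank)
    lora_a = torch.randn(rank, input_dim)
    lora_b = torch.randn(output_dim, rank)

    # Per-slot stacked buffers shaped to match that layout, so the copy
    # below needs no transpose (mirrors the updated set_lora() hunks).
    lora_a_stacked = torch.zeros(max_loras, 1, rank, input_dim)
    lora_b_stacked = torch.zeros(max_loras, 1, output_dim, rank)

    index = 0
    lora_a_stacked[index, 0, :lora_a.shape[0], :lora_a.shape[1]].copy_(lora_a)
    lora_b_stacked[index, 0, :lora_b.shape[0], :lora_b.shape[1]].copy_(lora_b)

    # LoRA delta for activations x of shape (num_tokens, input_dim):
    x = torch.randn(3, input_dim)
    delta = x @ lora_a.T @ lora_b.T * scaling
    assert delta.shape == (3, output_dim)

Under this convention, output-dimension sharding (column-parallel lora_b) slices along dim 0 and input-dimension sharding (row-parallel lora_a) slices along dim 1, which is why the slice_lora_a / slice_lora_b hunks swap their indexing in PATCH 2/8.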