
Commit 619af55

jeejeelee authored and xuebwang-amd committed
[Core] Ensure LoRA linear respect the base_layer's tp_size and tp_rank (vllm-project#25487)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 4dd70dc commit 619af55

5 files changed: +24 -41 lines changed

vllm/lora/layers/base_linear.py

Lines changed: 3 additions & 2 deletions

@@ -24,11 +24,12 @@ def __init__(self, base_layer: LinearBase):
         super().__init__()
         self.base_layer = base_layer
         self.input_size = self.base_layer.input_size
+        # Ensure tp_size and tp_rank consistency with the base_layer.
+        self.tp_size = self.base_layer.tp_size
+        self.tp_rank = self.base_layer.tp_rank
         self.device = _get_lora_device(self.base_layer)
         self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
-
         self.output_slices: tuple[int, ...]
-        self.tp_size: int
         self.output_size: int
         self.n_slices: int
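The change above moves tp_size/tp_rank ownership into the shared base class: every LoRA wrapper now mirrors whatever its wrapped base_layer reports instead of querying the distributed group (or hard-coding a value) itself. A minimal sketch of that pattern, using hypothetical FakeBaseLayer/LoRAWrapper classes rather than vLLM's real API:

# Minimal sketch of the pattern this commit introduces (hypothetical classes,
# not vLLM's real API): the LoRA wrapper copies tp_size/tp_rank from the layer
# it wraps instead of querying the distributed group itself.
from dataclasses import dataclass


@dataclass
class FakeBaseLayer:
    input_size: int
    tp_size: int   # the base layer knows its own parallel configuration
    tp_rank: int


class LoRAWrapper:
    def __init__(self, base_layer: FakeBaseLayer) -> None:
        self.base_layer = base_layer
        self.input_size = base_layer.input_size
        # Mirror the base layer so all sharding math stays consistent with it.
        self.tp_size = base_layer.tp_size
        self.tp_rank = base_layer.tp_rank


wrapper = LoRAWrapper(FakeBaseLayer(input_size=4096, tp_size=2, tp_rank=1))
assert (wrapper.tp_size, wrapper.tp_rank) == (2, 1)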

vllm/lora/layers/column_parallel_linear.py

Lines changed: 14 additions & 26 deletions

@@ -8,9 +8,7 @@
 from transformers import PretrainedConfig

 from vllm.config.lora import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_gather)
+from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,

@@ -85,7 +83,6 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None:
         # inconsistent when TP is greater than 1.
         self.is_merged_col_linear = type(
             base_layer) is MergedColumnParallelLinear
-        self.tp_size = get_tensor_model_parallel_world_size()
         self.output_size = self.base_layer.output_size_per_partition
         # There is only one LoRA layer
         self.n_slices = 1

@@ -97,33 +94,30 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         # Applicable to cases where the base_layer is
         # MergedColumnParallelLinear.
         if self.is_merged_col_linear:
-            tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size // 2
             offset = lora_b.shape[0] // 2

-            left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
+            left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) *
                                  shard_size, :]
-            right_weight = lora_b[offset + tp_rank * shard_size:offset +
-                                  (tp_rank + 1) * shard_size, :]
+            right_weight = lora_b[offset + self.tp_rank * shard_size:offset +
+                                  (self.tp_rank + 1) * shard_size, :]
             lora_b = torch.cat([left_weight, right_weight], dim=0)
         # Applicable to cases where the base_layer is
         # ColumnParallelLinear.
         else:
-            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size
-            start_idx = tensor_model_parallel_rank * shard_size
-            end_idx = (tensor_model_parallel_rank + 1) * shard_size
+            start_idx = self.tp_rank * shard_size
+            end_idx = (self.tp_rank + 1) * shard_size
             lora_b = lora_b[start_idx:end_idx, :]
         return lora_b

     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         # TODO: Fix the slicing logic of bias.
         if bias is None:
             return bias
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         shard_size = self.output_size
-        start_idx = tensor_model_parallel_rank * shard_size
-        end_idx = (tensor_model_parallel_rank + 1) * shard_size
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
         bias = bias[start_idx:end_idx]
         return bias

@@ -144,7 +138,7 @@ def forward(

         # Matrix multiply.
         output_parallel = self.apply(input_, bias)
-        if self.base_layer.gather_output:
+        if self.base_layer.gather_output and self.tp_size > 1:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
         else:

@@ -185,8 +179,6 @@ def __init__(
                                 QKVParallelLinear]) -> None:
         super().__init__(base_layer)
         # There are two LoRA layers
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
         # the output_sizes in MergedColumnParallelLinear is not sharded by tp
         # we need to divide it by the tp_size to get correct slices size
         output_sizes = self.base_layer.output_sizes

@@ -341,9 +333,9 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         self.n_slices = 1

     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
-        self.q_shard_id = tp_rank
-        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+
+        self.q_shard_id = self.tp_rank
+        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
         lora_b_q = lora_b[self.q_proj_shard_size *
                           self.q_shard_id:self.q_proj_shard_size *
                           (self.q_shard_id + 1), :]

@@ -397,8 +389,6 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         super().__init__(base_layer)
         # There are three LoRA layer.
         self.n_slices = len(self.base_layer.output_sizes)
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()

         self.q_proj_shard_size = (self.base_layer.num_heads *
                                   self.base_layer.head_size)

@@ -461,9 +451,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
     # Therefore, the sharding of `lora_a` only needs to correspond with the
     # gather operation.
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
-        start_idx = tp_rank * shard_size
+        start_idx = self.tp_rank * shard_size
         lora_a = lora_a[start_idx:start_idx + shard_size, :]
         return lora_a

@@ -547,9 +536,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
     """

     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
-        start_idx = tp_rank * shard_size
+        start_idx = self.tp_rank * shard_size
         lora_a = lora_a[start_idx:start_idx + shard_size, :]
         return lora_a
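The slicing helpers above now index with self.tp_rank (copied from the base layer) instead of calling get_tensor_model_parallel_rank() on every invocation. A standalone sketch of the rank-based lora_b slice for a plain column-parallel layer, using hypothetical sizes that mirror the arithmetic in slice_lora_b:

# Standalone sketch of rank-based slicing of lora_b for a column-parallel
# layer (hypothetical sizes; mirrors the arithmetic in slice_lora_b above).
import torch

total_output_size = 8      # full (unsharded) output dimension
rank_dim = 4               # LoRA rank
tp_size, tp_rank = 2, 1    # would come from the base layer in vLLM

lora_b = torch.arange(total_output_size * rank_dim,
                      dtype=torch.float32).reshape(total_output_size, rank_dim)

shard_size = total_output_size // tp_size   # per-rank slice of the output dim
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
lora_b_shard = lora_b[start_idx:end_idx, :]  # rows owned by this rank

assert lora_b_shard.shape == (shard_size, rank_dim)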

vllm/lora/layers/replicated_linear.py

Lines changed: 0 additions & 1 deletion

@@ -18,7 +18,6 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
     def __init__(self, base_layer: ReplicatedLinear) -> None:
         super().__init__(base_layer, )
         # To ensure interface compatibility, set to 1 always.
-        self.tp_size = 1
         self.output_size = self.base_layer.output_size
         self.n_slices = 1

vllm/lora/layers/row_parallel_linear.py

Lines changed: 5 additions & 10 deletions

@@ -8,9 +8,7 @@
 from transformers import PretrainedConfig

 from vllm.config.lora import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              split_tensor_along_last_dim,
+from vllm.distributed import (split_tensor_along_last_dim,
                               tensor_model_parallel_all_reduce)
 # yapf: disable
 from vllm.model_executor.layers.linear import RowParallelLinear

@@ -25,12 +23,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
     def __init__(self, base_layer: RowParallelLinear) -> None:
         super().__init__(base_layer)

-        self.tp_size = get_tensor_model_parallel_world_size()
         # reset input_size
         self.input_size = self.base_layer.input_size_per_partition
         self.output_size = self.base_layer.output_size
-
-        self.tp_rank = get_tensor_model_parallel_rank()
         # There is only one LoRA layer.
         self.n_slices = 1

@@ -68,12 +63,12 @@ def forward(
         else:
             # TODO: simplify code below
             splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.base_layer.tp_size)
+                input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()

         # Matrix multiply.
         output_parallel = self.apply(input_parallel)
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if self.base_layer.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel

@@ -154,8 +149,8 @@ def apply(self,
             buffer, x, self.lora_a_stacked, 1.0)
         if not current_platform.can_update_inplace():
             buffer = shrunk_buffer
-
-        buffer = tensor_model_parallel_all_reduce(buffer)
+        if self.tp_size > 1:
+            buffer = tensor_model_parallel_all_reduce(buffer)

         # following S-LoRA, allows the fusing of all_gather and all_reduce
         # by adding the column partitioned lora output to a slice of output
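Two behavioral points in this file: forward() now splits the input using the wrapper's own tp_size, and apply() only issues the all-reduce when more than one tensor-parallel rank exists. A minimal single-process sketch of that guard, with fake_all_reduce standing in for vLLM's tensor_model_parallel_all_reduce:

# Minimal single-process sketch of the guarded all-reduce in apply(): skip the
# collective entirely when tp_size == 1. fake_all_reduce is a stand-in for
# vLLM's tensor_model_parallel_all_reduce.
import torch


def fake_all_reduce(t: torch.Tensor) -> torch.Tensor:
    # In real code this sums the tensor across all tensor-parallel ranks.
    return t


def reduce_lora_buffer(buffer: torch.Tensor, tp_size: int) -> torch.Tensor:
    if tp_size > 1:
        # Only pay for the collective when the layer is actually sharded.
        buffer = fake_all_reduce(buffer)
    return buffer


buf = torch.ones(2, 4)
assert torch.equal(reduce_lora_buffer(buf, tp_size=1), buf)  # no-op path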

vllm/lora/lora_weights.py

Lines changed: 2 additions & 2 deletions

@@ -48,11 +48,11 @@ def optimize(self) -> "LoRALayerWeights":

     @property
     def input_dim(self) -> int:
-        return self.lora_a.shape[0]
+        return self.lora_a.shape[1]

     @property
     def output_dim(self) -> int:
-        return self.lora_b.shape[1]
+        return self.lora_b.shape[0]

     @property
     def is_packed(self) -> bool:
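The corrected properties match a layout in which lora_a is stored as (rank, input_dim) and lora_b as (output_dim, rank); this is an inference from the shapes used above, not something the commit states. A small sketch under that assumption:

# Hedged sketch of the shape convention implied by the corrected properties:
# lora_a stored as (rank, input_dim), lora_b stored as (output_dim, rank).
# The sizes below are illustrative only.
import torch

rank, input_dim, output_dim = 8, 4096, 11008
lora_a = torch.zeros(rank, input_dim)
lora_b = torch.zeros(output_dim, rank)

# Mirrors the fixed properties above.
assert lora_a.shape[1] == input_dim    # input_dim
assert lora_b.shape[0] == output_dim   # output_dim

# The low-rank update keeps the base layer's (input_dim -> output_dim) shape.
x = torch.zeros(2, input_dim)
delta = x @ lora_a.T @ lora_b.T        # (2, output_dim)
assert delta.shape == (2, output_dim)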
