File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -479,7 +479,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
479479 # ColumnParallelLinear.
480480 else :
481481 tensor_model_parallel_rank = get_tensor_model_parallel_rank ()
482- shard_size = self .output_dim
482+ shard_size = self .output_size
483483 start_idx = tensor_model_parallel_rank * shard_size
484484 end_idx = (tensor_model_parallel_rank + 1 ) * shard_size
485485 lora_b = lora_b [:, start_idx :end_idx ]
@@ -490,7 +490,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
490490 if bias is None :
491491 return bias
492492 tensor_model_parallel_rank = get_tensor_model_parallel_rank ()
493- shard_size = self .output_dim
493+ shard_size = self .output_size
494494 start_idx = tensor_model_parallel_rank * shard_size
495495 end_idx = (tensor_model_parallel_rank + 1 ) * shard_size
496496 bias = bias [start_idx :end_idx ]
You can’t perform that action at this time.
0 commit comments