14 changes: 10 additions & 4 deletions csrc/moe/moe_lora_align_sum_kernels.cu
@@ -33,11 +33,15 @@ __global__ void moe_lora_align_sum_kernel(
int64_t block_size, int num_experts, int max_loras, size_t numel,
int max_num_tokens_padded, int max_num_m_blocks,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int topk_num, int32_t* total_tokens_post_pad) {
int topk_num, int32_t* total_tokens_post_pad, int32_t* num_tokens_per_lora, int32_t* adapter_enabled) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;

int lora_id = blockIdx.x;
if (adapter_enabled[lora_id] * num_tokens_per_lora[lora_id] == 0) {
return;
}

Comment on lines +41 to +44
I’m wondering why we don’t use activate_lora_ids[lora_idx] to check if LoRA is enabled here, and num_tokens_per_lora[lora_id] to determine if there are any tokens that need to be processed by the LoRA. Is there another consideration?

Reply:

the intention with adapter_enabled is to allow us to skip the moe_lora_align_sum_kernel if the LoRA model does not have MoE adapters (e.g. a LoRA model which only includes adapters for attn projections).

from my understanding, activate_lora_ids and num_tokens_per_lora only indicate if a LoRA model is being invoked, but does not indicate which specific LoRA kernels are applicable for that model
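
To make that distinction concrete, here is a minimal host-side sketch of the condition the kernel's early exit encodes. This is not vLLM's actual API; the function name is hypothetical and the arguments simply mirror the kernel parameters above. A LoRA slot is worth processing only when its adapter actually contains MoE weights and at least one token was routed to it.

def should_process_lora(lora_id: int,
                        num_tokens_per_lora: list[int],
                        adapter_enabled: list[int]) -> bool:
    # Mirrors the kernel's early exit
    # `if (adapter_enabled[lora_id] * num_tokens_per_lora[lora_id] == 0) return;`:
    # skip adapters that have no MoE weights (e.g. attn-only LoRAs) and
    # adapters that received no tokens in the current batch.
    return adapter_enabled[lora_id] != 0 and num_tokens_per_lora[lora_id] != 0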

extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@@ -124,9 +128,10 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad) {
torch::Tensor num_tokens_post_pad,
torch::Tensor num_tokens_per_lora,
torch::Tensor adapter_enabled) {
const int topk_num = topk_ids.size(1);

int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size);
int max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size);
@@ -160,6 +165,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
max_loras, topk_ids.numel(), max_num_tokens_padded,
max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>());
num_tokens_post_pad.data_ptr<int32_t>(), num_tokens_per_lora.data_ptr<int32_t>(),
adapter_enabled.data_ptr<int32_t>());
});
}
4 changes: 3 additions & 1 deletion csrc/moe/moe_ops.h
@@ -19,7 +19,9 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
torch::Tensor num_tokens_post_pad,
torch::Tensor num_tokens_per_lora,
torch::Tensor adapter_enabled);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,
4 changes: 3 additions & 1 deletion csrc/moe/torch_bindings.cpp
@@ -31,7 +31,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" int block_size, int max_loras, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad) -> () ");
" Tensor !num_tokens_post_pad,"
" Tensor !num_tokens_per_lora,"
" Tensor !adapter_enabled) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);

#ifndef USE_ROCM
5 changes: 5 additions & 0 deletions vllm/_custom_ops.py
@@ -1793,6 +1793,9 @@ def moe_align_block_size(
def moe_lora_align_block_size(
topk_ids: torch.Tensor,
token_lora_mapping: torch.Tensor,
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
adapter_enabled: torch.Tensor, # shape [max-loras]
num_experts: int,
block_size: int,
max_loras: int,
@@ -1809,6 +1812,8 @@ def moe_lora_align_block_size(
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
num_tokens_per_lora,
adapter_enabled,
)


21 changes: 17 additions & 4 deletions vllm/lora/layers/fused_moe.py
@@ -74,9 +74,10 @@ def wrapper(*args, **kwargs):
global_num_experts = layer._lora["global_num_experts"]
expert_map = layer._lora["expert_map"]

(token_lora_mapping, _, _, _, _,
_) = layer.punica_wrapper.token_mapping_meta.meta_args(
(token_lora_mapping, _, num_tokens_per_lora, _, _,
no_lora_flag_cpu) = layer.punica_wrapper.token_mapping_meta.meta_args(
hidden_states.size(0))

config_dtype = _get_config_dtype_str(use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
@@ -99,7 +100,8 @@ def wrapper(*args, **kwargs):
config = get_config_func(M)
(sorted_token_ids_lora, expert_ids_lora,
num_tokens_post_padded_lora) = (moe_lora_align_block_size(
curr_topk_ids, token_lora_mapping, config['BLOCK_SIZE_M'],
curr_topk_ids, token_lora_mapping, num_tokens_per_lora, no_lora_flag_cpu,
layer.adapter_enabled, config['BLOCK_SIZE_M'],
global_num_experts, curr_topk_ids.shape[-1], expert_map))

layer._lora["sorted_token_ids_lora"] = sorted_token_ids_lora
@@ -132,6 +134,7 @@ def wrapper(*args, **kwargs):
max_lora_rank,
top_k,
config,
layer.adapter_enabled,
)

result = func(*args, **kwargs)
@@ -191,7 +194,7 @@ def wrapper(*args, **kwargs):
intermediate_cache3, intermediate_cache2,
[w2_lora_a_stacked], [w2_lora_b_stacked], topk_weights,
sorted_token_ids_lora, expert_ids_lora,
num_tokens_post_padded_lora, max_lora_rank, top_k, config,
num_tokens_post_padded_lora, max_lora_rank, top_k, config, layer.adapter_enabled,
True)

result = func(*args, **kwargs)
@@ -226,6 +229,8 @@ def create_lora_weights(
model_config: Optional[PretrainedConfig] = None,
) -> None:
"""Initializes lora matrices."""
self.adapter_enabled = torch.tensor([0] * (max_loras+1), dtype=torch.int, device=self.device)

self.w1_lora_a_stacked = torch.zeros(
(
max_loras,
@@ -288,6 +293,9 @@ def create_lora_weights(
dtype=lora_config.lora_dtype,
device=self.device,
)

# flags to track which LoRAs have MoE adapters
self.base_layer.adapter_enabled = self.adapter_enabled

self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
@@ -324,6 +332,8 @@ def reset_lora(self, index: int):
self.w3_lora_b_stacked[index] = 0
self.w2_lora_a_stacked[index] = 0
self.w2_lora_b_stacked[index] = 0

self.adapter_enabled[index] = 0

def set_lora(
self,
@@ -334,6 +344,9 @@ def set_lora(
bias: Optional[torch.Tensor] = None,
):
"""Overwrites lora tensors at index."""

self.adapter_enabled[index] = 1

for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
w2_lora_a = lora_a[eid * 3 + 1]
32 changes: 32 additions & 0 deletions vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -47,6 +47,8 @@ def _fused_moe_lora_kernel(
EM,
num_valid_tokens,
num_experts,
lora_ids,
adapter_enabled,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
@@ -78,6 +80,12 @@
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)

lora_id = tl.load(lora_ids + lora_idx)
moe_enabled = tl.load(adapter_enabled + lora_idx)
if lora_id == -1 or moe_enabled == 0:
# Early exit for the no-lora case.
return

# calculate pid_m,pid_n
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
@@ -160,6 +168,13 @@ def _fused_moe_lora(
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
adapter_enabled: torch.Tensor, # shape [max-loras]
# config:Optional[dict[str, Any]],
block_size_m:int,
block_size_n:int,
@@ -183,6 +198,12 @@ def _fused_moe_lora(
config (_type_): _description_
intermediate_cache1 (torch.Tensor): _description_
"""

assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA
return

assert len(lora_a_stacked) == len(lora_b_stacked)
device = qcurr_hidden_states.device
num_slices = len(lora_a_stacked)
@@ -242,6 +263,8 @@ def _fused_moe_lora(
EM,
num_tokens,
num_experts,
lora_ids,
adapter_enabled,
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0),
@@ -287,6 +310,8 @@ def _fused_moe_lora(
EM,
num_tokens,
num_experts,
lora_ids,
adapter_enabled,
a_intermediate_cache1.stride(1),
a_intermediate_cache1.stride(2),
w1_lora_b_stacked.stride(0),
@@ -324,6 +349,13 @@ def _fused_moe_lora_fake(
block_size_n:int,
block_size_k:int,
group_size_m:int,
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
no_moe_lora_flag_cpu: torch.Tensor,
mul_routed_weight:bool=False,
) -> None:
return
6 changes: 6 additions & 0 deletions vllm/lora/punica_wrapper/punica_gpu.py
@@ -92,6 +92,9 @@ def add_shrink(
scale (float): Scaling factor for the operation
"""

# note @gnovack - force input to be contiguous to support eager mode
x = x.contiguous()

x = x.view(-1, x.shape[-1])
lora_shrink(
x,
@@ -317,6 +320,7 @@ def add_lora_fused_moe(
max_lora_rank: int,
top_k_num: int,
config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
):
fused_moe_lora(
@@ -330,6 +334,8 @@ def add_lora_fused_moe(
num_tokens_post_padded,
max_lora_rank,
top_k_num,
*self.token_mapping_meta.meta_args(x.size(0)),
adapter_enabled,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
6 changes: 6 additions & 0 deletions vllm/model_executor/layers/fused_moe/moe_align_block_size.py
@@ -89,6 +89,9 @@ def moe_align_block_size(
def moe_lora_align_block_size(
topk_ids: torch.Tensor,
token_lora_mapping: torch.Tensor,
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
adapter_enabled: torch.Tensor, # shape [max-loras]
block_size: int,
num_experts: int,
max_loras: int,
@@ -119,6 +122,9 @@ def moe_lora_align_block_size(
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_tokens_per_lora,
no_lora_flag_cpu,
adapter_enabled,
num_experts,
block_size,
max_loras,