diff --git a/csrc/moe/moe_lora_align_sum_kernels.cu b/csrc/moe/moe_lora_align_sum_kernels.cu
index 1d80f14b9f19..9ab31f9ae38f 100644
--- a/csrc/moe/moe_lora_align_sum_kernels.cu
+++ b/csrc/moe/moe_lora_align_sum_kernels.cu
@@ -33,11 +33,15 @@ __global__ void moe_lora_align_sum_kernel(
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad) {
+    int topk_num, int32_t* total_tokens_post_pad,
+    int32_t* num_tokens_per_lora, int32_t* adapter_enabled) {
   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;
   int lora_id = blockIdx.x;
+  if (adapter_enabled[lora_id] * num_tokens_per_lora[lora_id] == 0) {
+    return;
+  }
+
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;
   token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@@ -124,9 +128,10 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
                                int64_t max_loras,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad) {
+                               torch::Tensor num_tokens_post_pad,
+                               torch::Tensor num_tokens_per_lora,
+                               torch::Tensor adapter_enabled) {
   const int topk_num = topk_ids.size(1);
-
   int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
   max_num_tokens_padded = round_up(max_num_tokens_padded, block_size);
   int max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size);
@@ -160,6 +165,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
         max_loras, topk_ids.numel(), max_num_tokens_padded,
         max_num_m_blocks, sorted_token_ids.data_ptr(),
         expert_ids.data_ptr(), topk_num,
-        num_tokens_post_pad.data_ptr());
+        num_tokens_post_pad.data_ptr(), num_tokens_per_lora.data_ptr(),
+        adapter_enabled.data_ptr());
   });
 }
\ No newline at end of file
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index c238a9a289db..ecc8c882b08c 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -19,7 +19,9 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
                                int64_t max_loras,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad);
+                               torch::Tensor num_tokens_post_pad,
+                               torch::Tensor num_tokens_per_lora,
+                               torch::Tensor adapter_enabled);
 #ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index 7ae18bb568a2..7fc9071d14fa 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -31,7 +31,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "    int block_size, int max_loras, "
      "    Tensor !sorted_token_ids,"
      "    Tensor !experts_ids,"
-      "    Tensor !num_tokens_post_pad) -> () ");
+      "    Tensor !num_tokens_post_pad,"
+      "    Tensor !num_tokens_per_lora,"
+      "    Tensor !adapter_enabled) -> () ");
   m.impl("moe_lora_align_block_size", torch::kCUDA,
          &moe_lora_align_block_size);
 #ifndef USE_ROCM
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 16dd1a1ec881..993e29110c77 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1793,6 +1793,9 @@ def moe_align_block_size(
 def moe_lora_align_block_size(
     topk_ids: torch.Tensor,
     token_lora_mapping: torch.Tensor,
+    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
+    adapter_enabled: torch.Tensor,  # shape [max-loras]
     num_experts: int,
     block_size: int,
     max_loras: int,
@@ -1809,6 +1812,8 @@ def moe_lora_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        num_tokens_per_lora,
+        adapter_enabled,
     )
 
 
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 58b4400c09b3..009538cc52f4 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -74,9 +74,10 @@ def wrapper(*args, **kwargs):
             global_num_experts = layer._lora["global_num_experts"]
             expert_map = layer._lora["expert_map"]
 
-            (token_lora_mapping, _, _, _, _,
-             _) = layer.punica_wrapper.token_mapping_meta.meta_args(
+            (token_lora_mapping, _, num_tokens_per_lora, _, _,
+             no_lora_flag_cpu) = layer.punica_wrapper.token_mapping_meta.meta_args(
                  hidden_states.size(0))
+
             config_dtype = _get_config_dtype_str(use_fp8_w8a8=False,
                                                  use_int8_w8a16=False,
                                                  use_int4_w4a16=False,
@@ -99,7 +100,8 @@ def wrapper(*args, **kwargs):
             config = get_config_func(M)
             (sorted_token_ids_lora, expert_ids_lora,
              num_tokens_post_padded_lora) = (moe_lora_align_block_size(
-                 curr_topk_ids, token_lora_mapping, config['BLOCK_SIZE_M'],
+                 curr_topk_ids, token_lora_mapping, num_tokens_per_lora, no_lora_flag_cpu,
+                 layer.adapter_enabled, config['BLOCK_SIZE_M'],
                  global_num_experts, curr_topk_ids.shape[-1], expert_map))
 
             layer._lora["sorted_token_ids_lora"] = sorted_token_ids_lora
@@ -132,6 +134,7 @@ def wrapper(*args, **kwargs):
                 max_lora_rank,
                 top_k,
                 config,
+                layer.adapter_enabled,
             )
 
             result = func(*args, **kwargs)
@@ -191,7 +194,7 @@ def wrapper(*args, **kwargs):
                 intermediate_cache3, intermediate_cache2, [w2_lora_a_stacked],
                 [w2_lora_b_stacked], topk_weights, sorted_token_ids_lora,
                 expert_ids_lora,
-                num_tokens_post_padded_lora, max_lora_rank, top_k, config,
+                num_tokens_post_padded_lora, max_lora_rank, top_k, config, layer.adapter_enabled,
                 True)
 
             result = func(*args, **kwargs)
@@ -226,6 +229,8 @@ def create_lora_weights(
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
         """Initializes lora matrices."""
+        self.adapter_enabled = torch.tensor([0] * (max_loras+1), dtype=torch.int, device=self.device)
+
         self.w1_lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -288,6 +293,9 @@ def create_lora_weights(
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
+
+        # flags to track which LoRAs have MoE adapters
+        self.base_layer.adapter_enabled = self.adapter_enabled
 
         self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
         self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
@@ -324,6 +332,8 @@ def reset_lora(self, index: int):
         self.w3_lora_b_stacked[index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
+
+        self.adapter_enabled[index] = 0
 
     def set_lora(
         self,
@@ -334,6 +344,9 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         """Overwrites lora tensors at index."""
+
+        self.adapter_enabled[index] = 1
+
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
             w2_lora_a = lora_a[eid * 3 + 1]
diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 243bb0bf8d75..240cd8d18bf1 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -47,6 +47,8 @@ def _fused_moe_lora_kernel(
     EM,
     num_valid_tokens,
     num_experts,
+    lora_ids,
+    adapter_enabled,
     # The stride variables represent how much to increase the ptr by when
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
@@ -78,6 +80,12 @@ def _fused_moe_lora_kernel(
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
 
+    lora_id = tl.load(lora_ids + lora_idx)
+    moe_enabled = tl.load(adapter_enabled + lora_idx)
+    if lora_id == -1 or moe_enabled == 0:
+        # Early exit for the no-lora case.
+        return
+
     # calculate pid_m,pid_n
     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
@@ -160,6 +168,13 @@ def _fused_moe_lora(
     num_tokens_post_padded: torch.Tensor,
     max_lora_rank: int,
     top_k_num: int,
+    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
+    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
+    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
+    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
+    lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
+    adapter_enabled: torch.Tensor,  # shape [max-loras]
     # config:Optional[dict[str, Any]],
     block_size_m:int,
     block_size_n:int,
@@ -183,6 +198,12 @@ def _fused_moe_lora(
         config (_type_): _description_
         intermediate_cache1 (torch.Tensor): _description_
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA
+        return
+
     assert len(lora_a_stacked) == len(lora_b_stacked)
     device = qcurr_hidden_states.device
     num_slices = len(lora_a_stacked)
@@ -242,6 +263,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         qcurr_hidden_states.stride(0),
         qcurr_hidden_states.stride(1),
         w1_lora_a_stacked.stride(0),
@@ -287,6 +310,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         a_intermediate_cache1.stride(1),
         a_intermediate_cache1.stride(2),
         w1_lora_b_stacked.stride(0),
@@ -324,6 +349,13 @@ def _fused_moe_lora_fake(
     block_size_n:int,
     block_size_k:int,
     group_size_m:int,
+    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
+    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
+    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
+    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
+    lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
+    no_moe_lora_flag_cpu: torch.Tensor,
     mul_routed_weight:bool=False,
 ) -> None:
     return
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index a4a69a2a94bb..a5c916ce8bd1 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -92,6 +92,9 @@ def add_shrink(
             scale (float): Scaling factor for the operation
         """
 
+        # note @gnovack - force input to be contiguous to support eager mode
+        x = x.contiguous()
+
         x = x.view(-1, x.shape[-1])
         lora_shrink(
             x,
@@ -317,6 +320,7 @@ def add_lora_fused_moe(
         max_lora_rank: int,
         top_k_num: int,
         config,
+        adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
     ):
         fused_moe_lora(
@@ -330,6 +334,8 @@ def add_lora_fused_moe(
            num_tokens_post_padded,
            max_lora_rank,
            top_k_num,
+           *self.token_mapping_meta.meta_args(x.size(0)),
+           adapter_enabled,
            config["BLOCK_SIZE_M"],
            config["BLOCK_SIZE_N"],
            config["BLOCK_SIZE_K"],
diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
index 880bcb43e9c6..70400aea49f8 100644
--- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
@@ -89,6 +89,9 @@ def moe_align_block_size(
 def moe_lora_align_block_size(
     topk_ids: torch.Tensor,
     token_lora_mapping: torch.Tensor,
+    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
+    adapter_enabled: torch.Tensor,  # shape [max-loras]
     block_size: int,
     num_experts: int,
     max_loras: int,
@@ -119,6 +122,9 @@ def moe_lora_align_block_size(
     ops.moe_lora_align_block_size(
         topk_ids,
         token_lora_mapping,
+        num_tokens_per_lora,
+        no_lora_flag_cpu,
+        adapter_enabled,
         num_experts,
         block_size,
         max_loras,
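
For reference, the gating added above can be read as: a LoRA slot is skipped when it either has no MoE adapter loaded (`adapter_enabled[lora_id] == 0`) or no tokens routed to it in the current batch (`num_tokens_per_lora[lora_id] == 0`), while `no_lora_flag_cpu` short-circuits `_fused_moe_lora` on the host before any kernel launch. The snippet below is a minimal CPU sketch of that per-slot check, not part of the patch; the tensor names follow the new kernel arguments, but the shapes and values are made up for illustration.

```python
# Minimal sketch (not part of the patch) of the per-LoRA gating that
# moe_lora_align_sum_kernel now performs per thread block. Values are
# illustrative; only slot 0 is both enabled and has tokens mapped to it.
import torch

max_loras = 4
adapter_enabled = torch.tensor([1, 1, 0, 0], dtype=torch.int32)  # [max_loras]
num_tokens_per_lora = torch.tensor([5, 0, 3, 0, 0], dtype=torch.int32)  # [max_loras + 1]

for lora_id in range(max_loras):
    # Same condition as the kernel's early return: the product is zero iff
    # the adapter has no MoE LoRA weights or no tokens in this batch.
    if adapter_enabled[lora_id] * num_tokens_per_lora[lora_id] == 0:
        continue
    # The real kernel then sorts this slot's token ids into block_size-aligned
    # groups and fills sorted_token_ids / expert_ids / num_tokens_post_pad.
    print(f"LoRA slot {lora_id}: {int(num_tokens_per_lora[lora_id])} tokens to align")
```

In the Triton kernel the analogous check is `lora_id == -1 or moe_enabled == 0`, i.e. the early exit for the no-lora case shown in `_fused_moe_lora_kernel`.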