From 85baec6c090c07f01d8d293fc5b4e55476a0ef80 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 20 Feb 2025 18:12:02 +0000
Subject: [PATCH 1/2] Optimize moe intermediate_cache allocation

Signed-off-by: mgoin
---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 543c8ced165a..b917f865b43a 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1173,15 +1173,16 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    cache = torch.empty(M * topk_ids.shape[1] * max(N, w2.shape[1]),
+                        device=hidden_states.device,
+                        dtype=hidden_states.dtype)
+    intermediate_cache1 = cache[:M * topk_ids.shape[1] * N].view(
+        (M, topk_ids.shape[1], N))
     intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    intermediate_cache3 = cache[:M * topk_ids.shape[1] * w2.shape[1]].view(
+        (M, topk_ids.shape[1], w2.shape[1]))
 
     if hidden_states.dtype == torch.bfloat16:
         compute_type = tl.bfloat16

From ab14d0efbfb0dd2bbace2d81893522f4be735c2f Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 20 Feb 2025 19:59:15 +0000
Subject: [PATCH 2/2] Improvement

Signed-off-by: mgoin
---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index b917f865b43a..3d4a8f72cb6b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1173,16 +1173,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     config = get_config_func(M)
 
-    cache = torch.empty(M * topk_ids.shape[1] * max(N, w2.shape[1]),
-                        device=hidden_states.device,
-                        dtype=hidden_states.dtype)
-    intermediate_cache1 = cache[:M * topk_ids.shape[1] * N].view(
+    # We can reuse the memory between these because by the time we need
+    # cache3, we're done with cache1
+    cache13 = torch.empty(M * topk_ids.shape[1] * max(N, w2.shape[1]),
+                          device=hidden_states.device,
+                          dtype=hidden_states.dtype)
+    intermediate_cache1 = cache13[:M * topk_ids.shape[1] * N].view(
         (M, topk_ids.shape[1], N))
+    intermediate_cache3 = cache13[:M * topk_ids.shape[1] * w2.shape[1]].view(
+        (M, topk_ids.shape[1], w2.shape[1]))
+
+    # This needs separate memory since it's used concurrently with cache1
     intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = cache[:M * topk_ids.shape[1] * w2.shape[1]].view(
-        (M, topk_ids.shape[1], w2.shape[1]))
 
     if hidden_states.dtype == torch.bfloat16:
         compute_type = tl.bfloat16
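
A minimal standalone sketch of the aliasing pattern these patches apply, not the vLLM code itself. The names M, top_k, N, and K below are hypothetical stand-ins for M, topk_ids.shape[1], N, and w2.shape[1] in fused_experts_impl, and the sizes are illustrative. One flat allocation backs both intermediate_cache1 and intermediate_cache3, since their lifetimes never overlap, while intermediate_cache2 keeps its own buffer because it is live at the same time as cache1.

import torch

# Hypothetical sizes standing in for M, topk_ids.shape[1], N, and
# w2.shape[1]; any positive values work for the demonstration.
M, top_k, N, K = 32, 2, 1024, 768
device, dtype = "cpu", torch.float32

# One flat buffer sized for the larger of the two tensors whose
# lifetimes never overlap (cache1 is fully consumed before cache3
# is written).
cache13 = torch.empty(M * top_k * max(N, K), device=device, dtype=dtype)

# Both views alias the front of the same storage, so the second view
# costs no additional memory.
intermediate_cache1 = cache13[:M * top_k * N].view(M, top_k, N)
intermediate_cache3 = cache13[:M * top_k * K].view(M, top_k, K)

# cache2 is live at the same time as cache1, so it keeps a separate
# allocation.
intermediate_cache2 = torch.empty((M * top_k, N // 2),
                                  device=device, dtype=dtype)

# The two views share the same underlying storage.
assert intermediate_cache1.data_ptr() == intermediate_cache3.data_ptr()

The saving relative to the original code is one full (M, top_k, max(N, K)) allocation per call, at the cost of the two views no longer being independently resizable; that trade-off is safe here precisely because cache1 and cache3 are used in strictly disjoint phases.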