From 1248e8506a4d98b4f15cbfe729cf2af42fb4223a Mon Sep 17 00:00:00 2001 From: Wenxiang <8460860+wenxcs@users.noreply.github.com> Date: Sat, 31 Aug 2024 03:42:57 +0800 Subject: [PATCH] [Model] Adding support for MSFT Phi-3.5-MoE (#7729) Co-authored-by: Your Name Co-authored-by: Zeqi Lin Co-authored-by: Zeqi Lin --- docs/source/models/supported_models.rst | 4 + tests/models/test_phimoe.py | 111 ++++ ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 130 ++++ ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 130 ++++ ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 130 ++++ .../layers/fused_moe/fused_moe.py | 19 +- vllm/model_executor/layers/fused_moe/layer.py | 90 ++- .../compressed_tensors_moe.py | 24 +- .../layers/quantization/experts_int8.py | 26 +- .../model_executor/layers/quantization/fp8.py | 26 +- .../model_executor/layers/rotary_embedding.py | 26 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/phimoe.py | 620 ++++++++++++++++++ 13 files changed, 1255 insertions(+), 82 deletions(-) create mode 100644 tests/models/test_phimoe.py create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/models/phimoe.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index f727c646b7da7..2c20b6e48407d 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -147,6 +147,10 @@ Decoder-only Language Models - Phi-3-Small - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. - + * - :code:`PhiMoEForCausalLM` + - Phi-3.5-MoE + - :code:`microsoft/Phi-3.5-MoE-instruct`, etc. + - * - :code:`PersimmonForCausalLM` - Persimmon - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc. diff --git a/tests/models/test_phimoe.py b/tests/models/test_phimoe.py new file mode 100644 index 0000000000000..2fb2eecc94672 --- /dev/null +++ b/tests/models/test_phimoe.py @@ -0,0 +1,111 @@ +"""Compare the outputs of HF and vLLM for moe models using greedy sampling. + +Run `pytest tests/models/test_phimoe.py`. 
+""" +import pytest +import torch + +from vllm.utils import is_cpu + +from .utils import check_logprobs_close + +MODELS = [ + "microsoft/Phi-3.5-MoE-instruct", +] + + +def test_phimoe_routing_function(): + from vllm.model_executor.models.phimoe import phimoe_routing_function + test_case = { + 0: { + "hidden_states": + torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], + dtype=torch.float32, + requires_grad=False).view(4, 2), + "gating_output": + torch.tensor([0.1, 0.2, 0.3, 0.4], + dtype=torch.float32, + requires_grad=False), + "topk": + 2, + "renormalize": + False, + }, + 1: { + "hidden_states": + torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], + dtype=torch.float32, + requires_grad=False).view(4, 2), + "gating_output": + torch.tensor([0.4, 0.2, 0.3, 0.4], + dtype=torch.float32, + requires_grad=False), + "topk": + 2, + "renormalize": + False, + } + } + + ground_truth = { + 0: { + "topk_weights": + torch.tensor([1., 1.], dtype=torch.float32, requires_grad=False), + "topk_ids": + torch.tensor([3, 2], dtype=torch.long, requires_grad=False), + }, + 1: { + "topk_weights": + torch.tensor([0.5, 1.], dtype=torch.float32, requires_grad=False), + "topk_ids": + torch.tensor([0, 3], dtype=torch.long, requires_grad=False), + } + } + + for test_id in test_case: + topk_weights, topk_ids = phimoe_routing_function(**test_case[test_id]) + assert torch.allclose(topk_weights, + ground_truth[test_id]["topk_weights"]) + assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"]) + + +def get_gpu_memory(): + try: + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + gpu_memory = props.total_memory / (1024**3) + return gpu_memory + except Exception: + return 0 + + +@pytest.mark.skipif(condition=is_cpu(), + reason="This test takes a lot time to run on CPU, " + "and vllm CI's disk space is not enough for this model.") +@pytest.mark.skipif(condition=get_gpu_memory() < 100, + reason="Skip this test if GPU memory is insufficient.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..cd0cdbea0c337 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + 
"num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..ba9041d008507 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..57055453aa24c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3584": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d2b152320e11e..05169eaddb256 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2,7 +2,7 @@ import functools import json import os -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import torch import triton @@ -446,7 +446,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor, rand_perm1: torch.Tensor, rand_perm2: torch.Tensor, topk: int, - renormalize: bool, + custom_routing_function: Optional[Callable] = None, + renormalize: bool = True, override_config: Optional[Dict[str, Any]] = None, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -497,8 +498,12 @@ def fused_marlin_moe(hidden_states: torch.Tensor, E = w1.shape[0] N = w2.shape[1] * 16 - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + if custom_routing_function is None: + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) get_config_func = functools.partial(try_get_optimal_moe_config, w1.shape, @@ -695,6 +700,7 @@ def fused_moe( use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -742,9 +748,12 @@ def fused_moe( topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk, renormalize, num_expert_group, topk_group) - else: + elif custom_routing_function is None: topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) return fused_experts(hidden_states, w1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 61ebef5e11f43..3df0b61a9ebe4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,6 +1,6 @@ from abc import abstractmethod from enum import Enum -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import torch @@ -62,15 +62,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -79,17 +82,21 @@ 
def apply(self, renormalize=renormalize, use_grouped_topk=use_grouped_topk, topk_group=topk_group, - num_expert_group=num_expert_group) - - def forward_cuda(self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts) @@ -101,7 +108,8 @@ def forward_cuda(self, top_k=top_k, renormalize=renormalize, topk_group=topk_group, - num_expert_group=num_expert_group) + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -114,20 +122,24 @@ def forward_cpu(self, *args, **kwargs): raise NotImplementedError( "The CPU backend currently does not support MoE.") - def forward_tpu(self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + def forward_tpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe assert not use_grouped_topk assert num_expert_group is None assert topk_group is None + assert custom_routing_function is None return fused_moe(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -172,6 +184,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", + custom_routing_function: Optional[Callable] = None, ): super().__init__() @@ -190,6 +203,7 @@ def __init__( assert num_expert_group is not None and topk_group is not None self.num_expert_group = num_expert_group self.topk_group = topk_group + self.custom_routing_function = custom_routing_function if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -390,7 +404,8 @@ def select_experts(hidden_states: torch.Tensor, use_grouped_topk: bool, renormalize: bool, topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None): + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None): from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, grouped_topk) @@ -405,11 +420,17 @@ def select_experts(hidden_states: torch.Tensor, renormalize=renormalize, num_expert_group=num_expert_group, topk_group=topk_group) - else: + elif custom_routing_function is None: topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) return topk_weights, 
topk_ids @@ -426,7 +447,8 @@ def forward(self, hidden_states: torch.Tensor, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, topk_group=self.topk_group, - num_expert_group=self.num_expert_group) + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 0e0ab9ce9169f..36323493d601e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,6 +1,6 @@ import enum from enum import Enum -from typing import List, Optional +from typing import Callable, List, Optional import torch @@ -256,15 +256,18 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, ) replace_tensor("w2_weight_scale", marlin_w2_scales) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_marlin_moe) @@ -278,6 +281,7 @@ def apply(self, layer.w13_g_idx_sort_indices, layer.w2_g_idx_sort_indices, top_k, + custom_routing_function, renormalize=renormalize, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index dabf17df78fef..116a4ea0aed89 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch @@ -96,15 +96,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w2_scale", w2_scale) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( @@ -114,7 +117,8 @@ def apply(self, top_k=top_k, renormalize=renormalize, topk_group=topk_group, - num_expert_group=num_expert_group) + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) return fused_experts(x, layer.w13_weight, diff --git 
a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1817dbcb023a7..32affe06b89b7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch from torch.nn import Module @@ -468,15 +468,18 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) return - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -487,7 +490,8 @@ def apply(self, top_k=top_k, renormalize=renormalize, topk_group=topk_group, - num_expert_group=num_expert_group) + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) return fused_experts(x, layer.w13_weight, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 0562b71aa7493..c5a0278e485d4 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -503,8 +503,8 @@ def __init__( dtype: torch.dtype, short_factor: List[float], long_factor: List[float], - short_mscale: float = 1.0, - long_mscale: float = 1.0, + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, ): super().__init__() @@ -523,18 +523,22 @@ def __init__( self.base = base self.short_factor = short_factor self.long_factor = long_factor - self.short_mscale = short_mscale - self.long_mscale = long_mscale - - scale = (self.max_position_embeddings / - self.original_max_position_embeddings) + scale = self.max_position_embeddings / \ + self.original_max_position_embeddings if scale <= 1.0: - self.scaling_factor = 1.0 + scaling_factor = 1.0 else: - self.scaling_factor = math.sqrt( + scaling_factor = math.sqrt( 1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale short_cache = self._compute_cos_sin_cache( original_max_position_embeddings, short_factor, short_mscale) @@ -571,8 +575,8 @@ def _compute_cos_sin_cache( inv_freq = self._compute_inv_freq(rescale_factors) t = torch.arange(max_position_embeddings, dtype=torch.float) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() * mscale * self.scaling_factor - sin = freqs.sin() * mscale * self.scaling_factor + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale cache = torch.cat((cos, sin), dim=-1) return cache diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index fc3d4922aea09..f4c3e43c8f2a4 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -50,6 +50,7 @@ "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), 
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), + "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py new file mode 100644 index 0000000000000..c8128052a3ebe --- /dev/null +++ b/vllm/model_executor/models/phimoe.py @@ -0,0 +1,620 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only PhiMoE model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput + +from .interfaces import SupportsLoRA + + +class PhiMoEConfig(PretrainedConfig): + + model_type = "phimoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=16, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + attention_bias=False, + 
lm_head_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class mp(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + scores: torch.Tensor, + multiplier: torch.Tensor, + selected_experts: torch.Tensor, + masked_gates: torch.Tensor, + mask_for_one: torch.Tensor, + ): + ctx.save_for_backward(multiplier, selected_experts, masked_gates) + return multiplier * mask_for_one + + @staticmethod + def backward( + ctx, + grad_at_output: torch.Tensor, + ): + multiplier, selected_experts, masked_gates = ctx.saved_tensors + + grad_at_output = grad_at_output * multiplier + + grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1) + grad_at_scores_expaned.scatter_add_( + dim=-1, + index=selected_experts, + src=grad_at_output, + ) + + return ( + grad_at_scores_expaned, + None, + None, + None, + None, + ) + + +def sparsemixer(scores, jitter_eps=0.01): + ################ first expert ################ + + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ( + (mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # apply mask + masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf")) + selected_experts = max_ind + + # compute scores for gradients + masked_gates = torch.softmax(masked_gates, dim=-1) + multiplier_o = masked_gates.gather(dim=-1, index=selected_experts) + + multiplier = multiplier_o + + # masked out first expert + masked_scores = torch.scatter( + scores, + -1, + selected_experts, + float("-inf"), + ) + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = masked_scores.max(dim=-1, + keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ( + (mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # apply mask + masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, + float("-inf")) + selected_experts_top2 = max_ind + # compute scores for gradients + masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) + multiplier_top2 = masked_gates_top2.gather(dim=-1, + index=selected_experts_top2) + + multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), + dim=-1) + + return ( + multiplier, + 
selected_experts, + ) + + +def phimoe_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert topk == 2, "Only top-2 routing is supported" + assert renormalize is False, "Renormalization is not supported" + + topk_weights, topk_ids = sparsemixer(gating_output) + return topk_weights, topk_ids + + +class PhiMoE(nn.Module): + """A tensor-parallel MoE implementation for PhiMoE that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, + custom_routing_function=phimoe_routing_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class PhiMoEAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[dict] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=None, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=True, + quant_config=None, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + rope_scaling=self.rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class PhiMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = PhiMoEAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=config.rope_scaling, + ) + self.block_sparse_moe = PhiMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + residual = hidden_states + + # Self Attention + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = hidden_states + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states, residual + + +class PhiMoEModel(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + 
) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList([ + PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], attn_metadata, + residual) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class PhiMoEForCausalLM(nn.Module, SupportsLoRA): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = PhiMoEModel(config, + cache_config, + quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size), + quant_config=None, + bias=True, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = 
FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)
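
The main extension point threaded through this patch is the `custom_routing_function` argument on `FusedMoE` (and on the fp8, experts_int8, and compressed-tensors MoE methods), which lets a model replace `fused_topk`/`grouped_topk` with its own expert selection; PhiMoE uses it to plug in `phimoe_routing_function` (sparsemixer-based top-2 routing with `renormalize=False`). The sketch below is illustrative only and not part of the patch: it shows the callable signature the hook expects, using a plain softmax top-k router and arbitrary tensor sizes rather than PhiMoE's actual routing.

from typing import Tuple

import torch


def softmax_topk_routing(
    hidden_states: torch.Tensor,  # (num_tokens, hidden_size), unused here
    gating_output: torch.Tensor,  # (num_tokens, num_experts) router logits
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Toy router with the same signature as phimoe_routing_function."""
    # Convert router logits to probabilities and keep the top-k experts
    # per token; this mirrors what fused_topk does, without the fused kernel.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


if __name__ == "__main__":
    num_tokens, hidden_size, num_experts = 4, 8, 16
    hidden_states = torch.randn(num_tokens, hidden_size)
    gating_output = torch.randn(num_tokens, num_experts)
    weights, ids = softmax_topk_routing(hidden_states, gating_output,
                                        topk=2, renormalize=False)
    print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
    # A model wires such a callable in exactly as PhiMoE does above, e.g.:
    #   FusedMoE(..., renormalize=False,
    #            custom_routing_function=softmax_topk_routing)

Once `PhiMoEForCausalLM` is registered in `_MODELS`, no extra wiring is needed on the user side: constructing `LLM(model="microsoft/Phi-3.5-MoE-instruct", ...)` resolves the architecture to this implementation through the normal model-loader path.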