diff --git a/run.sh b/run.sh
new file mode 100644
index 000000000000..027fadaa14fc
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,21 @@
+# Run config for DeepSeek-R1 on a single 8xH200 node.
+# Uses one MTP module for speculative decoding, applied
+# recursively to draft k=2 speculative tokens.
+# Expected draft acceptance rate is ~70%
+# (~80% for token 1, ~60% for token 2 due to accuracy decay).
+python3 \
+    -m vllm.entrypoints.openai.api_server \
+    --disable-log-requests \
+    --gpu-memory-utilization 0.85 \
+    --quantization fp8 \
+    --max-model-len 65536 \
+    --max-num-seqs 128 \
+    --seed 0 \
+    --tensor-parallel-size 8 \
+    --swap-space 0 \
+    --block-size 32 \
+    --model deepseek-ai/DeepSeek-R1 \
+    --distributed-executor-backend=mp \
+    --trust-remote-code \
+    --num-speculative-tokens 2 \
+    --speculative-model DeepSeekV3MTP
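The acceptance rates quoted in the comment above translate directly into an expected number of tokens per target forward pass. A minimal sketch of that arithmetic, using the quoted per-position rates purely as illustrative inputs (real rates depend on the workload and sampling settings):

# Illustrative only: expected tokens emitted per target-model step when draft
# tokens are accepted left to right (a rejected draft discards the rest).
# The rates below are the ones quoted in the run.sh comment, not measurements.
def expected_tokens_per_step(acceptance_rates):
    expected = 1.0  # the target model always contributes one token itself
    keep = 1.0
    for rate in acceptance_rates:
        keep *= rate  # draft i survives only if all earlier drafts did
        expected += keep
    return expected

print(expected_tokens_per_step([0.8, 0.6]))  # ~2.28 tokens per step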
diff --git a/vllm/config.py b/vllm/config.py
index 5579d6936d10..67015bc1afa1 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -313,8 +313,7 @@ def __init__(
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.encoder_config = self._get_encoder_config()
-        self.hf_image_processor_config = get_hf_image_processor_config(
-            self.model, revision)
+        self.hf_image_processor_config = {}  # get_hf_image_processor_config(self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
         self.mm_processor_kwargs = mm_processor_kwargs
@@ -420,6 +419,8 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str,
     def _init_multimodal_config(
         self, limit_mm_per_prompt: Optional[Mapping[str, int]]
     ) -> Optional["MultiModalConfig"]:
+        return None
+
         architectures = getattr(self.hf_config, "architectures", [])
         if ModelRegistry.is_multimodal_model(architectures):
             return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
@@ -756,7 +757,7 @@ def get_hidden_size(self) -> int:
     def is_deepseek_mla(self) -> bool:
         return (hasattr(self.hf_text_config, "model_type")) \
             and (self.hf_text_config.model_type in \
-                ('deepseek_v2', 'deepseek_v3'))\
+                ('deepseek_v2', 'deepseek_v3', 'eagle'))\
             and (self.hf_text_config.kv_lora_rank is not None)

     def get_head_size(self) -> int:
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 2a2c2523b725..494b838bd389 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -383,6 +383,8 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
             model = _initialize_model(vllm_config=vllm_config)

             weights_to_load = {name for name, _ in model.named_parameters()}
+            if hasattr(model_config.hf_config, 'model_type') and model_config.hf_config.model_type == 'eagle':
+                model_config.model = 'deepseek-ai/DeepSeek-R1'
             loaded_weights = model.load_weights(
                 self._get_all_weights(model_config, model))
             # We only enable strict check for non-quantized models
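The config.py and loader.py changes above hard-wire one idea: a draft model surfaced with model_type 'eagle' has no checkpoint of its own, so its weights are pulled from the base deepseek-ai/DeepSeek-R1 repo (which ships the MTP layer as model.layers.61.*). A table-driven sketch of that remapping; the helper name and table are hypothetical, not part of the patch:

# Hypothetical helper mirroring the remap the loader hunk hard-codes.
DRAFT_WEIGHT_SOURCES = {
    # draft model_type -> checkpoint that actually holds the weights
    "eagle": "deepseek-ai/DeepSeek-R1",
}

def resolve_weight_source(model_type: str, default: str) -> str:
    """Return the repo to load draft weights from, falling back to `default`."""
    return DRAFT_WEIGHT_SOURCES.get(model_type, default)

# resolve_weight_source("eagle", "DeepSeekV3MTP") -> "deepseek-ai/DeepSeek-R1"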
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
new file mode 100644
index 000000000000..244a98ce14ab
--- /dev/null
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -0,0 +1,188 @@
+from typing import Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+
+from vllm.attention import AttentionMetadata
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .utils import is_pp_missing_parameter
+from .deepseek_v2 import DeepseekV2DecoderLayer
+
+class DeepseekV3MTPSpeculator(nn.Module):
+    def __init__(self, vllm_config: VllmConfig, prefix: str = "", mtp_layer_index: int = 0):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        config.first_k_dense_replace = 0
+        self.config = config
+        self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
+        self.quant_config = vllm_config.quant_config
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+
+        self.shared_head = nn.ModuleDict({
+            "head": ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=self.quant_config),
+            "norm": RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        })
+
+        layer_index = 61
+
+        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+        self.transformer = DeepseekV2DecoderLayer(config, f"{prefix}.layers.{layer_index}", quant_config=self.quant_config, cache_config=self.cache_config, model_config=self.model_config)
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        previous_hidden_states: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> List[torch.Tensor]:
+        if inputs_embeds is not None:
+            embedding = inputs_embeds
+        else:
+            embedding = self.embed_tokens(input_ids)
+
+        h_normed = self.hnorm(previous_hidden_states)
+        e_normed = self.enorm(embedding)
+
+        cat_in = torch.cat([e_normed, h_normed], dim=-1)  # swapped from the paper
+        proj_out = self.eh_proj(cat_in)
+
+        (mtp_hidden, mtp_residual) = self.transformer(
+            positions,
+            proj_out,
+            kv_cache=kv_caches[0],
+            attn_metadata=attn_metadata,
+            residual=None
+        )
+
+        return mtp_hidden + mtp_residual
+        # hidden_states = mtp_hidden
+        # hidden_states, _ = self.shared_head["norm"](hidden_states, mtp_residual)
+        # return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.shared_head["head"], self.shared_head["norm"](hidden_states), sampling_metadata)
+        return logits
+
+    def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            assert self.config.num_nextn_predict_layers == 1
+            layer_idx = 61
+            if name.startswith(f"model.layers.{layer_idx}"):
+                name = name.replace(f"model.layers.{layer_idx}.", "")
+                if name.startswith("input_layernorm") or name.startswith("post_attention_layernorm") or name.startswith("mlp") or name.startswith("self_attn"):
+                    name = "transformer." + name
+            else:
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                if name not in params_dict:
+                    breakpoint()
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    if name not in params_dict:
+                        breakpoint()
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    if name not in params_dict:
+                        breakpoint()
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
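The core of the module above is the enorm/hnorm/eh_proj combine step: the incoming token embedding and the previous position's hidden state are each RMS-normalized, concatenated, and projected back to hidden_size before running the single MTP decoder layer. A standalone sketch of just that combine step in plain PyTorch (nn.RMSNorm needs PyTorch 2.4+; shapes and names here are illustrative, not the vLLM layers):

import torch
from torch import nn

class MTPCombine(nn.Module):
    """Illustrative stand-in for the enorm/hnorm/eh_proj path above."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.enorm = nn.RMSNorm(hidden_size, eps=eps)  # normalizes token embeddings
        self.hnorm = nn.RMSNorm(hidden_size, eps=eps)  # normalizes previous hidden states
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, embedding: torch.Tensor,
                previous_hidden: torch.Tensor) -> torch.Tensor:
        # Same concatenation order as the module above: [embedding, hidden].
        combined = torch.cat([self.enorm(embedding), self.hnorm(previous_hidden)], dim=-1)
        return self.eh_proj(combined)  # then fed to the single MTP decoder layer

out = MTPCombine(hidden_size=16)(torch.randn(2, 4, 16), torch.randn(2, 4, 16))
print(out.shape)  # torch.Size([2, 4, 16])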
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index fd0e58fa1458..27e4a1898c06 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -640,8 +640,8 @@ def forward(
                 "residual": residual
             })

-        hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+        # hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states + residual


 class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
@@ -684,7 +684,7 @@ def compute_logits(
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head, self.model.norm(hidden_states),
                                        sampling_metadata)
         return logits
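The two deepseek_v2.py hunks above work as a pair: DeepseekV2Model.forward stops applying the final RMSNorm and instead returns the pre-norm value hidden_states + residual (which the MTP head's hnorm expects as previous_hidden_states), while compute_logits applies self.model.norm so the logits path still sees normalized activations. A toy sketch of that reordering; the classes are illustrative stand-ins, not the vLLM ones:

import torch
from torch import nn

class ToyBackbone(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.layer = nn.Linear(hidden, hidden)  # stands in for the decoder stack
        self.norm = nn.RMSNorm(hidden)          # final norm, no longer applied in forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_states = self.layer(x)
        residual = x
        # Pre-norm output: reusable as `previous_hidden_states` by the draft head.
        return hidden_states + residual

class ToyLM(nn.Module):
    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.model = ToyBackbone(hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The final norm moves here, so the logits themselves are unchanged.
        return self.lm_head(self.model.norm(hidden_states))

lm = ToyLM(hidden=16, vocab=32)
print(lm.compute_logits(lm.model(torch.randn(2, 16))).shape)  # torch.Size([2, 32])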
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index c2d0fae7056c..251040839dcf 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -184,6 +184,7 @@

 _SPECULATIVE_DECODING_MODELS = {
     "EAGLEModel": ("eagle", "EAGLE"),
     "MedusaModel": ("medusa", "Medusa"),
+    "DeepseekV3MTPModel": ("deepseek_mtp", "DeepseekV3MTPSpeculator"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 5474917a6fab..77f6811ef542 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -96,12 +96,22 @@ def sampler_output(
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
             for _ in range(sample_len):
+                if expanded_request.previous_hidden_states is not None:
+                    self.worker.model_runner.return_hidden_states = True
                 model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
+                if expanded_request.previous_hidden_states is not None:
+                    assert hasattr(model_output, 'hidden_states')
+                    seq_group_meta_with_hidden = [
+                        sg for sg in expanded_request.seq_group_metadata_list
+                        if sg.do_sample
+                    ]
+                    expanded_request.previous_hidden_states = HiddenStates(model_output.hidden_states, seq_group_meta_with_hidden, expanded_request.previous_hidden_states.hidden_states)
+
                 self._append_new_tokens(
                     model_output, expanded_request.seq_group_metadata_list,
                     indices_of_seq_with_bonus_tokens)
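The multi_step_worker.py change above is what lets the single MTP module be applied recursively, as described in run.sh: after each draft step, the hidden states returned by the model are wrapped back into previous_hidden_states so the next draft step conditions on them. A schematic of that loop; mtp_step is a hypothetical callable standing in for one draft forward pass, not a vLLM API:

import torch

def draft_k_tokens(mtp_step, token_ids: torch.Tensor, hidden: torch.Tensor, k: int):
    """Schematic of the recursive MTP chain (not the worker code itself)."""
    drafts = []
    for _ in range(k):
        # mtp_step(ids, hidden) -> (next_ids, next_hidden); step i feeds step i+1.
        token_ids, hidden = mtp_step(token_ids, hidden)
        drafts.append(token_ids)
    return drafts

# Toy usage: a fake step that shifts ids and perturbs the hidden state.
fake_step = lambda ids, h: (ids + 1, h + 0.1)
print(draft_k_tokens(fake_step, torch.tensor([7]), torch.zeros(1, 4), k=2))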
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 8653bece8b5a..5dad6060a6aa 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -178,13 +178,13 @@ def create_worker(
                 proposer_worker = MedusaWorker(**draft_worker_kwargs)
             else:
                 if draft_tp == 1:
-                    if current_platform.is_cuda_alike():
+                    if current_platform.is_cuda_alike() and not draft_model_config.use_mla:
                         draft_worker_kwargs[
                             "model_runner_cls"] = TP1DraftModelRunner
                 else:
-                    if draft_model_config.hf_config.model_type == "eagle":
-                        raise NotImplementedError(
-                            "EAGLE does not support TP > 1 yet")
+                    # if draft_model_config.hf_config.model_type == "eagle":
+                    #     raise NotImplementedError(
+                    #         "EAGLE does not support TP > 1 yet")

                     allow_zero_draft_token_step = False
                 proposer_worker = MultiStepWorker(**draft_worker_kwargs)
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index fb5cc3ec0722..f7490f51c6d1 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -176,6 +176,11 @@ def get_config(
 ) -> PretrainedConfig:

     # Separate model folder from file path for GGUF models
+    if model == "DeepSeekV3MTP":
+        model_base_name = "deepseek-ai/DeepSeek-R1"
+    else:
+        model_base_name = model
+
     is_gguf = check_gguf_file(model)
     if is_gguf:
         kwargs["gguf_file"] = Path(model).name
@@ -183,9 +188,9 @@

     if config_format == ConfigFormat.AUTO:
         if is_gguf or file_or_path_exists(
-                model, HF_CONFIG_NAME, revision=revision):
+                model_base_name, HF_CONFIG_NAME, revision=revision):
             config_format = ConfigFormat.HF
-        elif file_or_path_exists(model, MISTRAL_CONFIG_NAME,
+        elif file_or_path_exists(model_base_name, MISTRAL_CONFIG_NAME,
                                  revision=revision):
             config_format = ConfigFormat.MISTRAL
         else:
@@ -193,7 +198,7 @@
             # raise an offline mode error to indicate to the user that they
             # don't have files cached and may need to go online.
             # This is conveniently triggered by calling file_exists().
-            file_exists(model,
+            file_exists(model_base_name,
                         HF_CONFIG_NAME,
                         revision=revision,
                         token=HF_TOKEN)
@@ -202,13 +207,18 @@

     if config_format == ConfigFormat.HF:
         config_dict, _ = PretrainedConfig.get_config_dict(
-            model,
+            model_base_name,
             revision=revision,
             code_revision=code_revision,
             token=HF_TOKEN,
             **kwargs,
         )

+        if model == "DeepSeekV3MTP":
+            config_dict["model_type"] = "eagle"
+            config_dict["num_hidden_layers"] = 1
+            config_dict["architectures"] = ["DeepseekV3MTPModel"]
+
         # Use custom model class if it's in our registry
         model_type = config_dict.get("model_type")
         if model_type in _CONFIG_REGISTRY:
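The get_config change above and the EAGLEConfig.from_pretrained change further below apply the same three overrides whenever the pseudo-model name "DeepSeekV3MTP" is requested: resolve the config from the base R1 checkpoint, then force model_type, num_hidden_layers, and architectures so the draft is routed through the speculative-decoding registry. Factored out, the shared logic would look roughly like this (the helpers are a sketch, not part of the patch):

# Sketch of the override shared by get_config() and EAGLEConfig.from_pretrained().
MTP_ALIAS = "DeepSeekV3MTP"
MTP_BASE_CHECKPOINT = "deepseek-ai/DeepSeek-R1"

def resolve_checkpoint(model: str) -> str:
    """The alias has no HF repo of its own; config and weights come from R1."""
    return MTP_BASE_CHECKPOINT if model == MTP_ALIAS else model

def apply_mtp_overrides(model: str, config_dict: dict) -> dict:
    """Patch a config dict so the MTP draft head loads via the 'eagle' path."""
    if model == MTP_ALIAS:
        config_dict = dict(config_dict)
        config_dict["model_type"] = "eagle"            # reuse the EAGLE plumbing
        config_dict["num_hidden_layers"] = 1           # the MTP head is one layer
        config_dict["architectures"] = ["DeepseekV3MTPModel"]
    return config_dict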
diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index b26aba66699f..8c6f2695c8c2 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -8,37 +8,104 @@
 class EAGLEConfig(PretrainedConfig):
     model_type = "eagle"
+    # model_type = "deepseek_v3"
+
+    keys_to_ignore_at_inference = ["past_key_values"]

-    def __init__(self,
-                 model: Union[PretrainedConfig, dict, None] = None,
-                 truncated_vocab_size: Optional[int] = None,
-                 **kwargs):
-
-        model_config = None if model is None else (AutoConfig.for_model(
-            **model) if isinstance(model, dict) else model)
-
-        for k, v in kwargs.items():
-            if k != "architectures" and k != "model_type" and hasattr(
-                    model_config, k):
-                setattr(model_config, k, v)
-
-        self.model = model_config
-
-        if self.model is None:
-            self.truncated_vocab_size = None
-        else:
-            self.truncated_vocab_size = self.model.vocab_size if \
-                truncated_vocab_size is None else truncated_vocab_size
-
-        if "architectures" not in kwargs:
-            kwargs["architectures"] = ["EAGLEModel"]
+    def __init__(
+        self,
+        vocab_size=129280,
+        hidden_size=7168,
+        intermediate_size=18432,
+        moe_intermediate_size = 2048,
+        num_hidden_layers=-1,
+        num_nextn_predict_layers=1,
+        num_attention_heads=128,
+        num_key_value_heads=128,
+        n_shared_experts = 1,
+        n_routed_experts = 256,
+        ep_size = 1,
+        routed_scaling_factor = 2.5,
+        kv_lora_rank = 512,
+        q_lora_rank = 1536,
+        qk_rope_head_dim = 64,
+        v_head_dim = 128,
+        qk_nope_head_dim = 128,
+        topk_method = 'noaux_tc',
+        n_group = 8,
+        topk_group = 4,
+        num_experts_per_tok = 8,
+        moe_layer_freq = 1,
+        first_k_dense_replace = 3,
+        norm_topk_prob = True,
+        scoring_func = 'sigmoid',
+        aux_loss_alpha = 0.001,
+        seq_aux = True,
+        hidden_act="silu",
+        max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=0,
+        eos_token_id=1,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_hidden_layers = 1
+        self.num_nextn_predict_layers = num_nextn_predict_layers
+        self.num_attention_heads = num_attention_heads
+        self.n_shared_experts = n_shared_experts
+        self.n_routed_experts = n_routed_experts
+        self.ep_size = ep_size
+        self.routed_scaling_factor = routed_scaling_factor
+        self.kv_lora_rank = kv_lora_rank
+        self.q_lora_rank = q_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.topk_method = topk_method
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.num_experts_per_tok = num_experts_per_tok
+        self.moe_layer_freq = moe_layer_freq
+        self.first_k_dense_replace = first_k_dense_replace
+        self.norm_topk_prob = norm_topk_prob
+        self.scoring_func = scoring_func
+        self.aux_loss_alpha = aux_loss_alpha
+        self.seq_aux = seq_aux
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads

-        super().__init__(**kwargs)
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout

-        if self.model is not None:
-            for k, v in self.model.to_dict().items():
-                if not hasattr(self, k):
-                    setattr(self, k, v)
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )

     @classmethod
     def from_pretrained(
@@ -46,6 +113,11 @@ def from_pretrained(
         pretrained_model_name_or_path: Union[str, os.PathLike],
         **kwargs,
     ) -> "EAGLEConfig":
-        config_dict, kwargs = cls.get_config_dict(
-            pretrained_model_name_or_path, **kwargs)
-        return cls.from_dict(config_dict, **kwargs)
+        if pretrained_model_name_or_path == "DeepSeekV3MTP":
+            pretrained_model_name_or_path = "deepseek-ai/DeepSeek-R1"
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+        config_dict["model_type"] = "eagle"
+        config_dict["num_hidden_layers"] = 1
+        config_dict["architectures"] = ["DeepseekV3MTPModel"]
+        res = cls.from_dict(config_dict, **kwargs)
+        return res
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 12baecde6e42..8f6308b3c890 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1650,6 +1650,7 @@ def execute_model(
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        previous_hidden_states: Optional[torch.Tensor] = None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1681,6 +1682,16 @@
             graph_batch_size = model_input.input_tokens.shape[0]
             model_executable = self.graph_runners[virtual_engine][
                 graph_batch_size]
+            if previous_hidden_states is not None:
+                previous_hidden_states = torch.cat([
+                    previous_hidden_states,
+                    torch.empty([
+                        graph_batch_size - previous_hidden_states.shape[0],
+                        *previous_hidden_states.shape[1:]
+                    ],
+                                dtype=previous_hidden_states.dtype,
+                                device=previous_hidden_states.device)
+                ])
         else:
             model_executable = self.model

@@ -1707,6 +1718,10 @@
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_inner_state else {}
+
+        if previous_hidden_states is not None:
+            seqlen_agnostic_kwargs["previous_hidden_states"] = (previous_hidden_states)
+
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
             model_forward_start = torch.cuda.Event(enable_timing=True)
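The first model_runner.py hunk pads previous_hidden_states up to the captured CUDA-graph batch size, because graph replay requires every input tensor to match the shape used at capture time; the padding rows can stay uninitialized since the extra padding sequences are discarded anyway. A standalone illustration of that padding step in plain PyTorch (the function name is illustrative):

import torch

def pad_to_graph_batch(hidden: torch.Tensor, graph_batch_size: int) -> torch.Tensor:
    """Pad [batch, ...] up to [graph_batch_size, ...] with uninitialized rows."""
    pad_rows = graph_batch_size - hidden.shape[0]
    if pad_rows <= 0:
        return hidden
    filler = torch.empty([pad_rows, *hidden.shape[1:]],
                         dtype=hidden.dtype, device=hidden.device)
    return torch.cat([hidden, filler])

print(pad_to_graph_batch(torch.randn(3, 8), graph_batch_size=8).shape)  # torch.Size([8, 8])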
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fbfdbb..7db09a2f5d80 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -316,12 +316,14 @@ def _get_worker_input_from_broadcast(
             return None

         worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
+
+        kwargs = extract_previous_hidden_states(broadcast_data)
+        broadcast_data.pop("previous_hidden_states", None)
+
         model_input = (
             self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                 broadcast_data))

-        kwargs = extract_previous_hidden_states(broadcast_data)
-
         return model_input, worker_input, kwargs

     def _get_driver_input_and_broadcast(
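With the pieces above in place, the speculative path is exercised transparently through the normal OpenAI-compatible API started by run.sh; nothing in the request changes. A quick smoke test with the requests library, assuming the server is on the default port 8000 and serves the model name passed to --model:

import requests

# Minimal smoke test against the server launched by run.sh.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1",
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "temperature": 0.0,
    },
    timeout=120,
)
print(resp.json()["choices"][0]["text"])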