diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index f56679c3c6d00..50cae1041bc8f 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -137,6 +137,10 @@ Decoder-only Language Models
     - Phi-3-Small
     - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
     -
+  * - :code:`PersimmonForCausalLM`
+    - Persimmon
+    - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.
+    -
   * - :code:`QWenLMHeadModel`
     - Qwen
     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
@@ -178,6 +182,10 @@ Vision Language Models
     - Models
     - Example HuggingFace Models
     - :ref:`LoRA <lora>`
+  * - :code:`FuyuForCausalLM`
+    - Fuyu
+    - :code:`adept/fuyu-8b` etc.
+    -
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
     - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
diff --git a/examples/fuyu_example.py b/examples/fuyu_example.py
new file mode 100644
index 0000000000000..c92b8fb4bc286
--- /dev/null
+++ b/examples/fuyu_example.py
@@ -0,0 +1,31 @@
+import requests
+from PIL import Image
+
+from vllm import LLM, SamplingParams
+
+
+def run_fuyu():
+    llm = LLM(model="adept/fuyu-8b", max_model_len=4096)
+
+    # single-image prompt
+    prompt = "What is the highest life expectancy at of male?\n"
+    url = "https://huggingface.co/adept/fuyu-8b/resolve/main/chart.png"
+    image = Image.open(requests.get(url, stream=True).raw)
+    sampling_params = SamplingParams(temperature=0, max_tokens=64)
+
+    outputs = llm.generate(
+        {
+            "prompt": prompt,
+            "multi_modal_data": {
+                "image": image
+            },
+        },
+        sampling_params=sampling_params)
+
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+
+if __name__ == "__main__":
+    run_fuyu()
diff --git a/tests/models/test_fuyu.py b/tests/models/test_fuyu.py
new file mode 100644
index 0000000000000..672470acb77e6
--- /dev/null
+++ b/tests/models/test_fuyu.py
@@ -0,0 +1,142 @@
+from typing import List, Optional, Tuple, Type
+
+import pytest
+
+from vllm.multimodal.utils import rescale_image_size
+from vllm.sequence import SampleLogprobs
+from vllm.utils import is_cpu
+
+from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_logprobs_close
+
+pytestmark = pytest.mark.vlm
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign": "What's the content of the image?\n",  # noqa: E501
+    "cherry_blossom": "What is the season?\n",
+    "boardwalk": "What's in this image?\n",
+})
+
+models = ["adept/fuyu-8b"]
+
+
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
+                                         Optional[SampleLogprobs]]):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str.lstrip() + "|ENDOFTEXT|"
+
+    return output_ids, hf_output_str, out_logprobs
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test are under tests/images.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding vision language config as input.
+    Note that the text input is also adjusted to abide by the vllm contract.
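+    For Fuyu, the prompts above contain no explicit image placeholder;
+    vLLM's input processor for this model prepends the image patch and
+    newline tokens to the tokenized prompt (see fuyu.py later in this diff).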
+ The text output is sanitized to be able to compare with hf. + """ + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + max_model_len=2560, + max_num_seqs=1, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=vllm_images) + for prompts, vllm_images in inputs_per_image + ] + + with hf_runner(model, dtype=dtype) as hf_model: + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.language_model.get_output_embeddings() + eos_token_id = hf_model.processor.tokenizer.eos_token_id + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=hf_images, + eos_token_id=eos_token_id) + for prompts, hf_images in inputs_per_image + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +target_dtype = "half" +if is_cpu(): + target_dtype = "bfloat16" + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [0.25], + # Single-scale, batched + [0.25, 0.25, 0.25], + # Multi-scale + [0.25, 0.2, 0.15], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, + dtype: str, max_tokens: int, num_logprobs: int) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 096e3f4724014..87508a1168e0c 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -23,6 +23,7 @@ "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), @@ -49,6 +50,7 @@ "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), + "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), diff --git 
a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py new file mode 100644 index 0000000000000..fdea8ee30ce68 --- /dev/null +++ b/vllm/model_executor/models/fuyu.py @@ -0,0 +1,328 @@ +# coding=utf-8 +# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py +# Copyright 2023 The vLLM team. +# Copyright 2023 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Fuyu model.""" +import math +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from PIL import Image +from transformers import FuyuConfig, FuyuImageProcessor + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.persimmon import PersimmonForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.image import (cached_get_image_processor, + cached_get_tokenizer) +from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData + +from .interfaces import SupportsVision +from .utils import merge_vision_embeddings + +logger = init_logger(__name__) + +# Cannot find the following 2 numbers from hf config. 
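+# They are the ids passed below to preprocess_with_tokenizer_info as
+# image_placeholder_id and image_newline_id; presumably they were read off
+# the adept/fuyu-8b tokenizer.
+# For reference, at the 1920 x 1080 maximum image size defined below,
+# _calculate_num_image_tokens gives ncol = ceil(1920 / 30) = 64 and
+# nrow = ceil(1080 / 30) = 36, i.e. (64 + 1) * 36 = 2340 image tokens.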
+_IMAGE_TOKEN_ID = 71011 +_NEWLINE_TOKEN_ID = 71019 + +MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080 +MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920 + + +class FuyuImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """ + Shape: + (batch_size, num_patches, patch_size_x * patch_size_y * num_channels) + """ + + +def _calculate_num_image_tokens( + height: int, + width: int, +) -> Tuple[int, int]: + """ + calculate number of image tokens needed for a given image size + The expected Fuyu image prompts is in format: + (image_token * ncols + newline_token) * nrows + args: + image_size: Tuple[int, int] - (width, height) of the image + returns: + ncols: int - number of image tokens in x direction + nrows: int - number of image tokens in y direction + """ + ncol = math.ceil(width / 30) + nrow = math.ceil(height / 30) + return ncol, nrow + + +def get_max_fuyu_image_feature_size(): + + return _calculate_num_image_tokens( + height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + ) + + +def get_max_fuyu_image_tokens(ctx: InputContext): + ncol, nrow = get_max_fuyu_image_feature_size() + return (ncol + 1) * nrow + + +def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int): + ncol, nrow = get_max_fuyu_image_feature_size() + image_feature_size = get_max_fuyu_image_tokens(ctx) + + token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow + token_ids += [0] * (seq_len - image_feature_size) + return SequenceData(token_ids) + + +def dummy_image_for_fuyu( + image_width: int, + image_height: int, +): + image = Image.new("RGB", (image_width, image_height), color=0) + return {"image": image} + + +def dummy_data_for_fuyu(ctx: InputContext, seq_len: int): + seq_data = dummy_seq_data_for_fuyu(ctx, seq_len) + mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH, + MAX_IMAGE_FEATURE_SIZE_HEIGHT) + return seq_data, mm_data + + +def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, + data: Image.Image): + image_encoding = image_processor.preprocess(data, return_tensors="pt") + batch_images = torch.stack([img[0] for img in image_encoding["images"] + ]).unsqueeze(1) + image_unpadded_heights = torch.tensor( + image_encoding["image_unpadded_heights"]) + image_unpadded_widths = torch.tensor( + image_encoding["image_unpadded_widths"]) + + batch_size = len(image_encoding["images"]) + image_present = torch.ones(batch_size, 1, 1) + model_image_input = image_processor.preprocess_with_tokenizer_info( + image_input=batch_images, + image_present=image_present, + image_unpadded_h=image_unpadded_heights, + image_unpadded_w=image_unpadded_widths, + image_placeholder_id=_IMAGE_TOKEN_ID, + image_newline_id=_NEWLINE_TOKEN_ID, + variable_sized=True, + ) + return model_image_input + + +def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + image_data = multi_modal_data["image"] + new_multi_modal_data = {} + # process image data + if isinstance(image_data, Image.Image): + # Fuyu's image_processor can also finish token padding + image_processor: FuyuImageProcessor = cached_get_image_processor( + model_config.model) + + model_image_input = _fuyu_image_preprocess(image_processor, image_data) + image_patches = torch.stack([ + image_patch[0] + for image_patch in model_image_input["image_patches"] + ]) + new_multi_modal_data["image"] = image_patches + + elif isinstance(image_data, 
torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
+    # process prompts
+    prompt = llm_inputs["prompt"]
+    prompt_token_ids = llm_inputs["prompt_token_ids"]
+    tokenizer = cached_get_tokenizer(model_config.model)
+    # dim0 is batch_size, dim1 is subseq_size which will always be 1
+    image_input_ids: List[List[
+        torch.Tensor]] = model_image_input["image_input_ids"]
+    image_input_ids = image_input_ids[0][0].tolist()
+    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
+    boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
+
+    new_prompt = prompt + "\x04"
+    new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
+        1:] + boa_token
+
+    return LLMInputs(prompt=new_prompt,
+                     prompt_token_ids=new_prompt_token_ids,
+                     multi_modal_data=new_multi_modal_data)
+
+
+def input_mapper_for_fuyu(ctx: InputContext, data: object):
+    model_config = ctx.model_config
+    if isinstance(data, Image.Image):
+        # Fuyu's image_processor can also finish token padding
+        image_processor: FuyuImageProcessor = cached_get_image_processor(
+            model_config.model)
+
+        model_image_input = _fuyu_image_preprocess(image_processor, data)
+        data = torch.stack([
+            image_patch[0]
+            for image_patch in model_image_input["image_patches"]
+        ])
+
+    # image has been processed with prompt in input processor
+    return MultiModalInputs({"image_patches": data})
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
+class FuyuForCausalLM(nn.Module, SupportsVision):
+
+    def __init__(self,
+                 config: FuyuConfig,
+                 multimodal_config: MultiModalConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__()
+        self.config = config
+        self.multimodal_config = multimodal_config
+
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.image_token_id = _IMAGE_TOKEN_ID
+        self.image_feature_size = config.patch_size**2 * config.num_channels
+
+        self.vision_embed_tokens = ColumnParallelLinear(
+            self.image_feature_size,
+            config.hidden_size,
+            quant_config=quant_config,
+        )
+        self.language_model = PersimmonForCausalLM(config,
+                                                   cache_config=cache_config,
+                                                   quant_config=quant_config)
+
+    def _parse_and_validate_image_input(self, **kwargs: object):
+        image_patches = kwargs.pop("image_patches", None)
+
+        if isinstance(image_patches, torch.Tensor):
+            expected_feature_size = self.image_feature_size
+            if image_patches.size(-1) != expected_feature_size:
+                raise ValueError(
+                    f"Expected image patches to have the last dimension of "
+                    f"{expected_feature_size}, got {image_patches.size(-1)}")
+            image_patches = image_patches.to(
+                self.vision_embed_tokens.weight.dtype)
+            return FuyuImagePixelInputs(type="pixel_values",
+                                        data=image_patches)
+        return None
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        **kwargs: object,
+    ):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+
+        if image_input is not None:
+            vision_embeddings, _ = self.vision_embed_tokens(
+                image_input["data"])
+            inputs_embeds = 
self.language_model.model.embed_tokens(input_ids) + inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, + vision_embeddings, + self.image_token_id) + + else: + inputs_embeds = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.language_model.logits_processor( + self.language_model.lm_head, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.language_model.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + param = params_dict[name] + + if "query_key_value" in name: + # copy from vllm/model_executor/models/bloom.py + # NOTE: Fuyu's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py new file mode 100644 index 0000000000000..bc38d4421b79e --- /dev/null +++ b/vllm/model_executor/models/persimmon.py @@ -0,0 +1,333 @@ +# coding=utf-8 +# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py +# Copyright 2023 The vLLM team. +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
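+#
+# Note: this Persimmon implementation is also reused by FuyuForCausalLM
+# (fuyu.py), which projects flattened image patches to the same hidden size
+# and feeds them in through the optional `inputs_embeds` argument of
+# PersimmonModel.forward.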
+"""Inference-only persimmon model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import PersimmonConfig +from transformers.activations import ReLUSquaredActivation + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput + + +class PersimmonMLP(nn.Module): + + def __init__(self, + config: PersimmonConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, + config.hidden_size, + quant_config=quant_config) + self.act = ReLUSquaredActivation() + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class PersimmonAttention(nn.Module): + + def __init__(self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + tensor_parallel_world_size = get_tensor_model_parallel_world_size() + + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tensor_parallel_world_size + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.partial_rotary_factor = config.partial_rotary_factor + self.is_causal = True + + assert (self.head_dim * self.total_num_heads) == self.hidden_size + assert self.total_num_heads % tensor_parallel_world_size == 0 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + self.is_qk_layernorm = config.qk_layernorm + + if self.is_qk_layernorm: + self.q_layernorm = nn.LayerNorm(self.head_dim) + self.k_layernorm = nn.LayerNorm(self.head_dim) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=int(self.partial_rotary_factor * self.head_dim), + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + def _split_heads(self, x: torch.Tensor) -> torch.Tensor: + # 
[seq_length, hidden_size] -> [seq_length, num_heads, head_dim] + seq_length = x.shape[0] + return x.view(seq_length, self.num_heads, self.head_dim) + + def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: + # [seq_length, num_heads, head_dim] -> [seq_length, hidden_size] + seq_length = x.shape[0] + return x.view(seq_length, self.num_heads * self.head_dim) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # [seq_length, 3 x hidden_size] + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + + if self.is_qk_layernorm: + # [seq_length, num_heads, head_dim] + q = self._split_heads(q) + k = self._split_heads(k) + + q = self.q_layernorm(q) + k = self.k_layernorm(k) + + q = self._merge_heads(q) + k = self._merge_heads(k) + + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.dense(attn_output) + return output + + +class PersimmonDecoderLayer(nn.Module): + + def __init__(self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = PersimmonAttention(config=config, + cache_config=cache_config, + quant_config=quant_config) + self.mlp = PersimmonMLP(config, quant_config=quant_config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + outputs = hidden_states + return outputs + + +class PersimmonModel(nn.Module): + + def __init__(self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.layers = nn.ModuleList([ + PersimmonDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + hidden_states = self.layers[i]( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + ) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class PersimmonForCausalLM(nn.Module): + + 
def __init__(self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.model = PersimmonModel(config, + cache_config=cache_config, + quant_config=quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + bias=False) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + param = params_dict[name] + + if "query_key_value" in name: + # copy from vllm/model_executor/models/bloom.py + # NOTE: Persimmon's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)
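The fused-QKV reordering performed in the two load_weights methods above is easiest to see with concrete shapes. The standalone sketch below is illustrative only and is not part of the diff: the tensor sizes are made up, and it assumes output_dim == 0, which is taken here to be the usual case for the fused QKV weight of QKVParallelLinear.

    import torch

    # Checkpoint layout (per the comments above): output rows are grouped per
    # head as (num_heads, 3, head_size), while vLLM expects
    # (3, num_heads, head_size), i.e. all Q rows, then all K rows, then all V.
    num_heads, head_size, hidden_size = 4, 8, 16  # toy sizes (assumption)
    output_dim = 0

    loaded_weight = torch.randn(num_heads * 3 * head_size, hidden_size)
    loaded_weight_shape = loaded_weight.shape

    # Same three steps as load_weights: view -> transpose -> reshape.
    converted = loaded_weight.view(loaded_weight_shape[:output_dim] +
                                   (num_heads, 3, -1) +
                                   loaded_weight_shape[output_dim + 1:])
    converted = converted.transpose(output_dim, output_dim + 1)
    converted = converted.reshape(loaded_weight_shape)

    # Sanity check: head h's query block moves from its per-head position in
    # the checkpoint layout to the Q section of the converted layout.
    h = 2  # arbitrary head index
    start = h * 3 * head_size
    q_checkpoint = loaded_weight[start:start + head_size]
    q_converted = converted[h * head_size:(h + 1) * head_size]
    assert torch.equal(q_checkpoint, q_converted)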