From 4e67f2ab05edcb9f0ae8c673e74ff7aeba76c9fe Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Wed, 29 May 2024 07:10:56 +0000 Subject: [PATCH 1/6] support cogvlm --- xinference/model/llm/__init__.py | 2 + xinference/model/llm/llm_family.json | 52 ++++ .../model/llm/llm_family_modelscope.json | 55 ++++ xinference/model/llm/pytorch/cogvlm2.py | 254 ++++++++++++++++++ xinference/model/llm/pytorch/core.py | 1 + 5 files changed, 364 insertions(+) create mode 100644 xinference/model/llm/pytorch/cogvlm2.py diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index 196d7fd686..d3674b3795 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -113,6 +113,7 @@ def _install(): from .ggml.llamacpp import LlamaCppChatModel, LlamaCppModel from .pytorch.baichuan import BaichuanPytorchChatModel from .pytorch.chatglm import ChatglmPytorchChatModel + from .pytorch.cogvlm2 import CogVLM2Model from .pytorch.core import PytorchChatModel, PytorchModel from .pytorch.deepseek_vl import DeepSeekVLChatModel from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel @@ -159,6 +160,7 @@ def _install(): DeepSeekVLChatModel, InternVLChatModel, PytorchModel, + CogVLM2Model, ] ) if OmniLMMModel: # type: ignore diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 110d7c3fe4..9cc96ccdc7 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -6083,5 +6083,57 @@ "<|im_end|>" ] } +}, + { + "version": 1, + "context_length": 8192, + "model_name": "cogvlm2", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "chat", + "vision" + ], + "model_description": "CogVLM2 have achieved good results in many lists compared to the previous generation of CogVLM open source models. Its excellent performance can compete with some non-open source models.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 20, + "quantizations": [ + "none" + ], + "model_id": "THUDM/cogvlm2-llama3-chinese-chat-19B", + "model_revision": "d88b352bce5ee58a289b1ac8328553eb31efa2ef" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 20, + "quantizations": [ + "int4" + ], + "model_id": "THUDM/cogvlm2-llama3-chinese-chat-19B-{quantizations}", + "model_revision": "7863e362174f4718c2fe9cba4befd0b580a3194f" + } + ], + "prompt_style": { + "style_name": "LLAMA3", + "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", + "roles": [ + "user", + "assistant" + ], + "intra_message_sep": "\n\n", + "inter_message_sep": "<|eot_id|>", + "stop_token_ids": [ + 128001, + 128009 + ], + "stop": [ + "<|end_of_text|>", + "<|eot_id|>" + ] + } } ] diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index 3bae8fe5e0..a1fc2be061 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -3739,5 +3739,60 @@ "<|im_end|>" ] } +}, + { + "version": 1, + "context_length": 8192, + "model_name": "cogvlm2", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "chat", + "vision" + ], + "model_description": "CogVLM2 have achieved good results in many lists compared to the previous generation of CogVLM open source models. Its excellent performance can compete with some non-open source models.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 20, + "quantizations": [ + "none" + ], + "model_hub": "modelscope", + + "model_id": "ZhipuAI/cogvlm2-llama3-chinese-chat-19B", + "model_revision": "master" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 20, + "quantizations": [ + "int4" + ], + "model_hub": "modelscope", + "model_id": "ZhipuAI/cogvlm2-llama3-chinese-chat-19B-{quantization}", + "model_revision": "master" + } + ], + "prompt_style": { + "style_name": "LLAMA3", + "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", + "roles": [ + "user", + "assistant" + ], + "intra_message_sep": "\n\n", + "inter_message_sep": "<|eot_id|>", + "stop_token_ids": [ + 128001, + 128009 + ], + "stop": [ + "<|end_of_text|>", + "<|eot_id|>" + ] + } } ] diff --git a/xinference/model/llm/pytorch/cogvlm2.py b/xinference/model/llm/pytorch/cogvlm2.py new file mode 100644 index 0000000000..fb02de3f4b --- /dev/null +++ b/xinference/model/llm/pytorch/cogvlm2.py @@ -0,0 +1,254 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import logging +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from io import BytesIO +from typing import Dict, Iterator, List, Optional, Tuple, Union + +import requests +import torch +from PIL import Image + +from ....model.utils import select_device +from ....types import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessage, + Completion, + CompletionChoice, + CompletionUsage, +) +from ..llm_family import LLMFamilyV1, LLMSpecV1 +from .core import PytorchChatModel, PytorchGenerateConfig + +logger = logging.getLogger(__name__) + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +class CogVLM2Model(PytorchChatModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._tokenizer = None + self._model = None + + @classmethod + def match( + cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str + ) -> bool: + family = model_family.model_family or model_family.model_name + if "cogvlm" in family.lower(): + return True + return False + + def load(self, **kwargs): + from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers.generation import GenerationConfig + + device = self._pytorch_model_config.get("device", "auto") + self._device = select_device(device) + self._torch_type = ( + torch.bfloat16 + if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 + else torch.float16 + ) + + self._tokenizer = AutoTokenizer.from_pretrained( + self.model_path, + trust_remote_code=True, + ) + + self._model = AutoModelForCausalLM.from_pretrained( + self.model_path, + torch_dtype=self._torch_type, + trust_remote_code=True, + device_map="auto", + ).eval() + + # Specify hyperparameters for generation + self._model.generation_config = GenerationConfig.from_pretrained( + self.model_path, + trust_remote_code=True, + ) + + def _message_content_to_cogvlm2(self, content): + def _load_image(_url): + if _url.startswith("data:"): + logging.info("Parse url by base64 decoder.") + # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images + # e.g. f"data:image/jpeg;base64,{base64_image}" + _type, data = _url.split(";") + _, ext = _type.split("/") + data = data[len("base64,") :] + data = base64.b64decode(data.encode("utf-8")) + return Image.open(BytesIO(data)).convert("RGB") + else: + try: + response = requests.get(_url) + except requests.exceptions.MissingSchema: + return Image.open(_url).convert("RGB") + else: + return Image.open(BytesIO(response.content)).convert("RGB") + + if not isinstance(content, str): + texts = [] + image_urls = [] + for c in content: + c_type = c.get("type") + if c_type == "text": + texts.append(c["text"]) + elif c_type == "image_url": + image_urls.append(c["image_url"]["url"]) + image_futures = [] + with ThreadPoolExecutor() as executor: + for image_url in image_urls: + fut = executor.submit(_load_image, image_url) + image_futures.append(fut) + images = [fut.result() for fut in image_futures] + text = " ".join(texts) + if len(images) == 0: + return text, None + elif len(images) == 1: + return text, images + else: + raise RuntimeError( + "Only one image per message is supported by CogVLM2." + ) + return content, None + + def _history_content_to_intern( + self, system_prompt: str, chat_history: List[ChatCompletionMessage] + ): + def _image_to_piexl_values(image): + if image.startswith("data:"): + logging.info("Parse url by base64 decoder.") + # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images + # e.g. f"data:image/jpeg;base64,{base64_image}" + _type, data = image.split(";") + _, ext = _type.split("/") + data = data[len("base64,") :] + data = base64.b64decode(data.encode("utf-8")) + return Image.open(BytesIO(data)).convert("RGB") + else: + try: + response = requests.get(image) + except requests.exceptions.MissingSchema: + return Image.open(image).convert("RGB") + else: + return Image.open(BytesIO(response.content)).convert("RGB") + + query = system_prompt + history: List[Tuple] = [] + pixel_values = None + for i in range(0, len(chat_history), 2): + user = chat_history[i]["content"] + if isinstance(user, List): + for content in user: + c_type = content.get("type") + if c_type == "text": + user = content["text"] + elif c_type == "image_url" and not pixel_values: + pixel_values = _image_to_piexl_values( + content["image_url"]["url"] + ) + assistant = chat_history[i + 1]["content"] + query = query + f" USER: {user} ASSISTANT:" + history.append((query, assistant)) + query = query + f" {assistant}" + return query, history, [pixel_values] + + def chat( + self, + prompt: Union[str, List[Dict]], + system_prompt: Optional[str] = None, + chat_history: Optional[List[ChatCompletionMessage]] = None, + generate_config: Optional[PytorchGenerateConfig] = None, + ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: + system_prompt = system_prompt if system_prompt else "" + if generate_config and generate_config.pop("stream"): + raise Exception( + f"Chat with model {self.model_family.model_name} does not support stream." + ) + + sanitized_config = { + "pad_token_id": 128002, + "max_new_tokens": generate_config.get("max_tokens", 512) + if generate_config + else 512, + } + + content, image = self._message_content_to_cogvlm2(prompt) + + history = [] + query = "" + history_image = None + if chat_history: + query, history, history_image = self._history_content_to_intern( + system_prompt, chat_history + ) + + if image and history_image: + history = [] + query = system_prompt + f" USER: {content} ASSISTANT:" + else: + image = image if image else history_image + query = query + f" USER: {content} ASSISTANT:" + + input_by_model = self._model.build_conversation_input_ids( + self._tokenizer, + query=query, + history=history, + images=image, + template_version="chat", + ) + + inputs = { + "input_ids": input_by_model["input_ids"].unsqueeze(0).to(self._device), + "token_type_ids": input_by_model["token_type_ids"] + .unsqueeze(0) + .to(self._device), + "attention_mask": input_by_model["attention_mask"] + .unsqueeze(0) + .to(self._device), + "images": [ + [input_by_model["images"][0].to(self._device).to(self._torch_type)] + ] + if image is not None + else None, + } + with torch.no_grad(): + outputs = self._model.generate(**inputs, **sanitized_config) + outputs = outputs[:, inputs["input_ids"].shape[1] :] + response = self._tokenizer.decode(outputs[0]) + response = response.split("<|end_of_text|>")[0] + + chunk = Completion( + id=str(uuid.uuid1()), + object="text_completion", + created=int(time.time()), + model=self.model_uid, + choices=[ + CompletionChoice( + index=0, text=response, finish_reason="stop", logprobs=None + ) + ], + usage=CompletionUsage( + prompt_tokens=-1, completion_tokens=-1, total_tokens=-1 + ), + ) + return self._to_chat_completion(chunk) diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index fb43d65c9f..f1598358ca 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -62,6 +62,7 @@ "deepseek-vl-chat", "internvl-chat", "mini-internvl-chat", + "cogvlm2", ] From 43e54f005e59fbaf066937e22a58e25cb2d89286 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Wed, 29 May 2024 07:55:46 +0000 Subject: [PATCH 2/6] support cogvlm --- xinference/model/llm/pytorch/cogvlm2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xinference/model/llm/pytorch/cogvlm2.py b/xinference/model/llm/pytorch/cogvlm2.py index fb02de3f4b..798754cffa 100644 --- a/xinference/model/llm/pytorch/cogvlm2.py +++ b/xinference/model/llm/pytorch/cogvlm2.py @@ -77,6 +77,7 @@ def load(self, **kwargs): self.model_path, torch_dtype=self._torch_type, trust_remote_code=True, + low_cpu_mem_usage=True, device_map="auto", ).eval() From 5b9eaa0a17af8f35e226a848bda870fbd4705cb5 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Thu, 30 May 2024 02:01:21 +0000 Subject: [PATCH 3/6] rename func --- xinference/model/llm/pytorch/cogvlm2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xinference/model/llm/pytorch/cogvlm2.py b/xinference/model/llm/pytorch/cogvlm2.py index 798754cffa..445f11627d 100644 --- a/xinference/model/llm/pytorch/cogvlm2.py +++ b/xinference/model/llm/pytorch/cogvlm2.py @@ -44,6 +44,8 @@ class CogVLM2Model(PytorchChatModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._torch_type = None + self._device = None self._tokenizer = None self._model = None @@ -132,7 +134,7 @@ def _load_image(_url): ) return content, None - def _history_content_to_intern( + def _history_content_to_cogvlm2( self, system_prompt: str, chat_history: List[ChatCompletionMessage] ): def _image_to_piexl_values(image): @@ -199,7 +201,7 @@ def chat( query = "" history_image = None if chat_history: - query, history, history_image = self._history_content_to_intern( + query, history, history_image = self._history_content_to_cogvlm2( system_prompt, chat_history ) From 659ef863198ac41febcd44c432bec55ea92572f8 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Fri, 31 May 2024 02:17:29 +0000 Subject: [PATCH 4/6] update doc --- doc/source/models/builtin/llm/cogvlm2.rst | 47 +++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 doc/source/models/builtin/llm/cogvlm2.rst diff --git a/doc/source/models/builtin/llm/cogvlm2.rst b/doc/source/models/builtin/llm/cogvlm2.rst new file mode 100644 index 0000000000..e907e1d157 --- /dev/null +++ b/doc/source/models/builtin/llm/cogvlm2.rst @@ -0,0 +1,47 @@ +.. _models_llm_cogvlm2: + +======================================== +cogvlm2 +======================================== + +- **Context Length:** 8192 +- **Model Name:** cogvlm2 +- **Languages:** en, zh +- **Abilities:** chat, vision +- **Description:** CogVLM2 have achieved good results in many lists compared to the previous generation of CogVLM open source models. Its excellent performance can compete with some non-open source models. + +Specifications +^^^^^^^^^^^^^^ + + +Model Spec 1 (pytorch, 20 Billion) +++++++++++++++++++++++++++++++++++++++++ + +- **Model Format:** pytorch +- **Model Size (in billions):** 20 +- **Quantizations:** none +- **Engines**: Transformers +- **Model ID:** THUDM/cogvlm2-llama3-chinese-chat-19B +- **Model Hubs**: `Hugging Face `__, `ModelScope `__ + +Execute the following command to launch the model, remember to replace ``${quantization}`` with your +chosen quantization method from the options listed above:: + + xinference launch --model-engine ${engine} --model-name cogvlm2 --size-in-billions 20 --model-format pytorch --quantization ${quantization} + + +Model Spec 2 (pytorch, 20 Billion) +++++++++++++++++++++++++++++++++++++++++ + +- **Model Format:** pytorch +- **Model Size (in billions):** 20 +- **Quantizations:** int4 +- **Engines**: Transformers +- **Model ID:** THUDM/cogvlm2-llama3-chinese-chat-19B-{quantizations} +- **Model Hubs**: `Hugging Face `__, `ModelScope `__ + +Execute the following command to launch the model, remember to replace ``${quantization}`` with your +chosen quantization method from the options listed above:: + + xinference launch --model-engine ${engine} --model-name cogvlm2 --size-in-billions 20 --model-format pytorch --quantization ${quantization} + From 44a75d42d1efc95d0823bf5db65c9790c74e1df2 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Fri, 31 May 2024 02:18:50 +0000 Subject: [PATCH 5/6] update doc1 --- doc/source/models/builtin/llm/index.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/source/models/builtin/llm/index.rst b/doc/source/models/builtin/llm/index.rst index b14e7b4609..ce9c4b01c7 100644 --- a/doc/source/models/builtin/llm/index.rst +++ b/doc/source/models/builtin/llm/index.rst @@ -126,6 +126,11 @@ The following is a list of built-in LLM in Xinference: - 8194 - CodeShell is a multi-language code LLM developed by the Knowledge Computing Lab of Peking University. + * - :ref:`cogvlm2 ` + - chat, vision + - 8192 + - CogVLM2 have achieved good results in many lists compared to the previous generation of CogVLM open source models. Its excellent performance can compete with some non-open source models. + * - :ref:`deepseek ` - generate - 4096 @@ -236,11 +241,6 @@ The following is a list of built-in LLM in Xinference: - 8192 - The Llama 3 instruction tuned models are optimized for dialogue use cases and outperform many of the available open source chat models on common industry benchmarks.. - * - :ref:`mini-internvl-chat ` - - chat, vision - - 32768 - - InternVL 1.5 is an open-source multimodal large language model (MLLM) to bridge the capability gap between open-source and proprietary commercial models in multimodal understanding. - * - :ref:`minicpm-2b-dpo-bf16 ` - chat - 4096 @@ -550,6 +550,8 @@ The following is a list of built-in LLM in Xinference: codeshell-chat + cogvlm2 + deepseek deepseek-chat @@ -594,8 +596,6 @@ The following is a list of built-in LLM in Xinference: llama-3-instruct - mini-internvl-chat - minicpm-2b-dpo-bf16 minicpm-2b-dpo-fp16 From f27cc0c7224fc2a4ca92b1f1414db7e264aa74f6 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Fri, 31 May 2024 03:23:34 +0000 Subject: [PATCH 6/6] debug --- xinference/model/llm/pytorch/cogvlm2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/model/llm/pytorch/cogvlm2.py b/xinference/model/llm/pytorch/cogvlm2.py index 445f11627d..c3cc31b23a 100644 --- a/xinference/model/llm/pytorch/cogvlm2.py +++ b/xinference/model/llm/pytorch/cogvlm2.py @@ -183,7 +183,7 @@ def chat( generate_config: Optional[PytorchGenerateConfig] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: system_prompt = system_prompt if system_prompt else "" - if generate_config and generate_config.pop("stream"): + if generate_config and generate_config.get("stream"): raise Exception( f"Chat with model {self.model_family.model_name} does not support stream." )