Peter/cogvlm #175

Merged · 12 commits · Dec 7, 2023 · Changes from all commits
2 changes: 2 additions & 0 deletions docker/dockerfiles/Dockerfile.onnx.gpu
@@ -20,6 +20,7 @@ COPY requirements/requirements.sam.txt \
requirements/requirements.waf.txt \
requirements/requirements.gaze.txt \
requirements/requirements.doctr.txt \
requirements/requirements.cog.txt \
requirements/_requirements.txt \
./

@@ -32,6 +33,7 @@ RUN pip3 install --upgrade pip && pip3 install \
-r requirements.waf.txt \
-r requirements.gaze.txt \
-r requirements.doctr.txt \
-r requirements.cog.txt \
--upgrade \
&& rm -rf ~/.cache/pip
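(The new requirements.cog.txt is not shown in this diff; judging from the model code below, it presumably pins the Hugging Face stack, i.e. transformers plus accelerate and bitsandbytes, needed for the 4-bit/8-bit loading paths.)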

89 changes: 89 additions & 0 deletions examples/cogvlm/cog_client.py
@@ -0,0 +1,89 @@
import asyncio
import base64
import os
import time

import aiohttp
import requests
from PIL import Image

PORT = 9001
API_KEY = os.environ["API_KEY"]
IMAGE_PATH = "image.jpg"


def encode_base64(image_path):
    with open(image_path, "rb") as image:
        image_bytes = image.read()
    image_string = base64.b64encode(image_bytes)
    return image_string.decode("ascii")


async def do_cog_request(session):
    api_key = API_KEY
    prompt = (
        "The player on the left's name is Moky."
        " What round of the tournament is he in? Answer in one word."
    )

    print("Starting")
    infer_payload = {
        "image": {
            "type": "base64",
            "value": encode_base64(IMAGE_PATH),
        },
        "api_key": api_key,
        "prompt": prompt,
    }
    async with session.post(
        f"http://localhost:{PORT}/llm/cogvlm",
        json=infer_payload,
    ) as response:
        if response.status != 200:
            print(response.status)
            print(await response.json())
            raise RuntimeError(f"Request failed with status {response.status}")
        resp = await response.json()
        res = resp["response"]
        print(resp)
    infer_payload = {
        "image": {
            "type": "base64",
            "value": encode_base64(IMAGE_PATH),
        },
        "api_key": api_key,
        "prompt": "What is the name of the player on the left?",
        # Pass the first exchange as chat history so the model can resolve
        # references like "the player on the left" across turns.
        "history": [(prompt, res)],
    }
    async with session.post(
        f"http://localhost:{PORT}/llm/cogvlm",
        json=infer_payload,
    ) as response:
        if response.status != 200:
            print(response.status)
            print(await response.json())
            raise RuntimeError(f"Request failed with status {response.status}")
        resp = await response.json()
        res = resp["response"]
        print(resp)


async def main():
    start = time.perf_counter()
    connector = aiohttp.TCPConnector(limit=100, limit_per_host=100)
    # Generation can take a while; disable the client-side timeout entirely.
    timeout = aiohttp.ClientTimeout(total=None)
    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        await do_cog_request(session)
    total = time.perf_counter() - start
    print(f"Total time: {total:.2f} seconds")


if __name__ == "__main__":
    # Download a sample image for the demo prompts.
    Image.open(
        requests.get(
            "https://source.roboflow.com/ACrZ7Hz8DRUB1NBMMtDoQK84Hf22/0qUjAGRJQWWhT5j9hUOG/original.jpg",
            stream=True,
        ).raw
    ).convert("RGB").save(IMAGE_PATH)
    asyncio.run(main())
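For a quick one-off check without asyncio, an equivalent synchronous call can be made with the requests library. A minimal sketch, reusing PORT, API_KEY, IMAGE_PATH, and encode_base64 from the script above:

import requests

payload = {
    "image": {"type": "base64", "value": encode_base64(IMAGE_PATH)},
    "api_key": API_KEY,
    "prompt": "Describe this image.",
}
resp = requests.post(f"http://localhost:{PORT}/llm/cogvlm", json=payload)
resp.raise_for_status()
print(resp.json()["response"])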
44 changes: 44 additions & 0 deletions inference/core/entities/requests/cog.py
@@ -0,0 +1,44 @@
from typing import List, Optional, Tuple

from pydantic import Field, validator

from inference.core.entities.requests.inference import (
    BaseRequest,
    InferenceRequestImage,
)
from inference.core.env import COG_VERSION_ID


class CogVLMInferenceRequest(BaseRequest):
    """Request for CogVLM inference.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        cogvlm_version_id (Optional[str]): The version ID of CogVLM to be used for this request.
    """

    cogvlm_version_id: Optional[str] = Field(
        default=COG_VERSION_ID,
        example="cogvlm-chat-hf",
        description="The version ID of CogVLM to be used for this request. See the huggingface model repo at THUDM.",
    )
    model_id: Optional[str] = Field(default=None)
    image: InferenceRequestImage = Field(
        description="Image for CogVLM to look at. Use prompt to specify what you want it to do with the image."
    )
    prompt: str = Field(
        description="Text to be passed to CogVLM. Use to prompt it to describe an image or provide only text to chat with the model.",
        example="Describe this image.",
    )
    history: Optional[List[Tuple[str, str]]] = Field(
        default=None,
        description="Optional chat history, formatted as a list of 2-tuples where the first entry is the user prompt"
        " and the second entry is the generated model response",
    )

    @validator("model_id", always=True)
    def validate_model_id(cls, value, values):
        if value is not None:
            return value
        if values.get("cogvlm_version_id") is None:
            return None
        return f"cogvlm/{values['cogvlm_version_id']}"
10 changes: 10 additions & 0 deletions inference/core/entities/responses/cog.py
@@ -0,0 +1,10 @@
from typing import Optional

from pydantic import BaseModel, Field


class CogVLMResponse(BaseModel):
    response: str = Field(description="Text generated by CogVLM")
    time: Optional[float] = Field(
        default=None,
        description="The time in seconds it took to produce the response including preprocessing",
    )
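A successful call therefore returns a JSON body of the shape {"response": "<generated text>", "time": 1.87} (illustrative values).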
6 changes: 6 additions & 0 deletions inference/core/env.py
@@ -38,6 +38,9 @@
# AWS secret access key, default is None
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", None)

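# CogVLM quantization flags: 4-bit loading is the default; enabling both 4-bit and 8-bit is rejected at model init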
COG_LOAD_4BIT = str2bool(os.getenv("COG_LOAD_4BIT", True))
COG_LOAD_8BIT = str2bool(os.getenv("COG_LOAD_8BIT", False))
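# CogVLM version ID, default is "cogvlm-chat-hf"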
COG_VERSION_ID = os.getenv("COG_VERSION_ID", "cogvlm-chat-hf")
# CLIP version ID, default is "ViT-B-16"
CLIP_VERSION_ID = os.getenv("CLIP_VERSION_ID", "ViT-B-16")

@@ -83,6 +86,8 @@
# Flag to enable DocTR core model, default is True
CORE_MODEL_DOCTR_ENABLED = str2bool(os.getenv("CORE_MODEL_DOCTR_ENABLED", True))

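# Flag to enable CogVLM core model, default is True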
CORE_MODEL_COGVLM_ENABLED = str2bool(os.getenv("CORE_MODEL_COGVLM_ENABLED", True))

# ID of host device, default is None
DEVICE_ID = os.getenv("DEVICE_ID", None)

@@ -227,6 +232,7 @@
# SAM version ID, default is "vit_h"
SAM_VERSION_ID = os.getenv("SAM_VERSION_ID", "vit_h")


# Device ID, default is "sample-device-id"
INFERENCE_SERVER_ID = os.getenv("INFERENCE_SERVER_ID", None)

43 changes: 43 additions & 0 deletions inference/core/interfaces/http/http_api.py
@@ -16,6 +16,7 @@
    ClipImageEmbeddingRequest,
    ClipTextEmbeddingRequest,
)
from inference.core.entities.requests.cog import CogVLMInferenceRequest
from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest
from inference.core.entities.requests.gaze import GazeDetectionInferenceRequest
from inference.core.entities.requests.inference import (
@@ -38,6 +39,7 @@
    ClipCompareResponse,
    ClipEmbeddingResponse,
)
from inference.core.entities.responses.cog import CogVLMResponse
from inference.core.entities.responses.doctr import DoctrOCRInferenceResponse
from inference.core.entities.responses.gaze import GazeDetectionInferenceResponse
from inference.core.entities.responses.inference import (
@@ -60,6 +62,7 @@
from inference.core.env import (
    ALLOW_ORIGINS,
    CORE_MODEL_CLIP_ENABLED,
    CORE_MODEL_COGVLM_ENABLED,
    CORE_MODEL_DOCTR_ENABLED,
    CORE_MODEL_GAZE_ENABLED,
    CORE_MODEL_SAM_ENABLED,
@@ -341,6 +344,7 @@ def load_core_model(
        Returns:
            The DocTR model ID.
        """

    load_cogvlm_model = partial(load_core_model, core_model="cogvlm")

    @app.get(
        "/info",
@@ -838,6 +842,45 @@ async def gaze_detection(
            trackUsage(gaze_model_id, actor)
        return response

    if CORE_MODEL_COGVLM_ENABLED:

        @app.post(
            "/llm/cogvlm",
            response_model=CogVLMResponse,
            summary="CogVLM",
            description="Run the CogVLM model to chat or describe an image.",
        )
        @with_route_exceptions
        async def cog_vlm(
            inference_request: CogVLMInferenceRequest,
            request: Request,
            api_key: Optional[str] = Query(
                None,
                description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
            ),
        ):
            """
            Chat with CogVLM or ask it about an image. Multi-image requests are not currently supported.

            Args:
                inference_request (CogVLMInferenceRequest): The request containing the prompt and image to be described.
                api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                request (Request, default Body()): The HTTP request.

            Returns:
                CogVLMResponse: The model's text response.
            """
            cog_model_id = load_cogvlm_model(inference_request, api_key=api_key)
            response = await self.model_manager.infer_from_request(
                cog_model_id, inference_request
            )
            if LAMBDA:
                actor = request.scope["aws.event"]["requestContext"][
                    "authorizer"
                ]["lambda"]["actor"]
                trackUsage(cog_model_id, actor)
            return response

    if LEGACY_ROUTE_ENABLED:
        # Legacy object detection inference path for backward compatibility
        @app.post(
1 change: 1 addition & 0 deletions inference/core/registries/roboflow.py
@@ -23,6 +23,7 @@
    "sam": ("embed", "sam"),
    "gaze": ("gaze", "l2cs"),
    "doctr": ("ocr", "doctr"),
    "cogvlm": ("llm", "cogvlm"),
}

STUB_VERSION_ID = "0"
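The new "cogvlm" entry maps the generic model ID prefix to a (task, model) pair. Conceptually the lookup works as below; this is an illustrative sketch, not the registry's actual code:

GENERIC_MODELS = {"cogvlm": ("llm", "cogvlm")}

def resolve(model_id: str):
    # "cogvlm/cogvlm-chat-hf" -> dataset "cogvlm", version "cogvlm-chat-hf"
    dataset_id, version_id = model_id.split("/")
    task_type, model_type = GENERIC_MODELS[dataset_id]
    return task_type, model_type

assert resolve("cogvlm/cogvlm-chat-hf") == ("llm", "cogvlm")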
5 changes: 5 additions & 0 deletions inference/models/__init__.py
@@ -18,6 +18,11 @@
except:
    pass

try:
    from inference.models.cogvlm import CogVLM
except ImportError:
    # CogVLM's extra dependencies may not be installed; skip the import in that case.
    pass

from inference.models.vit import VitClassification
from inference.models.yolact import YOLACT
from inference.models.yolov5 import YOLOv5InstanceSegmentation, YOLOv5ObjectDetection
1 change: 1 addition & 0 deletions inference/models/cogvlm/__init__.py
@@ -0,0 +1 @@
from inference.models.cogvlm.cog import CogVLM
97 changes: 97 additions & 0 deletions inference/models/cogvlm/cog.py
@@ -0,0 +1,97 @@
import os
from time import perf_counter
from typing import Any, Tuple

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

from inference.core.entities.requests.cog import CogVLMInferenceRequest
from inference.core.entities.responses.cog import CogVLMResponse
from inference.core.env import (
    API_KEY,
    COG_LOAD_4BIT,
    COG_LOAD_8BIT,
    COG_VERSION_ID,
    MODEL_CACHE_DIR,
)
from inference.core.models.base import Model, PreprocessReturnMetadata
from inference.core.utils.image_utils import load_image_rgb

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class CogVLM(Model):
    def __init__(self, model_id=f"cogvlm/{COG_VERSION_ID}", **kwargs):
        self.model_id = model_id
        self.endpoint = model_id
        self.api_key = API_KEY
        self.dataset_id, self.version_id = model_id.split("/")
        if COG_LOAD_4BIT and COG_LOAD_8BIT:
            raise ValueError(
                "Only one of environment variable `COG_LOAD_4BIT` or `COG_LOAD_8BIT` can be true"
            )
        self.cache_dir = os.path.join(MODEL_CACHE_DIR, self.endpoint)
        with torch.inference_mode():
            # CogVLM reuses the Vicuna tokenizer; the model weights come from THUDM on Hugging Face.
            self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
            self.model = AutoModelForCausalLM.from_pretrained(
                f"THUDM/{self.version_id}",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                load_in_4bit=COG_LOAD_4BIT,
                load_in_8bit=COG_LOAD_8BIT,
                cache_dir=self.cache_dir,
            ).eval()

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[Image.Image, PreprocessReturnMetadata]:
        pil_image = Image.fromarray(load_image_rgb(image))
        return pil_image, PreprocessReturnMetadata({})

    def postprocess(
        self,
        predictions: Tuple[str],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        return predictions[0]

    def predict(self, image_in: Image.Image, prompt="", history=None, **kwargs):
        images = [image_in]
        if history is None:
            history = []
        built_inputs = self.model.build_conversation_input_ids(
            self.tokenizer, query=prompt, history=history, images=images
        )  # chat mode
        # build_conversation_input_ids returns unbatched tensors; add a batch
        # dimension and move everything to the target device.
        inputs = {
            "input_ids": built_inputs["input_ids"].unsqueeze(0).to(DEVICE),
            "token_type_ids": built_inputs["token_type_ids"].unsqueeze(0).to(DEVICE),
            "attention_mask": built_inputs["attention_mask"].unsqueeze(0).to(DEVICE),
            "images": [[built_inputs["images"][0].to(DEVICE).to(torch.float16)]],
        }
        # max_length bounds prompt plus generated tokens; greedy decoding keeps answers deterministic.
        gen_kwargs = {"max_length": 2048, "do_sample": False}

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            # Drop the prompt tokens so only the newly generated text remains.
            outputs = outputs[:, inputs["input_ids"].shape[1] :]
            text = self.tokenizer.decode(outputs[0])
        if text.endswith("</s>"):
            text = text[:-4]
        return (text,)

    def infer_from_request(self, request: CogVLMInferenceRequest) -> CogVLMResponse:
        t1 = perf_counter()
        text = self.infer(**request.dict())
        response = CogVLMResponse(response=text)
        response.time = perf_counter() - t1
        return response


if __name__ == "__main__":
    # Minimal smoke test; requires the model weights and a reachable image URL.
    m = CogVLM()
    print(
        m.infer(
            image={"type": "url", "value": "https://example.com/image.jpg"},  # placeholder URL
            prompt="Describe this image.",
        )
    )
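Tying the request entity and the model class together, a direct in-process call that bypasses the HTTP layer would look roughly like this. A sketch assuming a CUDA machine with the weights downloaded; the image URL is a placeholder:

from inference.core.entities.requests.cog import CogVLMInferenceRequest
from inference.models.cogvlm import CogVLM

model = CogVLM()
request = CogVLMInferenceRequest(
    image={"type": "url", "value": "https://example.com/tournament.jpg"},  # placeholder URL
    prompt="What round of the tournament is this?",
)
result = model.infer_from_request(request)
print(result.response, f"({result.time:.2f}s)")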