ENH: Continuous batching supports vision model ability #1724

Merged: 11 commits, Jul 4, 2024
xinference/core/model.py: 21 additions & 4 deletions
@@ -65,6 +65,9 @@ class _OutOfMemoryError(Exception):
OutOfMemoryError = _OutOfMemoryError


XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = ["qwen-vl-chat", "cogvlm2", "glm-4v"]


def request_limit(fn):
"""
Used by ModelActor.
@@ -268,11 +271,25 @@ def allow_batching(self) -> bool:

model_ability = self._model_description.get("model_ability", [])

return (
XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
and isinstance(self._model, PytorchModel)
and "vision" not in model_ability
condition = XINFERENCE_TRANSFORMERS_ENABLE_BATCHING and isinstance(
self._model, PytorchModel
)
if condition and "vision" in model_ability:
if (
self._model.model_family.model_name
in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
or self._model.model_family.model_family
in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
):
return True
else:
logger.warning(
f"Currently for multimodal models, "
f"xinference only supports {', '.join(XINFERENCE_BATCHING_ALLOWED_VISION_MODELS)} for batching. "
f"Your model {self._model.model_family.model_name} with model family {self._model.model_family.model_family} is disqualified."
)
return False
return condition

async def load(self):
self._model.load()
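For readability, the new gating logic above can be summarized as a standalone sketch (illustrative only; the actual check is a method on ModelActor and reads these values from the model description):

XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = ["qwen-vl-chat", "cogvlm2", "glm-4v"]

def allow_batching(batching_enabled, is_pytorch_model, model_ability, model_name, model_family):
    # Batching requires the transformers batching switch and a PytorchModel.
    condition = batching_enabled and is_pytorch_model
    if condition and "vision" in model_ability:
        # Vision models are only batched when whitelisted by model name or family.
        return (
            model_name in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
            or model_family in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
        )
    return condition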
xinference/core/scheduler.py: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ def __init__(self, prompt, future_or_queue, is_prefill, *args, **kwargs):
# Record error message when this request has error.
# Must set stopped=True when this field is set.
self.error_msg: Optional[str] = None
# For compatibility: record extra parameters needed in some special cases.
self.extra_kwargs = {}

# check the integrity of args passed upstream
self._check_args()
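A minimal illustration of how the new extra_kwargs field is intended to be used; the keys shown are the ones the CogVLM2 changes below actually store, and the bare class is a stand-in for InferenceRequest, not its real signature:

class RequestStub:
    def __init__(self):
        # Free-form, model-specific per-request state (mirrors InferenceRequest.extra_kwargs).
        self.extra_kwargs = {}

req = RequestStub()
req.extra_kwargs["attention_mask_seq_len"] = 128  # recorded once at prefill time
req.extra_kwargs["max_position_id"] = 57          # advanced by one on every decode step
req.extra_kwargs["max_position_id"] += 1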
xinference/core/tests/test_continuous_batching.py: 4 additions & 3 deletions
@@ -71,9 +71,10 @@ def run_internal(self):
assert isinstance(res, dict)
choices = res["choices"]
assert isinstance(choices, list)
choice = choices[0]["text"]
assert isinstance(choice, str)
assert len(choice) > 0
choice = choices[0]["message"]
assert isinstance(choice, dict)
content = choice["content"]
assert len(content) > 0


class InferenceThreadWithError(InferenceThread):
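The updated assertions above expect a chat-completion shaped payload rather than a plain completion. A minimal example of that shape (field values are illustrative):

res = {
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
            "finish_reason": "stop",
        }
    ]
}
choice = res["choices"][0]["message"]
assert isinstance(choice, dict)
assert len(choice["content"]) > 0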
xinference/model/llm/llm_family.json: 11 additions & 1 deletion
@@ -944,7 +944,7 @@
"none"
],
"model_id": "THUDM/glm-4v-9b",
"model_revision": "e8b84fefc07e58a90c8489337675573fda95e289"
"model_revision": "6c2e4732db8443f64a48d5af04b74425a7d169c4"
}
],
"prompt_style": {
@@ -5809,6 +5809,16 @@
"roles": [
"user",
"assistant"
],
"stop_token_ids": [
151643,
151644,
151645
],
"stop": [
"<|endoftext|>",
"<|im_start|>",
"<|im_end|>"
]
}
},
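For context, a rough sketch of how stop token ids and stop strings like these typically terminate decoding (a simplified illustration, not the exact xinference implementation):

STOP_TOKEN_IDS = {151643, 151644, 151645}
STOP_STRINGS = ("<|endoftext|>", "<|im_start|>", "<|im_end|>")

def should_stop(new_token_id: int, generated_text: str) -> bool:
    # Stop when the sampled token is a designated stop token,
    # or when a stop string shows up in the decoded text.
    if new_token_id in STOP_TOKEN_IDS:
        return True
    return any(s in generated_text for s in STOP_STRINGS)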
xinference/model/llm/llm_family_modelscope.json: 10 additions & 0 deletions
@@ -3402,6 +3402,16 @@
"roles": [
"user",
"assistant"
],
"stop_token_ids": [
151643,
151644,
151645
],
"stop": [
"<|endoftext|>",
"<|im_start|>",
"<|im_end|>"
]
}
},
xinference/model/llm/pytorch/chatglm.py: 0 additions & 8 deletions
@@ -330,14 +330,6 @@ def _stream_generator():
),
)

@staticmethod
def require_attention_mask():
"""
GLM4 needs to use attention mask and position ids during inference.
Otherwise, the inference result would be not available.
"""
return True

def prepare_sanitize_generate_config(self, req: InferenceRequest):
"""
Set temperature and top_p to 0.8 by default
xinference/model/llm/pytorch/cogvlm2.py: 206 additions & 21 deletions
@@ -23,6 +23,7 @@
import torch
from PIL import Image

from ....core.scheduler import InferenceRequest
from ....model.utils import select_device
from ....types import (
ChatCompletion,
@@ -35,11 +36,30 @@
)
from ..llm_family import LLMFamilyV1, LLMSpecV1
from .core import PytorchChatModel, PytorchGenerateConfig
from .utils import get_max_src_len

logger = logging.getLogger(__name__)

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1


def recur_move_to(item, tgt, criterion_func):
"""
This function is copied from https://github.com/THUDM/CogVLM2/blob/main/basic_demo/cli_demo_batch_inference.py
"""
if criterion_func(item):
device_copy = item.to(tgt)
return device_copy
elif isinstance(item, list):
return [recur_move_to(v, tgt, criterion_func) for v in item]
elif isinstance(item, tuple):
return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
elif isinstance(item, dict):
return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
else:
return item
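# --- Illustrative usage of recur_move_to (editorial example, not part of the PR) ---
# Move every tensor in a nested batch onto a device, then cast only the floating-point
# tensors to the model dtype, mirroring how build_prefill_kwargs uses it further below.
example_batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "images": [[torch.rand(3, 4, 4)]],
}
example_batch = recur_move_to(example_batch, "cpu", lambda x: isinstance(x, torch.Tensor))
example_batch = recur_move_to(
    example_batch,
    torch.float16,
    lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
)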


class CogVLM2Model(PytorchChatModel):
@@ -171,11 +191,33 @@ def _image_to_piexl_values(image):
content["image_url"]["url"]
)
assistant = chat_history[i + 1]["content"]
query = query + f" USER: {user} ASSISTANT:"
history.append((query, assistant))
query = query + f" {assistant}"
history.append((user, assistant))
query = assistant # type: ignore
return query, history, [pixel_values]

def get_query_and_history(
self,
prompt: Union[str, List[Dict]],
system_prompt: Optional[str] = None,
chat_history: Optional[List[ChatCompletionMessage]] = None,
):
content, image = self._message_content_to_cogvlm2(prompt)

history = []
history_image = None
if chat_history:
query, history, history_image = self._history_content_to_cogvlm2(
system_prompt, chat_history # type: ignore
)

if image and history_image:
history = []
query = content
else:
image = image if image else history_image
query = content
return query, image, history

def chat(
self,
prompt: Union[str, List[Dict]],
@@ -193,22 +235,9 @@ def chat(
else 512,
}

content, image = self._message_content_to_cogvlm2(prompt)

history = []
query = ""
history_image = None
if chat_history:
query, history, history_image = self._history_content_to_cogvlm2(
system_prompt, chat_history
)

if image and history_image:
history = []
query = system_prompt + f" USER: {content} ASSISTANT:"
else:
image = image if image else history_image
query = query + f" USER: {content} ASSISTANT:"
query, image, history = self.get_query_and_history(
prompt, system_prompt=system_prompt, chat_history=chat_history
)

input_by_model = self._model.build_conversation_input_ids(
self._tokenizer,
@@ -314,3 +343,159 @@ def _streaming_chat_response(
),
)
yield chunk

@staticmethod
def build_position_ids(x, attention_mask=None):
"""
Copied from https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B-int4/blob/main/modeling_cogvlm.py
"""
# Fix: follow the official open-source implementation linked above
if attention_mask is not None:
tmp = x.clone()
tmp[~(attention_mask.bool())] = -1
else:
tmp = x.clone()
# image boi eoi token as LANGUAGE_TOKEN_TYPE
is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
)
is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
)
is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
# final position ids
y = torch.zeros_like(x, dtype=torch.long)
y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
(tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
)
y = y.cumsum(dim=-1)
return y
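# --- Worked example for build_position_ids (editorial example, not part of the PR) ---
# One leading text token, five vision tokens, one trailing text token: the first and last
# vision tokens act as boi/eoi markers and are re-typed as language tokens, so all interior
# vision tokens share a single position id.
example_token_types = torch.tensor([[0, 1, 1, 1, 1, 1, 0]])  # 0 = LANGUAGE, 1 = VISION
assert CogVLM2Model.build_position_ids(example_token_types).tolist() == [[0, 1, 2, 2, 2, 3, 4]]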

def get_dtype(self):
return self._torch_type

def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
query, image, history = self.get_query_and_history(
prompt, system_prompt=system_prompt, chat_history=chat_history
)

input_by_model: dict = self._model.build_conversation_input_ids(
self._tokenizer,
query=query,
history=history,
images=image,
template_version="chat",
)
return {
"input_ids": input_by_model["input_ids"], # seq_len
"token_type_ids": input_by_model["token_type_ids"], # seq_len
"attention_mask": input_by_model["attention_mask"], # seq_len
"images": input_by_model["images"],
}

def prepare_sanitize_generate_config(self, req: InferenceRequest):
"""
See https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B/blob/main/generation_config.json
"""
raw_config = req.inference_kwargs.get("raw_params", {})
temperature = raw_config.get("temperature", None)
if temperature is None:
raw_config["temperature"] = 0.6
top_p = raw_config.get("top_p", None)
if top_p is None:
raw_config["top_p"] = 0.9
return raw_config

def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
context_len = self.get_context_len()
assert isinstance(prompts[0], dict)
images = []
max_length = float("-inf")
for i, feature in enumerate(prompts):
req = req_list[i]
if "images" in feature:
images.append(feature.pop("images", None))
max_src_len = get_max_src_len(context_len, req)
input_ids = feature["input_ids"][-max_src_len:]
req.prompt_tokens = input_ids.tolist()
feature["input_ids"] = input_ids
feature["token_type_ids"] = feature["token_type_ids"][-max_src_len:]
feature["attention_mask"] = feature["attention_mask"][-max_src_len:]
req.extra_kwargs["attention_mask_seq_len"] = feature[
"attention_mask"
].shape[0]
max_length = max(len(input_ids), max_length)

def pad_to_max_length_internal(feature, max_len, idx):
padding_length = max_len - len(feature["input_ids"])
req_list[idx].padding_len = padding_length
feature["input_ids"] = torch.cat(
[torch.full((padding_length,), 0), feature["input_ids"]]
)
feature["token_type_ids"] = torch.cat(
[
torch.zeros(padding_length, dtype=torch.long),
feature["token_type_ids"],
]
)
feature["attention_mask"] = torch.cat(
[
torch.zeros(padding_length, dtype=torch.long),
feature["attention_mask"],
]
)
return feature

features = [
pad_to_max_length_internal(feature, max_length, i)
for i, feature in enumerate(prompts)
]
batch = {
key: torch.stack([feature[key] for feature in features])
for key in features[0].keys()
}

position_ids = self.build_position_ids(batch["token_type_ids"])
batch["position_ids"] = position_ids

for i in range(len(prompts)):
req = req_list[i]
req.extra_kwargs["max_position_id"] = position_ids[i : i + 1, -1].item()

if images:
batch["images"] = images

batch = recur_move_to(
batch, self._device, lambda x: isinstance(x, torch.Tensor)
)
dtype = self.get_dtype()
if dtype:
batch = recur_move_to(
batch,
dtype,
lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
)
return batch
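# --- Illustrative sketch of the left padding above (editorial example, not part of the PR) ---
# Two prompts of lengths 3 and 5 are padded on the left so they can be stacked into one
# batch; the attention mask is zero over the padding, as in pad_to_max_length_internal.
short_ids = torch.tensor([7, 8, 9])
pad_len = 5 - short_ids.shape[0]
padded_ids = torch.cat([torch.full((pad_len,), 0), short_ids])  # tensor([0, 0, 7, 8, 9])
padded_mask = torch.cat(
    [torch.zeros(pad_len, dtype=torch.long), torch.ones(short_ids.shape[0], dtype=torch.long)]
)  # tensor([0, 0, 1, 1, 1])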

def build_decode_token_type_ids(
self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
):
token_type_ids = torch.full(
(batch_size, 1), fill_value=1, dtype=torch.long, device=self._device
)
return token_type_ids

def build_decode_position_ids(
self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
):
tmp = []
for r in reqs:
r.extra_kwargs["max_position_id"] += 1
tmp.append(r.extra_kwargs["max_position_id"])
position_ids = torch.as_tensor(
tmp, device=self._device, dtype=torch.long
).unsqueeze(1)
return position_ids
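# --- Illustrative decode-step position tracking (editorial example, not part of the PR) ---
# At each decode step every request appends exactly one token, so its stored
# max_position_id advances by one and the step's position_ids tensor has shape
# (batch_size, 1), e.g. for three requests:
step_positions = [12, 40, 7]  # req.extra_kwargs["max_position_id"] for each request
step_positions = [p + 1 for p in step_positions]
decode_position_ids = torch.as_tensor(step_positions, dtype=torch.long).unsqueeze(1)
# -> tensor([[13], [41], [8]])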