
[NeuralChat] Support Assisted Generation on Multi-nodes #1283

Draft: wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions intel_extension_for_transformers/neural_chat/chatbot.py
@@ -277,6 +277,8 @@ def build_chatbot(config: PipelineConfig=None):
parameters["optimization_config"] = config.optimization_config
parameters["hf_access_token"] = config.hf_access_token
parameters["assistant_model"] = config.assistant_model
parameters["assistant_host"] = config.assistant_host
parameters["assistant_port"] = config.assistant_port
if config.serving_config and config.serving_config.framework == "vllm":
parameters["use_vllm"] = True
parameters["vllm_engine_params"] = config.serving_config.framework_config
4 changes: 4 additions & 0 deletions intel_extension_for_transformers/neural_chat/config.py
@@ -457,6 +457,8 @@ def __init__(self,
                 loading_config=None,
                 optimization_config=None,
                 assistant_model=None,
+                assistant_host=None,
+                assistant_port=None,
                 serving_config=None):
        self.model_name_or_path = model_name_or_path
        self.tokenizer_name_or_path = tokenizer_name_or_path
@@ -482,4 +484,6 @@ def __init__(self,
f"Expect optimization_config be an object of MixedPrecisionConfig, WeightOnlyQuantConfig" + \
" or BitsAndBytesConfig,got {type(self.optimization_config)}."
self.assistant_model = assistant_model
self.assistant_host = assistant_host
self.assistant_port = assistant_port
self.serving_config = serving_config
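
For reviewers, a minimal usage sketch of the new fields. `PipelineConfig` and `build_chatbot` appear in the diff above; the import path follows the project's README, and the host address is a placeholder:

```python
from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot

# The target model runs on this node; the draft (assistant) model is served on
# a second node reachable at assistant_host:assistant_port.
config = PipelineConfig(
    model_name_or_path="facebook/opt-13b",
    assistant_model="facebook/opt-350m",
    assistant_host="192.168.0.2",  # placeholder: address of the assistant node
    assistant_port=80,
)
chatbot = build_chatbot(config)
```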
@@ -27,6 +27,10 @@ model_name_or_path: "facebook/opt-13b"
device: "cpu"
assistant_model: "facebook/opt-350m"

# multi-node
assistant_host: "0.0.0.0"
assistant_port: 80


# task choices = ['textchat', 'voicechat', 'retrieval', 'text2image', 'finetune', 'codegen']
tasks_list: ['textchat']
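
A sketch of starting the serving side with this YAML. The executor name, its import path, and the file paths are assumptions based on the project's serving examples:

```python
from intel_extension_for_transformers.neural_chat.server import NeuralChatServerExecutor

# Launch the NeuralChat server using the YAML config above (paths are placeholders).
server_executor = NeuralChatServerExecutor()
server_executor(config_file="./neuralchat.yaml", log_file="./neuralchat.log")
```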
@@ -123,6 +123,8 @@ def load_model(self, kwargs: dict):
        self.use_cache = kwargs["use_cache"]
        self.ipex_int8 = kwargs["ipex_int8"]
        self.assistant_model = kwargs["assistant_model"]
+       self.assistant_host = kwargs["assistant_host"]
+       self.assistant_port = kwargs["assistant_port"]
        load_model(model_name=kwargs["model_name"],
                   tokenizer_name=kwargs["tokenizer_name"],
                   device=kwargs["device"],
@@ -487,13 +487,13 @@ def load_model(
        from transformers import AutoModelForCausalLM
        assistant_model_class = AutoModelForCausalLM
        print(f"Loading assistant model via {assistant_model_class}")
-       assis_model = assistant_model_class.from_pretrained(
+       assist_model = assistant_model_class.from_pretrained(
            assistant_model,
            low_cpu_mem_usage=True,
            torch_dtype=torch_dtype)
-       assis_model = assis_model.eval().to(device)
-       assis_model = assis_model.to(memory_format=torch.channels_last)
-       MODELS[model_name]["assistant_model"] = assis_model
+       assist_model = assist_model.eval().to(device)
+       assist_model = assist_model.to(memory_format=torch.channels_last)
+       MODELS[model_name]["assistant_model"] = assist_model
    else:
        MODELS[model_name]["assistant_model"] = None

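For context, a self-contained sketch of what the loaded assistant model enables: assisted (speculative) generation via the standard Hugging Face `generate(..., assistant_model=...)` API. Model names mirror the YAML above; the exact call site inside NeuralChat may differ:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pairing mirrors the YAML config: a large target model plus a small draft model.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-13b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-13b", torch_dtype=torch.bfloat16)
assist_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)

inputs = tokenizer("NeuralChat is", return_tensors="pt")
# The draft model proposes candidate tokens; the target model verifies them in a
# single forward pass and keeps the accepted prefix, reducing full forward passes.
outputs = model.generate(**inputs, assistant_model=assist_model, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```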
@@ -111,6 +111,8 @@ def init(self, config):
        peft_model_path = config.get("peft_model_path", "")
        plugin_as_service = config.get("plugin_as_service", False)
        assistant_model = config.get("assistant_model", None)
+       assistant_host = config.get("assistant_host", "0.0.0.0")
+       assistant_port = config.get("assistant_port", 80)
        serving = config.get("serving", None)

        serving_config = None
@@ -270,6 +272,8 @@ def init(self, config):
"loading_config": loading_config,
"optimization_config": optimization_config,
"assistant_model": assistant_model,
"assistant_host": assistant_host,
"assistant_port": assistant_port,
"serving_config": serving_config,
"task": "chat"
}
@@ -32,6 +32,7 @@
from .plugin_image2image_api import router as plugin_image2image_router
from .codegen_api import router as codegen_router
from .tgi_api import router as tgi_router
+from .assisted_gen_api import router as assist_router

_router = APIRouter()

@@ -47,7 +48,8 @@
    'plugin_audio': plugin_audio_router,
    "image2image": plugin_image2image_router,
    'codegen': codegen_router,
-   'tgi': tgi_router
+   'tgi': tgi_router,
+   'assist_generation': assist_router
}

def setup_router(api_list, chatbot=None, enable_llm=True, use_deepspeed=False, world_size=1, host="0.0.0.0", port=80):
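A hypothetical sketch of mounting the new routes, following the `setup_router` signature above; `chatbot` is assumed to be an instance returned by `build_chatbot`, and whether `setup_router` consumes `api_list` keys exactly this way is an assumption:

```python
# Hypothetical: include 'assist_generation' in api_list so the new routes
# (/v1/assist/chat, /v1/assist/decode, /v1/assist/data_transfer) are attached.
api_list = ["textchat", "assist_generation"]
app_router = setup_router(api_list, chatbot=chatbot, host="0.0.0.0", port=80)
```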
@@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import httpx
from fastapi import APIRouter
from ...cli.log import logger
from .openai_protocol import (
    ChatCompletionRequest,
    CompletionRequest,
)


class AssistedGenerationAPIRouter(APIRouter):

    def __init__(self) -> None:
        super().__init__()
        self.chatbot = None

    def set_chatbot(self, chatbot, use_deepspeed=False, world_size=1, host="0.0.0.0", port=80) -> None:
        self.chatbot = chatbot
        self.use_deepspeed = use_deepspeed
        self.world_size = world_size
        self.host = host
        self.port = port
        assistant_host = chatbot.assistant_host
        assistant_port = chatbot.assistant_port
        # Build the URL with an f-string so integer ports work; plain string
        # concatenation raised TypeError for the default port 80.
        self.assistant_prefix = f"http://{assistant_host}:{assistant_port}"

    def get_chatbot(self):
        if self.chatbot is None:
            logger.error("Chatbot instance is not found.")
            raise RuntimeError("Chatbot instance has not been set.")
        return self.chatbot

    async def handle_assist_chat(self, request: ChatCompletionRequest):
        # The target route is registered with POST, so forward the request body
        # as JSON rather than issuing a GET with query params.
        async with httpx.AsyncClient() as client:
            response = await client.post(self.assistant_prefix + "/v1/assist/decode",
                                         json=request.dict())
            return response.json()

    async def handle_assist_decode(self, request: CompletionRequest):
        chatbot = self.get_chatbot()
        # TODO: complete model inferencing process for assisted model
        pass

    async def handle_assist_data_transfer(self, request: CompletionRequest):
        async with httpx.AsyncClient() as client:
            response = await client.post(self.assistant_prefix + "/v1/assist/data_transfer",
                                         json=request.dict())
            return response.json()


router = AssistedGenerationAPIRouter()


# Entry route: receives the user chat request and forwards it to the assistant
# model node for draft decoding.
@router.post("/v1/assist/chat")
async def assist_chat(request: ChatCompletionRequest):
    return await router.handle_assist_chat(request)


# Route served on the assistant (small) model node to run its decode step.
@router.post("/v1/assist/decode")
async def assist_decode(request: CompletionRequest):
    return await router.handle_assist_decode(request)


# Route served on the assistant model node for transferring data between nodes.
@router.post("/v1/assist/data_transfer")
async def assist_data_transfer(request: CompletionRequest):
    return await router.handle_assist_data_transfer(request)
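
To exercise the new endpoint from another node, a hypothetical client call; the payload fields follow the OpenAI-style schema that `ChatCompletionRequest` mirrors, and the host and model names are placeholders:

```python
import httpx

# Placeholder host/port of the node serving the target model.
payload = {
    "model": "facebook/opt-13b",
    "messages": [{"role": "user", "content": "Tell me about assisted generation."}],
}
response = httpx.post("http://127.0.0.1:80/v1/assist/chat", json=payload)
print(response.json())
```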