From 8531cf6bdc49213dc69dc29924f1eb7a60bd92fc Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Fri, 13 Oct 2023 20:58:13 +0800
Subject: [PATCH] move BaseModelWorker outside serve.model_worker to make it
 independent (#2531)

---
 fastchat/serve/base_model_worker.py      | 240 +++++++++++++++++++++++
 fastchat/serve/huggingface_api_worker.py |   2 +-
 fastchat/serve/model_worker.py           | 208 +------------------
 fastchat/serve/vllm_worker.py            |   2 +-
 4 files changed, 246 insertions(+), 206 deletions(-)
 create mode 100644 fastchat/serve/base_model_worker.py

diff --git a/fastchat/serve/base_model_worker.py b/fastchat/serve/base_model_worker.py
new file mode 100644
index 000000000..d79417184
--- /dev/null
+++ b/fastchat/serve/base_model_worker.py
@@ -0,0 +1,240 @@
+import asyncio
+import threading
+import time
+from typing import List
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse, JSONResponse
+import requests
+
+from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
+from fastchat.conversation import Conversation
+from fastchat.utils import pretty_print_semaphore, build_logger
+
+
+worker_id = str(uuid.uuid4())[:8]
+worker = None
+logger = None
+
+app = FastAPI()
+
+
+def heart_beat_worker(obj):
+    while True:
+        time.sleep(WORKER_HEART_BEAT_INTERVAL)
+        obj.send_heart_beat()
+
+
+class BaseModelWorker:
+    def __init__(
+        self,
+        controller_addr: str,
+        worker_addr: str,
+        worker_id: str,
+        model_path: str,
+        model_names: List[str],
+        limit_worker_concurrency: int,
+        conv_template: str = None,
+    ):
+        global logger, worker
+
+        self.controller_addr = controller_addr
+        self.worker_addr = worker_addr
+        self.worker_id = worker_id
+        if model_path.endswith("/"):
+            model_path = model_path[:-1]
+        self.model_names = model_names or [model_path.split("/")[-1]]
+        self.limit_worker_concurrency = limit_worker_concurrency
+        self.conv = self.make_conv_template(conv_template, model_path)
+        self.conv.sep_style = int(self.conv.sep_style)
+        self.tokenizer = None
+        self.context_len = None
+        self.call_ct = 0
+        self.semaphore = None
+
+        self.heart_beat_thread = None
+
+        if logger is None:
+            logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+        if worker is None:
+            worker = self
+
+    def make_conv_template(
+        self,
+        conv_template: str = None,
+        model_path: str = None,
+    ) -> Conversation:
+        """
+        Can be overridden to customize the conversation template for different model workers.
+        """
+        from fastchat.conversation import get_conv_template
+        from fastchat.model.model_adapter import get_conversation_template
+
+        if conv_template:
+            conv = get_conv_template(conv_template)
+        else:
+            conv = get_conversation_template(model_path)
+        return conv
+
+    def init_heart_beat(self):
+        self.register_to_controller()
+        self.heart_beat_thread = threading.Thread(
+            target=heart_beat_worker,
+            args=(self,),
+            daemon=True,
+        )
+        self.heart_beat_thread.start()
+
+    def register_to_controller(self):
+        logger.info("Register to controller")
+
+        url = self.controller_addr + "/register_worker"
+        data = {
+            "worker_name": self.worker_addr,
+            "check_heart_beat": True,
+            "worker_status": self.get_status(),
+        }
+        r = requests.post(url, json=data)
+        assert r.status_code == 200
+
+    def send_heart_beat(self):
+        logger.info(
+            f"Send heart beat. Models: {self.model_names}. "
+            f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
+            f"call_ct: {self.call_ct}. "
+            f"worker_id: {self.worker_id}. "
" + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except (requests.exceptions.RequestException, KeyError) as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if ( + self.semaphore is None + or self.semaphore._value is None + or self.semaphore._waiters is None + ): + return 0 + else: + return ( + self.limit_worker_concurrency + - self.semaphore._value + + len(self.semaphore._waiters) + ) + + def get_status(self): + return { + "model_names": self.model_names, + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def count_token(self, params): + prompt = params["prompt"] + + try: + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + except TypeError: + input_echo_len = self.tokenizer.num_tokens(prompt) + + ret = { + "count": input_echo_len, + "error_code": 0, + } + return ret + + def get_conv_template(self): + return {"conv": self.conv} + + def generate_stream_gate(self, params): + raise NotImplementedError + + def generate_gate(self, params): + raise NotImplementedError + + def get_embeddings(self, params): + raise NotImplementedError + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + embedding = worker.get_embeddings(params) + release_worker_semaphore() + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index a356273d9..b16c96147 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -28,7 +28,7 @@ from huggingface_hub import InferenceClient from fastchat.constants import SERVER_ERROR_MSG, ErrorCode -from fastchat.serve.model_worker import BaseModelWorker +from 
 from fastchat.utils import build_logger
 
 worker_id = str(uuid.uuid4())[:8]
diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py
index 59fd1def7..8be916b9e 100644
--- a/fastchat/serve/model_worker.py
+++ b/fastchat/serve/model_worker.py
@@ -2,21 +2,14 @@
 A model worker that executes the model.
 """
 import argparse
-import asyncio
 import base64
 import dataclasses
 import gc
-import logging
 import json
 import os
-import threading
-import time
 from typing import List, Optional
 import uuid
 
-from fastapi import FastAPI, Request, BackgroundTasks
-from fastapi.responses import StreamingResponse, JSONResponse
-import requests
 
 try:
     from transformers import (
@@ -37,157 +30,28 @@
 from transformers import set_seed
 import uvicorn
 
-from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
-from fastchat.conversation import get_conv_template
+from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.model.model_adapter import (
     load_model,
     add_model_args,
-    get_conversation_template,
     get_generate_stream_function,
 )
+from fastchat.serve.base_model_worker import BaseModelWorker, app
+from fastchat.modules.gptq import GptqConfig
 from fastchat.modules.awq import AWQConfig
 from fastchat.modules.exllama import ExllamaConfig
 from fastchat.modules.gptq import GptqConfig
 from fastchat.utils import (
     build_logger,
-    pretty_print_semaphore,
     get_context_length,
     str_to_torch_dtype,
 )
-from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
+from fastchat.utils import build_logger, get_context_length
 
 
 worker_id = str(uuid.uuid4())[:8]
 logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
 
-app = FastAPI()
-
-
-def heart_beat_worker(obj):
-    while True:
-        time.sleep(WORKER_HEART_BEAT_INTERVAL)
-        obj.send_heart_beat()
-
-
-class BaseModelWorker:
-    def __init__(
-        self,
-        controller_addr: str,
-        worker_addr: str,
-        worker_id: str,
-        model_path: str,
-        model_names: List[str],
-        limit_worker_concurrency: int,
-        conv_template: str = None,
-    ):
-        self.controller_addr = controller_addr
-        self.worker_addr = worker_addr
-        self.worker_id = worker_id
-        if model_path.endswith("/"):
-            model_path = model_path[:-1]
-        self.model_names = model_names or [model_path.split("/")[-1]]
-        self.limit_worker_concurrency = limit_worker_concurrency
-        if conv_template:
-            self.conv = get_conv_template(conv_template)
-        else:
-            self.conv = get_conversation_template(model_path)
-        self.conv.sep_style = int(self.conv.sep_style)
-        self.tokenizer = None
-        self.context_len = None
-        self.call_ct = 0
-        self.semaphore = None
-
-        self.heart_beat_thread = None
-
-    def init_heart_beat(self):
-        self.register_to_controller()
-        self.heart_beat_thread = threading.Thread(
-            target=heart_beat_worker,
-            args=(self,),
-            daemon=True,
-        )
-        self.heart_beat_thread.start()
-
-    def register_to_controller(self):
-        logger.info("Register to controller")
-
-        url = self.controller_addr + "/register_worker"
-        data = {
-            "worker_name": self.worker_addr,
-            "check_heart_beat": True,
-            "worker_status": self.get_status(),
-        }
-        r = requests.post(url, json=data)
-        assert r.status_code == 200
-
-    def send_heart_beat(self):
-        logger.info(
-            f"Send heart beat. Models: {self.model_names}. "
-            f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
-            f"call_ct: {self.call_ct}. "
-            f"worker_id: {self.worker_id}. "
" - ) - - url = self.controller_addr + "/receive_heart_beat" - - while True: - try: - ret = requests.post( - url, - json={ - "worker_name": self.worker_addr, - "queue_length": self.get_queue_length(), - }, - timeout=5, - ) - exist = ret.json()["exist"] - break - except (requests.exceptions.RequestException, KeyError) as e: - logger.error(f"heart beat error: {e}") - time.sleep(5) - - if not exist: - self.register_to_controller() - - def get_queue_length(self): - if ( - self.semaphore is None - or self.semaphore._value is None - or self.semaphore._waiters is None - ): - return 0 - else: - return ( - self.limit_worker_concurrency - - self.semaphore._value - + len(self.semaphore._waiters) - ) - - def get_status(self): - return { - "model_names": self.model_names, - "speed": 1, - "queue_length": self.get_queue_length(), - } - - def count_token(self, params): - prompt = params["prompt"] - - try: - input_ids = self.tokenizer(prompt).input_ids - input_echo_len = len(input_ids) - except TypeError: - input_echo_len = self.tokenizer.num_tokens(prompt) - - ret = { - "count": input_echo_len, - "error_code": 0, - } - return ret - - def get_conv_template(self): - return {"conv": self.conv} - class ModelWorker(BaseModelWorker): def __init__( @@ -405,70 +269,6 @@ def get_embeddings(self, params): return ret -def release_worker_semaphore(): - worker.semaphore.release() - - -def acquire_worker_semaphore(): - if worker.semaphore is None: - worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) - return worker.semaphore.acquire() - - -def create_background_tasks(): - background_tasks = BackgroundTasks() - background_tasks.add_task(release_worker_semaphore) - return background_tasks - - -@app.post("/worker_generate_stream") -async def api_generate_stream(request: Request): - params = await request.json() - await acquire_worker_semaphore() - generator = worker.generate_stream_gate(params) - background_tasks = create_background_tasks() - return StreamingResponse(generator, background=background_tasks) - - -@app.post("/worker_generate") -async def api_generate(request: Request): - params = await request.json() - await acquire_worker_semaphore() - output = worker.generate_gate(params) - release_worker_semaphore() - return JSONResponse(output) - - -@app.post("/worker_get_embeddings") -async def api_get_embeddings(request: Request): - params = await request.json() - await acquire_worker_semaphore() - embedding = worker.get_embeddings(params) - release_worker_semaphore() - return JSONResponse(content=embedding) - - -@app.post("/worker_get_status") -async def api_get_status(request: Request): - return worker.get_status() - - -@app.post("/count_token") -async def api_count_token(request: Request): - params = await request.json() - return worker.count_token(params) - - -@app.post("/worker_get_conv_template") -async def api_get_conv(request: Request): - return worker.get_conv_template() - - -@app.post("/model_details") -async def api_model_details(request: Request): - return {"context_length": worker.context_len} - - def create_model_worker(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index 1a57dc660..30c741c20 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -18,8 +18,8 @@ from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid +from fastchat.serve.base_model_worker import BaseModelWorker from fastchat.serve.model_worker 
-    BaseModelWorker,
     logger,
     worker_id,
 )
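
For context, a minimal sketch of the kind of standalone worker this refactor enables: a subclass of BaseModelWorker that reuses the shared FastAPI app and its routes without ever importing fastchat.serve.model_worker, mirroring what huggingface_api_worker.py and vllm_worker.py now do. Everything below (the EchoWorker class, the addresses, the "echo-model" name) is an illustrative assumption, not part of the patch; only BaseModelWorker, app, and init_heart_beat come from the new module.

    # Hypothetical usage sketch (not part of this commit).
    import json
    import uuid

    import uvicorn

    from fastchat.serve.base_model_worker import BaseModelWorker, app


    class EchoWorker(BaseModelWorker):
        # Toy backend that streams the prompt back; a real worker runs a model.
        def generate_stream_gate(self, params):
            self.call_ct += 1
            # Workers stream null-delimited JSON chunks.
            chunk = {"text": params["prompt"], "error_code": 0}
            yield json.dumps(chunk).encode() + b"\0"

        def generate_gate(self, params):
            for x in self.generate_stream_gate(params):
                pass
            return json.loads(x[:-1].decode())


    if __name__ == "__main__":
        # Assumes a controller is already running at controller_addr; otherwise
        # register_to_controller() (called by init_heart_beat) will fail.
        worker = EchoWorker(
            controller_addr="http://localhost:21001",
            worker_addr="http://localhost:21002",
            worker_id=str(uuid.uuid4())[:8],
            model_path="echo-model",  # placeholder name for this sketch
            model_names=["echo-model"],
            limit_worker_concurrency=5,
            conv_template="raw",  # any template registered in fastchat.conversation
        )
        worker.init_heart_beat()
        # The shared `app` already exposes /worker_generate_stream,
        # /worker_generate, /count_token, etc., and binds to the first
        # BaseModelWorker instance created in this process.
        uvicorn.run(app, host="localhost", port=21002)

The same pattern applies to the in-tree workers: subclass, override the *_gate methods, and serve the shared app.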