Skip to content

Commit

Permalink
move BaseModelWorker outside serve.model_worker to make it independent (
Browse files Browse the repository at this point in the history
  • Loading branch information
liunux4odoo authored Oct 13, 2023
1 parent 7ebc29c commit 8531cf6
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 206 deletions.
240 changes: 240 additions & 0 deletions fastchat/serve/base_model_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import asyncio
import threading
import time
from typing import List
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import requests

from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
from fastchat.conversation import Conversation
from fastchat.utils import pretty_print_semaphore, build_logger


worker_id = str(uuid.uuid4())[:8]
worker = None
logger = None

app = FastAPI()


def heart_beat_worker(obj):
while True:
time.sleep(WORKER_HEART_BEAT_INTERVAL)
obj.send_heart_beat()


class BaseModelWorker:
def __init__(
self,
controller_addr: str,
worker_addr: str,
worker_id: str,
model_path: str,
model_names: List[str],
limit_worker_concurrency: int,
conv_template: str = None,
):
global logger, worker

self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
if model_path.endswith("/"):
model_path = model_path[:-1]
self.model_names = model_names or [model_path.split("/")[-1]]
self.limit_worker_concurrency = limit_worker_concurrency
self.conv = self.make_conv_template(conv_template, model_path)
self.conv.sep_style = int(self.conv.sep_style)
self.tokenizer = None
self.context_len = None
self.call_ct = 0
self.semaphore = None

self.heart_beat_thread = None

if logger is None:
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
if worker is None:
worker = self

def make_conv_template(
self,
conv_template: str = None,
model_path: str = None,
) -> Conversation:
"""
can be overrided to costomize the conversation template for different model workers.
"""
from fastchat.conversation import get_conv_template
from fastchat.model.model_adapter import get_conversation_template

if conv_template:
conv = get_conv_template(conv_template)
else:
conv = get_conversation_template(model_path)
return conv

def init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=heart_beat_worker,
args=(self,),
daemon=True,
)
self.heart_beat_thread.start()

def register_to_controller(self):
logger.info("Register to controller")

url = self.controller_addr + "/register_worker"
data = {
"worker_name": self.worker_addr,
"check_heart_beat": True,
"worker_status": self.get_status(),
}
r = requests.post(url, json=data)
assert r.status_code == 200

def send_heart_beat(self):
logger.info(
f"Send heart beat. Models: {self.model_names}. "
f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
f"call_ct: {self.call_ct}. "
f"worker_id: {self.worker_id}. "
)

url = self.controller_addr + "/receive_heart_beat"

while True:
try:
ret = requests.post(
url,
json={
"worker_name": self.worker_addr,
"queue_length": self.get_queue_length(),
},
timeout=5,
)
exist = ret.json()["exist"]
break
except (requests.exceptions.RequestException, KeyError) as e:
logger.error(f"heart beat error: {e}")
time.sleep(5)

if not exist:
self.register_to_controller()

def get_queue_length(self):
if (
self.semaphore is None
or self.semaphore._value is None
or self.semaphore._waiters is None
):
return 0
else:
return (
self.limit_worker_concurrency
- self.semaphore._value
+ len(self.semaphore._waiters)
)

def get_status(self):
return {
"model_names": self.model_names,
"speed": 1,
"queue_length": self.get_queue_length(),
}

def count_token(self, params):
prompt = params["prompt"]

try:
input_ids = self.tokenizer(prompt).input_ids
input_echo_len = len(input_ids)
except TypeError:
input_echo_len = self.tokenizer.num_tokens(prompt)

ret = {
"count": input_echo_len,
"error_code": 0,
}
return ret

def get_conv_template(self):
return {"conv": self.conv}

def generate_stream_gate(self, params):
raise NotImplementedError

def generate_gate(self, params):
raise NotImplementedError

def get_embeddings(self, params):
raise NotImplementedError


def release_worker_semaphore():
worker.semaphore.release()


def acquire_worker_semaphore():
if worker.semaphore is None:
worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
return worker.semaphore.acquire()


def create_background_tasks():
background_tasks = BackgroundTasks()
background_tasks.add_task(release_worker_semaphore)
return background_tasks


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
params = await request.json()
await acquire_worker_semaphore()
generator = worker.generate_stream_gate(params)
background_tasks = create_background_tasks()
return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
params = await request.json()
await acquire_worker_semaphore()
output = worker.generate_gate(params)
release_worker_semaphore()
return JSONResponse(output)


@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
params = await request.json()
await acquire_worker_semaphore()
embedding = worker.get_embeddings(params)
release_worker_semaphore()
return JSONResponse(content=embedding)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
params = await request.json()
return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
return {"context_length": worker.context_len}
2 changes: 1 addition & 1 deletion fastchat/serve/huggingface_api_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from huggingface_hub import InferenceClient

from fastchat.constants import SERVER_ERROR_MSG, ErrorCode
from fastchat.serve.model_worker import BaseModelWorker
from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.utils import build_logger

worker_id = str(uuid.uuid4())[:8]
Expand Down
Loading

0 comments on commit 8531cf6

Please sign in to comment.