Commit: refactor

aniketmaurya committed Nov 29, 2023
1 parent a0e4f12 commit 6e55c54
Showing 4 changed files with 74 additions and 25 deletions.
5 changes: 3 additions & 2 deletions src/fastserve/__main__.py
@@ -1,5 +1,6 @@
 import uvicorn
 
-from .fastserve import app
+from .fastserve import FastServe
 
-uvicorn.run(app)
+serve = FastServe()
+serve.run_server()
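With this change `python -m fastserve` no longer serves the llama.cpp app directly; it builds a `FastServe` instance and starts its server. A minimal usage sketch of the same flow from user code follows; the absolute import path and the keyword arguments are assumptions based on this diff, not an API documented by the commit.

    # Sketch only: mirrors the new __main__.py entrypoint (run via `python -m fastserve`).
    # Assumes FastServe is importable as fastserve.fastserve.FastServe.
    from fastserve.fastserve import FastServe

    serve = FastServe(batch_size=2, timeout=0.5)  # defaults shown in fastserve.py
    serve.run_server()  # wraps uvicorn.run(), which binds to 127.0.0.1:8000 by default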
16 changes: 11 additions & 5 deletions src/fastserve/batching.py
@@ -6,6 +6,8 @@
 from threading import Event, Thread
 from typing import Any, Callable, Dict, List
 
+from loguru import logger
+
 
 class BatchedQueue:
     def __init__(self, timeout=1.0, bs=1):
@@ -100,21 +102,25 @@ def __init__(
         self._thread.start()
 
     def _process_queue(self):
-        print("Started processing")
+        logger.info("Started processing")
         while True:
             if self._cancel_signal.is_set():
-                print("Stopped batch processor")
+                logger.info("Stopped batch processor")
                 return
             t0 = time.time()
             batch: List[WaitedObject] = self._batched_queue.get()
+            logger.debug(batch)
             t1 = time.time()
-            # print(f"waited {t1-t0:.2f}s for batch")
+            logger.debug(f"waited {t1-t0:.2f}s for batch")
             if not batch:
-                # print("no batch")
+                logger.debug("no batch")
                 continue
             batch_items = [b.item for b in batch]
-            # print(batch_items)
+            logger.debug(batch_items)
             results = self.func(batch_items)
+            if not isinstance(results, list):
+                logger.error(f"returned results must be List but is {type(results)}")
+            logger.debug(results)
             for b, result in zip(batch, results):
                 b.set_result(result)
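The print statements become loguru calls, so batch contents and wait times are now visible at DEBUG level, and a non-list return from the batch function is reported as an error. The sketch below exercises `BatchProcessor` directly with DEBUG logging enabled; it relies only on the calls visible in this commit (`process()` returning an object with `.get()`, and `cancel()` as used in fastserve.py), so treat the exact semantics as assumptions.

    import sys

    from loguru import logger

    from fastserve.batching import BatchProcessor

    # Route loguru output to stderr at DEBUG so the new logger.debug lines are shown.
    logger.remove()
    logger.add(sys.stderr, level="DEBUG")

    def square_batch(items):
        # Must return a list with one result per item, otherwise the new logger.error fires.
        return [x * x for x in items]

    processor = BatchProcessor(func=square_batch, bs=4, timeout=0.5)
    waited = processor.process(3)  # enqueue one item; batching happens in the worker thread
    print(waited.get())            # per fastserve.py, .get() yields the result once the batch is handled
    processor.cancel()             # sets the cancel signal; the worker logs "Stopped batch processor"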

50 changes: 32 additions & 18 deletions src/fastserve/fastserve.py
@@ -1,28 +1,42 @@
-from typing import List
+from typing import Any, List
 
 from fastapi import FastAPI
 from pydantic import BaseModel
 
-from .models.llama_cpp import LlamaCppLLM
+from .batching import BatchProcessor
 
 
-class PromptRequest(BaseModel):
-    prompt: str
-    temperature: float = 0.2
-    max_tokens: int = 60
-    stop: List[str] = []
+class BaseRequest(BaseModel):
+    request: Any
 
 
-app = FastAPI()
-llm = LlamaCppLLM(model_path="openhermes-2-mistral-7b.Q5_K_M.gguf")
+class FastServe:
+    def __init__(self, batch_size=2, timeout=0.5) -> None:
+        self.batch_processing = BatchProcessor(
+            func=self.handle, bs=batch_size, timeout=timeout
+        )
+        self._app = FastAPI()
 
+        @self._app.on_event("shutdown")
+        def shutdown_event():
+            self.batch_processing.cancel()
 
-@app.post("/serve")
-def serve(prompt: PromptRequest):
-    result = llm(
-        prompt=prompt.prompt,
-        temperature=prompt.temperature,
-        max_tokens=prompt.max_tokens,
-        stop=prompt.stop,
-    )
-    return result
+    def serve(
+        self,
+    ):
+        @self._app.post(path="/endpoint")
+        def api(request: BaseRequest):
+            wait_obj = self.batch_processing.process(request)
+            return wait_obj.get()
+
+    def handle(self, batch: List[BaseRequest]):
+        n = len(batch)
+        return n * [0.5 * n]
+
+    def run_server(
+        self,
+    ):
+        self.serve()
+        import uvicorn
+
+        uvicorn.run(self._app)
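`handle` is a placeholder here: it returns `0.5 * n` for each of the `n` batched requests, which makes `FastServe` a generic batching server rather than a llama.cpp-specific app. One plausible way to plug in a real workload is to subclass and override `handle`; the sketch below is not part of the commit, and the `MyServe` name and echo logic are assumptions.

    from typing import List

    from fastserve.fastserve import BaseRequest, FastServe

    class MyServe(FastServe):
        def handle(self, batch: List[BaseRequest]):
            # Each item is a BaseRequest whose `request` field holds the raw payload (typed Any);
            # replace this echo with a real batched model call.
            return [{"echo": item.request} for item in batch]

    serve = MyServe(batch_size=4, timeout=0.2)
    serve.run_server()  # exposes POST /endpoint accepting a JSON body like {"request": ...}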
28 changes: 28 additions & 0 deletions src/fastserve/serve_llama_cpp.py
@@ -0,0 +1,28 @@
+from typing import List
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+
+from .models.llama_cpp import LlamaCppLLM
+
+
+class PromptRequest(BaseModel):
+    prompt: str
+    temperature: float = 0.2
+    max_tokens: int = 60
+    stop: List[str] = []
+
+
+app = FastAPI()
+llm = LlamaCppLLM(model_path="openhermes-2-mistral-7b.Q5_K_M.gguf")
+
+
+@app.post("/serve")
+def serve(prompt: PromptRequest):
+    result = llm(
+        prompt=prompt.prompt,
+        temperature=prompt.temperature,
+        max_tokens=prompt.max_tokens,
+        stop=prompt.stop,
+    )
+    return result
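The llama.cpp app moves here unchanged. A small client sketch against its `/serve` route is shown below; the host and port are uvicorn defaults and therefore assumptions, and it presumes the GGUF weights referenced above are available locally.

    import requests  # third-party client library, not part of this repository

    payload = {
        "prompt": "What is request batching?",
        "temperature": 0.2,
        "max_tokens": 60,
        "stop": ["\n\n"],
    }
    # 127.0.0.1:8000 is uvicorn's default bind address; adjust if the server runs elsewhere.
    response = requests.post("http://127.0.0.1:8000/serve", json=payload)
    print(response.json())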
