diff --git a/src/benchmarks/locustfile.py b/src/benchmarks/locustfile.py
index 2d0f90b..9e7d768 100644
--- a/src/benchmarks/locustfile.py
+++ b/src/benchmarks/locustfile.py
@@ -1,4 +1,3 @@
-import requests
 from locust import HttpUser, task
 
 headers = {
diff --git a/src/fastserve/__main__.py b/src/fastserve/__main__.py
index bb8bf76..95a727b 100644
--- a/src/fastserve/__main__.py
+++ b/src/fastserve/__main__.py
@@ -1,5 +1,3 @@
-import uvicorn
-
 from .base_fastserve import FastServe
 
 serve = FastServe()
diff --git a/src/fastserve/batching.py b/src/fastserve/batching.py
index a4e4284..e92cfdf 100644
--- a/src/fastserve/batching.py
+++ b/src/fastserve/batching.py
@@ -1,12 +1,10 @@
 import logging
-import random
 import signal
 import time
-import uuid
 from dataclasses import dataclass, field
 from queue import Empty, Queue
 from threading import Event, Thread
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, List
 
 logger = logging.getLogger(__name__)
 
@@ -116,7 +114,7 @@ def _process_queue(self):
             if not batch:
                 logger.debug("no batch")
                 continue
-            logger.info(f"Aggregated batch size {len(batch)} in {t1-t0:.2f}s")
+            logger.info(f"Aggregated batch size {len(batch)} in {t1 - t0:.2f}s")
             batch_items = [b.item for b in batch]
             logger.debug(batch_items)
             try:
diff --git a/src/fastserve/models/__main__.py b/src/fastserve/models/__main__.py
index 6410db6..a9281fd 100644
--- a/src/fastserve/models/__main__.py
+++ b/src/fastserve/models/__main__.py
@@ -1,7 +1,6 @@
 import argparse
 
 from fastserve.utils import get_default_device
-
 from .ssd import FastServeSSD
 
 parser = argparse.ArgumentParser(description="Serve models with FastServe")
@@ -22,7 +21,6 @@
     help="Timeout to aggregate maximum batch size",
 )
 
-
 args = parser.parse_args()
 
 app = None
diff --git a/src/fastserve/models/llama_cpp.py b/src/fastserve/models/llama_cpp.py
index 1062ffe..089e97f 100644
--- a/src/fastserve/models/llama_cpp.py
+++ b/src/fastserve/models/llama_cpp.py
@@ -1,9 +1,10 @@
+import logging
 import os
-from pathlib import Path
 from typing import Any
 
 from llama_cpp import Llama
-from loguru import logger
+
+logger = logging.getLogger(__name__)
 
 # https://huggingface.co/TheBloke/OpenHermes-2-Mistral-7B-GGUF
 DEFAULT_MODEL = "openhermes-2-mistral-7b.Q6_K.gguf"
diff --git a/src/fastserve/models/ssd.py b/src/fastserve/models/ssd.py
index d929a08..f5be04f 100644
--- a/src/fastserve/models/ssd.py
+++ b/src/fastserve/models/ssd.py
@@ -6,7 +6,7 @@
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 
-from fastserve import BaseRequest, FastServe
+from fastserve import FastServe
 
 
 class PromptRequest(BaseModel):
diff --git a/src/fastserve/models/vllm.py b/src/fastserve/models/vllm.py
new file mode 100644
index 0000000..639a85a
--- /dev/null
+++ b/src/fastserve/models/vllm.py
@@ -0,0 +1,47 @@
+import os
+from typing import List
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+
+from vllm import LLM, SamplingParams
+
+tensor_parallel_size = int(os.environ.get("DEVICES", "1"))
+print("tensor_parallel_size: ", tensor_parallel_size)
+
+llm = LLM("meta-llama/Llama-2-7b-hf", tensor_parallel_size=tensor_parallel_size)
+
+
+class PromptRequest(BaseModel):
+    prompt: str
+    temperature: float = 1
+    max_tokens: int = 200
+    stop: List[str] = []
+
+
+class ResponseModel(BaseModel):
+    prompt: str
+    prompt_token_ids: List  # The token IDs of the prompt.
+    outputs: List[str]  # The output sequences of the request.
+    finished: bool  # Whether the whole request is finished.
+
+
+app = FastAPI()
+
+
+@app.post("/serve", response_model=ResponseModel)
+def serve(request: PromptRequest):
+    sampling_params = SamplingParams(
+        max_tokens=request.max_tokens,
+        temperature=request.temperature,
+        stop=request.stop,
+    )
+
+    result = llm.generate(request.prompt, sampling_params=sampling_params)[0]
+    response = ResponseModel(
+        prompt=request.prompt,
+        prompt_token_ids=result.prompt_token_ids,
+        outputs=result.outputs,
+        finished=result.finished,
+    )
+    return response
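
For reference, a minimal client-side sketch of how the /serve endpoint added in src/fastserve/models/vllm.py might be exercised once the app is running. The host, port, and use of the requests library are assumptions for illustration and are not part of this patch.

# Hypothetical client call; host and port are assumed, not defined by this diff.
import requests

payload = {
    "prompt": "Hello, my name is",
    "temperature": 0.8,
    "max_tokens": 64,
    "stop": ["\n\n"],
}
response = requests.post("http://localhost:8000/serve", json=payload)
print(response.json())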