reformat
aniketmaurya committed Dec 6, 2023
1 parent 00ed0ec commit 60db946
Showing 7 changed files with 53 additions and 12 deletions.
1 change: 0 additions & 1 deletion src/benchmarks/locustfile.py
@@ -1,4 +1,3 @@
-import requests
from locust import HttpUser, task

headers = {
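
The load-test script is truncated above. For context, a minimal Locust script of the same shape might look like the sketch below; the endpoint path, payload, and header values are assumptions for illustration, not taken from the actual file.

from locust import HttpUser, task

# Illustrative headers; the real values in locustfile.py are truncated above.
headers = {"Content-Type": "application/json"}


class FastServeUser(HttpUser):
    @task
    def predict(self):
        # Hypothetical endpoint and payload, for illustration only.
        self.client.post("/endpoint", json={"prompt": "hello"}, headers=headers)

Such a script would be run against a local server with something like `locust -f src/benchmarks/locustfile.py --host http://localhost:8000`.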
2 changes: 0 additions & 2 deletions src/fastserve/__main__.py
@@ -1,5 +1,3 @@
-import uvicorn
-
from .base_fastserve import FastServe

serve = FastServe()
6 changes: 2 additions & 4 deletions src/fastserve/batching.py
@@ -1,12 +1,10 @@
import logging
import random
import signal
import time
import uuid
from dataclasses import dataclass, field
from queue import Empty, Queue
from threading import Event, Thread
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, List

logger = logging.getLogger(__name__)

@@ -116,7 +114,7 @@ def _process_queue(self):
if not batch:
logger.debug("no batch")
continue
logger.info(f"Aggregated batch size {len(batch)} in {t1-t0:.2f}s")
logger.info(f"Aggregated batch size {len(batch)} in {t1 - t0:.2f}s")
batch_items = [b.item for b in batch]
logger.debug(batch_items)
try:
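
The imports above (Queue, Empty, Event, Thread, dataclass) outline the shape of the batching module: incoming items are put on a queue, and a worker loop drains them into batches bounded by a maximum size and a timeout before calling the model once per batch. A self-contained sketch of such a loop, using illustrative names rather than FastServe's actual batching API:

import time
from dataclasses import dataclass
from queue import Empty, Queue
from threading import Event
from typing import Any, Callable, List


@dataclass
class WaitedObject:
    # Illustrative wrapper; the dataclass in batching.py may differ.
    item: Any
    result: Any = None


def process_queue(q: Queue, stop_event: Event, func: Callable[[List[Any]], List[Any]],
                  batch_size: int = 4, timeout: float = 0.5) -> None:
    # Drain `q` into batches of at most `batch_size`, waiting at most `timeout`
    # seconds per batch, then call `func` once on the whole batch.
    while not stop_event.is_set():
        batch: List[WaitedObject] = []
        t0 = time.time()
        while len(batch) < batch_size and (time.time() - t0) < timeout:
            try:
                batch.append(q.get(timeout=timeout))
            except Empty:
                break
        t1 = time.time()
        if not batch:
            continue
        print(f"Aggregated batch size {len(batch)} in {t1 - t0:.2f}s")
        results = func([b.item for b in batch])
        for waited, result in zip(batch, results):
            waited.result = result

A background Thread would run this worker while request handlers enqueue items and wait for their result field to be filled in.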
2 changes: 0 additions & 2 deletions src/fastserve/models/__main__.py
@@ -1,7 +1,6 @@
import argparse

from fastserve.utils import get_default_device

from .ssd import FastServeSSD

parser = argparse.ArgumentParser(description="Serve models with FastServe")
@@ -22,7 +21,6 @@
help="Timeout to aggregate maximum batch size",
)


args = parser.parse_args()

app = None
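
The hunk above shows only the parser description and the tail of one add_argument call. A minimal completion of that setup as a sketch; the flag name, type, and default are assumptions, and only the description and help text come from the diff.

import argparse

parser = argparse.ArgumentParser(description="Serve models with FastServe")
# Flag name, type, and default below are hypothetical; only the description and
# help text appear in the hunk above.
parser.add_argument(
    "--timeout",
    type=float,
    default=0.5,
    help="Timeout to aggregate maximum batch size",
)

args = parser.parse_args()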
5 changes: 3 additions & 2 deletions src/fastserve/models/llama_cpp.py
@@ -1,9 +1,10 @@
import logging
import os
from pathlib import Path
from typing import Any

from llama_cpp import Llama
-from loguru import logger

+logger = logging.getLogger(__name__)

# https://huggingface.co/TheBloke/OpenHermes-2-Mistral-7B-GGUF
DEFAULT_MODEL = "openhermes-2-mistral-7b.Q6_K.gguf"
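
The module now uses the standard logging library and serves a GGUF checkpoint (openhermes-2-mistral-7b.Q6_K.gguf) through llama-cpp-python. For reference, a direct llama_cpp call with such a model looks roughly like the sketch below; the model path and sampling parameters are assumptions.

from llama_cpp import Llama

# The model path is an assumption; llama_cpp.py resolves DEFAULT_MODEL itself.
llm = Llama(model_path="openhermes-2-mistral-7b.Q6_K.gguf", n_ctx=2048)

output = llm(
    "Q: What does FastServe do? A:",
    max_tokens=128,
    stop=["Q:"],
    echo=False,
)
print(output["choices"][0]["text"])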
2 changes: 1 addition & 1 deletion src/fastserve/models/ssd.py
@@ -6,7 +6,7 @@
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

-from fastserve import BaseRequest, FastServe
+from fastserve import FastServe


class PromptRequest(BaseModel):
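
ssd.py pairs a pydantic PromptRequest with FastAPI's StreamingResponse, presumably to stream back the bytes generated by the SSD model. A stripped-down sketch of that pattern, with a placeholder in place of the actual model call (which is not shown in this hunk) and a hypothetical endpoint path:

import io

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()


class PromptRequest(BaseModel):
    prompt: str


@app.post("/generate")  # endpoint path is hypothetical
def generate(request: PromptRequest) -> StreamingResponse:
    # Placeholder: the real handler would run the SSD model on request.prompt
    # and stream back the generated image bytes.
    image_bytes = io.BytesIO(b"\x89PNG placeholder")
    return StreamingResponse(image_bytes, media_type="image/png")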
47 changes: 47 additions & 0 deletions src/fastserve/models/vllm.py
@@ -0,0 +1,47 @@
import os
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel

from vllm import LLM, SamplingParams

tensor_parallel_size = int(os.environ.get("DEVICES", "1"))
print("tensor_parallel_size: ", tensor_parallel_size)

llm = LLM("meta-llama/Llama-2-7b-hf", tensor_parallel_size=tensor_parallel_size)


class PromptRequest(BaseModel):
prompt: str
temperature: float = 1
max_tokens: int = 200
stop: List[str] = []


class ResponseModel(BaseModel):
prompt: str
prompt_token_ids: List # The token IDs of the prompt.
outputs: List[str] # The output sequences of the request.
finished: bool # Whether the whole request is finished.


app = FastAPI()


@app.post("/serve", response_model=ResponseModel)
def serve(request: PromptRequest):
sampling_params = SamplingParams(
max_tokens=request.max_tokens,
temperature=request.temperature,
stop=request.stop,
)

result = llm.generate(request.prompt, sampling_params=sampling_params)[0]
response = ResponseModel(
prompt=request.prompt,
prompt_token_ids=result.prompt_token_ids,
outputs=result.outputs,
finished=result.finished,
)
return response
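
Once the new module is running, the /serve endpoint accepts the PromptRequest schema defined above. A usage sketch, assuming the app is served locally on port 8000:

import requests

payload = {
    "prompt": "The capital of France is",
    "temperature": 0.8,
    "max_tokens": 50,
    "stop": ["\n"],
}
# Assumes the app is already running, e.g. `uvicorn fastserve.models.vllm:app --port 8000`.
response = requests.post("http://localhost:8000/serve", json=payload)
print(response.json()["outputs"])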
