-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c4a6314
commit 58ebc2c
Showing
5 changed files
with
222 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
FROM python:3.9-slim | ||
|
||
WORKDIR /app | ||
|
||
COPY requirements.txt requirements.txt | ||
|
||
# Install required packages | ||
RUN pip install -r requirements.txt | ||
RUN pip install awscli | ||
RUN pip install python-multipart | ||
RUN pip install ray[serve] # Install Ray and Ray Serve | ||
|
||
COPY . /app | ||
|
||
EXPOSE 8000 | ||
|
||
ARG AWS_SECRET_ACCESS_KEY | ||
ARG AWS_ACCESS_KEY_ID | ||
ENV AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} | ||
ENV AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} | ||
|
||
RUN aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive | ||
|
||
CMD ["python", "fastapi_server.py"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
# ray inference server | ||
docker build | ||
docker run | ||
|
||
build container | ||
```bash | ||
docker build \ | ||
--build-arg AWS_ACCESS_KEY_ID=key \ | ||
--build-arg AWS_SECRET_ACCESS_KEY="secret_key" \ | ||
-t alexuvarovskii/object_detection_rayserve:latest . | ||
``` | ||
|
||
```bash | ||
docker run -p 8000:8000 fastapi-rayserve | ||
``` | ||
|
||
run | ||
```bash | ||
curl -X POST "http://localhost:8000/predict/" -F "image=@path_to_your_image.jpg" -F "threshold=0.5" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import os | ||
from fastapi import FastAPI, UploadFile, File | ||
from fastapi.responses import JSONResponse | ||
from PIL import Image | ||
import torch | ||
from transformers import AutoModelForObjectDetection, AutoImageProcessor | ||
from io import BytesIO | ||
from http import HTTPStatus | ||
from typing import Dict | ||
from fastapi import HTTPException | ||
from PIL import UnidentifiedImageError | ||
|
||
import ray | ||
from ray import serve | ||
|
||
MODEL_LABEL_MAPPING = {0: "person", 1: "car", 2: "pet"} | ||
|
||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") | ||
|
||
app = FastAPI() | ||
|
||
@serve.deployment | ||
@serve.ingress(app) | ||
class ObjectDetectionModel: | ||
def __init__(self): | ||
model_path = "./rtdetr_model" | ||
try: | ||
self.model = AutoModelForObjectDetection.from_pretrained(model_path).to(device).eval() | ||
self.processor = AutoImageProcessor.from_pretrained(model_path) | ||
self.model_loaded = True | ||
except Exception as e: | ||
self.model_loaded = False | ||
print(f"Error loading model: {e}") | ||
|
||
@app.get("/") | ||
def _index(self) -> Dict: | ||
"""Health check.""" | ||
response = { | ||
"message": HTTPStatus.OK.phrase, | ||
"status_code": HTTPStatus.OK, | ||
"data": {"model_loaded": self.model_loaded}, | ||
} | ||
return response | ||
|
||
def predict(self, image: Image.Image, threshold: float): | ||
inputs = self.processor(images=image, return_tensors="pt") | ||
inputs = {k: v.to(device) for k, v in inputs.items()} | ||
|
||
outputs = self.model(**inputs) | ||
|
||
target_sizes = torch.tensor([image.size[::-1]]) # target size in (height, width) | ||
results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0] | ||
results = {k: v.detach().cpu() for k, v in results.items()} | ||
|
||
return results | ||
|
||
@app.post("/predict/") | ||
async def inference(self, image: UploadFile = File(...), threshold: float = 0.5): | ||
try: | ||
image_data = await image.read() | ||
image = Image.open(BytesIO(image_data)) | ||
except UnidentifiedImageError: | ||
raise HTTPException(status_code=400, detail="Invalid image file") | ||
|
||
results = self.predict(image, threshold) | ||
|
||
output_data = [] | ||
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): | ||
output_data.append({ | ||
"class": MODEL_LABEL_MAPPING[label.item()], | ||
"score": score.item(), | ||
"box": [box[0].item(), box[1].item(), box[2].item(), box[3].item()] | ||
}) | ||
|
||
return JSONResponse(content={"predictions": output_data}) | ||
|
||
|
||
object_detection_app = ObjectDetectionModel.bind() | ||
|
||
|
||
if __name__ == "__main__": | ||
import uvicorn | ||
ray.init() | ||
serve.start() | ||
serve.run(object_detection_app) | ||
uvicorn.run(app, host="0.0.0.0", port=8000) | ||
|
||
|
||
|
||
# curl -X POST "http://localhost:8000/predict/" \ | ||
# -F "image=@/Users/alexuvarovskiy/Downloads/Can-a-single-person-own-a-firm-in-India.jpg" \ | ||
# -F "threshold=0.5" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
fastapi==0.115.0 | ||
uvicorn===0.31.0 | ||
torch==2.4.0 | ||
torchvision==0.19.0 | ||
transformers==4.44.2 | ||
supervision==0.22.0 | ||
huggingface==0.0.1 | ||
accelerate==0.33.0 | ||
torchmetrics==1.4.1 | ||
albumentations==1.4.14 | ||
pillow==10.4.0 | ||
datasets==2.21.0 | ||
PyYAML==6.0.2 | ||
wandb==0.17.7 | ||
pytest==8.3.3 | ||
pycocotools==2.0.8 | ||
python-dotenv==1.0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import pytest | ||
import ray | ||
import ray.serve as serve | ||
from httpx import AsyncClient | ||
from fastapi.testclient import TestClient | ||
from io import BytesIO | ||
from PIL import Image | ||
from fastapi_server import app, ObjectDetectionModel | ||
|
||
@pytest.fixture(scope="module", autouse=True) | ||
def setup_ray_serve(): | ||
"""Fixture to set up Ray and Ray Serve.""" | ||
ray.init(ignore_reinit_error=True) | ||
serve.start(detached=True) | ||
ObjectDetectionModel.bind() | ||
|
||
yield # Run tests | ||
|
||
serve.shutdown() | ||
ray.shutdown() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_health_check(): | ||
"""Test the health check endpoint at `/`.""" | ||
async with AsyncClient(app=app, base_url="http://test") as ac: | ||
response = await ac.get("/") | ||
assert response.status_code == 200 | ||
assert response.json()["data"]["model_loaded"] == True | ||
|
||
|
||
def create_test_image(): | ||
"""Creates an in-memory test image.""" | ||
image = Image.new("RGB", (100, 100), color="white") # Create a simple white image | ||
img_byte_arr = BytesIO() | ||
image.save(img_byte_arr, format="JPEG") | ||
img_byte_arr.seek(0) | ||
return img_byte_arr | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_predict_endpoint(): | ||
image = create_test_image() | ||
|
||
files = {"image": ("test.jpg", image, "image/jpeg")} | ||
|
||
async with AsyncClient(app=app, base_url="http://test") as ac: | ||
response = await ac.post("/predict/", files=files, data={"threshold": "0.5"}) | ||
|
||
assert response.status_code == 200 | ||
|
||
json_response = response.json() | ||
assert "predictions" in json_response | ||
assert isinstance(json_response["predictions"], list) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invalid_image_upload(): | ||
"""Test uploading an invalid image file.""" | ||
invalid_image_data = BytesIO(b"this is not an image") | ||
|
||
files = {"image": ("test.txt", invalid_image_data, "text/plain")} | ||
|
||
async with AsyncClient(app=app, base_url="http://test") as ac: | ||
response = await ac.post("/predict/", files=files, data={"threshold": "0.5"}) | ||
|
||
assert response.status_code == 400 | ||
assert response.json()["detail"] == "Invalid image file" |