Skip to content

Commit

Permalink
Add ray server
Browse files Browse the repository at this point in the history
  • Loading branch information
alexuvarovskyi committed Oct 22, 2024
1 parent c4a6314 commit 58ebc2c
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 0 deletions.
24 changes: 24 additions & 0 deletions serving/ray_server/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
FROM python:3.9-slim

WORKDIR /app

COPY requirements.txt requirements.txt

# Install required packages
RUN pip install -r requirements.txt
RUN pip install awscli
RUN pip install python-multipart
RUN pip install ray[serve] # Install Ray and Ray Serve

COPY . /app

EXPOSE 8000

ARG AWS_SECRET_ACCESS_KEY
ARG AWS_ACCESS_KEY_ID
ENV AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
ENV AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}

RUN aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive

CMD ["python", "fastapi_server.py"]
21 changes: 21 additions & 0 deletions serving/ray_server/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

# ray inference server
docker build
docker run

build container
```bash
docker build \
--build-arg AWS_ACCESS_KEY_ID=key \
--build-arg AWS_SECRET_ACCESS_KEY="secret_key" \
-t alexuvarovskii/object_detection_rayserve:latest .
```

```bash
docker run -p 8000:8000 fastapi-rayserve
```

run
```bash
curl -X POST "http://localhost:8000/predict/" -F "image=@path_to_your_image.jpg" -F "threshold=0.5"
```
92 changes: 92 additions & 0 deletions serving/ray_server/fastapi_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image
import torch
from transformers import AutoModelForObjectDetection, AutoImageProcessor
from io import BytesIO
from http import HTTPStatus
from typing import Dict
from fastapi import HTTPException
from PIL import UnidentifiedImageError

import ray
from ray import serve

MODEL_LABEL_MAPPING = {0: "person", 1: "car", 2: "pet"}

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

app = FastAPI()

@serve.deployment
@serve.ingress(app)
class ObjectDetectionModel:
def __init__(self):
model_path = "./rtdetr_model"
try:
self.model = AutoModelForObjectDetection.from_pretrained(model_path).to(device).eval()
self.processor = AutoImageProcessor.from_pretrained(model_path)
self.model_loaded = True
except Exception as e:
self.model_loaded = False
print(f"Error loading model: {e}")

@app.get("/")
def _index(self) -> Dict:
"""Health check."""
response = {
"message": HTTPStatus.OK.phrase,
"status_code": HTTPStatus.OK,
"data": {"model_loaded": self.model_loaded},
}
return response

def predict(self, image: Image.Image, threshold: float):
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = self.model(**inputs)

target_sizes = torch.tensor([image.size[::-1]]) # target size in (height, width)
results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
results = {k: v.detach().cpu() for k, v in results.items()}

return results

@app.post("/predict/")
async def inference(self, image: UploadFile = File(...), threshold: float = 0.5):
try:
image_data = await image.read()
image = Image.open(BytesIO(image_data))
except UnidentifiedImageError:
raise HTTPException(status_code=400, detail="Invalid image file")

results = self.predict(image, threshold)

output_data = []
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
output_data.append({
"class": MODEL_LABEL_MAPPING[label.item()],
"score": score.item(),
"box": [box[0].item(), box[1].item(), box[2].item(), box[3].item()]
})

return JSONResponse(content={"predictions": output_data})


object_detection_app = ObjectDetectionModel.bind()


if __name__ == "__main__":
import uvicorn
ray.init()
serve.start()
serve.run(object_detection_app)
uvicorn.run(app, host="0.0.0.0", port=8000)



# curl -X POST "http://localhost:8000/predict/" \
# -F "image=@/Users/alexuvarovskiy/Downloads/Can-a-single-person-own-a-firm-in-India.jpg" \
# -F "threshold=0.5"
17 changes: 17 additions & 0 deletions serving/ray_server/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
fastapi==0.115.0
uvicorn===0.31.0
torch==2.4.0
torchvision==0.19.0
transformers==4.44.2
supervision==0.22.0
huggingface==0.0.1
accelerate==0.33.0
torchmetrics==1.4.1
albumentations==1.4.14
pillow==10.4.0
datasets==2.21.0
PyYAML==6.0.2
wandb==0.17.7
pytest==8.3.3
pycocotools==2.0.8
python-dotenv==1.0.1
68 changes: 68 additions & 0 deletions serving/ray_server/tests/test_rayserve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pytest
import ray
import ray.serve as serve
from httpx import AsyncClient
from fastapi.testclient import TestClient
from io import BytesIO
from PIL import Image
from fastapi_server import app, ObjectDetectionModel

@pytest.fixture(scope="module", autouse=True)
def setup_ray_serve():
"""Fixture to set up Ray and Ray Serve."""
ray.init(ignore_reinit_error=True)
serve.start(detached=True)
ObjectDetectionModel.bind()

yield # Run tests

serve.shutdown()
ray.shutdown()


@pytest.mark.asyncio
async def test_health_check():
"""Test the health check endpoint at `/`."""
async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.get("/")
assert response.status_code == 200
assert response.json()["data"]["model_loaded"] == True


def create_test_image():
"""Creates an in-memory test image."""
image = Image.new("RGB", (100, 100), color="white") # Create a simple white image
img_byte_arr = BytesIO()
image.save(img_byte_arr, format="JPEG")
img_byte_arr.seek(0)
return img_byte_arr


@pytest.mark.asyncio
async def test_predict_endpoint():
image = create_test_image()

files = {"image": ("test.jpg", image, "image/jpeg")}

async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.post("/predict/", files=files, data={"threshold": "0.5"})

assert response.status_code == 200

json_response = response.json()
assert "predictions" in json_response
assert isinstance(json_response["predictions"], list)


@pytest.mark.asyncio
async def test_invalid_image_upload():
"""Test uploading an invalid image file."""
invalid_image_data = BytesIO(b"this is not an image")

files = {"image": ("test.txt", invalid_image_data, "text/plain")}

async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.post("/predict/", files=files, data={"threshold": "0.5"})

assert response.status_code == 400
assert response.json()["detail"] == "Invalid image file"

0 comments on commit 58ebc2c

Please sign in to comment.