diff --git a/serving/ray_server/Dockerfile b/serving/ray_server/Dockerfile
new file mode 100644
index 0000000..3a478a8
--- /dev/null
+++ b/serving/ray_server/Dockerfile
@@ -0,0 +1,31 @@
+FROM python:3.9-slim
+
+WORKDIR /app
+
+COPY requirements.txt requirements.txt
+
+# Install required packages. --no-cache-dir keeps pip's download cache out of
+# the image layers.
+RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-cache-dir awscli
+# python-multipart lets FastAPI parse multipart/form-data uploads.
+RUN pip install --no-cache-dir python-multipart
+# Install Ray and Ray Serve (quoted so the shell never globs the brackets).
+RUN pip install --no-cache-dir "ray[serve]"
+
+COPY . /app
+
+EXPOSE 8000
+
+# Build-time credentials, needed only to download the model below. ARG values
+# are visible to the following RUN steps as environment variables, so they are
+# deliberately not re-exported with ENV: ENV would persist the secrets in the
+# environment of every container started from this image.
+# NOTE: build args can still appear in `docker history`; prefer BuildKit
+# secrets (RUN --mount=type=secret) for images that get published.
+ARG AWS_SECRET_ACCESS_KEY
+ARG AWS_ACCESS_KEY_ID
+
+RUN aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive
+
+CMD ["python", "fastapi_server.py"]
diff --git a/serving/ray_server/README.md b/serving/ray_server/README.md
new file mode 100644
index 0000000..f8bbe89
--- /dev/null
+++ b/serving/ray_server/README.md
@@ -0,0 +1,24 @@
+# Ray inference server
+
+Object-detection inference server built with FastAPI and Ray Serve.
+
+## Build the container
+
+```bash
+docker build \
+  --build-arg AWS_ACCESS_KEY_ID=key \
+  --build-arg AWS_SECRET_ACCESS_KEY="secret_key" \
+  -t alexuvarovskii/object_detection_rayserve:latest .
+```
+
+## Run the container
+
+```bash
+docker run -p 8000:8000 alexuvarovskii/object_detection_rayserve:latest
+```
+
+## Send a prediction request
+
+```bash
+curl -X POST "http://localhost:8000/predict/" -F "image=@path_to_your_image.jpg" -F "threshold=0.5"
+```
diff --git a/serving/ray_server/fastapi_server.py b/serving/ray_server/fastapi_server.py
new file mode 100644
index 0000000..8422526
--- /dev/null
+++ b/serving/ray_server/fastapi_server.py
@@ -0,0 +1,146 @@
+"""FastAPI application served through Ray Serve for RT-DETR object detection."""
+
+import logging
+import os
+from http import HTTPStatus
+from io import BytesIO
+from typing import Dict, Optional
+
+import ray
+import torch
+from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
+from fastapi.responses import JSONResponse
+from PIL import Image, UnidentifiedImageError
+from ray import serve
+from transformers import AutoImageProcessor, AutoModelForObjectDetection
+
+logger = logging.getLogger("ray.serve")
+
+MODEL_LABEL_MAPPING = {0: "person", 1: "car", 2: "pet"}
+DEFAULT_THRESHOLD = 0.5
+
+device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+
+app = FastAPI()
+
+
+@serve.deployment
+@serve.ingress(app)
+class ObjectDetectionModel:
+    """Ray Serve deployment that exposes the object detector over HTTP."""
+
+    def __init__(self):
+        """Load the model and image processor from the local model directory."""
+        model_path = "./rtdetr_model"
+        # Define every attribute up front so a failed load leaves the replica
+        # in a well-defined "not loaded" state.
+        self.model = None
+        self.processor = None
+        self.model_loaded = False
+        try:
+            self.model = AutoModelForObjectDetection.from_pretrained(model_path).to(device).eval()
+            self.processor = AutoImageProcessor.from_pretrained(model_path)
+            self.model_loaded = True
+        except Exception:
+            # Best effort: keep the replica alive so the health check can
+            # report the failure instead of the deployment crash-looping.
+            logger.exception("Error loading model from %s", model_path)
+
+    @app.get("/")
+    def _index(self) -> Dict:
+        """Health check."""
+        response = {
+            "message": HTTPStatus.OK.phrase,
+            "status_code": HTTPStatus.OK,
+            "data": {"model_loaded": self.model_loaded},
+        }
+        return response
+
+    def predict(self, image: Image.Image, threshold: float):
+        """Run the detector on one image.
+
+        Args:
+            image: Image to run detection on.
+            threshold: Score threshold forwarded to the processor's post-processing.
+
+        Returns:
+            The first post-processed result, with every tensor detached and
+            moved to the CPU. `inference` reads its "scores", "labels" and
+            "boxes" entries.
+        """
+        inputs = self.processor(images=image, return_tensors="pt")
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Inference only: skip autograd bookkeeping to save memory and time.
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+
+        target_sizes = torch.tensor([image.size[::-1]])  # target size in (height, width)
+        results = self.processor.post_process_object_detection(
+            outputs, target_sizes=target_sizes, threshold=threshold
+        )[0]
+        return {k: v.detach().cpu() for k, v in results.items()}
+
+    @app.post("/predict/")
+    async def inference(
+        self,
+        image: UploadFile = File(...),
+        threshold: Optional[float] = Form(None),
+        threshold_query: Optional[float] = Query(None, alias="threshold"),
+    ):
+        """Detect objects in an uploaded image.
+
+        `threshold` may be sent as a multipart form field (as in the README's
+        curl example) or as a query parameter; the form field wins when both
+        are present, and DEFAULT_THRESHOLD is used when neither is.
+
+        Raises:
+            HTTPException: 503 when the model failed to load, 400 when the
+                upload cannot be decoded as an image.
+        """
+        if not self.model_loaded:
+            raise HTTPException(status_code=503, detail="Model is not loaded")
+
+        if threshold is None:
+            threshold = DEFAULT_THRESHOLD if threshold_query is None else threshold_query
+
+        try:
+            image_data = await image.read()
+            # convert() forces a full decode, so corrupt files fail here, and
+            # gives the processor a three-channel image for any input mode.
+            pil_image = Image.open(BytesIO(image_data)).convert("RGB")
+        except (UnidentifiedImageError, OSError) as exc:
+            raise HTTPException(status_code=400, detail="Invalid image file") from exc
+
+        results = self.predict(pil_image, threshold)
+
+        output_data = []
+        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+            label_id = label.item()
+            output_data.append({
+                # Fall back to the raw id for classes missing from the mapping.
+                "class": MODEL_LABEL_MAPPING.get(label_id, str(label_id)),
+                "score": score.item(),
+                "box": box.tolist(),
+            })
+
+        return JSONResponse(content={"predictions": output_data})
+
+
+object_detection_app = ObjectDetectionModel.bind()
+
+
+if __name__ == "__main__":
+    ray.init()
+    # Requests must go through Ray Serve's own HTTP proxy: the routes depend on
+    # the deployment instance. Bind the proxy to all interfaces so the port
+    # published by `docker run -p 8000:8000` is reachable.
+    serve.start(http_options={"host": "0.0.0.0", "port": 8000})
+    # blocking=True keeps this process (the container's main process) alive.
+    serve.run(object_detection_app, blocking=True)
+
+
+# Example request:
+# curl -X POST "http://localhost:8000/predict/" \
+#      -F "image=@path_to_your_image.jpg" \
+#      -F "threshold=0.5"
diff --git a/serving/ray_server/requirements.txt b/serving/ray_server/requirements.txt
new file mode 100644
index 0000000..d3141b6
--- /dev/null
+++ b/serving/ray_server/requirements.txt
@@ -0,0 +1,19 @@
+fastapi==0.115.0
+uvicorn==0.31.0
+torch==2.4.0
+torchvision==0.19.0
+transformers==4.44.2
+supervision==0.22.0
+huggingface==0.0.1
+accelerate==0.33.0
+torchmetrics==1.4.1
+albumentations==1.4.14
+pillow==10.4.0
+datasets==2.21.0
+PyYAML==6.0.2
+wandb==0.17.7
+pytest==8.3.3
+# HTTP client used by tests/test_rayserve.py.
+httpx==0.27.2
+pycocotools==2.0.8
+python-dotenv==1.0.1
diff --git a/serving/ray_server/tests/test_rayserve.py b/serving/ray_server/tests/test_rayserve.py
new file mode 100644
index 0000000..b76c54a
--- /dev/null
+++ b/serving/ray_server/tests/test_rayserve.py
@@ -0,0 +1,78 @@
+"""End-to-end tests for the Ray Serve object detection application."""
+
+from io import BytesIO
+
+import httpx
+import pytest
+import ray
+import ray.serve as serve
+from PIL import Image
+
+from fastapi_server import object_detection_app
+
+# Ray Serve's HTTP proxy listens on localhost:8000 by default.
+BASE_URL = "http://localhost:8000"
+# Generous timeout: model inference on CPU can take a while.
+REQUEST_TIMEOUT = 60.0
+
+
+@pytest.fixture(scope="module", autouse=True)
+def setup_ray_serve():
+    """Start Ray, deploy the application on Ray Serve, and tear both down afterwards."""
+    ray.init(ignore_reinit_error=True)
+    # bind() alone only builds the application; serve.run() starts Serve if
+    # needed and actually deploys it.
+    serve.run(object_detection_app)
+
+    yield  # Run tests
+
+    serve.shutdown()
+    ray.shutdown()
+
+
+@pytest.fixture()
+def client():
+    """HTTP client that talks to the deployed application through the Serve proxy."""
+    with httpx.Client(base_url=BASE_URL, timeout=REQUEST_TIMEOUT) as http_client:
+        yield http_client
+
+
+def create_test_image():
+    """Creates an in-memory test image."""
+    image = Image.new("RGB", (100, 100), color="white")  # Create a simple white image
+    img_byte_arr = BytesIO()
+    image.save(img_byte_arr, format="JPEG")
+    img_byte_arr.seek(0)
+    return img_byte_arr
+
+
+def test_health_check(client):
+    """Test the health check endpoint at `/`."""
+    response = client.get("/")
+    assert response.status_code == 200
+    assert response.json()["data"]["model_loaded"] is True
+
+
+def test_predict_endpoint(client):
+    """Test that a valid image upload returns a list of predictions."""
+    files = {"image": ("test.jpg", create_test_image(), "image/jpeg")}
+
+    response = client.post("/predict/", files=files, data={"threshold": "0.5"})
+
+    assert response.status_code == 200
+
+    json_response = response.json()
+    assert "predictions" in json_response
+    assert isinstance(json_response["predictions"], list)
+
+
+def test_invalid_image_upload(client):
+    """Test uploading an invalid image file."""
+    invalid_image_data = BytesIO(b"this is not an image")
+
+    files = {"image": ("test.txt", invalid_image_data, "text/plain")}
+
+    response = client.post("/predict/", files=files, data={"threshold": "0.5"})
+
+    assert response.status_code == 400
+    assert response.json()["detail"] == "Invalid image file"