diff --git a/.github/workflows/test_fastapi.yaml b/.github/workflows/test_fastapi.yaml new file mode 100644 index 0000000..cf8889d --- /dev/null +++ b/.github/workflows/test_fastapi.yaml @@ -0,0 +1,46 @@ +name: Run FastApi Tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Set up AWS Credentials + uses: aws-actions/configure-aws-credentials@v2 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-1 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt # Ensure you have this file in the specified directory + pip install awscli httpx python-multipart + working-directory: serving/fastapi_server + + - name: Run AWS S3 Copy Command + run: | + aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive + working-directory: serving/gradio_server + + - name: Run tests + run: | + python -m pytest tests # This will run all tests in the specified directory + working-directory: serving/fastapi_server diff --git a/.github/workflows/test_gradio.yaml b/.github/workflows/test_gradio.yaml new file mode 100644 index 0000000..5999f28 --- /dev/null +++ b/.github/workflows/test_gradio.yaml @@ -0,0 +1,46 @@ +name: Run FastApi Tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Set up AWS Credentials + uses: aws-actions/configure-aws-credentials@v2 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-1 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt # Ensure you have this file in the specified directory + pip install awscli + working-directory: serving/gradio_server + + - name: Run AWS S3 Copy Command + run: | + aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive + working-directory: serving/gradio_server + + - name: Run tests + run: | + python -m pytest tests # This will run all tests in the specified directory + working-directory: serving/gradio_server diff --git a/.github/workflows/test_streamlit.yaml b/.github/workflows/test_streamlit.yaml new file mode 100644 index 0000000..b668e09 --- /dev/null +++ b/.github/workflows/test_streamlit.yaml @@ -0,0 +1,45 @@ +name: Run FastApi Tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Set up AWS Credentials + uses: aws-actions/configure-aws-credentials@v2 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-1 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt # Ensure you have this file in the specified directory + pip install awscli + working-directory: serving/streamlit_server + + - name: Run AWS S3 Copy Command + run: | + aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive + + - name: Run tests + run: | + python -m pytest tests # This will run all tests in the specified directory + working-directory: serving/streamlit_server diff --git a/serving/README.md b/serving/README.md new file mode 100644 index 0000000..a73403d --- /dev/null +++ b/serving/README.md @@ -0,0 +1,68 @@ +# Streamlit + + +## Local Deployment +```bash +streamlit run streamlit_ui.py -- --model_path path/to/model +``` + + +## Build Container +```bash +docker build \ + --build-arg AWS_ACCESS_KEY_ID=key \ + --build-arg AWS_SECRET_ACCESS_KEY="secret_key" \ + -t streamlit_app:latest . +``` + +Run: +```bash +docker run -it --rm -p 8501:8501 streamlit_app:latest +``` + + +# Gradio + +## Local Deployment +```bash +python gradio_ui.py --model_path path/to/model +``` + +## Build Container +```bash +docker build \ + --build-arg AWS_ACCESS_KEY_ID=key \ + --build-arg AWS_SECRET_ACCESS_KEY="secret_key" \ + -t gradio_app:latest . +``` + +Run: +```bash +docker run -it --rm -p 7860:7860 gradio_app:latest +``` + + +# FastAPI +## Local Deployment +```bash +pip install -r requirements.txt +python fastapi_server.py +``` + +## Build Container +```bash +docker build \ + --build-arg AWS_ACCESS_KEY_ID=key \ + --build-arg AWS_SECRET_ACCESS_KEY="secret_key" \ + -t fastapi_app:latest . +``` + +Run: +```bash +docker run -it --rm -p 8000:8000 fastapi_app:latest +``` + +How to make a request: +```bash +curl -X POST "http://localhost:8000/predict/" -F "image=@/path/to/image.jpg" -F "threshold=0.5" +``` \ No newline at end of file diff --git a/serving/fastapi_server/Dockerfile b/serving/fastapi_server/Dockerfile new file mode 100644 index 0000000..825358b --- /dev/null +++ b/serving/fastapi_server/Dockerfile @@ -0,0 +1,22 @@ +FROM python:3.9-slim + +WORKDIR /app + +COPY requirements.txt requirements.txt + +RUN pip install -r requirements.txt +RUN pip install awscli +RUN pip install python-multipart + +COPY . /app + +EXPOSE 8000 + +ARG AWS_SECRET_ACCESS_KEY +ARG AWS_ACCESS_KEY_ID +ENV AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} +ENV AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} + +RUN aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive + +CMD ["python", "fastapi_server.py"] \ No newline at end of file diff --git a/serving/fastapi_server/fastapi_server.py b/serving/fastapi_server/fastapi_server.py new file mode 100644 index 0000000..4c2a681 --- /dev/null +++ b/serving/fastapi_server/fastapi_server.py @@ -0,0 +1,77 @@ +import os +from fastapi import FastAPI, UploadFile, File +from fastapi.responses import JSONResponse +from PIL import Image +import torch +from transformers import AutoModelForObjectDetection, AutoImageProcessor +from io import BytesIO +from http import HTTPStatus +from typing import Dict + +from fastapi import HTTPException +from PIL import UnidentifiedImageError + +device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + +app = FastAPI() + + +model_path = "./rtdetr_model" + +try: + model = AutoModelForObjectDetection.from_pretrained(model_path).to(device).eval() + processor = AutoImageProcessor.from_pretrained(model_path) + model_loaded = True +except Exception as e: + model_loaded = False + print(f"Error loading model: {e}") + +MODEL_LABEL_MAPPING = {0: "person", 1: "car", 2: "pet"} + +@app.get("/") +def _index() -> Dict: + """Health check.""" + response = { + "message": HTTPStatus.OK.phrase, + "status_code": HTTPStatus.OK, + "data": {"model_loaded": model_loaded}, + } + return response + + +def predict(image: Image.Image, threshold: float): + inputs = processor(images=image, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + + outputs = model(**inputs) + + target_sizes = torch.tensor([image.size[::-1]]) # target size in (height, width) + results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0] + results = {k: v.detach().cpu() for k, v in results.items()} + + return results + +@app.post("/predict/") +async def inference(image: UploadFile = File(...), threshold: float = 0.5): + try: + image_data = await image.read() + image = Image.open(BytesIO(image_data)) + except UnidentifiedImageError: + raise HTTPException(status_code=400, detail="Invalid image file") + + results = predict(image, threshold) + + output_data = [] + for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + output_data.append({ + "class": MODEL_LABEL_MAPPING[label.item()], + "score": score.item(), + "box": [box[0].item(), box[1].item(), box[2].item(), box[3].item()] + }) + + return JSONResponse(content={"predictions": output_data}) + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/serving/fastapi_server/requirements.txt b/serving/fastapi_server/requirements.txt new file mode 100644 index 0000000..b67b2ba --- /dev/null +++ b/serving/fastapi_server/requirements.txt @@ -0,0 +1,18 @@ +fastapi==0.115.0 +uvicorn===0.31.0 +torch==2.4.0 +torchvision==0.19.0 +transformers==4.44.2 +supervision==0.22.0 +huggingface==0.0.1 +accelerate==0.33.0 +torchmetrics==1.4.1 +albumentations==1.4.14 +pillow==10.4.0 +datasets==2.21.0 +PyYAML==6.0.2 +wandb==0.17.7 +pytest==8.3.3 +pycocotools==2.0.8 +python-dotenv==1.0.1 +# boto3==1.34.158 \ No newline at end of file diff --git a/serving/fastapi_server/tests/test_api.py b/serving/fastapi_server/tests/test_api.py new file mode 100644 index 0000000..ea995b0 --- /dev/null +++ b/serving/fastapi_server/tests/test_api.py @@ -0,0 +1,54 @@ +import os +import pytest +from fastapi.testclient import TestClient +from PIL import Image +import io +from fastapi_server import app # Adjust if your FastAPI app is in a different file + +client = TestClient(app) + +@pytest.fixture +def test_image(): + # Create a simple 100x100 red image for testing + image = Image.new("RGB", (100, 100), color="red") + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format='JPEG') + img_byte_arr.seek(0) + return img_byte_arr + +def test_index(): + """Test the health check endpoint.""" + response = client.get("/") + assert response.status_code == 200 + data = response.json() + assert data["status_code"] == 200 + assert data["message"] == "OK" + assert "model_loaded" in data["data"] + +@pytest.mark.skipif(os.environ.get("MODEL_AVAILABLE") != "1", reason="Model not available") +def test_predict(test_image): + """Test the prediction endpoint with an example image.""" + # Simulate sending the image as form data + files = {'image': ('test_image.jpg', test_image, 'image/jpeg')} + response = client.post("/predict/", files=files, data={"threshold": "0.5"}) + + assert response.status_code == 200 + data = response.json() + + # Ensure the response contains the predictions + assert "predictions" in data + for prediction in data["predictions"]: + assert "class" in prediction + assert "score" in prediction + assert "box" in prediction + assert len(prediction["box"]) == 4 # Ensure the box has 4 coordinates + + +def test_predict_invalid_file(): + """Test prediction with invalid file input.""" + files = {'image': ('test_image.txt', io.BytesIO(b"not an image"), 'text/plain')} + response = client.post("/predict/", files=files, data={"threshold": "0.5"}) + + # Expecting a 400 Bad Request for invalid image input + assert response.status_code == 400 + assert response.json()["detail"] == "Invalid image file" diff --git a/serving/gradio_server/Dockerfile b/serving/gradio_server/Dockerfile new file mode 100644 index 0000000..3ab8fab --- /dev/null +++ b/serving/gradio_server/Dockerfile @@ -0,0 +1,23 @@ +FROM python:3.9-slim + +WORKDIR /app + +COPY requirements.txt requirements.txt + +RUN pip install -r requirements.txt +RUN pip install awscli + +COPY . /app + +ARG AWS_SECRET_ACCESS_KEY +ARG AWS_ACCESS_KEY_ID +ENV AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} +ENV AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} + +RUN aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive + + +EXPOSE 7860 +ENV GRADIO_SERVER_NAME="0.0.0.0" + +CMD ["python3", "gradio_ui.py", "--model_path", "./rtdetr_model"] \ No newline at end of file diff --git a/serving/gradio_server/gradio_ui.py b/serving/gradio_server/gradio_ui.py new file mode 100644 index 0000000..fd8a85c --- /dev/null +++ b/serving/gradio_server/gradio_ui.py @@ -0,0 +1,110 @@ +import gradio as gr +from PIL import Image, ImageDraw, ImageFont +import torch +import json +from transformers import AutoModelForObjectDetection, AutoImageProcessor + +device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + +def load_model_and_processor(model_path): + # Load model and processor + model = AutoModelForObjectDetection.from_pretrained(model_path) + processor = AutoImageProcessor.from_pretrained(model_path) + + model.to(device) + model.eval() + + return model, processor + +# Define color mapping for classes +CLASS_COLOR_MAPPING = { + "person": "red", + "car": "blue", + "pet": "green" +} + +# Define the model's label mapping (adjust as per your model) +MODEL_LABEL_MAPPING = {0: "person", 1: "car", 2: "pet"} + +def predict(image: Image.Image, threshold: float, model, processor): + # Preprocess image + inputs = processor(images=image, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + + # Perform inference + outputs = model(**inputs) + + # Convert outputs to numpy array + target_sizes = torch.tensor([image.size[::-1]]) # target size in (height, width) + results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0] + results = {k: v.detach().cpu() for k, v in results.items()} + + return results + +def draw_boxes_pillow(image: Image.Image, results): + draw = ImageDraw.Draw(image) + font = ImageFont.load_default(size=25) + + # Add bounding boxes + for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + # Un-normalize the bounding boxes + xmin, ymin, xmax, ymax = box + class_label = MODEL_LABEL_MAPPING[label.item()] + + # Draw rectangle + draw.rectangle(((xmin, ymin), (xmax, ymax)), outline=CLASS_COLOR_MAPPING[class_label], width=3) + + # Add class label and score + text = f'{class_label}: {score.item():.2f}' + + text_bbox = draw.textbbox((xmin, ymin), text, font=font) + text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1] + text_position = (xmin, ymin - text_height) + + # Draw text background and text + draw.rectangle([text_position, (xmin + text_width, ymin)], fill=CLASS_COLOR_MAPPING[class_label]) + draw.text((xmin, ymin - text_height), text, fill="white", font=font) + + return image + +def gradio_interface(model_path): + model, processor = load_model_and_processor(model_path) + + def inference(image, threshold): + results = predict(image, threshold, model, processor) + image_with_boxes = draw_boxes_pillow(image.copy(), results) + + # Prepare JSON output for predictions + output_data = [] + for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + output_data.append({ + "class": MODEL_LABEL_MAPPING[label.item()], + "score": score.item(), + "box": [box[0].item(), box[1].item(), box[2].item(), box[3].item()] + }) + + return image_with_boxes, output_data + + # Create Gradio interface + with gr.Blocks() as demo: + gr.Markdown("# Object Detection Inference") + + with gr.Row(): + image_input = gr.Image(type="pil", label="Upload an image") + threshold_input = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Confidence Threshold") + + submit_button = gr.Button("Start Inference") + image_output = gr.Image(label="Detected Objects") + json_output = gr.JSON(label="Predictions in JSON format") + + submit_button.click(inference, inputs=[image_input, threshold_input], outputs=[image_output, json_output]) + + demo.launch() + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run the Object Detection Gradio app.") + parser.add_argument('--model_path', type=str, required=True, help="Path to the pre-trained model directory.") + + args = parser.parse_args() + gradio_interface(args.model_path) diff --git a/serving/gradio_server/requirements.txt b/serving/gradio_server/requirements.txt new file mode 100644 index 0000000..4ffb53d --- /dev/null +++ b/serving/gradio_server/requirements.txt @@ -0,0 +1,17 @@ +gradio==4.44.1 +torch==2.4.0 +torchvision==0.19.0 +transformers==4.44.2 +supervision==0.22.0 +huggingface==0.0.1 +accelerate==0.33.0 +torchmetrics==1.4.1 +albumentations==1.4.14 +pillow==10.4.0 +datasets==2.21.0 +PyYAML==6.0.2 +wandb==0.17.7 +pytest==8.3.3 +pycocotools==2.0.8 +python-dotenv==1.0.1 +# boto3==1.34.158 \ No newline at end of file diff --git a/serving/gradio_server/tests/test_ui.py b/serving/gradio_server/tests/test_ui.py new file mode 100644 index 0000000..8b8aa2c --- /dev/null +++ b/serving/gradio_server/tests/test_ui.py @@ -0,0 +1,44 @@ +import pytest +from PIL import Image, ImageDraw +import torch +from transformers import RTDetrForObjectDetection, RTDetrImageProcessor +from gradio_ui import load_model_and_processor, predict, draw_boxes_pillow + +@pytest.fixture(scope='module') +def setup_model(): + model_path = "./rtdetr_model" + model, processor = load_model_and_processor(model_path) + yield model, processor + +def test_load_model_and_processor(setup_model): + model, processor = setup_model + assert isinstance(model, RTDetrForObjectDetection) + assert isinstance(processor, RTDetrImageProcessor) + +def test_predict(setup_model): + model, processor = setup_model + dummy_image = Image.new('RGB', (224, 224), color='white') + threshold = 0.5 + + results = predict(dummy_image, threshold, model, processor) + assert "scores" in results + assert "labels" in results + assert "boxes" in results + assert len(results["scores"]) == len(results["labels"]) == len(results["boxes"]) + +def test_draw_boxes_pillow(setup_model): + model, processor = setup_model + dummy_image = Image.new('RGB', (224, 224), color='white') + + results = { + "scores": torch.tensor([0.9, 0.8]), + "labels": torch.tensor([0, 1]), + "boxes": torch.tensor([[10, 10, 100, 100], [150, 150, 200, 200]]) + } + + image_with_boxes = draw_boxes_pillow(dummy_image.copy(), results) + + assert image_with_boxes != dummy_image + +if __name__ == "__main__": + pytest.main() diff --git a/serving/streamlit_server/Dockerfile b/serving/streamlit_server/Dockerfile new file mode 100644 index 0000000..abdd34b --- /dev/null +++ b/serving/streamlit_server/Dockerfile @@ -0,0 +1,29 @@ +FROM python:3.9-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + software-properties-common \ + git \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt requirements.txt + +RUN pip3 install -r requirements.txt +RUN pip install awscli + +EXPOSE 8501 + +COPY . /app +ARG AWS_SECRET_ACCESS_KEY +ARG AWS_ACCESS_KEY_ID +ENV AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} +ENV AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} + +RUN aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive + +HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health + +ENTRYPOINT ["streamlit", "run", "streamlit_ui.py", "--server.port=8501", "--server.address=0.0.0.0", "--", "--model_path", "./rtdetr_model"] \ No newline at end of file diff --git a/serving/streamlit_server/requirements.txt b/serving/streamlit_server/requirements.txt new file mode 100644 index 0000000..018fe08 --- /dev/null +++ b/serving/streamlit_server/requirements.txt @@ -0,0 +1,17 @@ +streamlit==1.39.0 +torch==2.4.0 +torchvision==0.19.0 +transformers==4.44.2 +supervision==0.22.0 +huggingface==0.0.1 +accelerate==0.33.0 +torchmetrics==1.4.1 +albumentations==1.4.14 +pillow==10.4.0 +datasets==2.21.0 +PyYAML==6.0.2 +wandb==0.17.7 +pytest==8.3.3 +pycocotools==2.0.8 +python-dotenv==1.0.1 +boto3==1.34.158 \ No newline at end of file diff --git a/serving/streamlit_server/streamlit_ui.py b/serving/streamlit_server/streamlit_ui.py new file mode 100644 index 0000000..8681b74 --- /dev/null +++ b/serving/streamlit_server/streamlit_ui.py @@ -0,0 +1,96 @@ +import streamlit as st +from PIL import Image, ImageDraw, ImageFont +import torch +import json +import argparse +from transformers import AutoModelForObjectDetection, AutoImageProcessor + +device = torch.device("cpu") + +def load_model_and_processor(model_path): + model = AutoModelForObjectDetection.from_pretrained(model_path) + processor = AutoImageProcessor.from_pretrained(model_path) + + model.to(device) + model.eval() + + return model, processor + +CLASS_COLOR_MAPPING = { + "person": "red", + "car": "blue", + "pet": "green" +} + +MODEL_LABEL_MAPPING = {0: "person", 1: "car", 2: "pet"} + +def predict(image: Image.Image, threshold: float, model, processor): + inputs = processor(images=image, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + outputs = model(**inputs) + target_sizes = torch.tensor([image.size[::-1]]) # target size in (height, width) + results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0] + results = {k: v.detach().cpu() for k, v in results.items()} + return results + + +def draw_boxes_pillow(image: Image.Image, results): + draw = ImageDraw.Draw(image) + font = ImageFont.load_default(size=25) + for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + xmin, ymin, xmax, ymax = box + class_label = MODEL_LABEL_MAPPING[label.item()] + draw.rectangle(((xmin, ymin), (xmax, ymax)), outline=CLASS_COLOR_MAPPING[class_label], width=3) + text = f'{class_label}: {score.item():.2f}' + text_bbox = draw.textbbox((xmin, ymin), text, font=font) + text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1] + text_position = (xmin, ymin - text_height) + draw.rectangle([text_position, (xmin + text_width, ymin)], fill=CLASS_COLOR_MAPPING[class_label]) + draw.text((xmin, ymin - text_height), text, fill="white", font=font) + return image + + +def main(model_path): + model, processor = load_model_and_processor(model_path) + + threshold = st.sidebar.slider('Confidence Threshold', 0.0, 1.0, 0.5, 0.05) + + st.title('Object Detection Inference') + + uploaded_image = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png']) + + if uploaded_image is not None: + image = Image.open(uploaded_image) + + st.image(image, caption='Uploaded Image', use_column_width=True) + + if st.button('Start Inference'): + st.write("Running inference...") + results = predict(image, threshold, model, processor) + + st.write("Inference complete! Displaying image with bounding boxes.") + image_with_boxes = image.copy() + image_with_boxes = draw_boxes_pillow(image_with_boxes, results) + + st.image(image_with_boxes, caption='Detected Objects', use_column_width=True) + + output_data = [] + for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + output_data.append({ + "class": MODEL_LABEL_MAPPING[label.item()], + "score": score.item(), + "box": [box[0].item(), box[1].item(), box[2].item(), box[3].item()] + }) + + st.write("Predictions in JSON format:") + st.json(output_data) + + st.write("Copy-pasteable JSON:") + st.code(json.dumps(output_data, indent=2), language='json') + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the Object Detection Streamlit app.") + parser.add_argument('--model_path', type=str, required=True, help="Path to the pre-trained model directory.") + + args = parser.parse_args() + main(args.model_path) diff --git a/serving/streamlit_server/tests/test_ui.py b/serving/streamlit_server/tests/test_ui.py new file mode 100644 index 0000000..f334473 --- /dev/null +++ b/serving/streamlit_server/tests/test_ui.py @@ -0,0 +1,69 @@ +import pytest +from unittest.mock import MagicMock, patch +from PIL import Image, ImageDraw +import torch +import json +from streamlit_ui import (load_model_and_processor, predict, draw_boxes_pillow, + CLASS_COLOR_MAPPING, MODEL_LABEL_MAPPING) + +# Mock for the model and processor +@pytest.fixture +def mock_model_and_processor(): + model = MagicMock() + processor = MagicMock() + return model, processor + +@pytest.fixture +def dummy_image(): + # Create a dummy image for testing + return Image.new('RGB', (100, 100), color='white') + +def test_load_model_and_processor(mock_model_and_processor): + model, processor = mock_model_and_processor + model_path = "dummy/model/path" + + with patch("transformers.AutoModelForObjectDetection.from_pretrained", return_value=model) as mock_model: + with patch("transformers.AutoImageProcessor.from_pretrained", return_value=processor) as mock_processor: + loaded_model, loaded_processor = load_model_and_processor(model_path) + + assert loaded_model == model + assert loaded_processor == processor + mock_model.assert_called_once_with(model_path) + mock_processor.assert_called_once_with(model_path) + +def test_predict(mock_model_and_processor, dummy_image): + model, processor = mock_model_and_processor + threshold = 0.5 + # Mock processor output + processor.post_process_object_detection.return_value = [{ + "scores": torch.tensor([0.9, 0.8]), + "labels": torch.tensor([0, 1]), + "boxes": torch.tensor([[10, 10, 50, 50], [60, 60, 90, 90]]) + }] + + results = predict(dummy_image, threshold, model, processor) + + assert "scores" in results + assert "labels" in results + assert "boxes" in results + assert len(results["scores"]) == 2 + assert len(results["labels"]) == 2 + assert round(results["scores"][0].item(), 2) == 0.9 + assert results["labels"][0].item() == 0 + assert results["boxes"][0].tolist() == [10, 10, 50, 50] + +def test_draw_boxes_pillow(dummy_image): + results = { + "scores": torch.tensor([0.9]), + "labels": torch.tensor([0]), + "boxes": torch.tensor([[10, 10, 50, 50]]) + } + + image_with_boxes = draw_boxes_pillow(dummy_image.copy(), results) + draw = ImageDraw.Draw(image_with_boxes) + + # Check if the color and text are correct + class_label = MODEL_LABEL_MAPPING[results["labels"][0].item()] + expected_color = CLASS_COLOR_MAPPING[class_label] + assert expected_color == "red" # For person class + assert image_with_boxes != dummy_image