-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add streamlit UI, Dockerfile, Tests, CI
- Loading branch information
1 parent
75753fd
commit 0d97ec6
Showing
6 changed files
with
261 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
name: Run FastApi Tests | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
branches: | ||
- main | ||
|
||
jobs: | ||
test: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v2 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: '3.9' | ||
|
||
- name: Set up AWS Credentials | ||
uses: aws-actions/configure-aws-credentials@v2 | ||
with: | ||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} | ||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} | ||
aws-region: us-east-1 | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -r requirements.txt # Ensure you have this file in the specified directory | ||
pip install awscli | ||
working-directory: serving/gradio_server | ||
|
||
- name: Run AWS S3 Copy Command | ||
run: | | ||
aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive | ||
working-directory: serving/gradio_server | ||
|
||
- name: Run tests | ||
run: | | ||
python -m pytest tests # This will run all tests in the specified directory | ||
working-directory: serving/gradio_server |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
FROM python:3.9-slim | ||
|
||
WORKDIR /app | ||
|
||
COPY requirements.txt requirements.txt | ||
|
||
RUN pip install -r requirements.txt | ||
RUN pip install awscli | ||
|
||
COPY . /app | ||
|
||
ARG AWS_SECRET_ACCESS_KEY | ||
ARG AWS_ACCESS_KEY_ID | ||
ENV AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} | ||
ENV AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} | ||
|
||
RUN aws s3 cp s3://mlp-data-2024/rtdetr_model/ ./rtdetr_model --recursive | ||
|
||
|
||
EXPOSE 7860 | ||
ENV GRADIO_SERVER_NAME="0.0.0.0" | ||
|
||
CMD ["python3", "gradio_ui.py", "--model_path", "./rtdetr_model"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import gradio as gr | ||
from PIL import Image, ImageDraw, ImageFont | ||
import torch | ||
import json | ||
from transformers import AutoModelForObjectDetection, AutoImageProcessor | ||
|
||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") | ||
|
||
def load_model_and_processor(model_path): | ||
# Load model and processor | ||
model = AutoModelForObjectDetection.from_pretrained(model_path) | ||
processor = AutoImageProcessor.from_pretrained(model_path) | ||
|
||
model.to(device) | ||
model.eval() | ||
|
||
return model, processor | ||
|
||
# Define color mapping for classes | ||
CLASS_COLOR_MAPPING = { | ||
"person": "red", | ||
"car": "blue", | ||
"pet": "green" | ||
} | ||
|
||
# Define the model's label mapping (adjust as per your model) | ||
MODEL_LABEL_MAPPING = {0: "person", 1: "car", 2: "pet"} | ||
|
||
def predict(image: Image.Image, threshold: float, model, processor): | ||
# Preprocess image | ||
inputs = processor(images=image, return_tensors="pt") | ||
inputs = {k: v.to(device) for k, v in inputs.items()} | ||
|
||
# Perform inference | ||
outputs = model(**inputs) | ||
|
||
# Convert outputs to numpy array | ||
target_sizes = torch.tensor([image.size[::-1]]) # target size in (height, width) | ||
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0] | ||
results = {k: v.detach().cpu() for k, v in results.items()} | ||
|
||
return results | ||
|
||
def draw_boxes_pillow(image: Image.Image, results): | ||
draw = ImageDraw.Draw(image) | ||
font = ImageFont.load_default(size=25) | ||
|
||
# Add bounding boxes | ||
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): | ||
# Un-normalize the bounding boxes | ||
xmin, ymin, xmax, ymax = box | ||
class_label = MODEL_LABEL_MAPPING[label.item()] | ||
|
||
# Draw rectangle | ||
draw.rectangle(((xmin, ymin), (xmax, ymax)), outline=CLASS_COLOR_MAPPING[class_label], width=3) | ||
|
||
# Add class label and score | ||
text = f'{class_label}: {score.item():.2f}' | ||
|
||
text_bbox = draw.textbbox((xmin, ymin), text, font=font) | ||
text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1] | ||
text_position = (xmin, ymin - text_height) | ||
|
||
# Draw text background and text | ||
draw.rectangle([text_position, (xmin + text_width, ymin)], fill=CLASS_COLOR_MAPPING[class_label]) | ||
draw.text((xmin, ymin - text_height), text, fill="white", font=font) | ||
|
||
return image | ||
|
||
def gradio_interface(model_path): | ||
model, processor = load_model_and_processor(model_path) | ||
|
||
def inference(image, threshold): | ||
results = predict(image, threshold, model, processor) | ||
image_with_boxes = draw_boxes_pillow(image.copy(), results) | ||
|
||
# Prepare JSON output for predictions | ||
output_data = [] | ||
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): | ||
output_data.append({ | ||
"class": MODEL_LABEL_MAPPING[label.item()], | ||
"score": score.item(), | ||
"box": [box[0].item(), box[1].item(), box[2].item(), box[3].item()] | ||
}) | ||
|
||
return image_with_boxes, output_data | ||
|
||
# Create Gradio interface | ||
with gr.Blocks() as demo: | ||
gr.Markdown("# Object Detection Inference") | ||
|
||
with gr.Row(): | ||
image_input = gr.Image(type="pil", label="Upload an image") | ||
threshold_input = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Confidence Threshold") | ||
|
||
submit_button = gr.Button("Start Inference") | ||
image_output = gr.Image(label="Detected Objects") | ||
json_output = gr.JSON(label="Predictions in JSON format") | ||
|
||
submit_button.click(inference, inputs=[image_input, threshold_input], outputs=[image_output, json_output]) | ||
|
||
demo.launch() | ||
|
||
if __name__ == "__main__": | ||
import argparse | ||
parser = argparse.ArgumentParser(description="Run the Object Detection Gradio app.") | ||
parser.add_argument('--model_path', type=str, required=True, help="Path to the pre-trained model directory.") | ||
|
||
args = parser.parse_args() | ||
gradio_interface(args.model_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
gradio==4.44.1 | ||
torch==2.4.0 | ||
torchvision==0.19.0 | ||
transformers==4.44.2 | ||
supervision==0.22.0 | ||
huggingface==0.0.1 | ||
accelerate==0.33.0 | ||
torchmetrics==1.4.1 | ||
albumentations==1.4.14 | ||
pillow==10.4.0 | ||
datasets==2.21.0 | ||
PyYAML==6.0.2 | ||
wandb==0.17.7 | ||
pytest==8.3.3 | ||
pycocotools==2.0.8 | ||
python-dotenv==1.0.1 | ||
# boto3==1.34.158 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import pytest | ||
from PIL import Image, ImageDraw | ||
import torch | ||
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor | ||
from gradio_ui import load_model_and_processor, predict, draw_boxes_pillow | ||
|
||
@pytest.fixture(scope='module') | ||
def setup_model(): | ||
model_path = "./rtdetr_model" | ||
model, processor = load_model_and_processor(model_path) | ||
yield model, processor | ||
|
||
def test_load_model_and_processor(setup_model): | ||
model, processor = setup_model | ||
assert isinstance(model, RTDetrForObjectDetection) | ||
assert isinstance(processor, RTDetrImageProcessor) | ||
|
||
def test_predict(setup_model): | ||
model, processor = setup_model | ||
dummy_image = Image.new('RGB', (224, 224), color='white') | ||
threshold = 0.5 | ||
|
||
results = predict(dummy_image, threshold, model, processor) | ||
assert "scores" in results | ||
assert "labels" in results | ||
assert "boxes" in results | ||
assert len(results["scores"]) == len(results["labels"]) == len(results["boxes"]) | ||
|
||
def test_draw_boxes_pillow(setup_model): | ||
model, processor = setup_model | ||
dummy_image = Image.new('RGB', (224, 224), color='white') | ||
|
||
results = { | ||
"scores": torch.tensor([0.9, 0.8]), | ||
"labels": torch.tensor([0, 1]), | ||
"boxes": torch.tensor([[10, 10, 100, 100], [150, 150, 200, 200]]) | ||
} | ||
|
||
image_with_boxes = draw_boxes_pillow(dummy_image.copy(), results) | ||
|
||
assert image_with_boxes != dummy_image | ||
|
||
if __name__ == "__main__": | ||
pytest.main() |