-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update object detection guide #9456
Changes from 7 commits
e18a9d1
4b1df28
d5d2222
e52985d
b14d30c
7504773
924cf30
b4e8352
161355a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
"gradio": minor | ||
--- | ||
|
||
feat:Update object detection guide |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import time | ||
import cv2 | ||
import numpy as np | ||
import onnxruntime # type: ignore | ||
|
||
from utils import draw_detections # type: ignore | ||
|
||
|
||
class YOLOv10:
    """Run a YOLOv10 ONNX model through onnxruntime and draw its detections.

    Instances are callable: ``detector(image)`` is shorthand for
    ``detector.detect_objects(image)``.
    """

    def __init__(self, path):
        """Load the ONNX model stored at ``path``."""
        self.initialize_model(path)

    def __call__(self, image, conf_threshold=0.3):
        """Detect objects in ``image``; see :meth:`detect_objects`."""
        return self.detect_objects(image, conf_threshold)

    def initialize_model(self, path):
        """Create the inference session and cache the model's input/output names."""
        self.session = onnxruntime.InferenceSession(
            path, providers=onnxruntime.get_available_providers()
        )
        self.get_input_details()
        self.get_output_details()

    def detect_objects(self, image, conf_threshold=0.3):
        """Return whatever ``draw_detections`` produces for ``image``.

        Args:
            image: Image array; converted with ``cv2.COLOR_BGR2RGB`` before
                inference, so it is presumably expected in BGR channel
                order -- confirm against the caller.
            conf_threshold: Detections scoring at or below this value are
                discarded.
        """
        input_tensor = self.prepare_input(image)
        return self.inference(image, input_tensor, conf_threshold)

    def prepare_input(self, image):
        """Convert ``image`` into the float32 NCHW tensor fed to the model.

        Also records the original image height and width on the instance so
        that boxes can later be mapped back to the original resolution.
        """
        self.img_height, self.img_width = image.shape[:2]

        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        input_img = cv2.resize(input_img, (self.input_width, self.input_height))

        # Scale pixel values to [0, 1] and move channels first (HWC -> CHW).
        input_img = (input_img / 255.0).transpose(2, 0, 1)

        # Add the leading batch axis.
        return input_img[np.newaxis, :, :, :].astype(np.float32)

    def inference(self, image, input_tensor, conf_threshold=0.3):
        """Run the model on ``input_tensor`` and draw the results on ``image``.

        The wall-clock inference time is printed to stdout.
        """
        start = time.perf_counter()
        outputs = self.session.run(
            self.output_names, {self.input_names[0]: input_tensor}
        )
        print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")

        boxes, scores, class_ids = self.process_output(outputs, conf_threshold)
        return self.draw_detections(image, boxes, scores, class_ids)

    def process_output(self, output, conf_threshold=0.3):
        """Filter raw model output into ``(boxes, scores, class_ids)``.

        Each prediction row is read as four box values, then the score
        (column 4), then the class id (column 5).

        Returns:
            A ``(boxes, scores, class_ids)`` tuple. When nothing scores above
            ``conf_threshold`` all three are empty lists.
        """
        raw = np.asarray(output[0])
        # Flatten any leading axes but always keep a 2-D (rows, fields)
        # layout. ``np.squeeze`` would collapse a single-row output to 1-D
        # and make the column indexing below fail.
        predictions = raw.reshape(-1, raw.shape[-1])

        # Drop rows whose confidence does not exceed the threshold.
        scores = predictions[:, 4]
        keep = scores > conf_threshold
        predictions = predictions[keep, :]
        scores = scores[keep]

        if scores.size == 0:
            return [], [], []

        class_ids = predictions[:, 5].astype(int)
        boxes = self.extract_boxes(predictions)

        return boxes, scores, class_ids

    def extract_boxes(self, predictions):
        """Return the box columns of ``predictions`` in original-image scale."""
        boxes = predictions[:, :4]

        # NOTE(review): no xywh -> xyxy conversion is applied here, so the
        # model presumably already emits corner-format boxes -- confirm
        # against the model export.
        return self.rescale_boxes(boxes)

    def rescale_boxes(self, boxes):
        """Map ``boxes`` from model-input resolution to original-image resolution."""
        input_shape = np.array(
            [self.input_width, self.input_height, self.input_width, self.input_height]
        )
        boxes = np.divide(boxes, input_shape, dtype=np.float32)
        boxes *= np.array(
            [self.img_width, self.img_height, self.img_width, self.img_height]
        )
        return boxes

    def draw_detections(
        self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
    ):
        """Delegate drawing to the module-level ``draw_detections`` helper.

        ``draw_scores`` is accepted for interface compatibility but is not
        forwarded to the helper.
        """
        return draw_detections(image, boxes, scores, class_ids, mask_alpha)

    def get_input_details(self):
        """Cache the model's input names and its input height/width."""
        model_inputs = self.session.get_inputs()
        self.input_names = [model_input.name for model_input in model_inputs]

        # Indices 2 and 3 are read as height and width, i.e. an NCHW layout
        # is assumed for the first input.
        self.input_shape = model_inputs[0].shape
        self.input_height = self.input_shape[2]
        self.input_width = self.input_shape[3]

    def get_output_details(self):
        """Cache the model's output names."""
        model_outputs = self.session.get_outputs()
        self.output_names = [model_output.name for model_output in model_outputs]
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: download a model and a sample photo, run detection,
    # and show the annotated result in an OpenCV window.
    import requests
    from huggingface_hub import hf_hub_download

    model_file = hf_hub_download(
        repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
    )

    detector = YOLOv10(model_file)

    # Fail fast on network problems instead of handing a truncated or error
    # payload to the image decoder.
    response = requests.get(
        "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg", timeout=30
    )
    response.raise_for_status()

    # Decode straight from memory; no temporary file is left behind on disk.
    img = cv2.imdecode(
        np.frombuffer(response.content, dtype=np.uint8), cv2.IMREAD_COLOR
    )
    if img is None:
        raise RuntimeError("Downloaded sample image could not be decoded")

    # Detect objects
    combined_image = detector.detect_objects(img)

    # Show the annotated image until a key is pressed
    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", combined_image)
    cv2.waitKey(0)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,6 @@ | ||
safetensors==0.4.3 | ||
git+https://github.com/THU-MIG/yolov10.git | ||
opencv-python | ||
twilio | ||
gradio>=5.0,<6.0 | ||
gradio-webrtc==0.0.1 | ||
onnxruntime-gpu |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: yolov10_webcam_stream"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio safetensors==0.4.3 git+https://github.com/THU-MIG/yolov10.git"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from ultralytics import YOLOv10\n", "\n", "model = YOLOv10.from_pretrained(\"jameslahm/yolov10n\")\n", "\n", "\n", "def yolov10_inference(image, conf_threshold):\n", " width, _ = image.size\n", " import time\n", "\n", " start = time.time()\n", " results = model.predict(source=image, imgsz=width, conf=conf_threshold)\n", " end = time.time()\n", " annotated_image = results[0].plot()\n", " print(\"time\", end - start)\n", " return annotated_image[:, :, ::-1]\n", "\n", "\n", "css = \"\"\".my-group {max-width: 600px !important; max-height: 600 !important;}\n", " .my-column {display: flex !important; justify-content: center !important; align-items: center !important};\"\"\"\n", "\n", "\n", "with gr.Blocks(css=css) as app:\n", " gr.HTML(\n", " \"\"\"\n", " <h1 style='text-align: center'>\n", " <a href='https://github.com/THU-MIG/yolov10' target='_blank'>YOLO V10</a> Webcam Stream Object Detection\n", " </h1>\n", " \"\"\"\n", " )\n", " with gr.Column(elem_classes=[\"my-column\"]):\n", " with gr.Group(elem_classes=[\"my-group\"]):\n", " image = gr.Image(type=\"pil\", label=\"Image\", sources=\"webcam\")\n", " conf_threshold = gr.Slider(\n", " label=\"Confidence Threshold\",\n", " minimum=0.0,\n", " maximum=1.0,\n", " step=0.05,\n", " value=0.30,\n", " )\n", " image.stream(\n", " fn=yolov10_inference,\n", " inputs=[image, conf_threshold],\n", " outputs=[image],\n", " stream_every=0.1,\n", " time_limit=30,\n", " )\n", "\n", "if 
__name__ == \"__main__\":\n", " app.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} | ||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: yolov10_webcam_stream"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio safetensors==0.4.3 opencv-python twilio gradio>=5.0,<6.0 gradio-webrtc==0.0.1 onnxruntime-gpu"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import cv2\n", "from huggingface_hub import hf_hub_download\n", "from gradio_webrtc import WebRTC # type: ignore\n", "from twilio.rest import Client # type: ignore\n", "import os\n", "from inference import YOLOv10 # type: ignore\n", "\n", "model_file = hf_hub_download(\n", " repo_id=\"onnx-community/yolov10n\", filename=\"onnx/model.onnx\"\n", ")\n", "\n", "model = YOLOv10(model_file)\n", "\n", "account_sid = os.environ.get(\"TWILIO_ACCOUNT_SID\")\n", "auth_token = os.environ.get(\"TWILIO_AUTH_TOKEN\")\n", "\n", "if account_sid and auth_token:\n", " client = Client(account_sid, auth_token)\n", "\n", " token = client.tokens.create()\n", "\n", " rtc_configuration = {\n", " \"iceServers\": token.ice_servers,\n", " \"iceTransportPolicy\": \"relay\",\n", " }\n", "else:\n", " rtc_configuration = None\n", "\n", "\n", "def detection(image, conf_threshold=0.3):\n", " image = cv2.resize(image, (model.input_width, model.input_height))\n", " new_image = model.detect_objects(image, conf_threshold)\n", " return cv2.resize(new_image, (500, 500))\n", "\n", "\n", "css = \"\"\".my-group {max-width: 600px !important; max-height: 600 !important;}\n", " .my-column {display: flex !important; justify-content: center !important; align-items: center !important};\"\"\"\n", "\n", "\n", "with gr.Blocks(css=css) as demo:\n", " gr.HTML(\n", " \"\"\"\n", " <h1 style='text-align: center'>\n", " YOLOv10 Webcam Stream (Powered by 
WebRTC \u26a1\ufe0f)\n", " </h1>\n", " \"\"\"\n", " )\n", " gr.HTML(\n", " \"\"\"\n", " <h3 style='text-align: center'>\n", " <a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>\n", " </h3>\n", " \"\"\"\n", " )\n", " with gr.Column(elem_classes=[\"my-column\"]):\n", " with gr.Group(elem_classes=[\"my-group\"]):\n", " image = WebRTC(label=\"Stream\", rtc_configuration=rtc_configuration)\n", " conf_threshold = gr.Slider(\n", " label=\"Confidence Threshold\",\n", " minimum=0.0,\n", " maximum=1.0,\n", " step=0.05,\n", " value=0.30,\n", " )\n", "\n", " image.stream(\n", " fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} |
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,51 +1,72 @@ | ||||||||||||||||
import gradio as gr | ||||||||||||||||
import cv2 | ||||||||||||||||
from huggingface_hub import hf_hub_download | ||||||||||||||||
from gradio_webrtc import WebRTC # type: ignore | ||||||||||||||||
from twilio.rest import Client # type: ignore | ||||||||||||||||
import os | ||||||||||||||||
from inference import YOLOv10 # type: ignore | ||||||||||||||||
|
||||||||||||||||
from ultralytics import YOLOv10 | ||||||||||||||||
model_file = hf_hub_download( | ||||||||||||||||
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx" | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
model = YOLOv10.from_pretrained("jameslahm/yolov10n") | ||||||||||||||||
model = YOLOv10(model_file) | ||||||||||||||||
|
||||||||||||||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID") | ||||||||||||||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN") | ||||||||||||||||
|
||||||||||||||||
def yolov10_inference(image, conf_threshold): | ||||||||||||||||
width, _ = image.size | ||||||||||||||||
import time | ||||||||||||||||
if account_sid and auth_token: | ||||||||||||||||
client = Client(account_sid, auth_token) | ||||||||||||||||
|
||||||||||||||||
start = time.time() | ||||||||||||||||
results = model.predict(source=image, imgsz=width, conf=conf_threshold) | ||||||||||||||||
end = time.time() | ||||||||||||||||
annotated_image = results[0].plot() | ||||||||||||||||
print("time", end - start) | ||||||||||||||||
return annotated_image[:, :, ::-1] | ||||||||||||||||
token = client.tokens.create() | ||||||||||||||||
|
||||||||||||||||
rtc_configuration = { | ||||||||||||||||
"iceServers": token.ice_servers, | ||||||||||||||||
"iceTransportPolicy": "relay", | ||||||||||||||||
} | ||||||||||||||||
else: | ||||||||||||||||
rtc_configuration = None | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
def detection(image, conf_threshold=0.3):
    """Annotate one stream frame with the module-level ``model``.

    The frame is resized to the model's input size, passed to
    ``model.detect_objects`` together with ``conf_threshold``, and the
    result is resized to 500x500 before being returned.
    """
    model_size = (model.input_width, model.input_height)
    resized_frame = cv2.resize(image, model_size)
    annotated_frame = model.detect_objects(resized_frame, conf_threshold)
    return cv2.resize(annotated_frame, (500, 500))
|
||||||||||||||||
|
||||||||||||||||
css = """.my-group {max-width: 600px !important; max-height: 600 !important;} | ||||||||||||||||
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};""" | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
with gr.Blocks(css=css) as app: | ||||||||||||||||
with gr.Blocks(css=css) as demo: | ||||||||||||||||
gr.HTML( | ||||||||||||||||
""" | ||||||||||||||||
<h1 style='text-align: center'> | ||||||||||||||||
<a href='https://github.com/THU-MIG/yolov10' target='_blank'>YOLO V10</a> Webcam Stream Object Detection | ||||||||||||||||
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️) | ||||||||||||||||
</h1> | ||||||||||||||||
""" | ||||||||||||||||
) | ||||||||||||||||
gr.HTML( | ||||||||||||||||
""" | ||||||||||||||||
<h3 style='text-align: center'> | ||||||||||||||||
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a> | ||||||||||||||||
</h3> | ||||||||||||||||
""" | ||||||||||||||||
) | ||||||||||||||||
Comment on lines
+49
to
+55
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd suggest removing to keep the demo simpler, but up to you
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Removed from the guide but kept in the demo because I think it's nice to link to the original model source. |
||||||||||||||||
with gr.Column(elem_classes=["my-column"]): | ||||||||||||||||
with gr.Group(elem_classes=["my-group"]): | ||||||||||||||||
image = gr.Image(type="pil", label="Image", sources="webcam") | ||||||||||||||||
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration) | ||||||||||||||||
conf_threshold = gr.Slider( | ||||||||||||||||
label="Confidence Threshold", | ||||||||||||||||
minimum=0.0, | ||||||||||||||||
maximum=1.0, | ||||||||||||||||
step=0.05, | ||||||||||||||||
value=0.30, | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
image.stream( | ||||||||||||||||
fn=yolov10_inference, | ||||||||||||||||
inputs=[image, conf_threshold], | ||||||||||||||||
outputs=[image], | ||||||||||||||||
stream_every=0.1, | ||||||||||||||||
time_limit=30, | ||||||||||||||||
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10 | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||
app.launch() | ||||||||||||||||
demo.launch() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to move any of these directories to the `_frontend_code` directory.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Related to some custom component fixing clean-up I'll do in a different PR