-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7504773
commit 924cf30
Showing
6 changed files
with
433 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import time | ||
import cv2 | ||
import numpy as np | ||
import onnxruntime # type: ignore | ||
|
||
from utils import draw_detections # type: ignore | ||
|
||
|
||
class YOLOv10:
    """Run a YOLOv10 ONNX model with ONNX Runtime and draw its detections.

    The model's first output is read as rows of six values: four box
    coordinates, a confidence score and a class id (see ``process_output``).
    """

    def __init__(self, path):
        """Create the inference session for the ONNX model stored at ``path``."""
        self.initialize_model(path)

    def __call__(self, image):
        """Shorthand for ``detect_objects(image)`` with the default threshold."""
        return self.detect_objects(image)

    def initialize_model(self, path):
        """Open the ONNX Runtime session and cache input/output metadata.

        Every execution provider reported as available by ONNX Runtime is
        passed to the session.
        """
        self.session = onnxruntime.InferenceSession(
            path, providers=onnxruntime.get_available_providers()
        )
        self.get_input_details()
        self.get_output_details()

    def detect_objects(self, image, conf_threshold=0.3):
        """Detect objects in ``image`` and return the annotated image.

        Args:
            image: Image array; it is converted with ``cv2.COLOR_BGR2RGB``,
                so it is expected in OpenCV's BGR channel order.
            conf_threshold: Detections scoring at or below this are dropped.

        Returns:
            Whatever ``draw_detections`` returns for the kept detections.
        """
        input_tensor = self.prepare_input(image)
        return self.inference(image, input_tensor, conf_threshold)

    def prepare_input(self, image):
        """Convert ``image`` into the float32 NCHW tensor fed to the model.

        Also records the original image size on ``self.img_height`` and
        ``self.img_width`` so boxes can later be scaled back to it.
        """
        self.img_height, self.img_width = image.shape[:2]

        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, (self.input_width, self.input_height))

        # Scale pixel values to [0, 1], then reorder HWC -> CHW and add the
        # leading batch axis.
        chw = (resized / 255.0).transpose(2, 0, 1)
        return chw[np.newaxis, :, :, :].astype(np.float32)

    def inference(self, image, input_tensor, conf_threshold=0.3):
        """Run the model on ``input_tensor`` and draw the results on ``image``.

        The wall-clock time of the session run is printed in milliseconds.
        """
        start = time.perf_counter()
        outputs = self.session.run(
            self.output_names, {self.input_names[0]: input_tensor}
        )
        print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")

        boxes, scores, class_ids = self.process_output(outputs, conf_threshold)
        return self.draw_detections(image, boxes, scores, class_ids)

    def process_output(self, output, conf_threshold=0.3):
        """Filter the raw model output by confidence.

        Args:
            output: Sequence of model outputs; only ``output[0]`` is used.
            conf_threshold: Rows whose score (column 4) is not strictly
                greater than this value are discarded.

        Returns:
            ``(boxes, scores, class_ids)``. When nothing passes the
            threshold, three empty lists are returned.
        """
        # np.squeeze collapses a single-row output to 1-D; atleast_2d restores
        # the (rows, values) layout so the column indexing below keeps working.
        predictions = np.atleast_2d(np.squeeze(output[0]))

        scores = predictions[:, 4]
        keep = scores > conf_threshold
        predictions = predictions[keep, :]
        scores = scores[keep]

        if len(scores) == 0:
            return [], [], []

        class_ids = predictions[:, 5].astype(int)
        boxes = self.extract_boxes(predictions)

        return boxes, scores, class_ids

    def extract_boxes(self, predictions):
        """Return the first four columns of ``predictions`` in image scale."""
        # NOTE(review): no xywh -> xyxy conversion is applied, so the model is
        # presumably emitting corner-format boxes already - confirm.
        return self.rescale_boxes(predictions[:, :4])

    def rescale_boxes(self, boxes):
        """Scale ``boxes`` from model-input size to the original image size.

        Relies on ``self.img_width`` / ``self.img_height`` set by
        ``prepare_input``.
        """
        model_size = np.array(
            [self.input_width, self.input_height, self.input_width, self.input_height],
            dtype=np.float32,
        )
        image_size = np.array(
            [self.img_width, self.img_height, self.img_width, self.img_height],
            dtype=np.float32,
        )
        return np.divide(boxes, model_size, dtype=np.float32) * image_size

    def draw_detections(
        self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
    ):
        """Delegate drawing to the module-level ``draw_detections`` helper.

        NOTE(review): ``draw_scores`` is accepted but not forwarded to the
        helper; it is kept only so existing callers do not break.
        """
        return draw_detections(image, boxes, scores, class_ids, mask_alpha)

    def get_input_details(self):
        """Cache the model's input names and its input height/width.

        Height and width are read from positions 2 and 3 of the first
        input's shape.
        """
        model_inputs = self.session.get_inputs()
        self.input_names = [model_input.name for model_input in model_inputs]

        self.input_shape = model_inputs[0].shape
        self.input_height = self.input_shape[2]
        self.input_width = self.input_shape[3]

    def get_output_details(self):
        """Cache the names of all model outputs."""
        model_outputs = self.session.get_outputs()
        self.output_names = [model_output.name for model_output in model_outputs]
|
||
|
||
if __name__ == "__main__":
    # Demo: download a model and a sample photo, run detection, show the result.
    import requests
    from huggingface_hub import hf_hub_download

    model_file = hf_hub_download(
        repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
    )
    detector = YOLOv10(model_file)

    # Fail fast on a hung connection or an HTTP error instead of passing
    # an error page to the image decoder.
    response = requests.get(
        "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg", timeout=30
    )
    response.raise_for_status()

    # Decode straight from memory; no temporary file is left behind on disk.
    img = cv2.imdecode(
        np.frombuffer(response.content, dtype=np.uint8), cv2.IMREAD_COLOR
    )
    if img is None:
        raise RuntimeError("Downloaded data could not be decoded as an image")

    # Detect objects and get back the annotated image
    combined_image = detector.detect_objects(img)

    # Show the annotated image until a key is pressed
    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", combined_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,6 @@ | ||
safetensors==0.4.3 | ||
git+https://github.com/THU-MIG/yolov10.git | ||
opencv-python | ||
twilio | ||
gradio>=5.0,<6.0 | ||
gradio-webrtc==0.0.1 | ||
onnxruntime-gpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: yolov10_webcam_stream"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio safetensors==0.4.3 git+https://github.com/THU-MIG/yolov10.git"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from ultralytics import YOLOv10\n", "\n", "model = YOLOv10.from_pretrained(\"jameslahm/yolov10n\")\n", "\n", "\n", "def yolov10_inference(image, conf_threshold):\n", " width, _ = image.size\n", " import time\n", "\n", " start = time.time()\n", " results = model.predict(source=image, imgsz=width, conf=conf_threshold)\n", " end = time.time()\n", " annotated_image = results[0].plot()\n", " print(\"time\", end - start)\n", " return annotated_image[:, :, ::-1]\n", "\n", "\n", "css = \"\"\".my-group {max-width: 600px !important; max-height: 600 !important;}\n", " .my-column {display: flex !important; justify-content: center !important; align-items: center !important};\"\"\"\n", "\n", "\n", "with gr.Blocks(css=css) as app:\n", " gr.HTML(\n", " \"\"\"\n", " <h1 style='text-align: center'>\n", " <a href='https://github.com/THU-MIG/yolov10' target='_blank'>YOLO V10</a> Webcam Stream Object Detection\n", " </h1>\n", " \"\"\"\n", " )\n", " with gr.Column(elem_classes=[\"my-column\"]):\n", " with gr.Group(elem_classes=[\"my-group\"]):\n", " image = gr.Image(type=\"pil\", label=\"Image\", sources=\"webcam\")\n", " conf_threshold = gr.Slider(\n", " label=\"Confidence Threshold\",\n", " minimum=0.0,\n", " maximum=1.0,\n", " step=0.05,\n", " value=0.30,\n", " )\n", " image.stream(\n", " fn=yolov10_inference,\n", " inputs=[image, conf_threshold],\n", " outputs=[image],\n", " stream_every=0.1,\n", " time_limit=30,\n", " )\n", "\n", "if 
__name__ == \"__main__\":\n", " app.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} | ||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: yolov10_webcam_stream"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio safetensors==0.4.3 opencv-python twilio gradio>=5.0,<6.0 gradio-webrtc==0.0.1 onnxruntime-gpu"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import cv2\n", "from huggingface_hub import hf_hub_download\n", "from gradio_webrtc import WebRTC # type: ignore\n", "from twilio.rest import Client # type: ignore\n", "import os\n", "from inference import YOLOv10 # type: ignore\n", "\n", "model_file = hf_hub_download(\n", " repo_id=\"onnx-community/yolov10n\", filename=\"onnx/model.onnx\"\n", ")\n", "\n", "model = YOLOv10(model_file)\n", "\n", "account_sid = os.environ.get(\"TWILIO_ACCOUNT_SID\")\n", "auth_token = os.environ.get(\"TWILIO_AUTH_TOKEN\")\n", "\n", "if account_sid and auth_token:\n", " client = Client(account_sid, auth_token)\n", "\n", " token = client.tokens.create()\n", "\n", " rtc_configuration = {\n", " \"iceServers\": token.ice_servers,\n", " \"iceTransportPolicy\": \"relay\",\n", " }\n", "else:\n", " rtc_configuration = None\n", "\n", "\n", "def detection(image, conf_threshold=0.3):\n", " image = cv2.resize(image, (model.input_width, model.input_height))\n", " new_image = model.detect_objects(image, conf_threshold)\n", " return cv2.resize(new_image, (500, 500))\n", "\n", "\n", "css = \"\"\".my-group {max-width: 600px !important; max-height: 600 !important;}\n", " .my-column {display: flex !important; justify-content: center !important; align-items: center !important};\"\"\"\n", "\n", "\n", "with gr.Blocks(css=css) as demo:\n", " gr.HTML(\n", " \"\"\"\n", " <h1 style='text-align: center'>\n", " YOLOv10 Webcam Stream (Powered by 
WebRTC \u26a1\ufe0f)\n", " </h1>\n", " \"\"\"\n", " )\n", " gr.HTML(\n", " \"\"\"\n", " <h3 style='text-align: center'>\n", " <a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>\n", " </h3>\n", " \"\"\"\n", " )\n", " with gr.Column(elem_classes=[\"my-column\"]):\n", " with gr.Group(elem_classes=[\"my-group\"]):\n", " image = WebRTC(label=\"Stream\", rtc_configuration=rtc_configuration)\n", " conf_threshold = gr.Slider(\n", " label=\"Confidence Threshold\",\n", " minimum=0.0,\n", " maximum=1.0,\n", " step=0.05,\n", " value=0.30,\n", " )\n", "\n", " image.stream(\n", " fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,51 +1,72 @@ | ||
import gradio as gr | ||
import cv2 | ||
from huggingface_hub import hf_hub_download | ||
from gradio_webrtc import WebRTC # type: ignore | ||
from twilio.rest import Client # type: ignore | ||
import os | ||
from inference import YOLOv10 # type: ignore | ||
|
||
from ultralytics import YOLOv10 | ||
model_file = hf_hub_download( | ||
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx" | ||
) | ||
|
||
model = YOLOv10.from_pretrained("jameslahm/yolov10n") | ||
model = YOLOv10(model_file) | ||
|
||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID") | ||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN") | ||
|
||
def yolov10_inference(image, conf_threshold): | ||
width, _ = image.size | ||
import time | ||
if account_sid and auth_token: | ||
client = Client(account_sid, auth_token) | ||
|
||
start = time.time() | ||
results = model.predict(source=image, imgsz=width, conf=conf_threshold) | ||
end = time.time() | ||
annotated_image = results[0].plot() | ||
print("time", end - start) | ||
return annotated_image[:, :, ::-1] | ||
token = client.tokens.create() | ||
|
||
rtc_configuration = { | ||
"iceServers": token.ice_servers, | ||
"iceTransportPolicy": "relay", | ||
} | ||
else: | ||
rtc_configuration = None | ||
|
||
|
||
def detection(image, conf_threshold=0.3):
    """Run the module-level model on one frame and return a 500x500 result.

    The frame is first resized to the model's input width/height, passed to
    ``model.detect_objects`` with ``conf_threshold``, and the image that call
    returns is resized to 500x500 before being handed back.
    """
    model_size = (model.input_width, model.input_height)
    annotated = model.detect_objects(cv2.resize(image, model_size), conf_threshold)
    return cv2.resize(annotated, (500, 500))
|
||
|
||
# Custom CSS for the demo layout, applied via the classes "my-group" and
# "my-column".
# NOTE(review): "max-height: 600" has no unit, so browsers will most likely
# ignore that declaration; it is left untouched because adding "px" would
# start enforcing a height limit - confirm the intended layout.
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
|
||
|
||
with gr.Blocks(css=css) as app: | ||
with gr.Blocks(css=css) as demo: | ||
gr.HTML( | ||
""" | ||
<h1 style='text-align: center'> | ||
<a href='https://github.com/THU-MIG/yolov10' target='_blank'>YOLO V10</a> Webcam Stream Object Detection | ||
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️) | ||
</h1> | ||
""" | ||
) | ||
gr.HTML( | ||
""" | ||
<h3 style='text-align: center'> | ||
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a> | ||
</h3> | ||
""" | ||
) | ||
with gr.Column(elem_classes=["my-column"]): | ||
with gr.Group(elem_classes=["my-group"]): | ||
image = gr.Image(type="pil", label="Image", sources="webcam") | ||
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration) | ||
conf_threshold = gr.Slider( | ||
label="Confidence Threshold", | ||
minimum=0.0, | ||
maximum=1.0, | ||
step=0.05, | ||
value=0.30, | ||
) | ||
|
||
image.stream( | ||
fn=yolov10_inference, | ||
inputs=[image, conf_threshold], | ||
outputs=[image], | ||
stream_every=0.1, | ||
time_limit=30, | ||
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10 | ||
) | ||
|
||
if __name__ == "__main__": | ||
app.launch() | ||
demo.launch() |
Oops, something went wrong.