Skip to content

Commit

Permalink
Add code
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton committed Sep 26, 2024
1 parent 7504773 commit 924cf30
Show file tree
Hide file tree
Showing 6 changed files with 433 additions and 23 deletions.
148 changes: 148 additions & 0 deletions demo/yolov10_webcam_stream/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import time
import cv2
import numpy as np
import onnxruntime # type: ignore

from utils import draw_detections # type: ignore


class YOLOv10:
def __init__(self, path):
# Initialize model
self.initialize_model(path)

def __call__(self, image):
return self.detect_objects(image)

def initialize_model(self, path):
self.session = onnxruntime.InferenceSession(
path, providers=onnxruntime.get_available_providers()
)
# Get model info
self.get_input_details()
self.get_output_details()

def detect_objects(self, image, conf_threshold=0.3):
input_tensor = self.prepare_input(image)

# Perform inference on the image
new_image = self.inference(image, input_tensor, conf_threshold)

return new_image

def prepare_input(self, image):
self.img_height, self.img_width = image.shape[:2]

input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Resize input image
input_img = cv2.resize(input_img, (self.input_width, self.input_height))

# Scale input pixel values to 0 to 1
input_img = input_img / 255.0
input_img = input_img.transpose(2, 0, 1)
input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)

return input_tensor

def inference(self, image, input_tensor, conf_threshold=0.3):
start = time.perf_counter()
outputs = self.session.run(
self.output_names, {self.input_names[0]: input_tensor}
)

print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
(
boxes,
scores,
class_ids,
) = self.process_output(outputs, conf_threshold)
return self.draw_detections(image, boxes, scores, class_ids)

def process_output(self, output, conf_threshold=0.3):
predictions = np.squeeze(output[0])

# Filter out object confidence scores below threshold
scores = predictions[:, 4]
predictions = predictions[scores > conf_threshold, :]
scores = scores[scores > conf_threshold]

if len(scores) == 0:
return [], [], []

# Get the class with the highest confidence
class_ids = predictions[:, 5].astype(int)

# Get bounding boxes for each object
boxes = self.extract_boxes(predictions)

return boxes, scores, class_ids

def extract_boxes(self, predictions):
# Extract boxes from predictions
boxes = predictions[:, :4]

# Scale boxes to original image dimensions
boxes = self.rescale_boxes(boxes)

# Convert boxes to xyxy format
# boxes = xywh2xyxy(boxes)

return boxes

def rescale_boxes(self, boxes):
# Rescale boxes to original image dimensions
input_shape = np.array(
[self.input_width, self.input_height, self.input_width, self.input_height]
)
boxes = np.divide(boxes, input_shape, dtype=np.float32)
boxes *= np.array(
[self.img_width, self.img_height, self.img_width, self.img_height]
)
return boxes

def draw_detections(
self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
):
return draw_detections(image, boxes, scores, class_ids, mask_alpha)

def get_input_details(self):
model_inputs = self.session.get_inputs()
self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]

self.input_shape = model_inputs[0].shape
self.input_height = self.input_shape[2]
self.input_width = self.input_shape[3]

def get_output_details(self):
model_outputs = self.session.get_outputs()
self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]


if __name__ == "__main__":
import requests
import tempfile
from huggingface_hub import hf_hub_download

model_file = hf_hub_download(
repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
)

yolov8_detector = YOLOv10(model_file)

with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
f.write(
requests.get(
"https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
).content
)
f.seek(0)
img = cv2.imread(f.name)

# # Detect Objects
combined_image = yolov8_detector.detect_objects(img)

# Draw detections
cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
cv2.imshow("Output", combined_image)
cv2.waitKey(0)
6 changes: 5 additions & 1 deletion demo/yolov10_webcam_stream/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
safetensors==0.4.3
git+https://github.com/THU-MIG/yolov10.git
opencv-python
twilio
gradio>=5.0,<6.0
gradio-webrtc==0.0.1
onnxruntime-gpu
2 changes: 1 addition & 1 deletion demo/yolov10_webcam_stream/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: yolov10_webcam_stream"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio safetensors==0.4.3 git+https://github.com/THU-MIG/yolov10.git"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from ultralytics import YOLOv10\n", "\n", "model = YOLOv10.from_pretrained(\"jameslahm/yolov10n\")\n", "\n", "\n", "def yolov10_inference(image, conf_threshold):\n", " width, _ = image.size\n", " import time\n", "\n", " start = time.time()\n", " results = model.predict(source=image, imgsz=width, conf=conf_threshold)\n", " end = time.time()\n", " annotated_image = results[0].plot()\n", " print(\"time\", end - start)\n", " return annotated_image[:, :, ::-1]\n", "\n", "\n", "css = \"\"\".my-group {max-width: 600px !important; max-height: 600 !important;}\n", " .my-column {display: flex !important; justify-content: center !important; align-items: center !important};\"\"\"\n", "\n", "\n", "with gr.Blocks(css=css) as app:\n", " gr.HTML(\n", " \"\"\"\n", " <h1 style='text-align: center'>\n", " <a href='https://github.com/THU-MIG/yolov10' target='_blank'>YOLO V10</a> Webcam Stream Object Detection\n", " </h1>\n", " \"\"\"\n", " )\n", " with gr.Column(elem_classes=[\"my-column\"]):\n", " with gr.Group(elem_classes=[\"my-group\"]):\n", " image = gr.Image(type=\"pil\", label=\"Image\", sources=\"webcam\")\n", " conf_threshold = gr.Slider(\n", " label=\"Confidence Threshold\",\n", " minimum=0.0,\n", " maximum=1.0,\n", " step=0.05,\n", " value=0.30,\n", " )\n", " image.stream(\n", " fn=yolov10_inference,\n", " inputs=[image, conf_threshold],\n", " outputs=[image],\n", " stream_every=0.1,\n", " time_limit=30,\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " app.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: yolov10_webcam_stream"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio safetensors==0.4.3 opencv-python twilio gradio>=5.0,<6.0 gradio-webrtc==0.0.1 onnxruntime-gpu"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import cv2\n", "from huggingface_hub import hf_hub_download\n", "from gradio_webrtc import WebRTC # type: ignore\n", "from twilio.rest import Client # type: ignore\n", "import os\n", "from inference import YOLOv10 # type: ignore\n", "\n", "model_file = hf_hub_download(\n", " repo_id=\"onnx-community/yolov10n\", filename=\"onnx/model.onnx\"\n", ")\n", "\n", "model = YOLOv10(model_file)\n", "\n", "account_sid = os.environ.get(\"TWILIO_ACCOUNT_SID\")\n", "auth_token = os.environ.get(\"TWILIO_AUTH_TOKEN\")\n", "\n", "if account_sid and auth_token:\n", " client = Client(account_sid, auth_token)\n", "\n", " token = client.tokens.create()\n", "\n", " rtc_configuration = {\n", " \"iceServers\": token.ice_servers,\n", " \"iceTransportPolicy\": \"relay\",\n", " }\n", "else:\n", " rtc_configuration = None\n", "\n", "\n", "def detection(image, conf_threshold=0.3):\n", " image = cv2.resize(image, (model.input_width, model.input_height))\n", " new_image = model.detect_objects(image, conf_threshold)\n", " return cv2.resize(new_image, (500, 500))\n", "\n", "\n", "css = \"\"\".my-group {max-width: 600px !important; max-height: 600 !important;}\n", " .my-column {display: flex !important; justify-content: center !important; align-items: center !important};\"\"\"\n", "\n", "\n", "with gr.Blocks(css=css) as demo:\n", " gr.HTML(\n", " \"\"\"\n", " <h1 style='text-align: center'>\n", " YOLOv10 Webcam Stream (Powered by WebRTC \u26a1\ufe0f)\n", " </h1>\n", " \"\"\"\n", " )\n", " gr.HTML(\n", " \"\"\"\n", " <h3 style='text-align: center'>\n", " <a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>\n", " </h3>\n", " \"\"\"\n", " )\n", " with gr.Column(elem_classes=[\"my-column\"]):\n", " with gr.Group(elem_classes=[\"my-group\"]):\n", " image = WebRTC(label=\"Stream\", rtc_configuration=rtc_configuration)\n", " conf_threshold = gr.Slider(\n", " label=\"Confidence Threshold\",\n", " minimum=0.0,\n", " maximum=1.0,\n", " step=0.05,\n", " value=0.30,\n", " )\n", "\n", " image.stream(\n", " fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
61 changes: 41 additions & 20 deletions demo/yolov10_webcam_stream/run.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,72 @@
import gradio as gr
import cv2
from huggingface_hub import hf_hub_download
from gradio_webrtc import WebRTC # type: ignore
from twilio.rest import Client # type: ignore
import os
from inference import YOLOv10 # type: ignore

from ultralytics import YOLOv10
model_file = hf_hub_download(
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
)

model = YOLOv10.from_pretrained("jameslahm/yolov10n")
model = YOLOv10(model_file)

account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

def yolov10_inference(image, conf_threshold):
width, _ = image.size
import time
if account_sid and auth_token:
client = Client(account_sid, auth_token)

start = time.time()
results = model.predict(source=image, imgsz=width, conf=conf_threshold)
end = time.time()
annotated_image = results[0].plot()
print("time", end - start)
return annotated_image[:, :, ::-1]
token = client.tokens.create()

rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None


def detection(image, conf_threshold=0.3):
image = cv2.resize(image, (model.input_width, model.input_height))
new_image = model.detect_objects(image, conf_threshold)
return cv2.resize(new_image, (500, 500))


css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""


with gr.Blocks(css=css) as app:
with gr.Blocks(css=css) as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
<a href='https://github.com/THU-MIG/yolov10' target='_blank'>YOLO V10</a> Webcam Stream Object Detection
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
</h1>
"""
)
gr.HTML(
"""
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
</h3>
"""
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
image = gr.Image(type="pil", label="Image", sources="webcam")
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
)

image.stream(
fn=yolov10_inference,
inputs=[image, conf_threshold],
outputs=[image],
stream_every=0.1,
time_limit=30,
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
)

if __name__ == "__main__":
app.launch()
demo.launch()
Loading

0 comments on commit 924cf30

Please sign in to comment.