Skip to content

Commit

Permalink
Feature/yolov9 (#21)
Browse files Browse the repository at this point in the history
* WIP Add YOLOv9ONNX

* Add YOLOv9

* Fix bug
  • Loading branch information
nmhaddad authored Feb 25, 2024
1 parent f0a90bc commit 42b10d3
Show file tree
Hide file tree
Showing 13 changed files with 137 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
## Known Issues:
- None

## v1.1.0 - Nate Haddad, 2/25/2024
- Add `YOLOv9ONNX` to `detectors`

## v1.0.0 - Nate Haddad, 2/19/2024
- Add database option to tracker
- Add `SQLDatabase` database class
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Fast-Track 🚀 Real-Time Object Tracking Pipelines

Installable Python package for object tracking pipelines with YOLO-NAS, YOLOv8, and YOLOv7 object detectors and BYTETracker object tracking with support for SQL database servers.
Installable Python package for object tracking pipelines with YOLOv9, YOLO-NAS, YOLOv8, and YOLOv7 object detectors and BYTETracker object tracking with support for SQL database servers.

[Try it out now with Gradio](#run-the-demo).

Expand Down Expand Up @@ -67,3 +67,5 @@ Author: Nate Haddad - nhaddad2112[at]gmail[dot]com
[4] Zhang, Yifu and Sun, Peize and Jiang, Yi and Yu, Dongdong and Weng, Fucheng and Yuan, Zehuan and Luo, Ping and Liu, Wenyu and Wang, Xinggang; "ByteTrack: Multi-Object Tracking by Associating Every Detection Box"; https://github.com/ifzhang; 2022; [Online]. Available: https://github.com/ifzhang/ByteTrack

[5] Aharon, Shay and Louis-Dupont and Ofri Masad and Yurkova, Kate and Lotem Fridman and Lkdci and Khvedchenya, Eugene and Rubin, Ran and Bagrov, Natan and Tymchenko, Borys and Keren, Tomer and Zhilko, Alexander and Eran-Deci; "Super-Gradients"; https://github.com/Deci-AI/super-gradients; 2023; [Online]. Available: https://github.com/Deci-AI/super-gradients

[6] Wang, Chien-Yao and Liao, Hong-Yuan Mark; "YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information"; https://github.com/WongKinYiu/yolov9; 2024; [Online]. Available: https://github.com/WongKinYiu/yolov9
2 changes: 1 addition & 1 deletion config/coco.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ outfile: nba-5-13_out.mp4

detector:
visualize: False
weights_path: yolo_nas_s
weights_path: models/yolov9-c.onnx


tracker:
Expand Down
5 changes: 3 additions & 2 deletions fast_track/detectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from .object_detector import ObjectDetector
from .object_detector_onnx import ObjectDetectorONNX
from .third_party.yolov7.yolov7 import YOLOv7
from .third_party.yolov7 import YOLOv7ONNX
from .third_party.yolov8 import YOLOv8, YOLOv8ONNX
from .third_party.yolo_nas.yolo_nas import YOLONAS
from .third_party.yolo_nas import YOLONAS
from .third_party.yolov9 import YOLOv9ONNX
from .util import get_detector
3 changes: 3 additions & 0 deletions fast_track/detectors/third_party/yolo_nas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
""" Initializes the yolo_nas module """

from .yolo_nas import YOLONAS
3 changes: 3 additions & 0 deletions fast_track/detectors/third_party/yolov7/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
""" Initializes the yolov7 module """

from .yolov7_onnx import YOLOv7ONNX
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" YOLOv7 ONNX detector wrapper """
""" YOLOv7ONNX detector wrapper """

from typing import Tuple, List

Expand All @@ -7,8 +7,8 @@
from ...object_detector_onnx import ObjectDetectorONNX


class YOLOv7(ObjectDetectorONNX):
""" YOLOv7 ONNX detector.
class YOLOv7ONNX(ObjectDetectorONNX):
""" YOLOv7ONNX detector.
Attributes:
weights_path: path to pretrained weights.
Expand Down
2 changes: 1 addition & 1 deletion fast_track/detectors/third_party/yolov8/yolov8_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def postprocess(self, tensor: np.ndarray) -> Tuple[list, list, list]:
predictions = ops.non_max_suppression(torch.tensor(tensor[0]),
conf_thres=self.conf_thresh,
iou_thres=self.iou_thresh,
classes=self.classes,
classes=len(self.classes),
agnostic=self.agnostic,
multi_label=self.multi_label,
labels=self.labels,
Expand Down
3 changes: 3 additions & 0 deletions fast_track/detectors/third_party/yolov9/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
""" Initializes the yolov9 module """

from .yolov9_onnx import YOLOv9ONNX
101 changes: 101 additions & 0 deletions fast_track/detectors/third_party/yolov9/yolov9_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
""" YOLOv9ONNX detector wrapper """

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import ultralytics.utils.ops as ops

from ...object_detector_onnx import ObjectDetectorONNX


class YOLOv9ONNX(ObjectDetectorONNX):
""" YOLOv9ONNX detector.
Attributes:
weights_path: path to pretrained weights.
providers: flags for CUDA execution.
sessions: ORT session.
input_names: model input names.
input_shape: input shape (B,C,H,W)
input_height: input height.
input_width: input width.
output_names: model output names.
conf_thresh: The confidence threshold below which boxes will be filtered out. Valid values are
between 0.0 and 1.0.
iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are
between 0.0 and 1.0.
agnostic:If True, the model is agnostic to the number of classes, and all classes will be considered as one.
multi_label: If True, each box may have multiple labels.
labels: A list of lists, where each inner list contains the apriori labels for a given image. The list should
be in the format output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
max_det: The maximum number of boxes to keep after NMS.
nm: The number of masks output by the model.
"""

def __init__(self,
weights_path: str,
names: List[str],
image_shape: Tuple[int, int],
visualize: bool = False,
conf_thres: float = 0.25,
iou_thres: float = 0.45,
classes: Optional[List[int]] = None,
agnostic: bool = False,
multi_label: bool = False,
labels: List[List[Union[int, float, torch.Tensor]]] = (),
max_det: int = 300):
""" Init YOLOv9ONNX objects with given parameters.
Args:
weights_path. path to pretrained weights.
names: a list of names for classes.
image_shape: shape of input images.
visualize: bool to visualize output or not.
conf_thresh: The confidence threshold below which boxes will be filtered out. Valid values are
between 0.0 and 1.0.
iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are
between 0.0 and 1.0.
agnostic: If True, the model is agnostic to the number of classes, and all classes will be considered
as one.
multi_label:If True, each box may have multiple labels.
labels: A list of lists, where each inner list contains the apriori labels for a given image. The list
should be in the format output by a dataloader, with each label being a tuple of
(class_index, x1, y1, x2, y2).
max_det: The maximum number of boxes to keep after NMS.
"""
super().__init__(weights_path, names, image_shape, visualize)
self.conf_thresh = conf_thres
self.iou_thresh = iou_thres
self.classes = classes
self.agnostic = agnostic
self.multi_label = multi_label
self.labels = labels
self.max_det = max_det

def postprocess(self,
tensor: np.ndarray) -> Tuple[list, list, list]:
""" Postprocesses output.
Args:
tensor: output tensor from ONNX session.
Returns:
Postprocessed output as a tuple of class_ids, scores, and boxes.
"""
predictions = ops.non_max_suppression(torch.tensor(tensor[0]),
conf_thres=self.conf_thresh,
iou_thres=self.iou_thresh,
classes=len(self.classes),
agnostic=self.agnostic,
multi_label=self.multi_label,
labels=self.labels,
max_det=self.max_det)
boxes = predictions[0][:, :4].int().numpy()
class_ids = predictions[0][:, 5:6].int().flatten().tolist()
scores = predictions[0][:, 4:5].flatten().tolist()
if len(scores) == 0:
return [], [], []
boxes = self.rescale_boxes(boxes)
return class_ids, scores, boxes
13 changes: 10 additions & 3 deletions fast_track/detectors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from .object_detector import ObjectDetector
from .third_party.yolo_nas.yolo_nas import YOLONAS
from .third_party.yolov8 import YOLOv8, YOLOv8ONNX
from .third_party.yolov7.yolov7 import YOLOv7
from .third_party.yolov7 import YOLOv7ONNX
from .third_party.yolov9 import YOLOv9ONNX


MODELS = {
Expand All @@ -19,7 +20,8 @@
"YOLOv8 L": "yolov8_l",
"YOLOv8 X": "yolov8_x",
"YOLOv8": "yolov8_custom",
"YOLOv7": "yolov7_custom"
"YOLOv7": "yolov7_custom",
"YOLOv9": "yolov9_custom"
}


Expand Down Expand Up @@ -61,7 +63,12 @@ def get_detector(weights_path: str,
image_shape=image_shape,
**detector_params)
elif detector_type.startswith("yolov7"):
return YOLOv7(weights_path=weights_path,
return YOLOv7ONNX(weights_path=weights_path,
names=names,
image_shape=image_shape,
**detector_params)
elif detector_type.startswith("yolov9"):
return YOLOv9ONNX(weights_path=weights_path,
names=names,
image_shape=image_shape,
**detector_params)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fast_track"
version = "1.0.0"
version = "1.1.0"
description = "Object detection and tracking pipeline"
readme = "README.md"
keywords = [
Expand Down
4 changes: 2 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dotenv import load_dotenv

from fast_track import Pipeline
from fast_track.detectors import YOLONAS
from fast_track.detectors import YOLOv9ONNX
from fast_track.trackers import BYTETracker
from fast_track.databases import SQLDatabase

Expand All @@ -16,7 +16,7 @@
config = yaml.safe_load(f)

camera = cv2.VideoCapture(config['data_path'])
detector = YOLONAS(**config['detector'], names=config['names'], image_shape=(camera.get(3), camera.get(4)))
detector = YOLOv9ONNX(**config['detector'], names=config['names'], image_shape=(camera.get(3), camera.get(4)))
tracker = BYTETracker(**config['tracker'], names=config['names'])
database = SQLDatabase(**config["db"], class_names=config['names'])

Expand Down

0 comments on commit 42b10d3

Please sign in to comment.