example/traffic_analysis #354

Merged · 9 commits · Sep 4, 2023

4 changes: 4 additions & 0 deletions docs/detection/annotate.md
@@ -5,3 +5,7 @@
## MaskAnnotator

:::supervision.detection.annotate.MaskAnnotator

## TraceAnnotator

:::supervision.detection.annotate.TraceAnnotator
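
For context on this docs addition, here is a minimal sketch of how the newly exported `TraceAnnotator` is wired together with YOLOv8 and ByteTrack, mirroring the constructor arguments used in this PR's `script.py` (the checkpoint and video paths are placeholders):

```python
import supervision as sv
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # placeholder checkpoint
tracker = sv.ByteTrack()
trace_annotator = sv.TraceAnnotator(
    color=sv.ColorPalette.default(),
    position=sv.Position.CENTER,
    trace_length=100,
    thickness=2,
)

for frame in sv.get_video_frames_generator(source_path="video.mov"):  # placeholder path
    results = model(frame, verbose=False)[0]
    detections = tracker.update_with_detections(sv.Detections.from_ultralytics(results))
    frame = trace_annotator.annotate(frame, detections)  # draws per-tracker motion trails
```
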
1 change: 1 addition & 0 deletions examples/traffic_analysis/.gitignore
@@ -0,0 +1 @@
data/
44 changes: 44 additions & 0 deletions examples/traffic_analysis/README.md
@@ -0,0 +1,44 @@
## 👋 hello

This script performs traffic flow analysis with YOLOv8, an object-detection model, and ByteTrack, a simple yet effective online multi-object tracker. It uses the supervision package for tracking, annotation, and more.

https://github.com/roboflow/supervision/assets/26109316/c9436828-9fbf-4c25-ae8c-60e9c81b3900

## 💻 install

- clone the repository and navigate to the example directory

```bash
git clone https://github.com/roboflow/supervision.git
cd supervision/examples/traffic_analysis
```

- set up a Python environment and activate it [optional]

```bash
python3 -m venv venv
source venv/bin/activate
```

- install the required dependencies

```bash
pip install -r requirements.txt
```

- download the `traffic_analysis.pt` and `traffic_analysis.mov` files

```bash
./script.sh
```

## ⚙️ run

```bash
python script.py \
--source_weights_path data/traffic_analysis.pt \
--source_video_path data/traffic_analysis.mov \
--confidence_threshold 0.3 \
--iou_threshold 0.5 \
--target_video_path data/traffic_analysis_result.mov
```
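
The same pipeline can also be driven from Python instead of the CLI; a minimal sketch using the `VideoProcessor` class defined in this example's `script.py` (run it from `examples/traffic_analysis` so the import resolves):

```python
from script import VideoProcessor

processor = VideoProcessor(
    source_weights_path="data/traffic_analysis.pt",
    source_video_path="data/traffic_analysis.mov",
    target_video_path="data/traffic_analysis_result.mov",  # omit to preview in a window
    confidence_threshold=0.3,
    iou_threshold=0.5,
)
processor.process_video()
```
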
4 changes: 4 additions & 0 deletions examples/traffic_analysis/requirements.txt
@@ -0,0 +1,4 @@
supervision
tqdm
ultralytics
gdown
219 changes: 219 additions & 0 deletions examples/traffic_analysis/script.py
@@ -0,0 +1,219 @@
import argparse
from typing import Dict, List, Optional, Set, Tuple

import cv2
import numpy as np
from tqdm import tqdm
from ultralytics import YOLO

import supervision as sv

COLORS = sv.ColorPalette.default()

# Entry and exit counting zones in pixel coordinates, hand-tuned for the demo video
ZONE_IN_POLYGONS = [
    np.array([[592, 282], [900, 282], [900, 82], [592, 82]]),
    np.array([[950, 860], [1250, 860], [1250, 1060], [950, 1060]]),
    np.array([[592, 582], [592, 860], [392, 860], [392, 582]]),
    np.array([[1250, 282], [1250, 530], [1450, 530], [1450, 282]]),
]

ZONE_OUT_POLYGONS = [
    np.array([[950, 282], [1250, 282], [1250, 82], [950, 82]]),
    np.array([[592, 860], [900, 860], [900, 1060], [592, 1060]]),
    np.array([[592, 282], [592, 550], [392, 550], [392, 282]]),
    np.array([[1250, 860], [1250, 560], [1450, 560], [1450, 860]]),
]


class DetectionsManager:
    def __init__(self) -> None:
        # tracker_id -> id of the entry zone where the tracker was first seen
        self.tracker_id_to_zone_id: Dict[int, int] = {}
        # exit zone id -> entry zone id -> set of tracker ids that made the trip
        self.counts: Dict[int, Dict[int, Set[int]]] = {}

    def update(
        self,
        detections_all: sv.Detections,
        detections_in_zones: List[sv.Detections],
        detections_out_zones: List[sv.Detections],
    ) -> sv.Detections:
        # remember the first entry zone each tracker appeared in
        for zone_in_id, detections_in_zone in enumerate(detections_in_zones):
            for tracker_id in detections_in_zone.tracker_id:
                self.tracker_id_to_zone_id.setdefault(tracker_id, zone_in_id)

        # credit a completed entry -> exit transition once per tracker
        for zone_out_id, detections_out_zone in enumerate(detections_out_zones):
            for tracker_id in detections_out_zone.tracker_id:
                if tracker_id in self.tracker_id_to_zone_id:
                    zone_in_id = self.tracker_id_to_zone_id[tracker_id]
                    self.counts.setdefault(zone_out_id, {})
                    self.counts[zone_out_id].setdefault(zone_in_id, set())
                    self.counts[zone_out_id][zone_in_id].add(tracker_id)

        # relabel each detection with its entry zone id and drop detections
        # that have not entered any zone yet (mapped to -1)
        detections_all.class_id = np.vectorize(
            lambda x: self.tracker_id_to_zone_id.get(x, -1)
        )(detections_all.tracker_id)
        return detections_all[detections_all.class_id != -1]
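
# Illustrative example (not in the original script): a vehicle first seen in
# entry zone 1 sets tracker_id_to_zone_id[42] = 1; when tracker 42 later
# triggers exit zone 3, counts[3][1] gains id 42, so len(counts[3][1]) is the
# number of unique vehicles that travelled from entry zone 1 to exit zone 3.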


def initiate_polygon_zones(
    polygons: List[np.ndarray],
    frame_resolution_wh: Tuple[int, int],
    triggering_position: sv.Position = sv.Position.CENTER,
) -> List[sv.PolygonZone]:
    return [
        sv.PolygonZone(
            polygon=polygon,
            frame_resolution_wh=frame_resolution_wh,
            triggering_position=triggering_position,
        )
        for polygon in polygons
    ]


class VideoProcessor:
    def __init__(
        self,
        source_weights_path: str,
        source_video_path: str,
        target_video_path: Optional[str] = None,
        confidence_threshold: float = 0.3,
        iou_threshold: float = 0.7,
    ) -> None:
        self.conf_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.source_video_path = source_video_path
        self.target_video_path = target_video_path

        self.model = YOLO(source_weights_path)
        self.tracker = sv.ByteTrack()

        self.video_info = sv.VideoInfo.from_video_path(source_video_path)
        self.zones_in = initiate_polygon_zones(
            ZONE_IN_POLYGONS, self.video_info.resolution_wh, sv.Position.CENTER
        )
        self.zones_out = initiate_polygon_zones(
            ZONE_OUT_POLYGONS, self.video_info.resolution_wh, sv.Position.CENTER
        )

        self.box_annotator = sv.BoxAnnotator(color=COLORS)
        self.trace_annotator = sv.TraceAnnotator(
            color=COLORS, position=sv.Position.CENTER, trace_length=100, thickness=2
        )
        self.detections_manager = DetectionsManager()

    def process_video(self):
        frame_generator = sv.get_video_frames_generator(
            source_path=self.source_video_path
        )

        if self.target_video_path:
            # write annotated frames to the target video
            with sv.VideoSink(self.target_video_path, self.video_info) as sink:
                for frame in tqdm(frame_generator, total=self.video_info.total_frames):
                    annotated_frame = self.process_frame(frame)
                    sink.write_frame(annotated_frame)
        else:
            # no target path: preview live in an OpenCV window, quit with "q"
            for frame in tqdm(frame_generator, total=self.video_info.total_frames):
                annotated_frame = self.process_frame(frame)
                cv2.imshow("Processed Video", annotated_frame)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
            cv2.destroyAllWindows()

    def annotate_frame(
        self, frame: np.ndarray, detections: sv.Detections
    ) -> np.ndarray:
        annotated_frame = frame.copy()
        for i, (zone_in, zone_out) in enumerate(zip(self.zones_in, self.zones_out)):
            annotated_frame = sv.draw_polygon(
                annotated_frame, zone_in.polygon, COLORS.colors[i]
            )
            annotated_frame = sv.draw_polygon(
                annotated_frame, zone_out.polygon, COLORS.colors[i]
            )

        labels = [f"#{tracker_id}" for tracker_id in detections.tracker_id]
        annotated_frame = self.trace_annotator.annotate(annotated_frame, detections)
        annotated_frame = self.box_annotator.annotate(
            annotated_frame, detections, labels
        )

        # at each exit zone's center, draw one count per entry zone, stacked 40 px apart
        for zone_out_id, zone_out in enumerate(self.zones_out):
            zone_center = sv.get_polygon_center(polygon=zone_out.polygon)
            if zone_out_id in self.detections_manager.counts:
                counts = self.detections_manager.counts[zone_out_id]
                for i, zone_in_id in enumerate(counts):
                    count = len(self.detections_manager.counts[zone_out_id][zone_in_id])
                    text_anchor = sv.Point(x=zone_center.x, y=zone_center.y + 40 * i)
                    annotated_frame = sv.draw_text(
                        scene=annotated_frame,
                        text=str(count),
                        text_anchor=text_anchor,
                        background_color=COLORS.colors[zone_in_id],
                    )

        return annotated_frame

    def process_frame(self, frame: np.ndarray) -> np.ndarray:
        results = self.model(
            frame, verbose=False, conf=self.conf_threshold, iou=self.iou_threshold
        )[0]
        detections = sv.Detections.from_ultralytics(results)
        # collapse all detections to a single class before tracking; class ids
        # are reassigned from entry zone ids in DetectionsManager.update
        detections.class_id = np.zeros(len(detections))
        detections = self.tracker.update_with_detections(detections)

        detections_in_zones = []
        detections_out_zones = []

        for zone_in, zone_out in zip(self.zones_in, self.zones_out):
            detections_in_zone = detections[zone_in.trigger(detections=detections)]
            detections_in_zones.append(detections_in_zone)
            detections_out_zone = detections[zone_out.trigger(detections=detections)]
            detections_out_zones.append(detections_out_zone)

        detections = self.detections_manager.update(
            detections, detections_in_zones, detections_out_zones
        )
        return self.annotate_frame(frame, detections)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Traffic Flow Analysis with YOLO and ByteTrack"
    )

    parser.add_argument(
        "--source_weights_path",
        required=True,
        help="Path to the source weights file",
        type=str,
    )
    parser.add_argument(
        "--source_video_path",
        required=True,
        help="Path to the source video file",
        type=str,
    )
    parser.add_argument(
        "--target_video_path",
        default=None,
        help="Path to the target video file (output)",
        type=str,
    )
    parser.add_argument(
        "--confidence_threshold",
        default=0.3,
        help="Confidence threshold for the model",
        type=float,
    )
    parser.add_argument(
        "--iou_threshold", default=0.7, help="IOU threshold for the model", type=float
    )

    args = parser.parse_args()
    processor = VideoProcessor(
        source_weights_path=args.source_weights_path,
        source_video_path=args.source_video_path,
        target_video_path=args.target_video_path,
        confidence_threshold=args.confidence_threshold,
        iou_threshold=args.iou_threshold,
    )
    processor.process_video()
17 changes: 17 additions & 0 deletions examples/traffic_analysis/script.sh
@@ -0,0 +1,17 @@
#!/bin/bash

# Get the directory where the script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

# Check if 'data' directory does not exist and then create it
if [[ ! -e $DIR/data ]]; then
    mkdir "$DIR/data"
else
    echo "'data' directory already exists."
fi

# Download the traffic_analysis.mov file from Google Drive
gdown -O "$DIR/data/traffic_analysis.mov" "https://drive.google.com/uc?id=1qadBd7lgpediafCpL_yedGjQPk-FLK-W"

# Download the traffic_analysis.pt file from Google Drive
gdown -O "$DIR/data/traffic_analysis.pt" "https://drive.google.com/uc?id=1y-IfToCjRXa3ZdC1JpnKRopC7mcQW-5z"
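
If the helper script is not executable after cloning, a standard shell fix (not part of this PR) is:

```bash
chmod +x script.sh && ./script.sh
```
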
2 changes: 1 addition & 1 deletion supervision/__init__.py
@@ -13,7 +13,7 @@
ClassificationDataset,
DetectionDataset,
)
from supervision.detection.annotate import BoxAnnotator, MaskAnnotator
from supervision.detection.annotate import BoxAnnotator, MaskAnnotator, TraceAnnotator
from supervision.detection.core import Detections
from supervision.detection.line_counter import LineZone, LineZoneAnnotator
from supervision.detection.tools.inference_slicer import InferenceSlicer
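
A quick way to confirm the new top-level export once this branch is installed (an illustrative check, not part of the PR):

```python
import supervision as sv

assert hasattr(sv, "TraceAnnotator")  # now re-exported at the package root
```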