Skip to content

Commit

Permalink
args from config instead of hardcoded
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikel Broström committed Sep 23, 2024
1 parent 106b530 commit 5c44e69
Show file tree
Hide file tree
Showing 10 changed files with 290 additions and 57 deletions.
1 change: 0 additions & 1 deletion boxmot/configs/botsort.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
appearance_thresh: 0.25
cmc_method: ecc
frame_rate: 30
lambda_: 0.98
match_thresh: 0.8
new_track_thresh: 0.7
proximity_thresh: 0.5
Expand Down
2 changes: 0 additions & 2 deletions boxmot/configs/hybridsort.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
# HOTA, MOTA, IDF1: [40.35]
TCM_first_step_weight: 0.2866529225304586
asso_func: hmiou
conf: 0.44745176349923044
delta_t: 5
det_thresh: 0.12442660055370669
inertia: 0.369525477649008
iou_thresh: 0.39555224612193407
longterm_reid_weight: 0.0509704360503877
max_age: 30
min_hits: 1
Expand Down
1 change: 0 additions & 1 deletion boxmot/configs/ocsort.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ asso_func: iou
delta_t: 3
det_thresh: 0.6
inertia: 0.2
iou_thresh: 0.3
max_age: 30
min_hits: 3
use_byte: false
Expand Down
1 change: 0 additions & 1 deletion boxmot/configs/strongsort.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
ecc: false
ema_alpha: 0.9
max_age: 30
max_cos_dist: 0.4
Expand Down
85 changes: 45 additions & 40 deletions boxmot/tracker_zoo.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,69 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license



from types import SimpleNamespace
import yaml
from boxmot.utils import BOXMOT

def get_tracker_config(tracker_type):
tracking_config = BOXMOT / 'configs' / (tracker_type + '.yaml')
return tracking_config
"""Returns the path to the tracker configuration file."""
return BOXMOT / 'configs' / f'{tracker_type}.yaml'

def create_tracker(tracker_type, tracker_config=None, reid_weights=None, device=None, half=None, per_class=None, evolve_param_dict=None):
# Load the configuration from file or use the provided dictionary
"""
Creates and returns an instance of the specified tracker type.
Parameters:
- tracker_type: The type of the tracker (e.g., 'strongsort', 'ocsort').
- tracker_config: Path to the tracker configuration file.
- reid_weights: Weights for ReID (re-identification).
- device: Device to run the tracker on (e.g., 'cpu', 'cuda').
- half: Boolean indicating whether to use half-precision.
- per_class: Boolean for class-specific tracking (optional).
- evolve_param_dict: A dictionary of parameters for evolving the tracker.
Returns:
- An instance of the selected tracker.
"""

# Load configuration from file or use provided dictionary
if evolve_param_dict is None:
with open(tracker_config, "r") as f:
tracker_args = yaml.load(f.read(), Loader=yaml.FullLoader)
tracker_args = yaml.load(f, Loader=yaml.FullLoader)
else:
tracker_args = evolve_param_dict

# Arguments specific to ReID models
reid_args = {
'reid_weights': reid_weights,
'device': device,
'half': half,
'per_class': per_class
}

if tracker_type == 'strongsort':
from boxmot.trackers.strongsort.strong_sort import StrongSORT
tracker_args.update(reid_args)
tracker_args.pop('per_class')
return StrongSORT(**tracker_args)

elif tracker_type == 'ocsort':
from boxmot.trackers.ocsort.ocsort import OCSort
return OCSort(**tracker_args)

elif tracker_type == 'bytetrack':
from boxmot.trackers.bytetrack.byte_tracker import BYTETracker
return BYTETracker(**tracker_args)

elif tracker_type == 'botsort':
from boxmot.trackers.botsort.bot_sort import BoTSORT
tracker_args.update(reid_args)
return BoTSORT(**tracker_args)

elif tracker_type == 'deepocsort':
from boxmot.trackers.deepocsort.deep_ocsort import DeepOCSort
tracker_args.update(reid_args)
return DeepOCSort(**tracker_args)
# Map tracker types to their corresponding classes
tracker_mapping = {
'strongsort': 'boxmot.trackers.strongsort.strong_sort.StrongSORT',
'ocsort': 'boxmot.trackers.ocsort.ocsort.OCSort',
'bytetrack': 'boxmot.trackers.bytetrack.byte_tracker.BYTETracker',
'botsort': 'boxmot.trackers.botsort.bot_sort.BoTSORT',
'deepocsort': 'boxmot.trackers.deepocsort.deep_ocsort.DeepOCSort',
'hybridsort': 'boxmot.trackers.hybridsort.hybridsort.HybridSORT',
'imprassoc': 'boxmot.trackers.imprassoc.impr_assoc_tracker.ImprAssocTrack'
}

elif tracker_type == 'hybridsort':
from boxmot.trackers.hybridsort.hybridsort import HybridSORT
tracker_args.update(reid_args)
return HybridSORT(**tracker_args)
# Check if the tracker type exists in the mapping
if tracker_type not in tracker_mapping:
print('Error: No such tracker found.')
exit()

elif tracker_type == 'imprassoc':
from boxmot.trackers.imprassoc.impr_assoc_tracker import ImprAssocTrack
# Dynamically import and instantiate the correct tracker class
module_path, class_name = tracker_mapping[tracker_type].rsplit('.', 1)
tracker_class = getattr(__import__(module_path, fromlist=[class_name]), class_name)

# For specific trackers, update tracker arguments with ReID parameters
if tracker_type in ['strongsort', 'botsort', 'deepocsort', 'hybridsort', 'imprassoc']:
tracker_args.update(reid_args)
return ImprAssocTrack(**tracker_args)
if tracker_type == 'strongsort':
tracker_args.pop('per_class') # Remove per_class if not needed

else:
print('No such tracker')
exit()
# Return the instantiated tracker class with arguments
return tracker_class(**tracker_args)
6 changes: 3 additions & 3 deletions boxmot/trackers/botsort/bot_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ class BoTSORT(BaseTracker):
"""
def __init__(
self,
model_weights: Path,
reid_weights: Path,
device: torch.device,
fp16: bool,
half: bool,
per_class: bool = False,
track_high_thresh: float = 0.5,
track_low_thresh: float = 0.1,
Expand Down Expand Up @@ -248,7 +248,7 @@ def __init__(
self.with_reid = with_reid
if self.with_reid:
self.model = ReidAutoBackend(
weights=model_weights, device=device, half=fp16
weights=reid_weights, device=device, half=half
).model

self.cmc = SOF()
Expand Down
6 changes: 3 additions & 3 deletions boxmot/trackers/deepocsort/deep_ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ class DeepOCSort(BaseTracker):
"""
def __init__(
self,
model_weights: Path,
reid_weights: Path,
device: torch.device,
fp16: bool,
half: bool,
per_class: bool = False,
det_thresh: float = 0.3,
max_age: int = 30,
Expand Down Expand Up @@ -292,7 +292,7 @@ def __init__(
KalmanBoxTracker.count = 1

self.model = ReidAutoBackend(
weights=model_weights, device=device, half=fp16
weights=reid_weights, device=device, half=half
).model
# "similarity transforms using feature point extraction, optical flow, and RANSAC"
self.cmc = get_cmc_method('sof')()
Expand Down
6 changes: 3 additions & 3 deletions boxmot/trackers/imprassoc/impr_assoc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@ class ImprAssocTrack(BaseTracker):
"""
def __init__(
self,
model_weights: Path,
reid_weights: Path,
device: device,
fp16: bool,
half: bool,
per_class: bool = False,
track_high_thresh: float = 0.6,
track_low_thresh: float = 0.1,
Expand Down Expand Up @@ -257,7 +257,7 @@ def __init__(
self.with_reid = with_reid
if self.with_reid:
rab = ReidAutoBackend(
weights=model_weights, device=device, half=fp16
weights=reid_weights, device=device, half=half
)
self.model = rab.get_backend()

Expand Down
6 changes: 3 additions & 3 deletions boxmot/trackers/strongsort/strong_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class StrongSORT(object):
"""
def __init__(
self,
model_weights: Path,
reid_weights: Path,
device: device,
fp16: bool,
half: bool,
per_class: bool = False,
max_cos_dist=0.2,
max_iou_dist=0.7,
Expand All @@ -47,7 +47,7 @@ def __init__(

self.per_class = per_class
self.model = ReidAutoBackend(
weights=model_weights, device=device, half=fp16
weights=reid_weights, device=device, half=half
).model

self.tracker = Tracker(
Expand Down
Loading

0 comments on commit 5c44e69

Please sign in to comment.