Skip to content

Commit

Permalink
Merge pull request #1052 from mikel-brostrom/simplify-track
Browse files Browse the repository at this point in the history
enable other trackers
  • Loading branch information
mikel-brostrom authored Aug 3, 2023
2 parents dfbec5a + 99770b4 commit 38833a4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
3 changes: 3 additions & 0 deletions boxmot/motion/cmc/sof.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def apply(self, img, dets):
except Exception as e:
LOGGER.warning(f'Affine matrix could not be generated: {e}')
return H
finally:
if H is None:
return np.eye(2, 3)

if self.draw_optical_flow:
self.warped_img = cv2.warpAffine(self.prev_img, H, (w, h), flags=cv2.INTER_LINEAR)
Expand Down
26 changes: 14 additions & 12 deletions examples/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

import torch

from boxmot import DeepOCSORT
from boxmot.tracker_zoo import create_tracker
from boxmot.utils import EXAMPLES, ROOT, WEIGHTS
from boxmot.utils.checks import TestRequirements
from examples.detectors import get_yolo_inferer

__tr = TestRequirements()
__tr.check_packages(('git+https://github.com/mikel-brostrom/ultralytics.git',)) # install
__tr.check_packages(('ultralytics @ git+https://github.com/mikel-brostrom/ultralytics.git', )) # install

from ultralytics import YOLO
from ultralytics.yolo.data.utils import VID_FORMATS
Expand All @@ -27,18 +27,20 @@ def on_predict_start(predictor, persist=False):
predictor (object): The predictor object to initialize trackers for.
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
"""
predictor.args.tracking_config = \
tracking_config = \
ROOT /\
'boxmot' /\
'configs' /\
'deepocsort.yaml'
(predictor.custom_args.tracking_method + '.yaml')
trackers = []
for i in range(predictor.dataset.bs):
tracker = DeepOCSORT(
model_weights=Path(WEIGHTS / 'osnet_x0_25_msmt17.pt'),
device=predictor.device,
fp16=False,
per_class=False
tracker = create_tracker(
predictor.custom_args.tracking_method,
tracking_config,
predictor.custom_args.reid_model,
predictor.device,
predictor.custom_args.half,
predictor.custom_args.per_class
)
trackers.append(tracker)

Expand All @@ -52,7 +54,6 @@ def run(args):
yolo = YOLO(
'yolov8n.pt',
)
print(yolo.__dict__.keys())

results = yolo.track(
source=args.source,
Expand Down Expand Up @@ -83,9 +84,11 @@ def run(args):
yolo.predictor.args.name = args.name
yolo.predictor.args.exist_ok = args.exist_ok
yolo.predictor.args.classes = args.classes
yolo.predictor.custom_args = args

for frame_idx, r in enumerate(results):
if len(r.boxes.data) != 0:

if r.boxes.data.shape[1] == 7:

if yolo.predictor.source_type.webcam or args.source.endswith(VID_FORMATS):
p = yolo.predictor.save_dir / 'mot' / (args.source + '.txt')
Expand All @@ -99,7 +102,6 @@ def run(args):
yolo.predictor.mot_txt_path,
r,
frame_idx,
frame_idx,
)

if args.save_id_crops:
Expand Down
3 changes: 1 addition & 2 deletions examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from ultralytics.yolo.utils import ops


def write_mot_results(txt_path, results, frame_idx, i):
def write_mot_results(txt_path, results, frame_idx):
nr_dets = len(results.boxes)
frame_idx = torch.full((1, 1), frame_idx + 1)
frame_idx = frame_idx.repeat(nr_dets, 1)
dont_care = torch.full((nr_dets, 1), -1)
i = torch.full((nr_dets, 1), i)
mot = torch.cat([
frame_idx,
results.boxes.id.unsqueeze(1).to('cpu'),
Expand Down

0 comments on commit 38833a4

Please sign in to comment.