Skip to content

Commit

Permalink
Merge pull request #1068 from mikel-brostrom/yolo-nas
Browse files Browse the repository at this point in the history
fix super-gradients custom model load
  • Loading branch information
mikel-brostrom authored Aug 11, 2023
2 parents 9e3a5c5 + aa9554c commit ee456e7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 38 deletions.
10 changes: 10 additions & 0 deletions examples/detectors/yolo_interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from pathlib import Path

import numpy as np
import torch
from ultralytics.engine.results import Results
Expand Down Expand Up @@ -63,3 +65,11 @@ def preds_to_yolov8_results(self, path, preds, im, im0s, names):
orig_img=im0s[0],
names=names
)

def get_model_from_weigths(self, l, model):
model_type = None
for key in l:
if Path(key).stem in str(model.name):
model_type = str(Path(key).with_suffix(''))
break
return model_type
84 changes: 54 additions & 30 deletions examples/detectors/yolonas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import numpy as np
import torch
from super_gradients.common.object_names import Models
from super_gradients.training import models
from ultralytics.engine.results import Results
from ultralytics.utils import ops

from .yolo_interface import YoloInterface
from boxmot.utils import logger as LOGGER
from examples.detectors.yolo_interface import YoloInterface


class YoloNASStrategy(YoloInterface):
Expand Down Expand Up @@ -36,23 +38,38 @@ class YoloNASStrategy(YoloInterface):
def __init__(self, model, device, args):
self.args = args

self.model = models.get(
str(model),
pretrained_weights="coco"
).to(device)
avail_models = [x.lower() for x in list(Models.__dict__.keys())]
model_type = self.get_model_from_weigths(avail_models, model)

LOGGER.info(f'Loading {model_type} with {str(model)}')
if not model.exists() and model.stem == model_type:
LOGGER.info('Downloading pretrained weights...')
self.model = models.get(
model_type,
pretrained_weights="coco"
).to(device)
else:
self.model = models.get(
model_type,
num_classes=-1, # set your num classes
checkpoint_path=str(model)
).to(device)

self.device = device

@torch.no_grad()
def __call__(self, im, augment, visualize):

self.has_run = False
im = im[0].permute(1, 2, 0).cpu().numpy() * 255

def __call__(self, im, augment, visualize):
with torch.no_grad():
preds = self.model.predict(
im,
iou=0.5,
conf=0.7,
fuse_model=False
)[0].prediction

preds = next(iter(
self.model.predict(
# (1, 3, h, w) norm --> (h, w, 3) un-norm
im[0].permute(1, 2, 0).cpu().numpy() * 255,
iou=self.args.iou,
conf=self.args.conf
)
)).prediction # Returns a generator of the batch, which here is 1
preds = np.concatenate(
[
preds.bboxes_xyxy,
Expand All @@ -61,29 +78,36 @@ def __call__(self, im, augment, visualize):
], axis=1
)

preds = torch.from_numpy(preds).unsqueeze(0)

return preds

def warmup(self, imgsz):
pass

def postprocess(self, path, preds, im, im0s):
preds = torch.from_numpy(preds).unsqueeze(0)

results = []
for i, pred in enumerate(preds):

# scale from im to im0
pred[:, :4] = ops.scale_boxes(im.shape[2:], pred[:, :4], im0s[i].shape)

if self.args.classes: # Filter boxes by classes
pred = pred[np.isin(pred[:, 5], self.args.classes)]

r = Results(
path=path,
boxes=pred,
orig_img=im0s[i],
names=self.names
)

if pred is None:
pred = torch.empty((0, 6))
r = Results(
path=path,
boxes=pred,
orig_img=im0s[i],
names=self.names
)
results.append(r)
else:

pred[:, :4] = ops.scale_boxes(im.shape[2:], pred[:, :4], im0s[i].shape)

r = Results(
path=path,
boxes=pred,
orig_img=im0s[i],
names=self.names
)
results.append(r)

return results
15 changes: 7 additions & 8 deletions examples/detectors/yolox.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from pathlib import Path

import gdown
import torch
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import ops
from yolox.exp import get_exp
from yolox.utils import postprocess
from yolox.utils.model_utils import fuse_model

from boxmot.utils import logger as LOGGER
from examples.detectors.yolo_interface import YoloInterface

# default model weigths for these model names
Expand All @@ -23,7 +21,7 @@
}


class YoloXStrategy(DetectionPredictor, YoloInterface):
class YoloXStrategy(YoloInterface):
pt = False
stride = 32
fp16 = False
Expand Down Expand Up @@ -54,17 +52,18 @@ def __init__(self, model, device, args):
self.stride = 32 # max stride in YOLOX

# model_type one of: 'yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x'
for key in YOLOX_ZOO.keys():
if Path(key).stem in str(model.name):
model_type = str(Path(key).with_suffix(''))
break
model_type = self.get_model_from_weigths(YOLOX_ZOO.keys(), model)

if model_type == 'yolox_n':
exp = get_exp(None, 'yolox_nano')
else:
exp = get_exp(None, model_type)

LOGGER.info(f'Loading {model_type} with {str(model)}')

# download crowdhuman bytetrack models
if not model.exists() and model.stem == model_type:
LOGGER.info('Downloading pretrained weights...')
gdown.download(
url=YOLOX_ZOO[model_type + '.pt'],
output=str(model),
Expand Down

0 comments on commit ee456e7

Please sign in to comment.