Skip to content

Commit e78e66c

Browse files
committed
Fix OBB prediction and update Ultralytics demo notebook
- Add OBB (Oriented Bounding Box) prediction example to inference notebook
- Enhance visualization for OBB predictions in cv utils
- Update AutoDetectionModel and prediction methods to support OBB models
- Bump package version to 0.11.22
- Improve demo notebook with additional test image and simplified imports
1 parent 9f5cdb7 commit e78e66c

7 files changed

+295
-171
lines changed

demo/inference_for_ultralytics.ipynb

+206-105
Large diffs are not rendered by default.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sahi"
3-
version = "0.11.21"
3+
version = "0.11.22"
44
readme = "README.md"
55
description = "A vision library for performing sliced inference on large images/small objects"
66
requires-python = ">=3.8"

sahi/auto_model.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def from_pretrained(
3939
4040
Args:
4141
model_type: str
42-
Name of the detection framework (example: "yolov5", "mmdet", "detectron2")
42+
Name of the detection framework (example: "ultralytics", "huggingface", "torchvision")
4343
model_path: str
4444
Path of the detection model (ex. 'model.pt')
4545
config_path: str
@@ -58,8 +58,10 @@ def from_pretrained(
5858
If True, automatically loads the model at initialization
5959
image_size: int
6060
Inference input size.
61+
6162
Returns:
6263
Returns an instance of a DetectionModel
64+
6365
Raises:
6466
ImportError: If given {model_type} framework is not installed
6567
"""

sahi/models/torchvision.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ def _create_object_prediction_list_from_original_predictions(
178178

179179
for ind in range(len(boxes)):
180180
if masks is not None:
181-
mask = get_coco_segmentation_from_bool_mask(np.array(masks[ind]))
181+
segmentation = get_coco_segmentation_from_bool_mask(np.array(masks[ind]))
182182
else:
183-
mask = None
183+
segmentation = None
184184

185185
object_prediction = ObjectPrediction(
186186
bbox=boxes[ind],
187-
segmentation=mask,
187+
segmentation=segmentation,
188188
category_id=int(category_ids[ind]),
189189
category_name=self.category_mapping[str(int(category_ids[ind]))],
190190
shift_amount=shift_amount,

sahi/models/ultralytics.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sahi.models.base import DetectionModel
1212
from sahi.prediction import ObjectPrediction
1313
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
14-
from sahi.utils.cv import get_coco_segmentation_from_bool_mask, get_coco_segmentation_from_obb_points
14+
from sahi.utils.cv import get_coco_segmentation_from_bool_mask
1515
from sahi.utils.import_utils import check_requirements
1616

1717
logger = logging.getLogger(__name__)
@@ -207,7 +207,7 @@ def _create_object_prediction_list_from_original_predictions(
207207
segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
208208
else: # is_obb
209209
obb_points = masks_or_points[pred_ind] # Get OBB points for this prediction
210-
segmentation = get_coco_segmentation_from_obb_points(obb_points)
210+
segmentation = [obb_points.reshape(-1).tolist()]
211211

212212
if len(segmentation) == 0:
213213
continue

sahi/predict.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def get_prediction(
113113
time_end = time.time() - time_start
114114
durations_in_seconds["prediction"] = time_end
115115

116+
if full_shape is None:
117+
full_shape = [image_as_pil.height, image_as_pil.width]
118+
116119
# process prediction
117120
time_start = time.time()
118121
# works only with 1 batch
@@ -239,19 +242,21 @@ def get_sliced_prediction(
239242
overlap_width_ratio=overlap_width_ratio,
240243
auto_slice_resolution=auto_slice_resolution,
241244
)
245+
from sahi.models.ultralytics import UltralyticsDetectionModel
242246

243247
num_slices = len(slice_image_result)
244248
time_end = time.time() - time_start
245249
durations_in_seconds["slice"] = time_end
246250

251+
if isinstance(detection_model, UltralyticsDetectionModel) and detection_model.is_obb:
252+
# Only NMS is supported for OBB model outputs
253+
postprocess_type = "NMS"
254+
247255
# init match postprocess instance
248256
if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys():
249257
raise ValueError(
250258
f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}"
251259
)
252-
elif postprocess_type == "UNIONMERGE":
253-
# deprecated in v0.9.3
254-
raise ValueError("'UNIONMERGE' postprocess_type is deprecated, use 'GREEDYNMM' instead.")
255260
postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
256261
postprocess = postprocess_constructor(
257262
match_threshold=postprocess_match_threshold,

sahi/utils/cv.py

+72-56
Original file line numberDiff line numberDiff line change
@@ -540,68 +540,83 @@ def visualize_object_predictions(
540540
# set text_size for category names
541541
text_size = text_size or rect_th / 3
542542

543-
# add masks to image if present
543+
# add masks or obb polygons to image if present
544544
for object_prediction in object_prediction_list:
545545
# deepcopy object_prediction_list so that original is not altered
546546
object_prediction = object_prediction.deepcopy()
547-
# visualize masks if present
548-
if object_prediction.mask is not None:
549-
# deepcopy mask so that original is not altered
550-
mask = object_prediction.mask.bool_mask
551-
# set color
552-
if colors is not None:
553-
color = colors(object_prediction.category.id)
554-
# draw mask
555-
rgb_mask = apply_color_mask(mask, color or (0, 0, 0))
556-
image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)
557-
558-
# add bboxes to image if present
559-
for object_prediction in object_prediction_list:
560-
# deepcopy object_prediction_list so that original is not altered
561-
object_prediction = object_prediction.deepcopy()
562-
563-
bbox = object_prediction.bbox.to_xyxy()
564-
category_name = object_prediction.category.name
565-
score = object_prediction.score.value
566-
547+
# arrange label to be displayed
548+
label = f"{object_prediction.category.name}"
549+
if not hide_conf:
550+
label += f" {object_prediction.score.value:.2f}"
567551
# set color
568552
if colors is not None:
569553
color = colors(object_prediction.category.id)
570-
# set bbox points
571-
point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
572-
# visualize boxes
573-
cv2.rectangle(
574-
image,
575-
point1,
576-
point2,
577-
color=color or (0, 0, 0),
578-
thickness=rect_th,
579-
)
580-
581-
if not hide_labels:
582-
# arange bounding box text location
583-
label = f"{category_name}"
584-
585-
if not hide_conf:
586-
label += f" {score:.2f}"
587-
588-
box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
589-
0
590-
] # label width, height
591-
outside = point1[1] - box_height - 3 >= 0 # label fits outside box
592-
point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
593-
# add bounding box text
594-
cv2.rectangle(image, point1, point2, color or (0, 0, 0), -1, cv2.LINE_AA) # filled
595-
cv2.putText(
554+
# visualize masks or obb polygons if present
555+
has_mask = object_prediction.mask is not None
556+
is_obb_pred = False
557+
if has_mask:
558+
segmentation = object_prediction.mask.segmentation
559+
if len(segmentation) == 1 and len(segmentation[0]) == 8:
560+
is_obb_pred = True
561+
562+
if is_obb_pred:
563+
points = np.array(segmentation).reshape((-1, 1, 2)).astype(np.int32)
564+
cv2.polylines(image, [points], isClosed=True, color=color or (0, 0, 0), thickness=rect_th)
565+
566+
if not hide_labels:
567+
lowest_point = points[points[:, :, 1].argmax()][0]
568+
box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]
569+
outside = lowest_point[1] - box_height - 3 >= 0
570+
text_bg_point1 = (lowest_point[0], lowest_point[1] - box_height - 3 if outside else lowest_point[1] + 3)
571+
text_bg_point2 = (lowest_point[0] + box_width, lowest_point[1])
572+
cv2.rectangle(image, text_bg_point1, text_bg_point2, color or (0, 0, 0), thickness=-1, lineType=cv2.LINE_AA)
573+
cv2.putText(
574+
image,
575+
label,
576+
(lowest_point[0], lowest_point[1] - 2 if outside else lowest_point[1] + box_height + 2),
577+
0,
578+
text_size,
579+
(255, 255, 255),
580+
thickness=text_th,
581+
)
582+
else:
583+
# draw mask
584+
rgb_mask = apply_color_mask(object_prediction.mask.bool_mask, color or (0, 0, 0))
585+
image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)
586+
587+
# add bboxes to image if is_obb_pred=False
588+
if not is_obb_pred:
589+
bbox = object_prediction.bbox.to_xyxy()
590+
591+
# set bbox points
592+
point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
593+
# visualize boxes
594+
cv2.rectangle(
596595
image,
597-
label,
598-
(point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
599-
0,
600-
text_size,
601-
(255, 255, 255),
602-
thickness=text_th,
596+
point1,
597+
point2,
598+
color=color or (0, 0, 0),
599+
thickness=rect_th,
603600
)
604601

602+
if not hide_labels:
603+
box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
604+
0
605+
] # label width, height
606+
outside = point1[1] - box_height - 3 >= 0 # label fits outside box
607+
point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
608+
# add bounding box text
609+
cv2.rectangle(image, point1, point2, color or (0, 0, 0), -1, cv2.LINE_AA) # filled
610+
cv2.putText(
611+
image,
612+
label,
613+
(point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
614+
0,
615+
text_size,
616+
(255, 255, 255),
617+
thickness=text_th,
618+
)
619+
605620
# export if output_dir is present
606621
if output_dir is not None:
607622
# export image with predictions
@@ -614,7 +629,7 @@ def visualize_object_predictions(
614629
return {"image": image, "elapsed_time": elapsed_time}
615630

616631

617-
def get_coco_segmentation_from_bool_mask(bool_mask):
632+
def get_coco_segmentation_from_bool_mask(bool_mask: np.ndarray) -> List[List[float]]:
618633
"""
619634
Convert boolean mask to coco segmentation format
620635
[
@@ -712,12 +727,13 @@ def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> List[List[f
712727
obb_points: np.ndarray
713728
OBB points tensor from ultralytics.engine.results.OBB
714729
Shape: (4, 2) containing 4 points with (x,y) coordinates each
730+
715731
Returns:
716732
List[List[float]]: Polygon points in COCO format
717-
[[x1, y1, x2, y2, x3, y3, x4, y4, x1, y1], [...], ...]
733+
[[x1, y1, x2, y2, x3, y3, x4, y4], [...], ...]
718734
"""
719735
# Convert from (4,2) to [x1,y1,x2,y2,x3,y3,x4,y4] format
720-
points = obb_points.reshape(-1).tolist()
736+
points = obb_points.reshape(-1).tolist() #
721737

722738
# Create polygon from points and close it by repeating first point
723739
polygons = []

0 commit comments

Comments
 (0)