I modified part of the code to enable parallel inference with num_batch > 1 #1113

Open · wants to merge 7 commits into main
11 changes: 8 additions & 3 deletions sahi/models/ultralytics.py
@@ -67,9 +67,11 @@ def perform_inference(self, image: np.ndarray):
         if self.image_size is not None:
             kwargs = {"imgsz": self.image_size, **kwargs}
-
-        prediction_result = self.model(image[:, :, ::-1], **kwargs)  # YOLOv8 expects numpy arrays to have BGR
+        if type(image) is list:
+            prediction_result = self.model(image, **kwargs)  # list input is already BGR (converted in get_sliced_prediction)
+        else:
+            prediction_result = self.model(image[:, :, ::-1], **kwargs)  # YOLOv8 expects numpy arrays to have BGR
         if self.has_mask:
             if not prediction_result[0].masks:
                 prediction_result[0].masks = Masks(
@@ -109,7 +111,10 @@ def perform_inference(self, image: np.ndarray):
             prediction_result = [result.boxes.data for result in prediction_result]

         self._original_predictions = prediction_result
-        self._original_shape = image.shape
+        if type(image) is list:
+            self._original_shape = image[0].shape
+        else:
+            self._original_shape = image.shape

     @property
     def category_names(self):
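In effect, the patched perform_inference now dispatches on the input type: a list of slices is passed straight to the Ultralytics model, while a single RGB array is flipped to BGR first. A minimal sketch of that dispatch in isolation (run_model and model are placeholder names, not SAHI or Ultralytics API):

```python
from typing import List, Union

import numpy as np


def run_model(model, image: Union[np.ndarray, List[np.ndarray]], **kwargs):
    """Mirror the patched dispatch: list inputs are assumed to be BGR already,
    single RGB arrays are converted to BGR before the forward pass."""
    if isinstance(image, list):
        # Batched path: get_sliced_prediction converts each slice to BGR beforehand.
        return model(image, **kwargs)
    # Single-image path: reverse the channel axis (RGB -> BGR).
    return model(image[:, :, ::-1], **kwargs)
```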
48 changes: 32 additions & 16 deletions sahi/predict.py
@@ -5,6 +5,7 @@
 import os
 import time
 from typing import Generator, List, Optional, Union
+import math

 from PIL import Image

@@ -106,10 +107,11 @@ def get_prediction(
     durations_in_seconds = dict()

     # read image as pil
-    image_as_pil = read_image_as_pil(image)
+    # image_as_pil = read_image_as_pil(image)
     # get prediction
     time_start = time.time()
-    detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
+    # detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
+    detection_model.perform_inference(image)
     time_end = time.time() - time_start
     durations_in_seconds["prediction"] = time_end

@@ -126,7 +128,6 @@ def get_prediction(
     # postprocess matching predictions
     if postprocess is not None:
        object_prediction_list = postprocess(object_prediction_list)
-
     time_end = time.time() - time_start
     durations_in_seconds["postprocess"] = time_end

@@ -159,6 +160,7 @@ def get_sliced_prediction(
     auto_slice_resolution: bool = True,
     slice_export_prefix: Optional[str] = None,
     slice_dir: Optional[str] = None,
+    num_batch: int = 1,
     exclude_classes_by_name: Optional[List[str]] = None,
     exclude_classes_by_id: Optional[List[int]] = None,
 ) -> PredictionResult:
@@ -225,8 +227,8 @@ def get_sliced_prediction(
     # for profiling
     durations_in_seconds = dict()

-    # currently only 1 batch supported
-    num_batch = 1
+    # # currently only 1 batch supported
+    # num_batch = 1
     # create slices from full image
     time_start = time.time()
     slice_image_result = slice_image(
@@ -260,7 +262,8 @@ def get_sliced_prediction(
     )

     # create prediction input
-    num_group = int(num_slices / num_batch)
+    # num_group = int(num_slices / num_batch)
+    num_group = math.ceil(num_slices / num_batch)
     if verbose == 1 or verbose == 2:
         tqdm.write(f"Performing prediction on {num_slices} slices.")
     object_prediction_list = []
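The switch from int() to math.ceil() matters whenever num_slices is not an exact multiple of num_batch; a quick standalone check of the arithmetic (illustrative numbers, not part of the PR):

```python
import math

num_slices, num_batch = 10, 4

# Old behaviour: int() truncates, so the trailing partial group of slices is never visited.
print(int(num_slices / num_batch))        # 2 groups -> only 8 of 10 slices processed
# New behaviour: ceil() adds a final, smaller group for the leftover slices.
print(math.ceil(num_slices / num_batch))  # 3 groups -> all 10 slices processed
```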
@@ -270,24 +273,33 @@ def get_sliced_prediction(
         image_list = []
         shift_amount_list = []
         for image_ind in range(num_batch):
-            image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
+            if (group_ind * num_batch + image_ind) >= num_slices:
+                break
+            # image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
+            img_slice = slice_image_result.images[group_ind * num_batch + image_ind]
+            img_slice = img_slice[:, :, ::-1]  # convert the RGB slice to BGR once, before batching
+            image_list.append(img_slice)
             shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
         # perform batch prediction
+        num_full = len(image_list)
         prediction_result = get_prediction(
-            image=image_list[0],
+            image=image_list,
             detection_model=detection_model,
-            shift_amount=shift_amount_list[0],
-            full_shape=[
+            shift_amount=shift_amount_list,
+            full_shape=[[
                 slice_image_result.original_image_height,
                 slice_image_result.original_image_width,
-            ],
+            ]] * num_full,
             exclude_classes_by_name=exclude_classes_by_name,
             exclude_classes_by_id=exclude_classes_by_id,
         )

         # convert sliced predictions to full predictions
-        for object_prediction in prediction_result.object_prediction_list:
-            if object_prediction:  # if not empty
-                object_prediction_list.append(object_prediction.get_shifted_object_prediction())
+        for object_prediction_per in prediction_result.object_prediction_list:
+            if len(object_prediction_per) != 0:  # if not empty
+                for object_prediction in object_prediction_per:
+                    object_prediction_list.append(object_prediction.get_shifted_object_prediction())

         # merge matching predictions during sliced prediction
         if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
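The batching loop above can also be read in isolation; the sketch below reproduces its grouping and bounds check under the same assumptions as the PR (slices and starting_pixels are placeholder inputs, not SAHI objects):

```python
import math
from typing import List, Tuple

import numpy as np


def make_batches(
    slices: List[np.ndarray],
    starting_pixels: List[List[int]],
    num_batch: int,
) -> List[Tuple[List[np.ndarray], List[List[int]]]]:
    """Group slices into batches of at most num_batch, keeping the final partial batch."""
    num_slices = len(slices)
    num_group = math.ceil(num_slices / num_batch)
    batches = []
    for group_ind in range(num_group):
        image_list, shift_amount_list = [], []
        for image_ind in range(num_batch):
            idx = group_ind * num_batch + image_ind
            if idx >= num_slices:  # the last group may hold fewer than num_batch slices
                break
            image_list.append(slices[idx][:, :, ::-1])  # RGB -> BGR, as in the PR
            shift_amount_list.append(starting_pixels[idx])
        batches.append((image_list, shift_amount_list))
    return batches
```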
@@ -296,7 +308,7 @@ def get_sliced_prediction(
     # perform standard prediction
     if num_slices > 1 and perform_standard_pred:
         prediction_result = get_prediction(
-            image=image,
+            image=[np.array(image)],
             detection_model=detection_model,
             shift_amount=[0, 0],
             full_shape=[
@@ -307,7 +319,9 @@ def get_sliced_prediction(
             exclude_classes_by_name=exclude_classes_by_name,
             exclude_classes_by_id=exclude_classes_by_id,
         )
-        object_prediction_list.extend(prediction_result.object_prediction_list)
+        if len(prediction_result.object_prediction_list) != 0:
+            for _predicion_result in prediction_result.object_prediction_list:
+                object_prediction_list.extend(_predicion_result)

Comment on lines +323 to 325

Copilot AI commented on Mar 5, 2025:

[nitpick] Typo detected: '_predicion_result' should be renamed to '_prediction_result' for clarity.

Suggested change:
-            for _predicion_result in prediction_result.object_prediction_list:
-                object_prediction_list.extend(_predicion_result)
+            for _prediction_result in prediction_result.object_prediction_list:
+                object_prediction_list.extend(_prediction_result)
     # merge matching predictions
     if len(object_prediction_list) > 1:

@@ -408,6 +422,7 @@ def predict(
     verbose: int = 1,
     return_dict: bool = False,
     force_postprocess_type: bool = False,
+    num_batch: int = 1,
     exclude_classes_by_name: Optional[List[str]] = None,
     exclude_classes_by_id: Optional[List[int]] = None,
     **kwargs,
@@ -610,6 +625,7 @@ def predict(
             postprocess_match_threshold=postprocess_match_threshold,
             postprocess_class_agnostic=postprocess_class_agnostic,
             verbose=1 if verbose else 0,
+            num_batch=num_batch,
             exclude_classes_by_name=exclude_classes_by_name,
             exclude_classes_by_id=exclude_classes_by_id,
         )
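With these changes applied, a call would look roughly like the sketch below; the weights file and image path are placeholders, and num_batch is the parameter introduced by this PR:

```python
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# Placeholder weights/paths; substitute your own model and image.
detection_model = AutoDetectionModel.from_pretrained(
    model_type="ultralytics",
    model_path="yolov8n.pt",
    confidence_threshold=0.4,
    device="cuda:0",
)

result = get_sliced_prediction(
    image="demo_image.jpg",
    detection_model=detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    num_batch=8,  # new in this PR: run up to 8 slices per forward pass
)
print(len(result.object_prediction_list))
```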
9 changes: 7 additions & 2 deletions sahi/prediction.py
@@ -164,8 +164,13 @@ def __init__(
         image: Union[Image.Image, str, np.ndarray],
         durations_in_seconds: Dict[str, Any] = dict(),
     ):
-        self.image: Image.Image = read_image_as_pil(image)
-        self.image_width, self.image_height = self.image.size
+        if type(image) is list:
+            self.image = image
+            self.image_height, self.image_width = self.image[0].shape[:2]  # numpy shape is (H, W, C)
+        else:
+            self.image: Image.Image = read_image_as_pil(image)
+            self.image_width, self.image_height = self.image.size
         self.object_prediction_list: List[ObjectPrediction] = object_prediction_list
         self.durations_in_seconds = durations_in_seconds