Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batch_size option to predict #1273

Merged
merged 3 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ def _get_pipeline(self, fuse_model: bool = True) -> ClassificationPipeline:
)
return pipeline

def predict(self, images: ImageSource, fuse_model: bool = True) -> ImagesPredictions:
def predict(self, images: ImageSource, batch_size: int = 32, fuse_model: bool = True) -> ImagesPredictions:
"""Predict an image or a list of images.

:param images: Images to predict.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(fuse_model=fuse_model)
return pipeline(images) # type: ignore
return pipeline(images, batch_size=batch_size) # type: ignore

def predict_webcam(self, fuse_model: bool = True):
"""Predict using webcam.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,25 @@ def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = Non
)
return pipeline

def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> ImagesDetectionPrediction:
def predict(
self,
images: ImageSource,
iou: Optional[float] = None,
conf: Optional[float] = None,
batch_size: int = 32,
fuse_model: bool = True,
) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.

:param images: Images to predict.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param images: Images to predict.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
return pipeline(images) # type: ignore
return pipeline(images, batch_size=batch_size) # type: ignore

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,25 @@ def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = Non
)
return pipeline

def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> ImagesDetectionPrediction:
def predict(
self,
images: ImageSource,
iou: Optional[float] = None,
conf: Optional[float] = None,
batch_size: int = 32,
fuse_model: bool = True,
) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.

:param images: Images to predict.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param images: Images to predict.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
return pipeline(images) # type: ignore
return pipeline(images, batch_size=batch_size) # type: ignore

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,17 +523,25 @@ def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = Non
)
return pipeline

def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> ImagesDetectionPrediction:
def predict(
self,
images: ImageSource,
iou: Optional[float] = None,
conf: Optional[float] = None,
batch_size: int = 32,
fuse_model: bool = True,
) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.

:param images: Images to predict.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param images: Images to predict.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
return pipeline(images) # type: ignore
return pipeline(images, batch_size=batch_size) # type: ignore

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -607,16 +607,17 @@ def _get_pipeline(self, conf: Optional[float] = None, fuse_model: bool = True) -
)
return pipeline

def predict(self, images: ImageSource, conf: Optional[float] = None, fuse_model: bool = True) -> ImagesPoseEstimationPrediction:
def predict(self, images: ImageSource, conf: Optional[float] = None, batch_size: int = 32, fuse_model: bool = True) -> ImagesPoseEstimationPrediction:
"""Predict an image or a list of images.

:param images: Images to predict.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(conf=conf, fuse_model=fuse_model)
return pipeline(images) # type: ignore
return pipeline(images, batch_size=batch_size) # type: ignore

def predict_webcam(self, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
Expand Down
2 changes: 1 addition & 1 deletion src/super_gradients/training/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __call__(self, inputs: Union[str, ImageSource, List[ImageSource]], batch_siz
- List: A list of images of any of the above image types (list of videos not supported).

:param inputs: inputs to the model, which can be any of the above-mentioned types.
:param batch_size: Number of images to be processed at the same time.
:param batch_size: Maximum number of images to process at the same time.
:return: Results of the prediction.
"""

Expand Down