diff --git a/src/super_gradients/training/pipelines/pipelines.py b/src/super_gradients/training/pipelines/pipelines.py
index f099eb592d..ec79d9e5ed 100644
--- a/src/super_gradients/training/pipelines/pipelines.py
+++ b/src/super_gradients/training/pipelines/pipelines.py
@@ -24,7 +24,7 @@
     ImagesClassificationPrediction,
     ClassificationPrediction,
 )
-from super_gradients.training.utils.utils import generate_batch
+from super_gradients.training.utils.utils import generate_batch, infer_model_device, resolve_torch_device
 from super_gradients.training.utils.media.video import load_video, includes_video_extension
 from super_gradients.training.utils.media.image import ImageSource, check_image_typing
 from super_gradients.training.utils.media.stream import WebcamStreaming
@@ -68,9 +68,13 @@ def __init__(
         fuse_model: bool = True,
         dtype: Optional[torch.dtype] = None,
     ):
-        self.device = device or next(model.parameters()).device
-        self.model = model.to(self.device)
+        model_device: torch.device = infer_model_device(model=model)
+        if device:
+            device: torch.device = resolve_torch_device(device=device)
+
+        self.device: torch.device = device or model_device
         self.dtype = dtype or next(model.parameters()).dtype
+        self.model = model.to(device) if device and device != model_device else model
         self.class_names = class_names
 
         if isinstance(image_processor, list):
@@ -169,8 +173,12 @@ def _generate_prediction_result_single_batch(self, images: Iterable[np.ndarray])
         :param images: Iterable of numpy arrays representing images.
         :return: Iterable of Results object, each containing the results of the prediction and the image.
         """
+        # Make sure the model is on the correct device, as it might have been moved after init
+        model_device: torch.device = infer_model_device(model=self.model)
+        if self.device != model_device:
+            self.model = self.model.to(self.device)
+
         images = list(images)  # We need to load all the images into memory, and to reuse it afterwards.
-        self.model = self.model.to(self.device)  # Make sure the model is on the correct device, as it might have been moved after init
 
         # Preprocess
         preprocessed_images, processing_metadatas = [], []
diff --git a/src/super_gradients/training/utils/utils.py b/src/super_gradients/training/utils/utils.py
index 2307ff6511..3de562e43c 100755
--- a/src/super_gradients/training/utils/utils.py
+++ b/src/super_gradients/training/utils/utils.py
@@ -664,6 +664,24 @@ def infer_model_device(model: nn.Module) -> Optional[torch.device]:
             return None
 
 
+def resolve_torch_device(device: Union[str, torch.device]) -> torch.device:
+    """
+    Resolve the specified torch device. It accepts either a string or a torch.device object.
+
+    This function takes the provided device identifier and returns a corresponding torch.device object,
+    which represents the device where a torch.Tensor will be allocated.
+
+    :param device: A string or torch.device object representing the device (e.g., 'cpu', 'cuda', 'cuda:0').
+    :return: A torch.device object representing the resolved device.
+
+    Example:
+        >>> torch.cuda.set_device(5)
+        >>> str(resolve_torch_device("cuda"))
+        'cuda:5'
+    """
+    return torch.zeros([], device=device).device
+
+
 def check_model_contains_quantized_modules(model: nn.Module) -> bool:
     """
     Check if the model contains any quantized modules.