diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py
index 005a84e61..8c84cbbe1 100644
--- a/src/marqo/s2_inference/clip_utils.py
+++ b/src/marqo/s2_inference/clip_utils.py
@@ -466,8 +466,13 @@ def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType]]
 
         self.image_input_processed = torch.stack([self.preprocess(_img).to(self.device) for _img in image_input])
 
-        with torch.no_grad(), torch.autocast(device_type="cuda" if self.device.startswith("cuda") else "cpu"):
-            outputs = self.model.encode_image(self.image_input_processed).to(torch.float32)
+        with torch.no_grad():
+            if self.device.startswith("cuda"):
+                with torch.cuda.amp.autocast():
+                    outputs = self.model.encode_image(self.image_input_processed).to(torch.float32)
+            else:
+                outputs = self.model.encode_image(self.image_input_processed).to(torch.float32)
+
         if normalize:
             _shape_before = outputs.shape
 
@@ -483,8 +488,12 @@ def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatT
 
         text = self.tokenizer(sentence).to(self.device)
 
-        with torch.no_grad(), torch.autocast(device_type="cuda" if self.device.startswith("cuda") else "cpu"):
-            outputs = self.model.encode_text(text).to(torch.float32)
+        with torch.no_grad():
+            if self.device.startswith("cuda"):
+                with torch.cuda.amp.autocast():
+                    outputs = self.model.encode_text(text).to(torch.float32)
+            else:
+                outputs = self.model.encode_text(text).to(torch.float32)
 
         if normalize:
             _shape_before = outputs.shape
diff --git a/tests/s2_inference/test_encoding.py b/tests/s2_inference/test_encoding.py
index e25e3e91d..f3fbf5388 100644
--- a/tests/s2_inference/test_encoding.py
+++ b/tests/s2_inference/test_encoding.py
@@ -1,14 +1,12 @@
 import unittest
-import os
 import torch
-
+from unittest.mock import MagicMock, patch
 from marqo.s2_inference.types import FloatTensor
 from marqo.s2_inference.s2_inference import clear_loaded_models, get_model_properties_from_registry
 from marqo.s2_inference.model_registry import load_model_properties, _get_open_clip_properties
 from marqo.s2_inference.s2_inference import _convert_tensor_to_numpy
 import numpy as np
 import functools
-from unittest.mock import MagicMock
 from marqo.s2_inference.s2_inference import (
     _check_output_type, vectorise,
@@ -432,3 +430,11 @@ def test_model_un_normalization(self):
 
         clear_loaded_models()
 
+    @patch("torch.cuda.amp.autocast")
+    def test_autocast_called_when_cuda(self, mock_autocast):
+        names = self.open_clip_test_model
+        contents = ['this is a test sentence. so is this.', "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image0.jpg"]
+        for model_name in names:
+            for content in contents:
+                vectorise(model_name=model_name, content=content, device="cpu")
+                mock_autocast.assert_not_called()
diff --git a/tests/s2_inference/test_large_model_encoding.py b/tests/s2_inference/test_large_model_encoding.py
index bf360d009..f514f39bc 100644
--- a/tests/s2_inference/test_large_model_encoding.py
+++ b/tests/s2_inference/test_large_model_encoding.py
@@ -200,4 +200,15 @@ def test_cuda_encode_type(self):
             assert isinstance(output_v, np.ndarray)
 
             del model
-            clear_loaded_models()
\ No newline at end of file
+            clear_loaded_models()
+
+    @patch("torch.cuda.amp.autocast")
+    def test_autocast_called_in_open_clip(self, mock_autocast):
+        names = ["open_clip/ViT-B-32/laion400m_e31"]
+        contents = ['this is a test sentence. so is this.',
+                    "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image0.jpg"]
+        for model_name in names:
+            for content in contents:
+                vectorise(model_name=model_name, content=content, device="cuda")
+                mock_autocast.assert_called_once()
+                mock_autocast.reset_mock()
\ No newline at end of file
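
Note for reviewers: the pattern this patch applies to both encode_image and encode_text can be summarized in a standalone sketch. The helper below is illustrative only (encode_with_optional_autocast and model_fn are not names from the Marqo codebase): torch.cuda.amp.autocast() is entered solely when the device string starts with "cuda", the forward pass runs under torch.no_grad(), and the result is cast back to float32, presumably so CPU inference stays in full precision and downstream normalization is numerically stable.

import torch


def encode_with_optional_autocast(model_fn, inputs, device: str) -> torch.Tensor:
    """Hypothetical helper mirroring the branching added in clip_utils.py:
    CUDA AMP autocast is engaged only on CUDA devices; on CPU the autocast
    context is skipped entirely."""
    with torch.no_grad():
        if device.startswith("cuda"):
            # Mixed precision only for CUDA devices.
            with torch.cuda.amp.autocast():
                outputs = model_fn(inputs).to(torch.float32)
        else:
            outputs = model_fn(inputs).to(torch.float32)
    return outputs

Usage would look like encode_with_optional_autocast(model.encode_text, tokens, device="cuda"); the explicit branch (rather than torch.autocast(device_type=...)) is what lets the new tests assert that autocast is invoked exactly once per vectorise call on CUDA and never on CPU.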