Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove autocast for cpu #491

Merged
merged 4 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/marqo/s2_inference/clip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,13 @@ def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType]]

self.image_input_processed = torch.stack([self.preprocess(_img).to(self.device) for _img in image_input])

with torch.no_grad(), torch.autocast(device_type="cuda" if self.device.startswith("cuda") else "cpu"):
outputs = self.model.encode_image(self.image_input_processed).to(torch.float32)
with torch.no_grad():
if self.device.startswith("cuda"):
with torch.cuda.amp.autocast():
outputs = self.model.encode_image(self.image_input_processed).to(torch.float32)
else:
outputs = self.model.encode_image(self.image_input_processed).to(torch.float32)


if normalize:
_shape_before = outputs.shape
Expand All @@ -483,8 +488,12 @@ def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatT

text = self.tokenizer(sentence).to(self.device)

with torch.no_grad(), torch.autocast(device_type="cuda" if self.device.startswith("cuda") else "cpu"):
outputs = self.model.encode_text(text).to(torch.float32)
with torch.no_grad():
if self.device.startswith("cuda"):
with torch.cuda.amp.autocast():
outputs = self.model.encode_text(text).to(torch.float32)
else:
outputs = self.model.encode_text(text).to(torch.float32)

if normalize:
_shape_before = outputs.shape
Expand Down
12 changes: 9 additions & 3 deletions tests/s2_inference/test_encoding.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import unittest
import os
import torch

from unittest.mock import MagicMock, patch
from marqo.s2_inference.types import FloatTensor
from marqo.s2_inference.s2_inference import clear_loaded_models, get_model_properties_from_registry
from marqo.s2_inference.model_registry import load_model_properties, _get_open_clip_properties
from marqo.s2_inference.s2_inference import _convert_tensor_to_numpy
import numpy as np
import functools
from unittest.mock import MagicMock

from marqo.s2_inference.s2_inference import (
_check_output_type, vectorise,
Expand Down Expand Up @@ -432,3 +430,11 @@ def test_model_un_normalization(self):

clear_loaded_models()

@patch("torch.cuda.amp.autocast")
def test_autocast_called_when_cuda(self, mock_autocast):
    """Check that vectorising on the CPU does not invoke ``torch.cuda.amp.autocast``.

    ``torch.cuda.amp.autocast`` is replaced by a mock for the duration of the
    test; the test fails if that mock is called while ``vectorise`` runs with
    ``device="cpu"``.

    NOTE(review): the method name says "called when cuda", but the body runs
    on ``device="cpu"`` and asserts the mock is NOT called. Consider renaming
    (e.g. ``test_autocast_not_called_when_cpu``) in a follow-up.

    NOTE(review): the nesting below was reconstructed from a diff view that
    dropped indentation -- confirm against the real file.
    """
    # Model names are taken from an attribute set elsewhere on the test case.
    names = self.open_clip_test_model
    # One plain-text input and one image URL -- presumably so that both the
    # text and the image encoding paths are exercised; verify against vectorise.
    contents = ['this is a test sentence. so is this.', "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image0.jpg"]
    for model_name in names:
        for content in contents:
            vectorise(model_name=model_name, content=content, device="cpu")
            # The CPU path must never enter the CUDA autocast context.
            mock_autocast.assert_not_called()
13 changes: 12 additions & 1 deletion tests/s2_inference/test_large_model_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,15 @@ def test_cuda_encode_type(self):
assert isinstance(output_v, np.ndarray)

del model
clear_loaded_models()
clear_loaded_models()

@patch("torch.cuda.amp.autocast")
def test_autocast_called_in_open_clip(self, mock_autocast):
    """Check that each CUDA ``vectorise`` call invokes ``torch.cuda.amp.autocast`` exactly once.

    ``torch.cuda.amp.autocast`` is replaced by a mock; after every
    ``vectorise`` call with ``device="cuda"`` the mock must have been called
    exactly once, and it is then reset before the next input.

    NOTE(review): the nesting below was reconstructed from a diff view that
    dropped indentation. The assertion and reset are placed inside the inner
    loop because a single ``assert_called_once()`` after both inputs would see
    two calls -- confirm against the real file.
    """
    names = ["open_clip/ViT-B-32/laion400m_e31"]
    # One plain-text input and one image URL -- presumably so that both the
    # text and the image encoding paths are exercised; verify against vectorise.
    contents = ['this is a test sentence. so is this.',
                "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image0.jpg"]
    for model_name in names:
        for content in contents:
            vectorise(model_name=model_name, content=content, device="cuda")
            mock_autocast.assert_called_once()
            # Clear the recorded call so the next iteration's
            # assert_called_once() counts only its own vectorise call.
            mock_autocast.reset_mock()