Skip to content

Commit

Permalink
fix: keep logit_scale on same device (#710)
Browse files Browse the repository at this point in the history
* fix: keep logit_scale on cpu

* fix: use cuda when computing ranked score
  • Loading branch information
numb3r3 authored May 9, 2022
1 parent da87d13 commit bb520d1
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,12 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
else:
_img_da = await self.encode(_img_da)
_txt_da = await self.encode(_txt_da)
_img_da.embeddings = torch.from_numpy(_img_da.embeddings)
_txt_da.embeddings = torch.from_numpy(_txt_da.embeddings)
_img_da.embeddings = torch.from_numpy(_img_da.embeddings).to(
self._device, non_blocking=True
)
_txt_da.embeddings = torch.from_numpy(_txt_da.embeddings).to(
self._device, non_blocking=True
)

# normalized features
image_features = _img_da.embeddings / _img_da.embeddings.norm(
Expand Down

0 comments on commit bb520d1

Please sign in to comment.