Skip to content

Commit

Permalink
fix: mclip cuda device (#792)
Browse files Browse the repository at this point in the history
* fix: model device

* fix: enable eval mode

* fix: typo

* fix: minor revision
  • Loading branch information
numb3r3 authored Aug 2, 2022
1 parent 8681b88 commit 8b8082a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions server/clip_server/model/mclip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from clip_server.model.clip_model import CLIPModel

corresponding_clip_models = {
_CLIP_MODEL_MAPS = {
'M-CLIP/XLM-Roberta-Large-Vit-B-32': ('ViT-B-32', 'openai'),
'M-CLIP/XLM-Roberta-Large-Vi-L-14': ('ViT-L-14', 'openai'),
'M-CLIP/XLM-Roberta-Large-Vit-L-14': ('ViT-L-14', 'openai'),
'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': ('ViT-B-16-plus-240', 'laion400m_e31'),
'M-CLIP/LABSE-Vit-L-14': ('ViT-L-14', 'openai'),
}
Expand Down Expand Up @@ -54,11 +54,15 @@ class MultilingualCLIPModel(CLIPModel):
def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs):
super().__init__(name, **kwargs)
self._mclip_model = MultilingualCLIP.from_pretrained(name)
self._mclip_model.to(device=device)
self._mclip_model.eval()

clip_name, clip_pretrained = corresponding_clip_models[name]
clip_name, clip_pretrained = _CLIP_MODEL_MAPS[name]
self._model = open_clip.create_model(
clip_name, pretrained=clip_pretrained, device=device, jit=jit
)
self._model.eval()

self._clip_name = clip_name

@property
Expand Down

0 comments on commit 8b8082a

Please sign in to comment.