From 8b8082a939f67f7ea01cc9f55ebce9c5368ebe1a Mon Sep 17 00:00:00 2001 From: felix-wang <35718120+numb3r3@users.noreply.github.com> Date: Tue, 2 Aug 2022 15:28:55 +0800 Subject: [PATCH] fix: mclip cuda device (#792) * fix: model device * fix: enable eval mode * fix: typo * fix: minor revision --- server/clip_server/model/mclip_model.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/server/clip_server/model/mclip_model.py b/server/clip_server/model/mclip_model.py index c5b9058d5..149de0482 100644 --- a/server/clip_server/model/mclip_model.py +++ b/server/clip_server/model/mclip_model.py @@ -6,9 +6,9 @@ from clip_server.model.clip_model import CLIPModel -corresponding_clip_models = { +_CLIP_MODEL_MAPS = { 'M-CLIP/XLM-Roberta-Large-Vit-B-32': ('ViT-B-32', 'openai'), - 'M-CLIP/XLM-Roberta-Large-Vi-L-14': ('ViT-L-14', 'openai'), + 'M-CLIP/XLM-Roberta-Large-Vit-L-14': ('ViT-L-14', 'openai'), 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': ('ViT-B-16-plus-240', 'laion400m_e31'), 'M-CLIP/LABSE-Vit-L-14': ('ViT-L-14', 'openai'), } @@ -54,11 +54,15 @@ class MultilingualCLIPModel(CLIPModel): def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): super().__init__(name, **kwargs) self._mclip_model = MultilingualCLIP.from_pretrained(name) + self._mclip_model.to(device=device) + self._mclip_model.eval() - clip_name, clip_pretrained = corresponding_clip_models[name] + clip_name, clip_pretrained = _CLIP_MODEL_MAPS[name] self._model = open_clip.create_model( clip_name, pretrained=clip_pretrained, device=device, jit=jit ) + self._model.eval() + self._clip_name = clip_name @property