From d7e156bd1ae2467e9ea1dbc44f31da0ed2296aee Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 7 Jul 2021 22:50:27 +0530 Subject: [PATCH] fix loading clip vision model (#12566) --- .../jax-projects/hybrid_clip/configuration_hybrid_clip.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py index 1a2c51f554a0b6..5272ac44a1a884 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py @@ -75,6 +75,10 @@ def __init__(self, projection_dim=512, **kwargs): if vision_model_type == "clip": self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config + elif vision_model_type == "clip_vision_model": + from transformers import CLIPVisionConfig + + self.vision_config = CLIPVisionConfig(**vision_config) else: self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)