From 4811fbc0c9f8ebd1da3b43fbf78457ac27b0aa8d Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 4 Sep 2024 12:43:17 +0000
Subject: [PATCH 1/3] Fix missing `post_layernorm` in CLIP

---
 vllm/model_executor/models/clip.py   | 22 ++++++++++++++++++----
 vllm/model_executor/models/siglip.py | 23 +++++++++--------------
 2 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index b581a501e3333..0d9e69cd39064 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -355,6 +355,12 @@ def __init__(self,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override)
 
+        if len(self.encoder.layers) == config.num_hidden_layers:
+            self.post_layernorm = nn.LayerNorm(embed_dim,
+                                               eps=config.layer_norm_eps)
+        else:
+            self.post_layernorm = None
+
     def forward(
         self,
         pixel_values: torch.Tensor,
@@ -364,7 +370,10 @@ def forward(
         hidden_states = self.pre_layrnorm(hidden_states)
         hidden_states = self.encoder(inputs_embeds=hidden_states)
 
-        return hidden_states
+        if self.post_layernorm is None:
+            return hidden_states
+
+        return self.post_layernorm(hidden_states)
 
 
 class CLIPVisionModel(nn.Module):
@@ -386,9 +395,12 @@ def __init__(self,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override)
 
-    def forward(self, pixel_values: Optional[torch.Tensor] = None):
+    @property
+    def need_post_layernorm(self) -> bool:
+        return self.vision_model.post_layernorm is not None
 
-        return self.vision_model(pixel_values=pixel_values)
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        return self.vision_model(pixel_values)
 
     @property
     def device(self):
@@ -408,8 +420,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
         for name, loaded_weight in weights:
             # post_layernorm is not needed in CLIPVisionModel
-            if "vision_model.post_layernorm" in name:
+            if ("vision_model.post_layernorm" in name
+                    and not self.need_post_layernorm):
                 continue
+
             # omit layers when num_hidden_layers_override is set
             if "vision_model.encoder.layers." in name:
                 layer_idx = int(name.split(".")[3])
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index 0bee75e2f0cbb..ffd89fd2114c0 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -443,27 +443,19 @@ def __init__(
         self.config = config
         embed_dim = config.hidden_size
 
-        if (num_hidden_layers_override is None
-                or num_hidden_layers_override == config.num_hidden_layers):
-            self.need_post_layernorm = True
-        elif num_hidden_layers_override > config.num_hidden_layers:
-            raise ValueError(
-                "num_hidden_layers_override cannot be greater than "
-                "num_hidden_layers")
-        else:
-            self.need_post_layernorm = False
-
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
         )
-        if self.need_post_layernorm:
+
+        if len(self.encoder.layers) == config.num_hidden_layers:
             self.post_layernorm = nn.LayerNorm(embed_dim,
                                                eps=config.layer_norm_eps)
         else:
-            self.post_layernorm = nn.Identity()
+            self.post_layernorm = None
+
         self.use_head = (True if not hasattr(config, "vision_use_head") else
                          config.vision_use_head)
         if self.use_head:
@@ -482,6 +474,9 @@ def forward(
 
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)
 
+        if self.post_layernorm is None:
+            return encoder_outputs
+
         last_hidden_state = self.post_layernorm(encoder_outputs)
         # TODO: add this back when pooled_output is used in inference
         # if self.use_head:
@@ -512,8 +507,8 @@ def __init__(
         )
 
     @property
-    def need_post_layernorm(self):
-        return self.vision_model.need_post_layernorm
+    def need_post_layernorm(self) -> bool:
+        return self.vision_model.post_layernorm is not None
 
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding

From 26feda6a6f31541209f51279a87324f6631d7f41 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 4 Sep 2024 12:54:32 +0000
Subject: [PATCH 2/3] Fix missing shard weight loading in SigLIP

---
 vllm/model_executor/models/siglip.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index ffd89fd2114c0..5b02a523e005f 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -524,6 +524,12 @@ def forward(
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ] if self.shard_weight else []
         params_dict = dict(self.named_parameters())
         layer_count = len(self.vision_model.encoder.layers)
 
@@ -539,7 +545,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 if layer_idx >= layer_count:
                     continue
 
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

From ad2ac42867fc11626b21a3a64ee2b0f8bc7b3afd Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Tue, 10 Sep 2024 06:12:42 +0000
Subject: [PATCH 3/3] Address comments

---
 vllm/model_executor/models/clip.py   | 13 ++++++++++---
 vllm/model_executor/models/siglip.py | 13 ++++++++++---
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 4f2171a05273c..078928f281c26 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -355,10 +355,17 @@ def __init__(self,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override)
 
-        if len(self.encoder.layers) == config.num_hidden_layers:
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {config.num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+        elif len(self.encoder.layers) == config.num_hidden_layers:
             self.post_layernorm = nn.LayerNorm(embed_dim,
                                                eps=config.layer_norm_eps)
         else:
+            # post_layernorm is unused when we extract intermediate features
+            # In this case, we can skip it to conserve memory
             self.post_layernorm = None
 
     def forward(
@@ -396,7 +403,7 @@ def __init__(self,
             num_hidden_layers_override=num_hidden_layers_override)
 
     @property
-    def need_post_layernorm(self) -> bool:
+    def _require_post_layernorm(self) -> bool:
         return self.vision_model.post_layernorm is not None
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
@@ -421,7 +428,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             # post_layernorm is not needed in CLIPVisionModel
             if ("vision_model.post_layernorm" in name
-                    and not self.need_post_layernorm):
+                    and not self._require_post_layernorm):
                 continue
 
             # omit layers when num_hidden_layers_override is set
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index a1b34288dea80..f7976eba7420b 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -450,10 +450,17 @@ def __init__(
             num_hidden_layers_override=num_hidden_layers_override,
         )
 
-        if len(self.encoder.layers) == config.num_hidden_layers:
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {config.num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+        elif len(self.encoder.layers) == config.num_hidden_layers:
             self.post_layernorm = nn.LayerNorm(embed_dim,
                                                eps=config.layer_norm_eps)
         else:
+            # post_layernorm is unused when we extract intermediate features
+            # In this case, we can skip it to conserve memory
            self.post_layernorm = None
 
         self.use_head = (True if not hasattr(config, "vision_use_head") else
@@ -507,7 +514,7 @@ def __init__(
     )
 
     @property
-    def need_post_layernorm(self) -> bool:
+    def _require_post_layernorm(self) -> bool:
         return self.vision_model.post_layernorm is not None
 
     def get_input_embeddings(self) -> nn.Module:
@@ -536,7 +543,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             # post_layernorm is optional in SiglipVisionModel
             if ("vision_model.post_layernorm" in name
-                    and not self.need_post_layernorm):
+                    and not self._require_post_layernorm):
                 continue
 
             # omit layers when num_hidden_layers_override is set
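
Editor's note, not part of the patches above: the for/else construct that PATCH 2/3 adds to SiglipVisionModel.load_weights is the usual pattern for folding per-projection checkpoint shards into a fused qkv parameter. The sketch below illustrates that idea in isolation; the toy module, checkpoint names, and integer shard offsets are assumptions made to keep it self-contained (vLLM's real loader instead passes "q"/"k"/"v" shard ids to each parameter's weight_loader, which also handles tensor-parallel slicing).

# Standalone sketch: fold q/k/v checkpoint shards into a fused qkv parameter
# using the same stacked_params_mapping + for/else pattern as PATCH 2/3.
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Toy module whose q/k/v projections are stored fused as qkv_proj."""

    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)


def load_weights(model: nn.Module, weights):
    # (param_name, shard_name, shard_offset): checkpoint shard -> fused param
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", 0),
        ("qkv_proj", "k_proj", 1),
        ("qkv_proj", "v_proj", 2),
    ]
    params_dict = dict(model.named_parameters())

    for name, loaded_weight in weights:
        for param_name, shard_name, shard_offset in stacked_params_mapping:
            if shard_name not in name:
                continue
            # Copy this shard into its slice of the fused parameter.
            param = params_dict[name.replace(shard_name, param_name)]
            rows = loaded_weight.shape[0]
            with torch.no_grad():
                param[shard_offset * rows:(shard_offset + 1) * rows].copy_(
                    loaded_weight)
            break
        else:
            # No stacked mapping matched: plain one-to-one load.
            with torch.no_grad():
                params_dict[name].copy_(loaded_weight)


model = ToyAttention()
checkpoint = {f"{proj}.weight": torch.randn(8, 8)
              for proj in ("q_proj", "k_proj", "v_proj")}
load_weights(model, checkpoint.items())

The for/else is the detail worth noting: the else branch runs only when no stacked mapping matched, so unfused weights still fall back to a direct copy, which is exactly the behaviour the second patch restores for SigLIP.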