[VLM] Post-layernorm override and quant config in vision encoder #9217

Merged · 20 commits · Oct 23, 2024
20 changes: 16 additions & 4 deletions vllm/model_executor/layers/quantization/awq.py
@@ -3,7 +3,8 @@
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -21,10 +22,12 @@ def __init__(
weight_bits: int,
group_size: int,
zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.modules_to_not_convert = modules_to_not_convert or []

if self.weight_bits != 4:
raise ValueError(
@@ -35,7 +38,8 @@ def __repr__(self) -> str:
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})")

def get_name(self) -> str:
return "awq"
@@ -61,18 +65,26 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQLinearMethod"]:
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return AWQLinearMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)


class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.

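Illustrative sketch (not part of this diff; assumes a vLLM build that includes the change above): the new key is read with get_from_keys_or, so existing AWQ configs without it keep working, and the parsed list is normalized and surfaced in the repr. The config values and module name below are made up.

from vllm.model_executor.layers.quantization.awq import AWQConfig

# Hypothetical quantization configs; only the second one excludes modules.
legacy_cfg = AWQConfig.from_config({
    "w_bit": 4, "q_group_size": 128, "zero_point": True,
})
print(legacy_cfg.modules_to_not_convert)  # [] -- missing key defaults to None,
                                          # then normalized to an empty list

vlm_cfg = AWQConfig.from_config({
    "w_bit": 4, "q_group_size": 128, "zero_point": True,
    "modules_to_not_convert": ["vision_model"],  # hypothetical module name
})
print(vlm_cfg)  # repr now includes modules_to_not_convert
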
87 changes: 59 additions & 28 deletions vllm/model_executor/models/blip.py
@@ -122,7 +122,7 @@ def input_processor_for_blip(
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):

def __init__(self, config: BlipVisionConfig):
def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
super().__init__()

self.config = config
@@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module):

def __init__(
self,
config: BlipVisionConfig,
config: Union[BlipVisionConfig, Blip2VisionConfig],
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -189,11 +190,13 @@ def __init__(
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.projection = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.projection",
)

self.tp_size = get_tensor_model_parallel_world_size()
@@ -235,9 +238,12 @@ def forward(

class BlipMLP(nn.Module):

def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()

self.config = config
@@ -246,11 +252,13 @@ def __init__(self,
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc2")

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
@@ -262,24 +270,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

class BlipEncoderLayer(nn.Module):

def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()

# fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(config,
quant_config=quant_config)
self.self_attn = BlipParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
self.mlp = BlipMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

@@ -307,10 +323,13 @@ class BlipEncoder(nn.Module):
config: BlipConfig
"""

def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()

self.config = config
@@ -321,8 +340,10 @@ def __init__(self,
num_hidden_layers = num_hidden_layers_override

self.layers = nn.ModuleList([
BlipEncoderLayer(config=config, quant_config=quant_config)
for _ in range(num_hidden_layers)
BlipEncoderLayer(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])

def forward(self, inputs_embeds: torch.Tensor):
Expand All @@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module):
config_class = BlipVisionConfig
main_input_name = "pixel_values"

def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()

tp_size = get_tensor_model_parallel_world_size()
@@ -354,19 +380,24 @@ def __init__(self,
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)

num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:

# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers

if require_post_norm:
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
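
A standalone sketch (not vLLM code) mirroring the post_layernorm decision introduced above: the final LayerNorm is built only when the full-depth encoder is loaded, unless require_post_norm overrides that default. The layer counts below are illustrative.

from typing import Optional

def needs_post_layernorm(num_loaded_layers: int,
                         num_hidden_layers: int,
                         require_post_norm: Optional[bool] = None) -> bool:
    """Mirror of the default used by BlipVisionModel after this change."""
    if num_loaded_layers > num_hidden_layers:
        raise ValueError(f"The original encoder only has {num_hidden_layers} "
                         f"layers, but you requested {num_loaded_layers} "
                         f"layers.")
    if require_post_norm is None:
        # Default: only a full-depth tower applies post_layernorm; a truncated
        # tower used for intermediate features skips it to conserve memory.
        return num_loaded_layers == num_hidden_layers
    return require_post_norm

print(needs_post_layernorm(39, 39))                          # True
print(needs_post_layernorm(38, 39))                          # False
print(needs_post_layernorm(38, 39, require_post_norm=True))  # True
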
2 changes: 1 addition & 1 deletion vllm/model_executor/models/blip2.py
@@ -490,7 +490,7 @@ def __init__(self,
self.multimodal_config = multimodal_config

# TODO: Optionally initializes this for supporting embeddings.
self.vision_model = BlipVisionModel(config.vision_config)
self.vision_model = BlipVisionModel(config.vision_config, quant_config)

self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens,
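
End-to-end effect for BLIP-2 (illustrative; the layer prefixes are assumptions, not output of this PR): because the vision tower now receives quant_config and per-layer prefixes, an AWQ checkpoint whose modules_to_not_convert names the tower keeps those linear layers unquantized via is_layer_skipped_awq, while the language model is still quantized.

from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq

modules_to_not_convert = ["vision_model"]  # hypothetical checkpoint entry

# Hypothetical fully qualified prefixes after the prefix plumbing above.
for prefix in [
        "vision_model.encoder.layers.0.self_attn.qkv",
        "vision_model.encoder.layers.0.mlp.fc1",
        "language_model.model.layers.0.self_attn.qkv_proj",
]:
    skip = is_layer_skipped_awq(prefix, modules_to_not_convert)
    print(prefix, "->", "unquantized" if skip else "AWQ")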