vllm/model_executor/models/step3_vl.py: 91 additions & 41 deletions
@@ -21,6 +21,7 @@
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
+                                               ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -33,6 +34,7 @@
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_vision_model
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
def __init__(self,
config,
quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -659,20 +662,42 @@ def __init__(self,

self.scale = self.head_dim**-0.5

-        tp_size = get_tensor_model_parallel_world_size()
+        tp_size = (1 if use_data_parallel else
+                   get_tensor_model_parallel_world_size())
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
-        self.qkv_proj = QKVParallelLinear(self.embed_dim,
-                                          self.head_dim,
-                                          self.total_num_heads,
-                                          bias=True,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
-        self.out_proj = RowParallelLinear(self.embed_dim,
-                                          self.embed_dim,
-                                          bias=True,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
+
+        self.q_size = self.num_heads * self.head_dim
+
+        if use_data_parallel:
+            self.qkv_proj = ReplicatedLinear(
+                self.embed_dim,
+                3 * self.q_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+            self.out_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                self.total_num_heads,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+            self.out_proj = RowParallelLinear(self.embed_dim,
+                                              self.embed_dim,
+                                              bias=True,
+                                              quant_config=quant_config,
+                                              prefix=prefix)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
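A quick sketch of what the new `use_data_parallel` switch in the hunk above changes in practice: under tensor parallelism each rank owns a slice of the attention heads and of the QKV projection, while the data-parallel path treats `tp_size` as 1 and replicates the full projection on every rank. The dimensions below are illustrative placeholders, not the real Step3 vision config values.

```python
# Illustrative shape arithmetic only; the numbers are placeholders and do not
# come from the real Step3 vision config.
embed_dim, total_num_heads, head_dim, tp_size = 1792, 16, 112, 4

# Tensor-parallel path (QKVParallelLinear / RowParallelLinear): each rank
# owns total_num_heads // tp_size heads and the matching slice of the QKV
# projection.
heads_per_rank = total_num_heads // tp_size             # 4 heads per rank
tp_qkv_out_per_rank = 3 * heads_per_rank * head_dim     # 1344 features per rank

# Data-parallel path (ReplicatedLinear): tp_size is forced to 1, so every
# rank keeps all heads and the full 3 * q_size = 3 * embed_dim projection.
dp_qkv_out_per_rank = 3 * total_num_heads * head_dim    # 5376 features per rank

assert dp_qkv_out_per_rank == tp_size * tp_qkv_out_per_rank
```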
@@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
def __init__(self,
config,
quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
-        self.fc1 = ColumnParallelLinear(config.hidden_size,
-                                        config.intermediate_size,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=prefix)
-        self.fc2 = RowParallelLinear(config.intermediate_size,
-                                     config.hidden_size,
-                                     bias=True,
-                                     quant_config=quant_config,
-                                     prefix=prefix)
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(config.hidden_size,
+                           config.intermediate_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=prefix)
+        cls_fc2 = (ReplicatedLinear
+                   if use_data_parallel else RowParallelLinear)
+        self.fc2 = cls_fc2(config.intermediate_size,
+                           config.hidden_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=prefix)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
@@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
def __init__(self,
config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
super().__init__()
+        self.use_data_parallel = use_data_parallel
self.embed_dim = config.hidden_size
-        self.self_attn = Step3VisionAttention(config,
-                                              quant_config,
-                                              prefix=f"{prefix}.self_attn")
+        self.self_attn = Step3VisionAttention(
+            config,
+            quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=self.use_data_parallel)
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
-        self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
+        self.mlp = Step3VisionMLP(config,
+                                  quant_config,
+                                  prefix=f"{prefix}.mlp",
+                                  use_data_parallel=self.use_data_parallel)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)

@@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
def __init__(self,
config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
super().__init__()
self.config = config
+        self.use_data_parallel = use_data_parallel
self.layers = nn.ModuleList([
Step3VisionEncoderLayer(config,
quant_config,
-                                    prefix=f"{prefix}.layers.{i}")
+                                    prefix=f"{prefix}.layers.{i}",
+                                    use_data_parallel=self.use_data_parallel)
for i in range(config.num_hidden_layers)
])

@@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
def __init__(self,
config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
super().__init__()
self.config = config
+        self.use_data_parallel = use_data_parallel
self.image_size = config.image_size
self.embeddings = Step3VisionEmbeddings(config)
-        self.transformer = Step3VisionEncoder(config,
-                                              quant_config,
-                                              prefix=f"{prefix}.transformer")
+        self.transformer = Step3VisionEncoder(
+            config,
+            quant_config,
+            prefix=f"{prefix}.transformer",
+            use_data_parallel=self.use_data_parallel)

def forward(
self,
pixel_values: torch.Tensor,
):
hidden_states = self.embeddings(pixel_values)
-        hidden_states = self.transformer(inputs_embeds=hidden_states)
+        if self.use_data_parallel:
+            hidden_states = run_dp_sharded_vision_model(
+                hidden_states, self.transformer)
+        else:
+            hidden_states = self.transformer(inputs_embeds=hidden_states)
return hidden_states
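
In the data-parallel branch above, the embedded patch batch is handed to `run_dp_sharded_vision_model` together with the replicated transformer. As a rough mental model only (a sketch under assumptions, not vLLM's actual helper), such a function pads and splits the batch across ranks, runs the local shard through the module, and all-gathers the outputs:

```python
# Conceptual sketch only; vLLM's run_dp_sharded_vision_model handles padding,
# rank grouping, and call kwargs differently. Names here are illustrative, and
# torch.distributed is assumed to be initialized already.
import torch
import torch.distributed as dist


def dp_sharded_forward(batch: torch.Tensor, module) -> torch.Tensor:
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Pad the batch so it splits evenly across the data-parallel ranks.
    pad = (-batch.shape[0]) % world_size
    if pad:
        batch = torch.cat([batch, batch.new_zeros(pad, *batch.shape[1:])])

    # Each rank runs the replicated module on its own slice of the batch.
    local_out = module(batch.chunk(world_size, dim=0)[rank])

    # All-gather so every rank ends up with outputs for the full batch,
    # then drop the padded rows.
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out)
    out = torch.cat(gathered, dim=0)
    return out[:out.shape[0] - pad] if pad else out
```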


@@ -836,13 +884,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

self.config = config
self.multimodal_config = multimodal_config
+        self.use_data_parallel = (vllm_config.parallel_config.
+                                  enable_multimodal_encoder_data_parallel)

if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_model = Step3VisionTransformer(config.vision_config,
-                                                       None,
-                                                       prefix=maybe_prefix(
-                                                           prefix,
-                                                           "vision_model"))
+            self.vision_model = Step3VisionTransformer(
+                config.vision_config,
+                None,
+                prefix=maybe_prefix(prefix, "vision_model"),
+                use_data_parallel=self.use_data_parallel)
self.vit_downsampler = nn.Conv2d(
config.vision_config.hidden_size,
config.vision_config.output_hidden_size,