Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,7 @@ def _cross_encoding_score(

input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

pooling_params = PoolingParams()
pooling_params = PoolingParams(use_cross_encoder=True)

tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
Expand Down
43 changes: 32 additions & 11 deletions vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import (
get_classification_activation_function,
get_cross_encoder_activation_function)
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

Expand Down Expand Up @@ -388,15 +389,14 @@ def __init__(
self.classifier = classifier
self.pooler = pooler

if config.task == "score":
self.default_activation_function = \
get_cross_encoder_activation_function(config.hf_config)
elif config.task == "classify":
self.default_activation_function = nn.Sigmoid() \
if config.hf_config.num_labels == 1 else nn.Softmax()
else:
raise NotImplementedError(f"task={config.task!r} is not supported"
" with the classification pooler")
self.classification_act_fn = get_classification_activation_function(
config.hf_config)
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
config.hf_config)

def _get_act_fn(self, use_cross_encoder: bool):
    """Pick the activation applied to classifier logits.

    Cross-encoder (scoring) requests use the sentence-transformers
    activation function; plain classification requests use the
    default classification activation.
    """
    if use_cross_encoder:
        return self.cross_encoder_act_fn
    return self.classification_act_fn

def get_prompt_lens(
self,
Expand Down Expand Up @@ -446,8 +446,29 @@ def forward(
# apply classifier once on the full batch if possible
pooled_output = self.classifier(pooled_output)

# shape: (batch_size, num_labels)
scores = self.default_activation_function(pooled_output)
if isinstance(pooling_metadata, V0PoolingMetadata):
use_cross_encoder_list = [
pooling_param.use_cross_encoder
for _, pooling_param in pooling_metadata.seq_groups
]
else:
assert isinstance(pooled_data, list)
use_cross_encoder_list = [
pooling_param.use_cross_encoder
for pooling_param in pooling_metadata.pooling_params
]

# shape of scores: (batch_size, num_labels)
if all(use_cross_encoder == use_cross_encoder_list[0]
for use_cross_encoder in use_cross_encoder_list):
Comment on lines +461 to +462
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if all(use_cross_encoder == use_cross_encoder_list[0]
for use_cross_encoder in use_cross_encoder_list):
if len(set(use_cross_encoder_list)) == 1:

I think we can simplify the condition here.

act_fn = self._get_act_fn(use_cross_encoder_list[0])
scores = act_fn(pooled_output)
else:
scores = torch.stack([
self._get_act_fn(use_cross_encoder)(vecs)
for use_cross_encoder, vecs in zip(use_cross_encoder_list,
pooled_data)
])

pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
return PoolerOutput(outputs=pooled_outputs)
5 changes: 0 additions & 5 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)

from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix
Expand Down Expand Up @@ -462,9 +460,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config

self.default_activation_function = \
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't used by the current module so I removed them

get_cross_encoder_activation_function(config)

self.num_labels = config.num_labels
self.bert = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
Expand Down
5 changes: 0 additions & 5 deletions vllm/model_executor/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)

from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only
Expand Down Expand Up @@ -178,9 +176,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config

self.default_activation_function = \
get_cross_encoder_activation_function(config)

self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
Expand Down
3 changes: 3 additions & 0 deletions vllm/pooling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ class PoolingParams(
"""

dimensions: Optional[int] = None
use_cross_encoder: bool = False
additional_data: Optional[Any] = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

def clone(self) -> "PoolingParams":
    """Return a copy of this PoolingParams instance.

    Note: this is a shallow copy — ``additional_data`` is shared with
    the original, not duplicated.
    """
    # Bug fix: previously ``output_kind`` was dropped, so every clone
    # silently reverted to the default RequestOutputKind.FINAL_ONLY.
    return PoolingParams(dimensions=self.dimensions,
                         use_cross_encoder=self.use_cross_encoder,
                         additional_data=self.additional_data,
                         output_kind=self.output_kind)

def verify(self, model_config: "ModelConfig") -> None:
Expand All @@ -54,6 +56,7 @@ def verify(self, model_config: "ModelConfig") -> None:
def __repr__(self) -> str:
    """Debug representation listing the pooling parameters.

    Bug fix: the label previously read ``additional_metadata`` even
    though the attribute is named ``additional_data``, which made the
    repr misleading when grepping for the field.
    """
    return (f"PoolingParams("
            f"dimensions={self.dimensions}, "
            f"use_cross_encoder={self.use_cross_encoder}, "
            f"additional_data={self.additional_data})")

def __post_init__(self) -> None:
Expand Down
20 changes: 11 additions & 9 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,24 +866,26 @@ def try_get_generation_config(
return None


def get_cross_encoder_activation_function(config: PretrainedConfig):
def get_classification_activation_function(config: PretrainedConfig):
    """Return the default activation for classification heads.

    Single-label (regression-style) heads get a sigmoid; multi-label
    heads return raw logits via an identity module.
    """
    if config.num_labels == 1:
        return nn.Sigmoid()
    return nn.Identity()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Resolve the activation for a sentence-transformers cross-encoder.

    Looks for an explicitly configured activation on the HF config
    (newer ``sentence_transformers.activation_fn`` key first, then the
    legacy ``sbert_ce_default_activation_function`` attribute) and
    instantiates it; otherwise falls back to the default classification
    activation.

    Raises:
        ValueError: if the configured activation is not from
            ``torch.nn.modules`` (loading arbitrary qualnames would be
            a code-execution risk).
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is not None:
        # Bug fix: this security boundary was an ``assert``, which is
        # stripped under ``python -O`` and would then silently allow
        # loading arbitrary qualnames. Raise explicitly instead.
        if not function_name.startswith("torch.nn.modules."):
            raise ValueError(
                "Loading of activation functions is restricted to "
                "torch.nn.modules for security reasons")
        return resolve_obj_by_qualname(function_name)()

    return get_classification_activation_function(config)


def try_get_safetensors_metadata(
Expand Down