Commit

Merge branch 'vllm-project:main' into refactor-punica-kernel
jeejeelee authored Jul 30, 2024
2 parents a6d9e46 + f058403 commit 37c3cbd
Showing 3 changed files with 23 additions and 17 deletions.
6 changes: 3 additions & 3 deletions vllm/engine/arg_utils.py
@@ -632,9 +632,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '--preemption-mode',
             type=str,
             default=None,
-            help='If \'recompute\', the engine performs preemption by block '
-            'swapping; If \'swap\', the engine performs preemption by block '
-            'swapping.')
+            help='If \'recompute\', the engine performs preemption by '
+            'recomputing; If \'swap\', the engine performs preemption by '
+            'block swapping.')
 
         parser.add_argument(
             "--served-model-name",
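For context, a minimal sketch of how this option is consumed; the model name and chosen mode are illustrative, assuming the `--preemption-mode` flag maps onto a `preemption_mode` field of `EngineArgs` in `vllm/engine/arg_utils.py`:

```python
# Hedged sketch: programmatic equivalent of the CLI flag documented above.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",   # illustrative model choice
    # Per the corrected help text: 'recompute' preempts by recomputing,
    # 'swap' preempts by block swapping.
    preemption_mode="recompute",
)
```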
7 changes: 2 additions & 5 deletions vllm/model_executor/models/paligemma.py
@@ -9,7 +9,6 @@
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
     def __init__(self, vision_hidden_size: int, projection_dim: int):
         super().__init__()
 
-        self.linear = ColumnParallelLinear(vision_hidden_size,
-                                           projection_dim,
-                                           bias=True)
+        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
 
     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states, _ = self.linear(image_features)
+        hidden_states = self.linear(image_features)
         return hidden_states


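The call-site change in `forward` follows from the differing return types: vLLM's `ColumnParallelLinear.forward` returns a tuple of output and optional bias, while `torch.nn.Linear` returns the output tensor directly. A minimal sketch (dimensions are illustrative, not PaliGemma's real sizes):

```python
import torch
import torch.nn as nn

# Plain nn.Linear returns a single tensor, so no tuple unpacking is needed.
proj = nn.Linear(1152, 2048, bias=True)   # illustrative dims
image_features = torch.randn(4, 1152)     # 4 image tokens
hidden_states = proj(image_features)      # a tensor, not an (output, bias) tuple
assert hidden_states.shape == (4, 2048)
```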
27 changes: 18 additions & 9 deletions vllm/worker/tpu_model_runner.py
@@ -28,7 +28,9 @@

 logger = init_logger(__name__)
 
-_PAD_SLOT_ID = -1  # NOTE(woosuk): In PyTorch XLA, index -1 is ignored.
+# Here we utilize the behavior that out-of-bound index is ignored.
+# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
+_PAD_SLOT_ID = 1_000_000_000
 # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
 _ENABLE_TOP_P = False
 # FIXME(woosuk): A temporary hack to support `n > 1`.
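To illustrate the padding trick in this hunk, a hedged sketch; `pad_slot_mapping` is a hypothetical helper, and the drop-on-out-of-bounds write behavior it relies on is specific to PyTorch XLA on TPU:

```python
# Hypothetical helper illustrating the comment above: padded positions in the
# slot mapping point at an index far beyond any real KV-cache slot, so on
# PyTorch XLA the corresponding cache writes are silently dropped.
_PAD_SLOT_ID = 1_000_000_000

def pad_slot_mapping(slot_mapping: list[int], padded_len: int) -> list[int]:
    return slot_mapping + [_PAD_SLOT_ID] * (padded_len - len(slot_mapping))

print(pad_slot_mapping([3, 17, 42], padded_len=8))
# [3, 17, 42, 1000000000, 1000000000, 1000000000, 1000000000, 1000000000]
```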
@@ -414,10 +416,7 @@ def _prepare_sample(
         best_of = []
         for seq_group_metadata in seq_group_metadata_list:
             sampling_params = seq_group_metadata.sampling_params
-            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
-            # low temperature. This is not accurate.
-            t.append(sampling_params.temperature
-                     if sampling_params.temperature >= 1e-5 else 1e-5)
+            t.append(sampling_params.temperature)
             if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                 raise NotImplementedError(
                     "Top-p sampling is currently disabled for the TPU backend "
@@ -678,13 +677,23 @@ def forward(
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
 
-        logits = logits / t.unsqueeze(dim=1)
+        # Argmax sampling.
+        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
+
+        # Zero temperature means greedy decoding. Avoid division by zero.
+        nonzero_t = torch.where(t != 0, t, 1.0)
+        logits = logits / nonzero_t.unsqueeze(dim=1)
         if _ENABLE_TOP_P:
             logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+
+        # Random sampling.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples,
-                                           replacement=True)
+        sampled_token_ids = torch.multinomial(probs,
+                                              num_samples,
+                                              replacement=True)
+        next_token_ids = torch.where(t != 0, sampled_token_ids,
+                                     argmax_token_ids)
         return next_token_ids


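Taken together with the `_prepare_sample` change above (raw temperatures, including 0, are now stored), the new `forward` logic selects between greedy and random sampling per sequence. A self-contained sketch of the same scheme, with `sample` as a hypothetical standalone function:

```python
import torch

def sample(logits: torch.Tensor, t: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Greedy decoding where t == 0, temperature sampling elsewhere."""
    # Argmax sampling, repeated so each sequence yields num_samples ids.
    argmax_ids = torch.argmax(logits, dim=-1, keepdim=True).repeat(1, num_samples)

    # Replace zero temperatures with 1.0: the division is then always safe,
    # and those rows are overridden by the argmax result below anyway.
    nonzero_t = torch.where(t != 0, t, torch.ones_like(t))
    probs = torch.softmax(logits / nonzero_t.unsqueeze(1), dim=-1)
    sampled_ids = torch.multinomial(probs, num_samples, replacement=True)

    # Per-sequence select: zero-temperature rows take the argmax ids.
    return torch.where((t != 0).unsqueeze(1), sampled_ids, argmax_ids)

logits = torch.randn(2, 32)    # batch of 2, toy vocabulary of 32
t = torch.tensor([0.0, 0.8])   # first sequence greedy, second sampled
print(sample(logits, t, num_samples=1))
```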
