Skip to content

Commit

Permalink
[Misc] Various simplifications and typing fixes (vllm-project#5368)
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored and joerunde committed Jun 13, 2024
1 parent 4679416 commit 6599db4
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 90 deletions.
2 changes: 1 addition & 1 deletion vllm/engine/output_processor/multi_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def process_outputs(self, sequence_group: SequenceGroup,

# Since there's only one sequence per sequence group, we can take the
# first sample.
samples = [outputs[step].samples[0] for step in range(len(outputs))]
samples = [output.samples[0] for output in outputs]

# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
Expand Down
6 changes: 4 additions & 2 deletions vllm/model_executor/layers/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,10 @@ def _create_output(

# Fill in the first k columns of the output tensor using masks and data
# tensors.
output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-torch.ones_like(draft_token_ids))
torch.where(accepted_mask,
draft_token_ids,
-torch.ones_like(draft_token_ids),
out=output)

# Fill the last column.
# We check output directly as accepted may have True values inconsistent
Expand Down
39 changes: 12 additions & 27 deletions vllm/spec_decode/batch_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def score_proposals(

target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list, ))
seq_group_metadata_list=target_seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]

Expand Down Expand Up @@ -140,8 +140,7 @@ def _expand_batch(
num_scoring_tokens)

def _contract_batch(
self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand All @@ -167,30 +166,16 @@ def _contract_batch(
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

target_token_ids = target_token_ids.squeeze().reshape(
spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size)
target_logprobs = target_logprobs.squeeze().reshape(
spec_expanded_bs, k + 1, self._vocab_size)

all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1,
device=self._device,
dtype=torch.long)
all_probs = torch.zeros(contracted_bs,
k + 1,
self._vocab_size,
device=self._device,
dtype=torch.float32)
all_logprobs = torch.full(size=(
contracted_bs,
k + 1,
self._vocab_size,
),
fill_value=-float("inf"),
device=self._device,
dtype=torch.float32)
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
target_probs = target_probs.reshape(*target_token_ids.shape,
self._vocab_size)
target_logprobs = target_logprobs.reshape(target_probs.shape)

all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))

if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
Expand Down
14 changes: 7 additions & 7 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch

from vllm.config import SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
Expand Down Expand Up @@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
assert "speculative_config" in kwargs
speculative_config = kwargs.get("speculative_config")
speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
assert speculative_config is not None

target_worker = Worker(*args, **kwargs)
Expand Down Expand Up @@ -109,12 +110,11 @@ def create_worker(
logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker))

return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
rejection_sampler=RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens, ))
return SpecDecodeWorker(proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
rejection_sampler=RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens))

def __init__(
self,
Expand Down
45 changes: 18 additions & 27 deletions vllm/spec_decode/top1_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def _split_by_proposal_len(
nonzero_proposal_len_indices,
)

def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
@staticmethod
def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
nonzero_proposal_len_indices, transposed):
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
Expand Down Expand Up @@ -207,7 +208,7 @@ def _merge_outputs(
self,
batch_size: int,
proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput],
maybe_sampler_output: Optional[List[SamplerOutput]],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
Expand All @@ -218,25 +219,19 @@ def _merge_outputs(
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.full(
size=(
batch_size,
proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device,
)
proposal_probs = torch.zeros(
batch_size,
proposal_len,
self._vocab_size,
dtype=torch.float32,
device=self._device,
)
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
proposal_tokens = torch.tensor(-1,
dtype=torch.long,
device=self._device).expand(
batch_size, proposal_len)
proposal_probs = torch.tensor(0,
dtype=torch.float32,
device=self._device).expand(
batch_size, proposal_len,
self._vocab_size)
proposal_lens_tensor = torch.tensor(0,
dtype=torch.long,
device=self._device).expand(
len(proposal_lens))
return proposal_tokens, proposal_probs, proposal_lens_tensor

sampler_output = maybe_sampler_output
Expand All @@ -246,18 +241,14 @@ def _merge_outputs(
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]

entire_proposal_tokens = torch.full(
entire_proposal_tokens = proposal_tokens.new_full(
size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1,
dtype=torch.long,
device=self._device,
)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros(
entire_proposal_probs = proposal_probs.new_zeros(
batch_size,
*proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device,
)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

Expand Down
11 changes: 3 additions & 8 deletions vllm/spec_decode/util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from contextlib import contextmanager
from itertools import chain
from typing import Dict, List, Tuple

import torch

from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
SequenceOutput)

SeqId = int

Expand All @@ -16,11 +15,7 @@ def get_all_seq_ids(
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
return list(
chain.from_iterable([
seq_group_metadata.seq_data.keys()
for seq_group_metadata in seq_group_metadata_list
]))
return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


def get_all_num_logprobs(
Expand Down Expand Up @@ -68,7 +63,7 @@ def create_sequence_group_output(
seq_id: SeqId,
topk_token_ids: List[int],
topk_logprobs: List[float],
) -> SequenceGroupOutput:
) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
Expand Down
6 changes: 3 additions & 3 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict, Optional, Type

from transformers import PretrainedConfig

Expand All @@ -9,7 +9,7 @@

logger = init_logger(__name__)

_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
"mpt": MPTConfig,
Expand Down Expand Up @@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig):
assert hasattr(config.text_config, "num_attention_heads")
return config.text_config
else:
return config
return config
30 changes: 15 additions & 15 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,28 +527,13 @@ def _prepare_model_input(
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))

context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)

seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)

torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
Expand Down Expand Up @@ -601,6 +586,21 @@ def _prepare_model_input(
seq_start_loc=seq_start_loc,
data_type=kv_cache_dtype)
else:
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)

torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
Expand Down

0 comments on commit 6599db4

Please sign in to comment.