[Misc] Various simplifications and typing fixes #5368

Merged: 3 commits, merged on Jun 11, 2024
2 changes: 1 addition & 1 deletion vllm/engine/output_processor/multi_step.py
@@ -78,7 +78,7 @@ def process_outputs(self, sequence_group: SequenceGroup,

         # Since there's only one sequence per sequence group, we can take the
         # first sample.
-        samples = [outputs[step].samples[0] for step in range(len(outputs))]
+        samples = [output.samples[0] for output in outputs]

         # -1 means the output token is not valid (eg. due to spec decode
         # rejecting tokens).
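The change above swaps index-based access for direct iteration; the two forms are equivalent whenever every element is visited in order. A tiny sketch with stand-in objects (names are illustrative, not from vLLM):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class FakeStepOutput:
        samples: List[str]

    outputs = [FakeStepOutput(["a"]), FakeStepOutput(["b"])]

    # Index-based (old) and direct iteration (new) produce the same list.
    old = [outputs[step].samples[0] for step in range(len(outputs))]
    new = [output.samples[0] for output in outputs]
    assert old == new == ["a", "b"]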
39 changes: 12 additions & 27 deletions vllm/spec_decode/batch_expansion.py
@@ -80,7 +80,7 @@ def score_proposals(

         target_sampler_output = self._scorer_worker.execute_model(
             execute_model_req=execute_model_req.clone(
-                seq_group_metadata_list=target_seq_group_metadata_list, ))
+                seq_group_metadata_list=target_seq_group_metadata_list))
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]

@@ -140,8 +140,7 @@ def _expand_batch(
                 num_scoring_tokens)

     def _contract_batch(
-            self, contracted_bs: int,
-            target_sampler_output: List[SamplerOutput],
+            self, contracted_bs: int, target_sampler_output: SamplerOutput,
             proposals: SpeculativeProposals, num_scoring_tokens: int,
             non_spec_indices: List[int], spec_indices: List[int],
             k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -167,30 +166,16 @@ def _contract_batch(
         non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

-        target_token_ids = target_token_ids.squeeze().reshape(
-            spec_expanded_bs, k + 1)
-        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
-                                                      self._vocab_size)
-        target_logprobs = target_logprobs.squeeze().reshape(
-            spec_expanded_bs, k + 1, self._vocab_size)
-
-        all_tokens = torch.full(size=(contracted_bs, k + 1),
-                                fill_value=-1,
-                                device=self._device,
-                                dtype=torch.long)
-        all_probs = torch.zeros(contracted_bs,
-                                k + 1,
-                                self._vocab_size,
-                                device=self._device,
-                                dtype=torch.float32)
-        all_logprobs = torch.full(size=(
-            contracted_bs,
-            k + 1,
-            self._vocab_size,
-        ),
-                                  fill_value=-float("inf"),
-                                  device=self._device,
-                                  dtype=torch.float32)
+        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
+        target_probs = target_probs.reshape(*target_token_ids.shape,
+                                            self._vocab_size)
+        target_logprobs = target_logprobs.reshape(target_probs.shape)
+
+        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
+                                               fill_value=-1)
+        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
+        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                fill_value=-float("inf"))

         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
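The refactor above leans on Tensor.new_full / Tensor.new_zeros, which allocate tensors inheriting the dtype and device of the tensor they are called on, so the explicit dtype=/device= arguments can be dropped. A minimal sketch of the pattern (tensor names and shapes are illustrative):

    import torch

    # The existing tensor fixes dtype and device once.
    target_token_ids = torch.ones(4, 3, dtype=torch.long)

    # new_full / new_zeros inherit dtype=torch.long and the device of
    # target_token_ids without repeating them at every call site.
    all_tokens = target_token_ids.new_full(size=(2, 3), fill_value=-1)
    all_probs = target_token_ids.new_zeros(2, 3, 8)

    assert all_tokens.dtype == target_token_ids.dtype
    assert all_tokens.device == target_token_ids.device
    assert all_probs.shape == (2, 3, 8)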
14 changes: 7 additions & 7 deletions vllm/spec_decode/spec_decode_worker.py
@@ -3,6 +3,7 @@

 import torch

+from vllm.config import SpeculativeConfig
 from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
     """
     assert "speculative_config" in kwargs
-    speculative_config = kwargs.get("speculative_config")
+    speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
     assert speculative_config is not None

     target_worker = Worker(*args, **kwargs)
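Annotating the value pulled out of kwargs tells static checkers what **kwargs alone cannot; without it the variable is typed as Any and attribute access on it goes unchecked. A self-contained sketch of the idea (the config class here is a stand-in, not the real vllm.config.SpeculativeConfig):

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class FakeSpeculativeConfig:
        num_speculative_tokens: int = 0

    def create_spec_worker(*args: Any, **kwargs: Any) -> None:
        # The annotation lets mypy/pyright check the attribute access below.
        speculative_config: FakeSpeculativeConfig = kwargs.get(
            "speculative_config")
        assert speculative_config is not None
        print(speculative_config.num_speculative_tokens)

    create_spec_worker(speculative_config=FakeSpeculativeConfig(5))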
@@ -109,12 +110,11 @@ def create_worker(
         logger.info("Configuring SpecDecodeWorker with proposer=%s",
                     type(proposer_worker))

-        return SpecDecodeWorker(
-            proposer_worker,
-            scorer_worker,
-            disable_by_batch_size=disable_by_batch_size,
-            rejection_sampler=RejectionSampler(
-                disable_bonus_tokens=disable_bonus_tokens, ))
+        return SpecDecodeWorker(proposer_worker,
+                                scorer_worker,
+                                disable_by_batch_size=disable_by_batch_size,
+                                rejection_sampler=RejectionSampler(
+                                    disable_bonus_tokens=disable_bonus_tokens))

     def __init__(
         self,
44 changes: 17 additions & 27 deletions vllm/spec_decode/top1_proposer.py
@@ -148,7 +148,8 @@ def _split_by_proposal_len(
             nonzero_proposal_len_indices,
         )

-    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+    @staticmethod
+    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                  nonzero_proposal_len_indices, transposed):
         """Remove sequences from nonzero_proposal_len_indices and reset
         their proposal_len to 0 the draft worker does not provide a proposal
@@ -207,7 +208,7 @@ def _merge_outputs(
         self,
         batch_size: int,
         proposal_len: int,
-        maybe_sampler_output: Optional[SamplerOutput],
+        maybe_sampler_output: Optional[List[SamplerOutput]],
         proposal_lens: List[int],
         nonzero_proposal_len_indices: List[int],
         sampler_transposed: bool,
@@ -218,25 +219,18 @@ def _merge_outputs(
         if maybe_sampler_output is None:
             # If no speculative tokens, the sampler output will be None.
             # In this case we return empty proposals.
-            proposal_tokens = torch.full(
-                size=(
-                    batch_size,
-                    proposal_len,
-                ),
-                fill_value=-1,
-                dtype=torch.long,
-                device=self._device,
-            )
-            proposal_probs = torch.zeros(
-                batch_size,
-                proposal_len,
-                self._vocab_size,
-                dtype=torch.float32,
-                device=self._device,
-            )
-            proposal_lens_tensor = torch.zeros(len(proposal_lens),
-                                               dtype=torch.long,
-                                               device=self._device)
+            proposal_tokens = torch.tensor(-1,
+                                           dtype=torch.long,
+                                           device=self._device).expand(
+                                               batch_size, proposal_len)
+            proposal_probs = torch.tensor(0,
+                                          dtype=torch.float32,
+                                          device=self._device).expand(
+                                              batch_size, proposal_len)
+            proposal_lens_tensor = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=self._device).expand(
+                                                    len(proposal_lens))
             return proposal_tokens, proposal_probs, proposal_lens_tensor

         sampler_output = maybe_sampler_output
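The new empty-proposal branch uses torch.tensor(...).expand(...), which returns a broadcast view over a single stored element rather than materialising a full batch_size x proposal_len buffer; since the placeholders are constant and never written to, nothing is lost. A rough sketch of the behaviour (sizes are arbitrary):

    import torch

    batch_size, proposal_len = 8, 4

    # torch.full allocates batch_size * proposal_len elements.
    full = torch.full((batch_size, proposal_len), -1, dtype=torch.long)

    # expand() returns a zero-stride view of the single scalar element:
    # no extra allocation, but it must be treated as read-only.
    view = torch.tensor(-1, dtype=torch.long).expand(batch_size, proposal_len)

    assert torch.equal(full, view)
    assert view.stride() == (0, 0)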
@@ -246,18 +240,14 @@ def _merge_outputs(
         # Now, reformat the output GPU tensors such that each sequence has
         # a proposal. the proposal can be empty, e.g. [-1, -1, -1]

-        entire_proposal_tokens = torch.full(
+        entire_proposal_tokens = proposal_tokens.new_full(
             size=(batch_size, *proposal_tokens.shape[1:]),
             fill_value=-1,
-            dtype=torch.long,
-            device=self._device,
         )
         entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
-        entire_proposal_probs = torch.zeros(
+        entire_proposal_probs = proposal_probs.new_zeros(
             batch_size,
             *proposal_probs.shape[1:],
-            dtype=torch.float32,
-            device=self._device,
         )
         entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

11 changes: 3 additions & 8 deletions vllm/spec_decode/util.py
@@ -1,12 +1,11 @@
 from contextlib import contextmanager
-from itertools import chain
 from typing import Dict, List, Tuple

 import torch

 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceOutput)
+                           SequenceOutput)

 SeqId = int

@@ -16,11 +15,7 @@ def get_all_seq_ids(
     """Given a list of SequenceGroupMetadata, create a list of all
     sequence ids.
    """
-    return list(
-        chain.from_iterable([
-            seq_group_metadata.seq_data.keys()
-            for seq_group_metadata in seq_group_metadata_list
-        ]))
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


 def get_all_num_logprobs(
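The rewritten get_all_seq_ids relies on two facts: iterating over a dict yields its keys, and a doubly-nested comprehension flattens one level, so the chain.from_iterable plus .keys() round-trip is unnecessary. A small sketch with a stand-in for SequenceGroupMetadata:

    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass
    class FakeSeqGroupMetadata:
        seq_data: Dict[int, str]

    groups = [FakeSeqGroupMetadata({1: "a", 2: "b"}),
              FakeSeqGroupMetadata({7: "c"})]

    # Iterating over a dict yields its keys, so the nested comprehension
    # flattens the per-group key sets in order.
    seq_ids: List[int] = [seq_id for g in groups for seq_id in g.seq_data]
    assert seq_ids == [1, 2, 7]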
@@ -68,7 +63,7 @@ def create_sequence_group_output(
     seq_id: SeqId,
     topk_token_ids: List[int],
     topk_logprobs: List[float],
-) -> SequenceGroupOutput:
+) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.

     Args:
6 changes: 3 additions & 3 deletions vllm/transformers_utils/config.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Type

 from transformers import PretrainedConfig

@@ -9,7 +9,7 @@

 logger = init_logger(__name__)

-_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,
@@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig):
         assert hasattr(config.text_config, "num_attention_heads")
         return config.text_config
     else:
-        return config
\ No newline at end of file
+        return config
30 changes: 15 additions & 15 deletions vllm/worker/model_runner.py
@@ -527,28 +527,13 @@ def _prepare_model_input(
         )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device=self.device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=self.device)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=self.device)
-
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=self.device)
         seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=self.device)

-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
-
Member Author: These tensors aren't used in the flashinfer case

         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -601,6 +586,21 @@ def _prepare_model_input(
                 seq_start_loc=seq_start_loc,
                 data_type=kv_cache_dtype)
         else:
+            context_lens_tensor = torch.tensor(context_lens,
+                                               dtype=torch.int,
+                                               device=self.device)
+            query_lens_tensor = torch.tensor(query_lens,
+                                             dtype=torch.long,
+                                             device=self.device)
+            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=self.device)
+
+            torch.cumsum(query_lens_tensor,
+                         dim=0,
+                         dtype=query_start_loc.dtype,
+                         out=query_start_loc[1:])
+

             attn_metadata = self.attn_backend.make_metadata(
                 num_prefills=num_prefills,
                 slot_mapping=slot_mapping_tensor,
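The block moved into the else: branch builds query_start_loc as an exclusive prefix sum: a zero-filled tensor of length n + 1 whose tail is written in place by torch.cumsum(..., out=query_start_loc[1:]), so element i is the start offset of query i and the last element is the total token count. A standalone sketch of that indexing trick (the lengths are made up):

    import torch

    query_lens = [3, 1, 4]
    query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)

    # Length n + 1: query_start_loc[i] is the start offset of query i,
    # query_start_loc[-1] is the total number of query tokens.
    query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
    torch.cumsum(query_lens_tensor,
                 dim=0,
                 dtype=query_start_loc.dtype,
                 out=query_start_loc[1:])

    assert query_start_loc.tolist() == [0, 3, 4, 8]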