Cleanup code paths
kasohrab committed Sep 11, 2024
1 parent a430cf6 commit 4ec295f
Showing 1 changed file with 2 additions and 85 deletions.
sarathi/model_executor/layers/sampler.py: 2 additions & 85 deletions
@@ -69,7 +69,6 @@ def forward(
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)

-        flashinfer_sample_result = []
         if not do_top_p and not do_top_k:
             probs = torch.softmax(logits, dim=-1, dtype=torch.float)
             flashinfer_sample_result = _sample_with_flashinfer(probs).cpu()
@@ -150,91 +149,9 @@ def _get_top_p_top_k(
     return top_ps, top_ks


-def _apply_top_p_top_k(
-    logits: torch.Tensor,
-    top_ps: List[float],
-    top_ks: List[int],
-) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
-
-    # Apply top-p.
-    probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1)
-    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    logits_sort[top_p_mask] = -float("inf")
-
-    # Apply top-k.
-    # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
-    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    logits_sort[top_k_mask] = -float("inf")
-
-    # Re-sort the probabilities.
-    logits = torch.gather(logits_sort, dim=-1, index=torch.argsort(logits_idx, dim=-1))
-    return logits
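
The deleted `_apply_top_p_top_k` was the native PyTorch filtering path. For reference, a minimal runnable sketch of the same masking logic on a made-up two-row batch (all tensor values are illustrative only):

```python
import torch

# Made-up logits: 2 sequences over a 5-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0],
                       [0.0, 0.0, 0.0, 0.0, 4.0]])
p = torch.tensor([0.9, 0.9])  # per-row top-p
k = torch.tensor([3, 3])      # per-row top-k

logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

# Top-p: drop tokens once the cumulative mass *before* them exceeds p,
# so the highest-probability token always survives.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
logits_sort[(probs_sum - probs_sort) > p.unsqueeze(1)] = -float("inf")

# Top-k: drop everything past the k-th sorted position.
positions = torch.arange(logits.shape[-1]).expand(logits.shape[0], -1)
logits_sort[positions >= k.unsqueeze(1)] = -float("inf")

# Undo the sort so rows line up with the original vocabulary order.
restored = torch.gather(logits_sort, -1, torch.argsort(logits_idx, dim=-1))
print(restored)
```

With these made-up values the second row keeps only its dominant token, while the first row's four top-p survivors are trimmed to three by top-k.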


-def _greedy_sample(
-    logprobs: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
-    return torch.argmax(logprobs, dim=-1).view(-1).cpu().tolist()
-
-
-def _random_sample(
-    probs: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
-    random_samples = (
-        torch.multinomial(probs, num_samples=1, replacement=True)
-        .view(-1)
-        .cpu()
-        .tolist()
-    )
-
-    return random_samples
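
Both deleted helpers were thin wrappers over stock PyTorch ops: `argmax` for greedy decoding, `multinomial` for random sampling. A runnable sketch with made-up distributions:

```python
import torch

torch.manual_seed(0)  # reproducible multinomial draws

# Made-up per-sequence distributions over a 4-token vocabulary.
probs = torch.tensor([[0.70, 0.10, 0.10, 0.10],
                      [0.25, 0.25, 0.25, 0.25]])
logprobs = probs.log()

greedy_ids = torch.argmax(logprobs, dim=-1).view(-1).tolist()  # argmax per row
random_ids = (
    torch.multinomial(probs, num_samples=1, replacement=True).view(-1).tolist()
)
print(greedy_ids, random_ids)
```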


-def _sample(
-    probs: torch.Tensor,
-    logprobs: torch.Tensor,
-    seq_metadata_list: List[SequenceMetadata],
-) -> SamplerOutputs:
-    categorized_seq_indices = {t: [] for t in SamplingType}
-    category_num_tokens = {t: 0 for t in SamplingType}
-
-    for i, seq_metadata in enumerate(seq_metadata_list):
-        sampling_type = seq_metadata.seq.sampling_params.sampling_type
-        categorized_seq_indices[sampling_type].append(i)
-        category_num_tokens[sampling_type] += 1
-
-    outputs: List[SamplerOutput] = [None] * len(seq_metadata_list)
-
-    for sampling_type in SamplingType:
-        seq_indices = categorized_seq_indices[sampling_type]
-        num_tokens = category_num_tokens[sampling_type]
-        if num_tokens == 0:
-            continue
-        category_logprobs = logprobs[seq_indices]
-        category_probs = probs[seq_indices]
-        if sampling_type == SamplingType.GREEDY:
-            sample_results = _greedy_sample(category_logprobs)
-        elif sampling_type == SamplingType.RANDOM:
-            sample_results = _random_sample(category_probs)
-        else:
-            raise ValueError(f"Unsupported sampling type: {sampling_type}")
-
-        for seq_idx, sample_result in zip(seq_indices, sample_results):
-            seq_id = seq_metadata_list[seq_idx].seq.seq_id
-            outputs[seq_idx] = SamplerOutput(seq_id, sample_result)
-
-    return outputs
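
The deleted `_sample` implemented a bucket-and-scatter pattern: group row indices by sampling type, sample each group in one batched call, then write results back in the original order. The grouping step can be exercised standalone; here `SimpleNamespace` and the `meta` helper are hypothetical stand-ins for the real `SequenceMetadata`, using the same attribute paths the code above reads:

```python
from enum import Enum
from types import SimpleNamespace

class SamplingType(Enum):  # stand-in for sarathi's enum
    GREEDY = 0
    RANDOM = 1

def meta(seq_id, sampling_type):
    # Mirrors the paths _sample dereferenced:
    # seq_metadata.seq.sampling_params.sampling_type and seq.seq_id
    params = SimpleNamespace(sampling_type=sampling_type)
    return SimpleNamespace(seq=SimpleNamespace(seq_id=seq_id, sampling_params=params))

metas = [meta("a", SamplingType.GREEDY),
         meta("b", SamplingType.RANDOM),
         meta("c", SamplingType.GREEDY)]

buckets = {t: [] for t in SamplingType}
for i, m in enumerate(metas):
    buckets[m.seq.sampling_params.sampling_type].append(i)
print(buckets)  # rows 0 and 2 land in GREEDY, row 1 in RANDOM
```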


 def _top_k_top_p_with_flashinfer(
     logits: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
-):
+) -> torch.Tensor:
     batch_size = logits.shape[0]
     uniform_samples = torch.empty((_MAX_TOP_K_ROUND, batch_size), device=logits.device)
     uniform_samples.uniform_()
@@ -252,7 +169,7 @@ def _top_k_top_p_with_flashinfer(
     return batch_next_token_ids.view(-1)
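
`_top_k_top_p_with_flashinfer` feeds `_MAX_TOP_K_ROUND` rows of uniform noise to flashinfer's batched rejection sampler. A functionally equivalent but slower reference, assuming the same filter semantics as the deleted `_apply_top_p_top_k` above, is to mask first and sample once. This is a sketch of that equivalence, not the flashinfer algorithm:

```python
import torch

def reference_top_k_top_p_sample(logits, top_ks, top_ps):
    # Mask everything outside top-k/top-p, renormalize, draw once per row.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    probs_sort = logits_sort.softmax(dim=-1)
    cum = probs_sort.cumsum(dim=-1)
    mask = (cum - probs_sort) > top_ps.unsqueeze(-1)
    positions = torch.arange(logits.shape[-1], device=logits.device)
    mask |= positions.expand_as(logits) >= top_ks.unsqueeze(-1)
    logits_sort[mask] = -float("inf")
    sampled_pos = torch.multinomial(logits_sort.softmax(dim=-1), num_samples=1)
    # Map sorted positions back to original vocabulary ids.
    return logits_idx.gather(-1, sampled_pos).view(-1)
```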


-def _sample_with_flashinfer(probs: torch.Tensor):
+def _sample_with_flashinfer(probs: torch.Tensor) -> torch.Tensor:
     batch_size = probs.shape[0]
     uniform_samples = torch.rand(batch_size).to(probs.device)
     samples = flashinfer_sampling_from_probs(probs, uniform_samples)
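
`_sample_with_flashinfer` pairs each row of `probs` with one uniform draw. One standard way to turn a uniform number into a categorical sample, and presumably what the kernel computes, is an inverse-CDF lookup; a plain-PyTorch sketch, ignoring floating-point edge cases when u is very close to 1:

```python
import torch

def inverse_cdf_sample(probs, u):
    # First index whose cumulative probability reaches u.
    cdf = probs.cumsum(dim=-1)
    return torch.searchsorted(cdf, u.unsqueeze(-1)).view(-1)

probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.5, 0.5, 0.0]])
u = torch.rand(probs.shape[0])
print(inverse_cdf_sample(probs, u))
```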
