Skip to content

Commit

Permalink
feat: add accept num, emit num metric for ChainSpeculativeSampling (#450
Browse files Browse the repository at this point in the history
)
  • Loading branch information
LiuXiaoxuanPKU authored Aug 17, 2024
1 parent 86c9e55 commit fa38b5e
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 25 deletions.
43 changes: 33 additions & 10 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1154,8 +1154,10 @@ template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
typename DType, typename IdType>
__global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
DType* uniform_samples, DType* target_probs,
IdType* output_token_ids, uint32_t num_speculative_tokens,
uint32_t d) {
IdType* output_token_ids,
IdType* output_accepted_token_num,
IdType* output_emitted_token_num,
uint32_t num_speculative_tokens, uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;

Expand All @@ -1165,20 +1167,38 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

uint32_t pos = 0;
for (pos = 0; pos < num_speculative_tokens; ++pos) {
IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos];
float q = target_probs[(row_idx * (num_speculative_tokens + 1) + pos) * d + draft_id],
p = draft_probs[(row_idx * num_speculative_tokens + pos) * d + draft_id];
DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + pos];
uint32_t pos = num_speculative_tokens;
for (uint32_t i = 0; i < num_speculative_tokens; ++i) {
IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
if (u * p < q) {
// accept the draft models output
output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = draft_id;
output_token_ids[row_idx * (num_speculative_tokens + 1) + i] = draft_id;
} else {
pos = i;
break;
}
}

uint32_t emitted_token_num = pos;
uint32_t accepted_token_num = pos;
for (uint32_t i = pos; i < num_speculative_tokens; ++i) {
IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
if (u * p < q) {
++accepted_token_num;
}
}

if (tx == 0) {
output_accepted_token_num[row_idx] += accepted_token_num;
output_emitted_token_num[row_idx] += emitted_token_num;
}

// sample from relu(target_probs - draft_probs)
DType sum_relu_q_minus_p(0);
vec_t<DType, VEC_SIZE> q_vec, p_vec;
Expand Down Expand Up @@ -1284,7 +1304,8 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o
template <typename DType, typename IdType>
cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
DType* uniform_samples, DType* target_probs,
IdType* output_token_ids, uint32_t batch_size,
IdType* output_token_ids, IdType* output_accepted_token_num,
IdType* output_emitted_token_num, uint32_t batch_size,
uint32_t num_speculative_tokens, uint32_t d,
bool deterministic, cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
Expand All @@ -1299,6 +1320,8 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids
&uniform_samples,
&target_probs,
&output_token_ids,
&output_accepted_token_num,
&output_emitted_token_num,
&num_speculative_tokens,
&d};
DISPATCH_ALIGNED_VEC_SIZE(
Expand Down
7 changes: 4 additions & 3 deletions python/csrc/flashinfer_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor
torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
unsigned int top_k_val, double eps);

torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
torch::Tensor uniform_samples, torch::Tensor target_probs,
bool deterministic);
std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic);

torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);

Expand Down
27 changes: 21 additions & 6 deletions python/csrc/sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,10 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tenso
return mask_logits;
}

torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
torch::Tensor uniform_samples, torch::Tensor target_probs,
bool deterministic) {
std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic) {
CHECK_INPUT(draft_probs);
CHECK_INPUT(draft_token_ids);
CHECK_INPUT(uniform_samples);
Expand Down Expand Up @@ -349,14 +350,28 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1},
torch::dtype(torch::kInt32).device(device));

bool has_output_accepted_token_num = maybe_output_accepted_token_num.has_value();
bool has_output_emitted_token_num = maybe_output_emitted_token_num.has_value();
auto output_accepted_token_num = maybe_output_accepted_token_num.value_or(
torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
auto output_emitted_token_num = maybe_output_emitted_token_num.value_or(
torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
if (has_output_accepted_token_num) {
CHECK_EQ(has_output_emitted_token_num, true);
CHECK_EQ(batch_size, output_accepted_token_num.size(0));
CHECK_EQ(batch_size, output_emitted_token_num.size(0));
}

cudaError_t status = sampling::ChainSpeculativeSampling<float, int>(
static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
static_cast<int*>(output_token_ids.data_ptr()), batch_size, num_speculate_tokens, vocab_size,
deterministic, torch_current_stream);
static_cast<int*>(output_token_ids.data_ptr()),
static_cast<int*>(output_accepted_token_num.data_ptr()),
static_cast<int*>(output_emitted_token_num.data_ptr()), batch_size, num_speculate_tokens,
vocab_size, deterministic, torch_current_stream);

TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
std::string(cudaGetErrorString(status)));

return output_token_ids;
return {output_token_ids, output_accepted_token_num, output_emitted_token_num};
}
19 changes: 18 additions & 1 deletion python/flashinfer/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ def chain_speculative_sampling(
draft_token_ids,
uniform_samples,
target_probs,
maybe_output_accepted_token_num: torch.Tensor = None,
maybe_output_emitted_token_num: torch.Tensor = None,
deterministic: bool = True,
) -> torch.Tensor:
r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in
Expand All @@ -614,6 +616,15 @@ def chain_speculative_sampling(
Compared to input :attr:`draft_probs`, the target model's probability has an additional
slot at the end because the target model will generate one more token than the draft model.
Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``
maybe_output_accepted_token_num: torch.Tensor
The number of tokens that can be accepted if each token is considered independently for each request.
This metric does not consider the fact that rejection sampling will stop at the first token that does not
satisfy the probablity requirement r < p/q.
It only evaluates the alignment of draft model and target model.
Shape: ``(batch_size)``
maybe_output_emitted_token_num: torch.Tensor
The number of tokens that are finally emitted/generated for each request.
Shape: ``(batch_size)``
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
Expand All @@ -628,5 +639,11 @@ def chain_speculative_sampling(
Shape: (batch_size, num_specutate_tokens + 1)
"""
return _kernels.chain_speculative_sampling(
draft_probs, draft_token_ids, uniform_samples, target_probs, deterministic
draft_probs,
draft_token_ids,
uniform_samples,
target_probs,
maybe_output_accepted_token_num,
maybe_output_emitted_token_num,
deterministic,
)
36 changes: 31 additions & 5 deletions python/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,17 @@ def test_chain_speculative_sampling(
# NOTE(Zihao): this is a very simple test that only checks whether output is valid or not.
for trials in range(10):
uniform_samples.uniform_()
output_token_ids = flashinfer.sampling.chain_speculative_sampling(
normalized_draft_prob,
draft_token_ids,
uniform_samples,
target_onehot_prob,
accepted_num = torch.zeros(batch_size, dtype=torch.int32).to(0)
emitted_num = torch.zeros(batch_size, dtype=torch.int32).to(0)
output_token_ids, accepted_num, emitted_num = (
flashinfer.sampling.chain_speculative_sampling(
normalized_draft_prob,
draft_token_ids,
uniform_samples,
target_onehot_prob,
accepted_num,
emitted_num,
)
)
if onehot_target:
assert torch.all(output_token_ids == target_token_ids)
Expand All @@ -359,6 +365,26 @@ def test_chain_speculative_sampling(
# from the second mismatched token on, the output tokens should be -1
assert torch.all(output_token_ids[row, mismatch_idx[0] + 1 :] == -1)

assert torch.all(emitted_num + 1 == (output_token_ids != -1).sum(dim=1))
batch_indices = torch.arange(batch_size, device=normalized_draft_prob.device)[
:, None
]
probs_indicies = torch.arange(
num_speculate_tokens, device=normalized_draft_prob.device
)
selected_draft_probs = normalized_draft_prob[
batch_indices, probs_indicies, draft_token_ids
]
selected_target_probs = target_onehot_prob[
batch_indices, probs_indicies, draft_token_ids
]
capped_ratio = torch.minimum(
selected_target_probs / selected_draft_probs,
torch.full((1,), 1, device=normalized_draft_prob.device),
)
ref_accepted = (uniform_samples[:, :-1] < capped_ratio).sum(dim=1)
assert torch.all(accepted_num == ref_accepted)


if __name__ == "__main__":
test_sampling(1, 111)
Expand Down

0 comments on commit fa38b5e

Please sign in to comment.