Extend in beam support
iefode committed May 27, 2024
1 parent edc53e5 commit b5a9f28
Showing 2 changed files with 42 additions and 14 deletions.
@@ -39,12 +39,14 @@ GenerationConfig GenerationConfig::greedy() {
     GenerationConfig greedy_params;
     greedy_params.temperature = 0.0f;
     greedy_params.ignore_eos = true;
+    greedy_params.num_return_sequences = 1;
     return greedy_params;
 }
 
 GenerationConfig GenerationConfig::beam_search() {
     GenerationConfig beam_search;
     beam_search.num_groups = 2;
+    beam_search.num_return_sequences = 3;
     beam_search.group_size = 2;
     beam_search.max_new_tokens = 100;
     beam_search.diversity_penalty = 2.0f;
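The two preset changes above give each preset an explicit num_return_sequences: greedy keeps a single sequence, while the beam-search preset asks for three (out of num_groups 2 x group_size 2 = 4 beams). A minimal caller-side sketch, assuming the config is used roughly as in this diff; the temperature/top_p values are made up and only illustrate a multinomial-style request:

    // Hypothetical caller code, not part of this commit.
    GenerationConfig config = GenerationConfig::greedy();
    config.temperature = 0.8f;        // assumption: non-zero temperature selects multinomial sampling
    config.top_p = 0.9f;
    config.num_return_sequences = 3;  // ask the sampler to keep three continuations per prompt
    config.max_new_tokens = 100;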
@@ -295,12 +295,16 @@ class Sampler {
         return out_token;
     }
 
-    std::vector<int64_t> _multinomial_sample(ov::Tensor logits, float temperature, float top_p, size_t top_k) {
+    std::vector<int64_t> _multinomial_sample(ov::Tensor logits, float temperature, float top_p, size_t top_k, size_t n) {
         std::vector<int64_t> out_tokens;
         ov::Shape logits_shape = logits.get_shape();
         size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
-        for (size_t i = 0; i < batch_size; ++i) {
-            const float * logits_data = logits.data<const float>() + (seq_len - 1) * vocab_size * i;
+        for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+            size_t batch_offset = batch_idx * seq_len * vocab_size;
+            size_t sequence_offset = (seq_len - 1) * vocab_size;
+            const float* logits_data = logits.data<const float>() + batch_offset + sequence_offset;
+
+            // const float * logits_data = &logits_data_group[i * seq_tensor_size] + seq_tensor_size;
             std::vector<LogitWithIdx> logit_vector(vocab_size);
             for (size_t i = 0; i < logit_vector.size(); i++) {
                 logit_vector[i] = LogitWithIdx(logits_data[i], i);
@@ -327,11 +331,14 @@ class Sampler {
             for (size_t i = 0; i < filtered.size(); i++) multinomial_weights[i] = filtered[i].first;
 
             auto dist = std::discrete_distribution<size_t>(multinomial_weights.begin(), multinomial_weights.end()); // equivalent to multinomial with number of trials == 1
-            size_t element_to_pick = dist(rng_engine);
-            int64_t out_token = filtered[element_to_pick].second;
-            out_tokens.push_back(out_token);
+            for (size_t i = 0; i < n; ++i) {
+                size_t element_to_pick = dist(rng_engine);
+                int64_t out_token = filtered[element_to_pick].second;
+                out_tokens.push_back(out_token);
+            }
         }
 
+        OPENVINO_ASSERT(out_tokens.size() == n || out_tokens.size() == batch_size);
         return out_tokens;
     }
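Two things change in _multinomial_sample. The per-batch pointer arithmetic is split into an explicit batch_offset (batch_idx * seq_len * vocab_size, the stride of one whole sequence) plus a sequence_offset to the last token, instead of scaling (seq_len - 1) * vocab_size by the batch index, which is not a whole-sequence stride. And the function now takes the number of tokens to draw, n, samples that many times from the same filtered distribution, and asserts on exit that the output size matches either n or batch_size. A self-contained sketch of both pieces; last_token_offset and draw_n are stand-in names, not functions from this repository:

    #include <cstddef>
    #include <random>
    #include <vector>

    // Offset of the last-token logits for one batch entry in a row-major
    // [batch_size, seq_len, vocab_size] tensor, mirroring batch_offset + sequence_offset above.
    size_t last_token_offset(size_t batch_idx, size_t seq_len, size_t vocab_size) {
        return batch_idx * seq_len * vocab_size + (seq_len - 1) * vocab_size;
    }

    // Draw n token indices from one filtered weight vector; each call to the
    // distribution is a single trial, matching the "number of trials == 1" comment.
    std::vector<size_t> draw_n(const std::vector<float>& weights, size_t n, std::mt19937& rng) {
        std::discrete_distribution<size_t> dist(weights.begin(), weights.end());
        std::vector<size_t> picks;
        for (size_t i = 0; i < n; ++i)
            picks.push_back(dist(rng));
        return picks;
    }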

@@ -360,12 +367,6 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             continue;
 
         const GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();
-        if (sequence_group->requires_sampling() && sampling_params.is_multinomial() && sequence_group->get_num_processed_tokens() == 0) {
-            for (size_t i = 1; i < sampling_params.num_return_sequences; ++i) {
-                sequence_group->add_sequence(Sequence::create());
-            }
-        }
-
         size_t num_running_sequences = sequence_group->num_running_seqs();
         size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
         size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
@@ -382,8 +383,19 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
                 sampled_token_ids.push_back(_greedy_sample(sequence_group_logits));
             } else {
                 // is_multinomial()
-                OPENVINO_ASSERT(running_sequences.size() <= sampling_params.num_return_sequences);
-                sampled_token_ids = _multinomial_sample(sequence_group_logits, sampling_params.temperature, sampling_params.top_p, sampling_params.top_k);
+                auto n = sequence_group->get_num_processed_tokens() > 0 ? 1 : sampling_params.num_return_sequences;
+                sampled_token_ids = _multinomial_sample(sequence_group_logits, sampling_params.temperature, sampling_params.top_p, sampling_params.top_k, n);
+
+                if (n > 1) {
+                    const auto sequence_to_fork = running_sequences[0];
+                    std::list<uint64_t> forked_seq_ids;
+                    for (; num_running_sequences < n; ++num_running_sequences) {
+                        const auto forked_sequence = sequence_group->fork_sequence(sequence_to_fork);
+                        forked_seq_ids.push_back(forked_sequence->get_id());
+                        running_sequences.push_back(forked_sequence);
+                    }
+                    sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids});
+                }
             }
             for (size_t i = 0; i < num_running_sequences; ++i) {
                 // in case of greedy search we always have a single parent sequence to sample from
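This hunk replaces the block removed above, which pre-created empty sequences with add_sequence before anything had been sampled. The fan-out now happens at the first sampling step: while no tokens have been processed yet, n = num_return_sequences tokens are drawn in one call, the single prompt sequence is forked until the group holds n running sequences, and the forked ids are reported through sampler_output.m_forked_sequences (presumably so downstream bookkeeping, such as the block/KV-cache manager, can mirror the fork); on every later step get_num_processed_tokens() > 0, so n falls back to 1 and the per-sequence loop that follows can hand each running sequence its own sampled token. A toy, self-contained sketch of the same fan-out pattern; ToySequence and the ids below are invented stand-ins, not the library's Sequence/SequenceGroup API:

    #include <cstdint>
    #include <iostream>
    #include <list>
    #include <vector>

    struct ToySequence {
        uint64_t id;
        std::vector<int64_t> tokens;
    };

    int main() {
        std::vector<ToySequence> running = {{0, {}}};   // the single prompt sequence
        std::vector<int64_t> sampled = {11, 42, 7};     // n = 3 draws from one distribution
        const size_t n = sampled.size();

        std::list<uint64_t> forked_ids;                 // reported to the caller, like m_forked_sequences
        uint64_t next_id = 1;
        while (running.size() < n) {
            ToySequence fork = running.front();         // copy the parent's state so far
            fork.id = next_id++;
            forked_ids.push_back(fork.id);
            running.push_back(fork);
        }

        for (size_t i = 0; i < running.size(); ++i)     // one sampled token per running sequence
            running[i].tokens.push_back(sampled[i]);

        for (const auto& seq : running)
            std::cout << "sequence " << seq.id << " continues with token " << seq.tokens.back() << "\n";
        return 0;
    }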
@@ -424,6 +436,20 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
                     // stop sequence by max_new_tokens
                     m_beam_search_info.at(request_id).finalize(sampler_output);
                 }
+                const auto& finished_sequences = sequence_group->get_finished_sequences();
+                if (finished_sequences.size() > sampling_params.num_return_sequences) {
+                    // save only `sampling_params.num_return_sequences` sequences in result
+                    std::map<float, size_t> probs;
+                    for (const auto& finished_sequence : finished_sequences) {
+                        probs.insert({finished_sequence->get_cumulative_log_probs(), finished_sequence->get_id()});
+                    }
+                    auto it = probs.begin();
+                    std::advance(it, sampling_params.num_return_sequences);
+                    while (it != probs.end()) {
+                        sequence_group->remove_sequence(it->second);
+                        ++it;
+                    }
+                }
             }
         } else {
             // we are in prompt processing phase when prompt is split into chunks and processed step by step
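The last hunk trims beam-search results once a request finishes with more sequences than num_return_sequences: finished sequences are keyed by their cumulative log probability in a std::map, the iterator is advanced past the first num_return_sequences entries, and every remaining entry's sequence is removed from the group. A minimal sketch of that advance-and-drop pattern with made-up scores and ids; it only demonstrates the container mechanics, not the library types:

    #include <cstddef>
    #include <iostream>
    #include <iterator>
    #include <map>

    int main() {
        // key: cumulative log probability, value: sequence id (all values invented)
        std::map<float, size_t> probs = {{-3.5f, 4}, {-0.7f, 2}, {-2.9f, 3}, {-1.2f, 1}};
        const size_t num_return_sequences = 2;

        auto it = probs.begin();                  // std::map visits keys in ascending order
        std::advance(it, num_return_sequences);   // skip the first num_return_sequences entries
        for (; it != probs.end(); ++it)
            std::cout << "drop sequence " << it->second << "\n";
        return 0;
    }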
