Skip to content

Commit

Permalink
[Continuous batching] Late token vector initialization in sampling (#649
Browse files Browse the repository at this point in the history
)

Changes:
- Further split of greedy and multinomial paths - using original logits
buffer in greedy and whenever possible in multinomial sampling. Sorted
vector is created only when top_p or top_k filters need to be applied.
- Fixed an issue where the top_k filter was always applied during
multinomial sampling unless top_k was explicitly set to 0. Now the
default value (the maximum of size_t) no longer triggers the top_k
filter. The filter is also skipped when top_k is larger than the
logits vector size.
- Skipping multinomial tests
  • Loading branch information
mzegla authored Aug 2, 2024
1 parent 3cb2829 commit 3304798
Show file tree
Hide file tree
Showing 10 changed files with 266 additions and 188 deletions.
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/generation_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {

public:
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
m_generation_stream(generation_stream),
m_generation_stream(std::move(generation_stream)),
m_sampling_params(sampling_params) {};

~GenerationHandleImpl();
Expand Down
4 changes: 2 additions & 2 deletions src/cpp/src/block_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ class BlockManager {
}

bool can_append_slots(SequenceGroup::CPtr seq_group) {
return required_blocks_count(seq_group) <= m_allocator.num_free_blocks();
return required_blocks_count(std::move(seq_group)) <= m_allocator.num_free_blocks();
}

size_t required_blocks_count(SequenceGroup::CPtr seq_group) {
Expand Down Expand Up @@ -466,7 +466,7 @@ class BlockManager {
// write information about block forking for later usage in CacheManager
copy_blocks_map[last_block->get_index()].push_back(new_block->get_index());
// release `last_block` usage
m_allocator.free(last_block);
m_allocator.free(std::move(last_block));
} else {
// we are the only users of this block
if (m_enable_prefix_caching) {
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/generation_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class GenerationStream {
}

void push(GenerationOutputs outputs) {
m_output_queue.push(outputs);
m_output_queue.push(std::move(outputs));
}

// Retrieving vector of pairs <sequence_id, token_id> as we can generate multiple outputs for a single prompt
Expand Down
135 changes: 85 additions & 50 deletions src/cpp/src/logit_processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,38 @@ struct Token {
Token() = default;
};

struct Logits {
float * m_data = nullptr;
size_t m_size;
// Late initialized for top_p or top_k transforms
std::vector<Token> m_vector;

Logits(float* data, size_t size): m_data(data), m_size(size) {}


void initialize_vector() {
OPENVINO_ASSERT(m_vector.size() == 0, "Logits vector already initialized");
m_vector.reserve(m_size);
for (size_t i = 0; i < m_size; i++)
m_vector.emplace_back(m_data[i], i);
}

bool is_vector_initialized() const {
return m_vector.size() > 0;
}

void resize(size_t new_size) {
m_size = new_size;
m_vector.resize(new_size);
}
};

namespace LogitTransformers {
using TokenIds = std::vector<int64_t>;

class ILogitTransformer {
public:
virtual void apply(std::vector<Token>& logits) = 0;
virtual void apply(Logits& logits) = 0;

virtual bool is_applicable(size_t generated_tokens_cnt = 0) {
return true;
Expand All @@ -32,11 +58,15 @@ class TopPFilter : public ILogitTransformer {
public:
TopPFilter(double top_p) : m_top_p(top_p) {}

void apply(std::vector<Token>& logits) override {
std::sort(logits.begin(), logits.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
void apply(Logits& logits) override {
if (!logits.is_vector_initialized()) {
// Initialize and sort vector
logits.initialize_vector();
std::sort(logits.m_vector.begin(), logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
}
float probability_sum = 0.0f;
size_t nucleus_size = 0;
for (const auto& probability : logits) {
for (const auto& probability : logits.m_vector) {
probability_sum += probability.m_log_prob;
nucleus_size += 1;
if (probability_sum > m_top_p) break;
Expand All @@ -52,10 +82,18 @@ class TopKFilter : public ILogitTransformer {
public:
TopKFilter(size_t top_k) : m_top_k(top_k) {}

void apply(std::vector<Token>& logits) override {
std::sort(logits.begin(), logits.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
size_t top_k = logits.size() >= m_top_k ? m_top_k : logits.size();
logits.resize(top_k);
// NOTE: when combined with top_p, this transform must run after it — top_p
// fully sorts the token vector, while top_k only performs a partial sort.
void apply(Logits& logits) override {
    // A top_k of at least the vocabulary size filters nothing; this also
    // covers the "unset" default (maximum value of size_t).
    if (m_top_k >= logits.m_size) {
        return;
    }

    if (!logits.is_vector_initialized()) {
        logits.initialize_vector();
        auto by_descending_prob = [](const Token& lhs, const Token& rhs) {
            return lhs.m_log_prob > rhs.m_log_prob;
        };
        // Only the first m_top_k entries need to end up ordered.
        std::partial_sort(logits.m_vector.begin(),
                          logits.m_vector.begin() + m_top_k,
                          logits.m_vector.end(),
                          by_descending_prob);
    }

    logits.resize(m_top_k);
}

protected:
Expand All @@ -66,18 +104,23 @@ class TemperatureLogitTransform : public ILogitTransformer {
public:
TemperatureLogitTransform(double temperature) : m_temperature(temperature) {};

void apply(std::vector<Token>& logits) override {
auto max_prob_token = std::max_element(logits.begin(), logits.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob < rhs.m_log_prob; });
float max_logit = max_prob_token->m_log_prob;

std::for_each(logits.begin(), logits.end(), [max_logit, this](Token& val) {val.m_log_prob = expf((val.m_log_prob - max_logit) / this->m_temperature);});
void apply(Logits& logits) override {
float max_logit = -std::numeric_limits<float>::infinity();
for (size_t i = 0; i < logits.m_size; i++) {
if (logits.m_data[i] > max_logit) {
max_logit = logits.m_data[i];
}
}

float norm_sum = 0.0;
for (const auto& val : logits) {
norm_sum += val.m_log_prob;
for (size_t i = 0; i < logits.m_size; i++) {
logits.m_data[i] = expf((logits.m_data[i] - max_logit) / this->m_temperature);
norm_sum += logits.m_data[i];
}

std::for_each(logits.begin(), logits.end(), [norm_sum](Token& val) {val.m_log_prob /= norm_sum;});
for (size_t i = 0; i < logits.m_size; i++) {
logits.m_data[i] /= norm_sum;
}
}

protected:
Expand Down Expand Up @@ -118,32 +161,28 @@ class RepetitionPenaltyTransform : public IPenaltyTransformer {
m_penalty = repetition_penalty;
};

void apply(std::vector<Token>& logits) override {
size_t vocab_size = logits.size();
void apply(Logits& logits) override {
size_t vocab_size = logits.m_size;
for (const auto& prompt_id : *m_unique_prompt_token_ids) {
OPENVINO_ASSERT((prompt_id >= 0) && (prompt_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(logits[prompt_id].m_index == prompt_id, "input_logits must have original index order");
auto logit_value = logits[prompt_id].m_log_prob;
if (logit_value >= 0) {
logits[prompt_id].m_log_prob /= m_penalty;
if (logits.m_data[prompt_id] >= 0) {
logits.m_data[prompt_id] /= m_penalty;
} else {
logits[prompt_id].m_log_prob *= m_penalty;
logits.m_data[prompt_id] *= m_penalty;
};
}
for (const auto& input_id_pair : *m_unique_generated_token_ids) {
const auto& input_id = input_id_pair.first;
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(logits[input_id].m_index == input_id, "input_logits must have original index order");
auto logit_value = logits[input_id].m_log_prob;
if (logit_value >= 0) {
logits[input_id].m_log_prob /= m_penalty;
if (logits.m_data[input_id] >= 0) {
logits.m_data[input_id] /= m_penalty;
} else {
logits[input_id].m_log_prob *= m_penalty;
logits.m_data[input_id] *= m_penalty;
};
}
}

void apply(std::vector<Token>& logits, const TokenIds& input_ids) {
void apply(Logits& logits, const TokenIds& input_ids) {
set_unique_prompt_token_ids(nullptr);
extract_generated_tokens(input_ids);
apply(logits);
Expand All @@ -166,10 +205,10 @@ class EOSPenaltyTransform : public ILogitTransformer {
EOSPenaltyTransform(size_t eos_token_id, size_t min_generated_tokens) :
m_eos_token_id(eos_token_id), m_applicable_tensor_len(min_generated_tokens) {}

void apply(std::vector<Token>& logits) override {
// Since EOS penalty is applied early, the token vector is not sorted
void apply(Logits& logits) override {
// Since EOS penalty is applied early, the token vector is not initialized yet
// and we can assume element order match token ids.
logits[m_eos_token_id].m_log_prob = 0.f;
logits.m_data[m_eos_token_id] = 0.f;
}


Expand All @@ -188,22 +227,20 @@ class FrequencyPenaltyTransform : public IPenaltyTransformer {
m_penalty = value;
};

void apply(std::vector<Token>& logits) override {
size_t vocab_size = logits.size();
void apply(Logits& logits) override {
size_t vocab_size = logits.m_size;
for (const auto& input_id_pair : *m_unique_generated_token_ids) {
const auto& input_id = input_id_pair.first;
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(logits[input_id].m_index == input_id, "input_logits must have original index order");
auto logit_value = logits[input_id].m_log_prob;
if (logit_value >= 0) {
logits[input_id].m_log_prob -= m_penalty * input_id_pair.second;
if (logits.m_data[input_id] >= 0) {
logits.m_data[input_id] -= m_penalty * input_id_pair.second;
} else {
logits[input_id].m_log_prob += m_penalty * input_id_pair.second;
logits.m_data[input_id] += m_penalty * input_id_pair.second;
};
}
}

void apply(std::vector<Token>& logits, const TokenIds& input_ids) {
void apply(Logits& logits, const TokenIds& input_ids) {
extract_generated_tokens(input_ids);
apply(logits);
}
Expand All @@ -215,22 +252,20 @@ class PresencePenaltyTransform : public IPenaltyTransformer {
m_penalty = value;
};

void apply(std::vector<Token>& logits) override {
size_t vocab_size = logits.size();
void apply(Logits& logits) override {
size_t vocab_size = logits.m_size;
for (const auto& input_id_pair : *m_unique_generated_token_ids) {
const auto& input_id = input_id_pair.first;
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(logits[input_id].m_index == input_id, "input_logits must have original index order");
auto logit_value = logits[input_id].m_log_prob;
if (logit_value >= 0) {
logits[input_id].m_log_prob -= m_penalty;
if (logits.m_data[input_id] >= 0) {
logits.m_data[input_id] -= m_penalty;
} else {
logits[input_id].m_log_prob += m_penalty;
logits.m_data[input_id] += m_penalty;
};
}
}

void apply(std::vector<Token>& logits, const TokenIds& input_ids) {
void apply(Logits& logits, const TokenIds& input_ids) {
extract_generated_tokens(input_ids);
apply(logits);
}
Expand Down Expand Up @@ -286,14 +321,14 @@ class LogitProcessor {
if (sampling_params.top_p != 1.0f) {
m_logit_transformers.emplace_back(new LogitTransformers::TopPFilter(sampling_params.top_p));
}
if (sampling_params.top_k > 0) {
if (sampling_params.top_k > 0 && sampling_params.top_k < std::numeric_limits<size_t>::max()) {
m_logit_transformers.emplace_back(new LogitTransformers::TopKFilter(sampling_params.top_k));
}
}
}
}

void apply(std::vector<Token>& logits) {
void apply(Logits& logits) {
for (const auto& transformer : m_logit_transformers) {
if (transformer->is_applicable(m_generated_tokens)) {
transformer->apply(logits);
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/model_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ModelRunner {
SchedulerConfig m_scheduler_config;
public:
ModelRunner(ov::InferRequest request, const SchedulerConfig& scheduler_config) :
m_request(request),
m_request(std::move(request)),
m_scheduler_config(scheduler_config) { }

ov::InferRequest get_infer_request() const {
Expand Down
46 changes: 27 additions & 19 deletions src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct Beam {
float m_score = -std::numeric_limits<float>::infinity();

Beam(Sequence::Ptr sequence)
: m_sequence(sequence) { }
: m_sequence(std::move(sequence)) { }

size_t get_generated_len() const {
return m_sequence->get_generated_len();
Expand Down Expand Up @@ -203,40 +203,49 @@ class GroupBeamSearcher {

class Sampler {

std::vector<Token> _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
Logits _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
ov::Shape logits_shape = logits.get_shape();
size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
OPENVINO_ASSERT(batch_idx <= batch_size);
size_t batch_offset = batch_idx * seq_len * vocab_size;
size_t sequence_offset = (seq_len - 1) * vocab_size;
const float* logits_data = logits.data<const float>() + batch_offset + sequence_offset;
float* logits_data = logits.data<float>() + batch_offset + sequence_offset;

std::vector<Token> logit_vector(vocab_size);
for (size_t i = 0; i < logit_vector.size(); i++) {
logit_vector[i] = Token(logits_data[i], i);
}
return logit_vector;
return Logits{logits_data, vocab_size};
}

Token _greedy_sample(const std::vector<Token>& logit_vector) const {
Token max_token{-std::numeric_limits<float>::infinity() , 0};
for (const auto& logit : logit_vector) {
if (logit.m_log_prob > max_token.m_log_prob) {
max_token = logit;
Token _greedy_sample(const Logits& logits) const {
// For greedy sampling we do not expect sorting or shrinking considered tokens
// so we can operate directly on the data buffer
float max_value = -std::numeric_limits<float>::infinity();
size_t max_index = 0;
for (size_t i = 0; i < logits.m_size; ++i) {
if (logits.m_data[i] > max_value) {
max_value = logits.m_data[i];
max_index = i;
}
}
return max_token;
return Token(logits.m_data[max_index], max_index);
}

std::vector<Token> _multinomial_sample(const std::vector<Token>& logit_vector, size_t num_tokens_per_sequence) {
std::vector<float> multinomial_weights(logit_vector.size());
for (size_t i = 0; i < logit_vector.size(); i++) multinomial_weights[i] = logit_vector[i].m_log_prob;
std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence) {
// If top_p or top_k was applied we use sorted vector, if not we go with original buffer.
std::vector<float> multinomial_weights;
multinomial_weights.reserve(logits.m_size);
if (logits.is_vector_initialized())
for (auto& logit: logits.m_vector) multinomial_weights.emplace_back(logit.m_log_prob);
else
multinomial_weights.assign(logits.m_data, logits.m_data + logits.m_size);

auto dist = std::discrete_distribution<size_t>(multinomial_weights.begin(), multinomial_weights.end()); // equivalent to multinomial with number of trials == 1

std::vector<Token> out_tokens;
for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) {
size_t element_to_pick = dist(rng_engine);
out_tokens.push_back(logit_vector[element_to_pick]);
if (logits.is_vector_initialized())
out_tokens.push_back(logits.m_vector[element_to_pick]);
else
out_tokens.emplace_back(logits.m_data[element_to_pick], element_to_pick);
}
return out_tokens;
}
Expand Down Expand Up @@ -296,7 +305,6 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
for (size_t running_sequence_id = 0; running_sequence_id < num_running_sequences; ++running_sequence_id) {
auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id);
logit_processor.apply(logit_vector);

Token sampled_token_id;
if (sampling_params.is_greedy_decoding()) {
sampled_token_id = _greedy_sample(logit_vector);
Expand Down
Loading

0 comments on commit 3304798

Please sign in to comment.