
Commit

Remove extra token desc
iefode committed Jun 11, 2024
1 parent a64f30a commit adec0e0
Showing 3 changed files with 96 additions and 97 deletions.
@@ -8,15 +8,20 @@
 
 #include "generation_config.hpp"
 
-namespace LogitTransformers {
+struct Token {
+    float m_log_prob = 0.;
+    int64_t m_index = 0;
+
+    Token(float log_prob, int64_t index) : m_log_prob(log_prob), m_index(index) {}
+    Token() = default;
+};
+
+namespace LogitTransformers {
 using TokenIds = std::vector<int64_t>;
-using LogitWithIdx = std::pair<float, size_t>;
-using ProbabilityWithIdx = std::pair<float, size_t>;
 
 class ILogitTransformer {
 public:
-    virtual std::vector<ProbabilityWithIdx> apply(const std::vector<ProbabilityWithIdx>& input_logits) = 0;
+    virtual std::vector<Token> apply(const std::vector<Token>& input_logits) = 0;
 
     void set_unique_generated_token_ids(const std::shared_ptr<std::map<int64_t, size_t>>& unique_generated_token_ids) {
         if (unique_generated_token_ids != nullptr) {
@@ -61,13 +66,13 @@ class TopPFilter : public ILogitTransformer {
         m_value = top_p;
     }
 
-    std::vector<ProbabilityWithIdx> apply(const std::vector<ProbabilityWithIdx>& input_probs) override {
-        std::vector<ProbabilityWithIdx> tmp(input_probs);
-        std::sort(tmp.begin(), tmp.end(), [](const ProbabilityWithIdx& lhs, const ProbabilityWithIdx& rhs) {return lhs.first > rhs.first; });
+    std::vector<Token> apply(const std::vector<Token>& input_probs) override {
+        std::vector<Token> tmp(input_probs);
+        std::sort(tmp.begin(), tmp.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
         float probability_sum = 0.0f;
         size_t nucleus_size = 0;
         for (const auto& probability : tmp) {
-            probability_sum += probability.first;
+            probability_sum += probability.m_log_prob;
             nucleus_size += 1;
             if (probability_sum > m_value) break;
         }
@@ -82,9 +87,9 @@ class TopKFilter : public ILogitTransformer {
         m_value = top_k;
     }
 
-    std::vector<ProbabilityWithIdx> apply(const std::vector<ProbabilityWithIdx>& input_probs) override {
-        std::vector<ProbabilityWithIdx> tmp(input_probs);
-        std::sort(tmp.begin(), tmp.end(), [](const ProbabilityWithIdx& lhs, const ProbabilityWithIdx& rhs) {return lhs.first > rhs.first; });
+    std::vector<Token> apply(const std::vector<Token>& input_probs) override {
+        std::vector<Token> tmp(input_probs);
+        std::sort(tmp.begin(), tmp.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
         size_t top_k = input_probs.size() >= m_value ? m_value : input_probs.size();
         tmp.resize(top_k);
         return tmp;
@@ -97,19 +102,19 @@ class TemperatureLogitTransform : public ILogitTransformer {
         m_value = temperature;
     };
 
-    std::vector<ProbabilityWithIdx> apply(const std::vector<LogitWithIdx>& input_logits) override {
-        std::vector<ProbabilityWithIdx> output(input_logits.begin(), input_logits.end());
-        std::sort(output.begin(), output.end(), [](const ProbabilityWithIdx& lhs, const ProbabilityWithIdx& rhs) {return lhs.first > rhs.first; });
-        float max_logit = output[0].first;
+    std::vector<Token> apply(const std::vector<Token>& input_logits) override {
+        std::vector<Token> output(input_logits.begin(), input_logits.end());
+        std::sort(output.begin(), output.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
+        float max_logit = output[0].m_log_prob;
 
-        std::for_each(output.begin(), output.end(), [max_logit, this](ProbabilityWithIdx& val) {val.first = expf((val.first - max_logit) / this->m_value);});
+        std::for_each(output.begin(), output.end(), [max_logit, this](Token& val) {val.m_log_prob = expf((val.m_log_prob - max_logit) / this->m_value);});
 
         float norm_sum = 0.0;
         for (const auto& val : output) {
-            norm_sum += val.first;
+            norm_sum += val.m_log_prob;
         }
 
-        std::for_each(output.begin(), output.end(), [norm_sum](ProbabilityWithIdx& val) {val.first /= norm_sum;});
+        std::for_each(output.begin(), output.end(), [norm_sum](Token& val) {val.m_log_prob /= norm_sum;});
         return output;
     }
 };
@@ -120,34 +125,34 @@ class RepetitionPenaltyTransform : public ILogitTransformer {
         m_value = value;
     };
 
-    std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits) override {
-        std::vector<LogitWithIdx> output(input_logits.begin(), input_logits.end());
+    std::vector<Token> apply(const std::vector<Token>& input_logits) override {
+        std::vector<Token> output(input_logits.begin(), input_logits.end());
         size_t vocab_size = input_logits.size();
         for (const auto& prompt_id : *m_unique_prompt_token_ids) {
             OPENVINO_ASSERT((prompt_id >= 0) && (prompt_id < vocab_size), "input_ids token out of bounds");
-            OPENVINO_ASSERT(input_logits[prompt_id].second == prompt_id, "input_logits must have original index order");
-            auto logit_value = output[prompt_id].first;
+            OPENVINO_ASSERT(input_logits[prompt_id].m_index == prompt_id, "input_logits must have original index order");
+            auto logit_value = output[prompt_id].m_log_prob;
             if (logit_value >= 0) {
-                output[prompt_id].first /= m_value;
+                output[prompt_id].m_log_prob /= m_value;
             } else {
-                output[prompt_id].first *= m_value;
+                output[prompt_id].m_log_prob *= m_value;
             };
         }
         for (const auto& input_id_pair : *m_unique_generated_token_ids) {
             const auto& input_id = input_id_pair.first;
             OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
-            OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order");
-            auto logit_value = output[input_id].first;
+            OPENVINO_ASSERT(input_logits[input_id].m_index == input_id, "input_logits must have original index order");
+            auto logit_value = output[input_id].m_log_prob;
             if (logit_value >= 0) {
-                output[input_id].first /= m_value;
+                output[input_id].m_log_prob /= m_value;
             } else {
-                output[input_id].first *= m_value;
+                output[input_id].m_log_prob *= m_value;
             };
         }
         return output;
     }
 
-    std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits, const TokenIds& input_ids) {
+    std::vector<Token> apply(const std::vector<Token>& input_logits, const TokenIds& input_ids) {
         extract_generated_tokens(input_ids);
         return this->apply(input_logits);
     }
@@ -159,24 +164,24 @@ class FrequencyPenaltyTransform : public ILogitTransformer {
         m_value = value;
     };
 
-    std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits) override {
-        std::vector<LogitWithIdx> output(input_logits.begin(), input_logits.end());
+    std::vector<Token> apply(const std::vector<Token>& input_logits) override {
+        std::vector<Token> output(input_logits.begin(), input_logits.end());
         size_t vocab_size = input_logits.size();
         for (const auto& input_id_pair : *m_unique_generated_token_ids) {
             const auto& input_id = input_id_pair.first;
             OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
-            OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order");
-            auto logit_value = output[input_id].first;
+            OPENVINO_ASSERT(input_logits[input_id].m_index == input_id, "input_logits must have original index order");
+            auto logit_value = output[input_id].m_log_prob;
             if (logit_value >= 0) {
-                output[input_id].first -= m_value * input_id_pair.second;
+                output[input_id].m_log_prob -= m_value * input_id_pair.second;
             } else {
-                output[input_id].first += m_value * input_id_pair.second;
+                output[input_id].m_log_prob += m_value * input_id_pair.second;
             };
         }
         return output;
     }
 
-    std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits, const TokenIds& input_ids) {
+    std::vector<Token> apply(const std::vector<Token>& input_logits, const TokenIds& input_ids) {
         extract_generated_tokens(input_ids);
         return this->apply(input_logits);
     }
@@ -188,24 +193,24 @@ class PresencePenaltyTransform : public ILogitTransformer {
         m_value = value;
     };
 
-    std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits) override {
-        std::vector<LogitWithIdx> output(input_logits.begin(), input_logits.end());
+    std::vector<Token> apply(const std::vector<Token>& input_logits) override {
+        std::vector<Token> output(input_logits.begin(), input_logits.end());
         size_t vocab_size = input_logits.size();
         for (const auto& input_id_pair : *m_unique_generated_token_ids) {
             const auto& input_id = input_id_pair.first;
             OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
-            OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order");
-            auto logit_value = output[input_id].first;
+            OPENVINO_ASSERT(input_logits[input_id].m_index == input_id, "input_logits must have original index order");
+            auto logit_value = output[input_id].m_log_prob;
             if (logit_value >= 0) {
-                output[input_id].first -= m_value;
+                output[input_id].m_log_prob -= m_value;
             } else {
-                output[input_id].first += m_value;
+                output[input_id].m_log_prob += m_value;
             };
         }
         return output;
     }
 
-    std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits, const TokenIds& input_ids) {
+    std::vector<Token> apply(const std::vector<Token>& input_logits, const TokenIds& input_ids) {
         extract_generated_tokens(input_ids);
         return this->apply(input_logits);
     }
@@ -216,11 +221,11 @@ class ProbabilityNormalizeTransform : public ILogitTransformer {
 public:
     ProbabilityNormalizeTransform() = default;
 
-    std::vector<ProbabilityWithIdx> apply(const std::vector<ProbabilityWithIdx>& input_probs) override {
-        std::vector<ProbabilityWithIdx> output(input_probs);
+    std::vector<Token> apply(const std::vector<Token>& input_probs) override {
+        std::vector<Token> output(input_probs);
         float norm_sum = 0.0;
-        for (const auto& val : output) norm_sum += val.first;
-        for (auto& val : output) val.first /= norm_sum;
+        for (const auto& val : output) norm_sum += val.m_log_prob;
+        for (auto& val : output) val.m_log_prob /= norm_sum;
         return output;
     }
 };
@@ -273,8 +278,8 @@ class LogitProcessor {
         }
     }
 
-    std::vector<LogitTransformers::ProbabilityWithIdx> apply(const std::vector<LogitTransformers::ProbabilityWithIdx>& logits) {
-        std::vector<LogitTransformers::ProbabilityWithIdx> outputs(logits.begin(), logits.end());
+    std::vector<Token> apply(const std::vector<Token>& logits) {
+        std::vector<Token> outputs(logits.begin(), logits.end());
        for (const auto& transformer : m_logit_transformers) {
             outputs = transformer->apply(outputs);
         }
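The net effect of the first file's changes: the ad-hoc LogitWithIdx/ProbabilityWithIdx pairs are gone, and every transformer now consumes and returns the one Token struct. A minimal self-contained sketch of that pattern follows; the top_k helper, the main driver, and the sample values are illustrative assumptions, not code from this commit.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the struct added by this commit.
struct Token {
    float m_log_prob = 0.;
    int64_t m_index = 0;

    Token(float log_prob, int64_t index) : m_log_prob(log_prob), m_index(index) {}
    Token() = default;
};

// Illustrative top-k filter following the shape of TopKFilter::apply above:
// sort by score descending, keep the k best entries.
std::vector<Token> top_k(const std::vector<Token>& input, size_t k) {
    std::vector<Token> tmp(input);
    std::sort(tmp.begin(), tmp.end(),
              [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob > rhs.m_log_prob; });
    tmp.resize(std::min(k, tmp.size()));
    return tmp;
}

int main() {
    std::vector<Token> logits = {{0.1f, 0}, {2.3f, 1}, {-0.7f, 2}, {1.5f, 3}};
    for (const auto& t : top_k(logits, 2))
        std::printf("index=%lld log_prob=%f\n", static_cast<long long>(t.m_index), t.m_log_prob);
}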
@@ -60,11 +60,6 @@ std::vector<int64_t> kmp_search(const std::vector<int64_t>& haystack, const std:
     return res;
 }
 
-struct Token {
-    float m_log_prob;
-    int64_t m_index;
-};
-
 std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx) {
     ov::Shape shape = logits.get_shape();
     OPENVINO_ASSERT(shape.size() == 3);
@@ -203,35 +198,34 @@ class GroupBeamSearcher {
     }
 };
 
-using LogitWithIdx = std::pair<float, size_t>;
 class Sampler {
 
-    std::vector<LogitWithIdx> _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
+    std::vector<Token> _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
         ov::Shape logits_shape = logits.get_shape();
         size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
         OPENVINO_ASSERT(batch_idx <= batch_size);
         size_t batch_offset = batch_idx * seq_len * vocab_size;
         size_t sequence_offset = (seq_len - 1) * vocab_size;
         const float* logits_data = logits.data<const float>() + batch_offset + sequence_offset;
 
-        std::vector<LogitWithIdx> logit_vector(vocab_size);
+        std::vector<Token> logit_vector(vocab_size);
         for (size_t i = 0; i < logit_vector.size(); i++) {
-            logit_vector[i] = LogitWithIdx(logits_data[i], i);
+            logit_vector[i] = Token(logits_data[i], i);
         }
         return logit_vector;
     }
 
-    LogitWithIdx _greedy_sample(const std::vector<LogitWithIdx>& logit_vector) const {
-        auto out_token = std::max_element(logit_vector.begin(), logit_vector.end(), [](const LogitWithIdx& lhs, const LogitWithIdx& rhs) { return lhs.first < rhs.first; });
+    Token _greedy_sample(const std::vector<Token>& logit_vector) const {
+        auto out_token = std::max_element(logit_vector.begin(), logit_vector.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob < rhs.m_log_prob; });
         return *out_token;
     }
 
-    std::vector<LogitWithIdx> _multinomial_sample(const std::vector<LogitWithIdx>& logit_vector, size_t num_tokens_per_sequence) {
+    std::vector<Token> _multinomial_sample(const std::vector<Token>& logit_vector, size_t num_tokens_per_sequence) {
         std::vector<float> multinomial_weights(logit_vector.size());
-        for (size_t i = 0; i < logit_vector.size(); i++) multinomial_weights[i] = logit_vector[i].first;
+        for (size_t i = 0; i < logit_vector.size(); i++) multinomial_weights[i] = logit_vector[i].m_log_prob;
 
         auto dist = std::discrete_distribution<size_t>(multinomial_weights.begin(), multinomial_weights.end()); // equivalent to multinomial with number of trials == 1
-        std::vector<LogitWithIdx> out_tokens;
+        std::vector<Token> out_tokens;
         for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) {
             size_t element_to_pick = dist(rng_engine);
             out_tokens.push_back(logit_vector[element_to_pick]);
@@ -285,12 +279,12 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             if (sampling_params.is_greedy_sampling()) {
                 OPENVINO_ASSERT(num_running_sequences == 1);
             }
-            auto register_new_token = [&](const LogitWithIdx& sampled_token_id, Sequence::Ptr running_sequence) {
-                logit_processor.register_new_generated_token(sampled_token_id.second);
-                running_sequence->append_token(sampled_token_id.second, sampled_token_id.first);
+            auto register_new_token = [&](const Token& sampled_token_id, Sequence::Ptr running_sequence) {
+                logit_processor.register_new_generated_token(sampled_token_id.m_index);
+                running_sequence->append_token(sampled_token_id.m_index, sampled_token_id.m_log_prob);
 
                 if (sampling_params.max_new_tokens == running_sequence->get_generated_len() ||
-                    sampled_token_id.second == sampling_params.eos_token_id && !sampling_params.ignore_eos) {
+                    sampled_token_id.m_index == sampling_params.eos_token_id && !sampling_params.ignore_eos) {
                     // stop sequence by max_new_tokens or EOS token
                     running_sequence->set_status(SequenceStatus::FINISHED);
                     // drop sequence from scheduler
@@ -301,7 +295,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
                 auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id);
                 logit_vector = logit_processor.apply(logit_vector);
 
-                LogitWithIdx sampled_token_id;
+                Token sampled_token_id;
                 if (sampling_params.is_greedy_sampling()) {
                     sampled_token_id = _greedy_sample(logit_vector);
                 } else {
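For reference, the sampling helpers touched in the second file boil down to the standalone sketch below. The greedy path takes the argmax by m_log_prob; the multinomial path feeds the scores to std::discrete_distribution, which with one trial per draw is equivalent to a multinomial sample. The function names, fixed seed, and example weights are assumptions for illustration, not code from this commit.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

struct Token {
    float m_log_prob = 0.;
    int64_t m_index = 0;
    Token(float log_prob, int64_t index) : m_log_prob(log_prob), m_index(index) {}
    Token() = default;
};

// Greedy: pick the single highest-scoring token (cf. _greedy_sample above).
Token greedy_sample(const std::vector<Token>& logit_vector) {
    return *std::max_element(logit_vector.begin(), logit_vector.end(),
        [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob < rhs.m_log_prob; });
}

// Multinomial: treat the (already normalized) scores as weights and draw
// num_tokens indices from them (cf. _multinomial_sample above).
std::vector<Token> multinomial_sample(const std::vector<Token>& logit_vector,
                                      size_t num_tokens, std::mt19937& rng) {
    std::vector<float> weights(logit_vector.size());
    for (size_t i = 0; i < logit_vector.size(); i++) weights[i] = logit_vector[i].m_log_prob;

    std::discrete_distribution<size_t> dist(weights.begin(), weights.end());
    std::vector<Token> out;
    for (size_t i = 0; i < num_tokens; ++i) out.push_back(logit_vector[dist(rng)]);
    return out;
}

int main() {
    std::vector<Token> probs = {{0.1f, 0}, {0.6f, 1}, {0.3f, 2}};  // assumed post-softmax values
    std::mt19937 rng(0);  // fixed seed so the illustration is reproducible
    std::printf("greedy -> %lld\n", static_cast<long long>(greedy_sample(probs).m_index));
    for (const auto& t : multinomial_sample(probs, 3, rng))
        std::printf("sampled -> %lld\n", static_cast<long long>(t.m_index));
}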
