diff --git a/text_generation/causal_lm/cpp/continuous_batching/library/src/logit_processor.hpp b/text_generation/causal_lm/cpp/continuous_batching/library/src/logit_processor.hpp index 7b2e9f4a3..fd4f1444c 100644 --- a/text_generation/causal_lm/cpp/continuous_batching/library/src/logit_processor.hpp +++ b/text_generation/causal_lm/cpp/continuous_batching/library/src/logit_processor.hpp @@ -8,15 +8,20 @@ #include "generation_config.hpp" -namespace LogitTransformers { +struct Token { + float m_log_prob = 0.; + int64_t m_index = 0; + + Token(float log_prob, int64_t index) : m_log_prob(log_prob), m_index(index) {} + Token() = default; +}; +namespace LogitTransformers { using TokenIds = std::vector; -using LogitWithIdx = std::pair; -using ProbabilityWithIdx = std::pair; class ILogitTransformer { public: - virtual std::vector apply(const std::vector& input_logits) = 0; + virtual std::vector apply(const std::vector& input_logits) = 0; void set_unique_generated_token_ids(const std::shared_ptr>& unique_generated_token_ids) { if (unique_generated_token_ids != nullptr) { @@ -61,13 +66,13 @@ class TopPFilter : public ILogitTransformer { m_value = top_p; } - std::vector apply(const std::vector& input_probs) override { - std::vector tmp(input_probs); - std::sort(tmp.begin(), tmp.end(), [](const ProbabilityWithIdx& lhs, const ProbabilityWithIdx& rhs) {return lhs.first > rhs.first; }); + std::vector apply(const std::vector& input_probs) override { + std::vector tmp(input_probs); + std::sort(tmp.begin(), tmp.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; }); float probability_sum = 0.0f; size_t nucleus_size = 0; for (const auto& probability : tmp) { - probability_sum += probability.first; + probability_sum += probability.m_log_prob; nucleus_size += 1; if (probability_sum > m_value) break; } @@ -82,9 +87,9 @@ class TopKFilter : public ILogitTransformer { m_value = top_k; } - std::vector apply(const std::vector& input_probs) override { - std::vector tmp(input_probs); - std::sort(tmp.begin(), tmp.end(), [](const ProbabilityWithIdx& lhs, const ProbabilityWithIdx& rhs) {return lhs.first > rhs.first; }); + std::vector apply(const std::vector& input_probs) override { + std::vector tmp(input_probs); + std::sort(tmp.begin(), tmp.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; }); size_t top_k = input_probs.size() >= m_value ? m_value : input_probs.size(); tmp.resize(top_k); return tmp; @@ -97,19 +102,19 @@ class TemperatureLogitTransform : public ILogitTransformer { m_value = temperature; }; - std::vector apply(const std::vector& input_logits) override { - std::vector output(input_logits.begin(), input_logits.end()); - std::sort(output.begin(), output.end(), [](const ProbabilityWithIdx& lhs, const ProbabilityWithIdx& rhs) {return lhs.first > rhs.first; }); - float max_logit = output[0].first; + std::vector apply(const std::vector& input_logits) override { + std::vector output(input_logits.begin(), input_logits.end()); + std::sort(output.begin(), output.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; }); + float max_logit = output[0].m_log_prob; - std::for_each(output.begin(), output.end(), [max_logit, this](ProbabilityWithIdx& val) {val.first = expf((val.first - max_logit) / this->m_value);}); + std::for_each(output.begin(), output.end(), [max_logit, this](Token& val) {val.m_log_prob = expf((val.m_log_prob - max_logit) / this->m_value);}); float norm_sum = 0.0; for (const auto& val : output) { - norm_sum += val.first; + norm_sum += val.m_log_prob; } - std::for_each(output.begin(), output.end(), [norm_sum](ProbabilityWithIdx& val) {val.first /= norm_sum;}); + std::for_each(output.begin(), output.end(), [norm_sum](Token& val) {val.m_log_prob /= norm_sum;}); return output; } }; @@ -120,34 +125,34 @@ class RepetitionPenaltyTransform : public ILogitTransformer { m_value = value; }; - std::vector apply(const std::vector& input_logits) override { - std::vector output(input_logits.begin(), input_logits.end()); + std::vector apply(const std::vector& input_logits) override { + std::vector output(input_logits.begin(), input_logits.end()); size_t vocab_size = input_logits.size(); for (const auto& prompt_id : *m_unique_prompt_token_ids) { OPENVINO_ASSERT((prompt_id >= 0) && (prompt_id < vocab_size), "input_ids token out of bounds"); - OPENVINO_ASSERT(input_logits[prompt_id].second == prompt_id, "input_logits must have original index order"); - auto logit_value = output[prompt_id].first; + OPENVINO_ASSERT(input_logits[prompt_id].m_index == prompt_id, "input_logits must have original index order"); + auto logit_value = output[prompt_id].m_log_prob; if (logit_value >= 0) { - output[prompt_id].first /= m_value; + output[prompt_id].m_log_prob /= m_value; } else { - output[prompt_id].first *= m_value; + output[prompt_id].m_log_prob *= m_value; }; } for (const auto& input_id_pair : *m_unique_generated_token_ids) { const auto& input_id = input_id_pair.first; OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds"); - OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order"); - auto logit_value = output[input_id].first; + OPENVINO_ASSERT(input_logits[input_id].m_index == input_id, "input_logits must have original index order"); + auto logit_value = output[input_id].m_log_prob; if (logit_value >= 0) { - output[input_id].first /= m_value; + output[input_id].m_log_prob /= m_value; } else { - output[input_id].first *= m_value; + output[input_id].m_log_prob *= m_value; }; } return output; } - std::vector apply(const std::vector& input_logits, const TokenIds& input_ids) { + std::vector apply(const std::vector& input_logits, const TokenIds& input_ids) { extract_generated_tokens(input_ids); return this->apply(input_logits); } @@ -159,24 +164,24 @@ class FrequencyPenaltyTransform : public ILogitTransformer { m_value = value; }; - std::vector apply(const std::vector& input_logits) override { - std::vector output(input_logits.begin(), input_logits.end()); + std::vector apply(const std::vector& input_logits) override { + std::vector output(input_logits.begin(), input_logits.end()); size_t vocab_size = input_logits.size(); for (const auto& input_id_pair : *m_unique_generated_token_ids) { const auto& input_id = input_id_pair.first; OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds"); - OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order"); - auto logit_value = output[input_id].first; + OPENVINO_ASSERT(input_logits[input_id].m_index == input_id, "input_logits must have original index order"); + auto logit_value = output[input_id].m_log_prob; if (logit_value >= 0) { - output[input_id].first -= m_value * input_id_pair.second; + output[input_id].m_log_prob -= m_value * input_id_pair.second; } else { - output[input_id].first += m_value * input_id_pair.second; + output[input_id].m_log_prob += m_value * input_id_pair.second; }; } return output; } - std::vector apply(const std::vector& input_logits, const TokenIds& input_ids) { + std::vector apply(const std::vector& input_logits, const TokenIds& input_ids) { extract_generated_tokens(input_ids); return this->apply(input_logits); } @@ -188,24 +193,24 @@ class PresencePenaltyTransform : public ILogitTransformer { m_value = value; }; - std::vector apply(const std::vector& input_logits) override { - std::vector output(input_logits.begin(), input_logits.end()); + std::vector apply(const std::vector& input_logits) override { + std::vector output(input_logits.begin(), input_logits.end()); size_t vocab_size = input_logits.size(); for (const auto& input_id_pair : *m_unique_generated_token_ids) { const auto& input_id = input_id_pair.first; OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds"); - OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order"); - auto logit_value = output[input_id].first; + OPENVINO_ASSERT(input_logits[input_id].m_index == input_id, "input_logits must have original index order"); + auto logit_value = output[input_id].m_log_prob; if (logit_value >= 0) { - output[input_id].first -= m_value; + output[input_id].m_log_prob -= m_value; } else { - output[input_id].first += m_value; + output[input_id].m_log_prob += m_value; }; } return output; } - std::vector apply(const std::vector& input_logits, const TokenIds& input_ids) { + std::vector apply(const std::vector& input_logits, const TokenIds& input_ids) { extract_generated_tokens(input_ids); return this->apply(input_logits); } @@ -216,11 +221,11 @@ class ProbabilityNormalizeTransform : public ILogitTransformer { public: ProbabilityNormalizeTransform() = default; - std::vector apply(const std::vector& input_probs) override { - std::vector output(input_probs); + std::vector apply(const std::vector& input_probs) override { + std::vector output(input_probs); float norm_sum = 0.0; - for (const auto& val : output) norm_sum += val.first; - for (auto& val : output) val.first /= norm_sum; + for (const auto& val : output) norm_sum += val.m_log_prob; + for (auto& val : output) val.m_log_prob /= norm_sum; return output; } }; @@ -273,8 +278,8 @@ class LogitProcessor { } } - std::vector apply(const std::vector& logits) { - std::vector outputs(logits.begin(), logits.end()); + std::vector apply(const std::vector& logits) { + std::vector outputs(logits.begin(), logits.end()); for (const auto& transformer : m_logit_transformers) { outputs = transformer->apply(outputs); } diff --git a/text_generation/causal_lm/cpp/continuous_batching/library/src/sampler.hpp b/text_generation/causal_lm/cpp/continuous_batching/library/src/sampler.hpp index 7e45fb005..8723db11f 100644 --- a/text_generation/causal_lm/cpp/continuous_batching/library/src/sampler.hpp +++ b/text_generation/causal_lm/cpp/continuous_batching/library/src/sampler.hpp @@ -60,11 +60,6 @@ std::vector kmp_search(const std::vector& haystack, const std: return res; } -struct Token { - float m_log_prob; - int64_t m_index; -}; - std::vector log_softmax(const ov::Tensor& logits, size_t batch_idx) { ov::Shape shape = logits.get_shape(); OPENVINO_ASSERT(shape.size() == 3); @@ -203,10 +198,9 @@ class GroupBeamSearcher { } }; -using LogitWithIdx = std::pair; class Sampler { - std::vector _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) { + std::vector _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) { ov::Shape logits_shape = logits.get_shape(); size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2]; OPENVINO_ASSERT(batch_idx <= batch_size); @@ -214,24 +208,24 @@ class Sampler { size_t sequence_offset = (seq_len - 1) * vocab_size; const float* logits_data = logits.data() + batch_offset + sequence_offset; - std::vector logit_vector(vocab_size); + std::vector logit_vector(vocab_size); for (size_t i = 0; i < logit_vector.size(); i++) { - logit_vector[i] = LogitWithIdx(logits_data[i], i); + logit_vector[i] = Token(logits_data[i], i); } return logit_vector; } - LogitWithIdx _greedy_sample(const std::vector& logit_vector) const { - auto out_token = std::max_element(logit_vector.begin(), logit_vector.end(), [](const LogitWithIdx& lhs, const LogitWithIdx& rhs) { return lhs.first < rhs.first; }); + Token _greedy_sample(const std::vector& logit_vector) const { + auto out_token = std::max_element(logit_vector.begin(), logit_vector.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob < rhs.m_log_prob; }); return *out_token; } - std::vector _multinomial_sample(const std::vector& logit_vector, size_t num_tokens_per_sequence) { + std::vector _multinomial_sample(const std::vector& logit_vector, size_t num_tokens_per_sequence) { std::vector multinomial_weights(logit_vector.size()); - for (size_t i = 0; i < logit_vector.size(); i++) multinomial_weights[i] = logit_vector[i].first; + for (size_t i = 0; i < logit_vector.size(); i++) multinomial_weights[i] = logit_vector[i].m_log_prob; auto dist = std::discrete_distribution(multinomial_weights.begin(), multinomial_weights.end()); // equivalent to multinomial with number of trials == 1 - std::vector out_tokens; + std::vector out_tokens; for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) { size_t element_to_pick = dist(rng_engine); out_tokens.push_back(logit_vector[element_to_pick]); @@ -285,12 +279,12 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, if (sampling_params.is_greedy_sampling()) { OPENVINO_ASSERT(num_running_sequences == 1); } - auto register_new_token = [&](const LogitWithIdx& sampled_token_id, Sequence::Ptr running_sequence) { - logit_processor.register_new_generated_token(sampled_token_id.second); - running_sequence->append_token(sampled_token_id.second, sampled_token_id.first); + auto register_new_token = [&](const Token& sampled_token_id, Sequence::Ptr running_sequence) { + logit_processor.register_new_generated_token(sampled_token_id.m_index); + running_sequence->append_token(sampled_token_id.m_index, sampled_token_id.m_log_prob); if (sampling_params.max_new_tokens == running_sequence->get_generated_len() || - sampled_token_id.second == sampling_params.eos_token_id && !sampling_params.ignore_eos) { + sampled_token_id.m_index == sampling_params.eos_token_id && !sampling_params.ignore_eos) { // stop sequence by max_new_tokens or EOS token running_sequence->set_status(SequenceStatus::FINISHED); // drop sequence from scheduler @@ -301,7 +295,7 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id); logit_vector = logit_processor.apply(logit_vector); - LogitWithIdx sampled_token_id; + Token sampled_token_id; if (sampling_params.is_greedy_sampling()) { sampled_token_id = _greedy_sample(logit_vector); } else { diff --git a/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/logit_filtering.cpp b/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/logit_filtering.cpp index e6e20d31a..159a65300 100644 --- a/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/logit_filtering.cpp +++ b/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/logit_filtering.cpp @@ -10,8 +10,8 @@ using namespace LogitTransformers; struct TemperatureTransformTestStruct { float temperature; - std::vector input; - std::vector expected_output; + std::vector input; + std::vector expected_output; }; using TemperatureTransformTest = testing::TestWithParam; @@ -21,10 +21,10 @@ TEST_P(TemperatureTransformTest, TransformResultEqualToReference) { auto transform = TemperatureLogitTransform(test_struct.temperature); auto test_result = transform.apply(test_struct.input); ASSERT_EQ(test_result.size(), test_struct.expected_output.size()); - std::sort(test_result.begin(), test_result.end(), [](const ProbabilityWithIdx& lhs, const ProbabilityWithIdx& rhs) {return lhs.first > rhs.first; }); + std::sort(test_result.begin(), test_result.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; }); for (size_t i = 0; i < test_result.size(); i++) { - EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6); - EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second); + EXPECT_NEAR(test_result[i].m_log_prob, test_struct.expected_output[i].m_log_prob, 1e-6); + EXPECT_EQ(test_result[i].m_index, test_struct.expected_output[i].m_index); } } @@ -43,8 +43,8 @@ INSTANTIATE_TEST_SUITE_P(VariousInputs, struct TopPTestStruct { float top_p; - std::vector input; - std::vector expected_output; + std::vector input; + std::vector expected_output; }; using TopPFilteringTest = testing::TestWithParam; @@ -55,8 +55,8 @@ TEST_P(TopPFilteringTest, FilterResultEqualToReference) { auto test_result = transform.apply(test_struct.input); ASSERT_EQ(test_result.size(), test_struct.expected_output.size()); for (size_t i = 0; i < test_result.size(); i++) { - EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6); - EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second); + EXPECT_NEAR(test_result[i].m_log_prob, test_struct.expected_output[i].m_log_prob, 1e-6); + EXPECT_EQ(test_result[i].m_index, test_struct.expected_output[i].m_index); } } @@ -75,8 +75,8 @@ INSTANTIATE_TEST_SUITE_P(VariousInputs, struct TopKTestStruct { size_t top_k; - std::vector input; - std::vector expected_output; + std::vector input; + std::vector expected_output; }; using TopKFilteringTest = testing::TestWithParam; @@ -87,8 +87,8 @@ TEST_P(TopKFilteringTest, FilterResultEqualToReference) { auto test_result = transform.apply(test_struct.input); ASSERT_EQ(test_result.size(), test_struct.expected_output.size()); for (size_t i = 0; i < test_result.size(); i++) { - EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6); - EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second); + EXPECT_NEAR(test_result[i].m_log_prob, test_struct.expected_output[i].m_log_prob, 1e-6); + EXPECT_EQ(test_result[i].m_index, test_struct.expected_output[i].m_index); } } @@ -105,8 +105,8 @@ INSTANTIATE_TEST_SUITE_P(VariousInputs, struct ProbabilityNormalizeTransformTestStruct { - std::vector input; - std::vector expected_output; + std::vector input; + std::vector expected_output; }; using ProbabilityNormalizeTransformTest = testing::TestWithParam; @@ -117,8 +117,8 @@ TEST_P(ProbabilityNormalizeTransformTest, TransformResultEqualToReference) { auto test_result = transform.apply(test_struct.input); ASSERT_EQ(test_result.size(), test_struct.expected_output.size()); for (size_t i = 0; i < test_result.size(); i++) { - EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6); - EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second); + EXPECT_NEAR(test_result[i].m_log_prob, test_struct.expected_output[i].m_log_prob, 1e-6); + EXPECT_EQ(test_result[i].m_index, test_struct.expected_output[i].m_index); } } @@ -134,9 +134,9 @@ INSTANTIATE_TEST_SUITE_P(VariousInputs, struct RepetitionPenaltyTransformTestStruct { float penalty; - std::vector input_logits; + std::vector input_logits; TokenIds input_ids; - std::vector expected_output; + std::vector expected_output; }; using RepetitionPenaltyTransformTest = testing::TestWithParam; @@ -147,8 +147,8 @@ TEST_P(RepetitionPenaltyTransformTest, TransformResultEqualToReference) { auto test_result = transform.apply(test_struct.input_logits, test_struct.input_ids); ASSERT_EQ(test_result.size(), test_struct.expected_output.size()); for (size_t i = 0; i < test_result.size(); i++) { - EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6); - EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second); + EXPECT_NEAR(test_result[i].m_log_prob, test_struct.expected_output[i].m_log_prob, 1e-6); + EXPECT_EQ(test_result[i].m_index, test_struct.expected_output[i].m_index); } } @@ -186,9 +186,9 @@ TEST(RepetitionPenaltyTransformInitializationTest, ThrowsForInvalidInputIds) { struct FrequencyPenaltyTransformTestStruct { float penalty; - std::vector input_logits; + std::vector input_logits; TokenIds input_ids; - std::vector expected_output; + std::vector expected_output; }; using FrequencyPenaltyTransformTest = testing::TestWithParam; @@ -199,8 +199,8 @@ TEST_P(FrequencyPenaltyTransformTest, TransformResultEqualToReference) { auto test_result = transform.apply(test_struct.input_logits, test_struct.input_ids); ASSERT_EQ(test_result.size(), test_struct.expected_output.size()); for (size_t i = 0; i < test_result.size(); i++) { - EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6); - EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second); + EXPECT_NEAR(test_result[i].m_log_prob, test_struct.expected_output[i].m_log_prob, 1e-6); + EXPECT_EQ(test_result[i].m_index, test_struct.expected_output[i].m_index); } }; @@ -239,9 +239,9 @@ TEST(FrequencyPenaltyTransformInitializationTest, ThrowsForInvalidInputIds) { struct PresencePenaltyTransformTestStruct { float penalty; - std::vector input_logits; + std::vector input_logits; TokenIds input_ids; - std::vector expected_output; + std::vector expected_output; }; using PresencePenaltyTransformTest = testing::TestWithParam; @@ -252,8 +252,8 @@ TEST_P(PresencePenaltyTransformTest, TransformResultEqualToReference) { auto test_result = transform.apply(test_struct.input_logits, test_struct.input_ids); ASSERT_EQ(test_result.size(), test_struct.expected_output.size()); for (size_t i = 0; i < test_result.size(); i++) { - EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6); - EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second); + EXPECT_NEAR(test_result[i].m_log_prob, test_struct.expected_output[i].m_log_prob, 1e-6); + EXPECT_EQ(test_result[i].m_index, test_struct.expected_output[i].m_index); } };