Skip to content

Commit

Permalink
Enable frequency and presence penalties
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Jun 5, 2024
1 parent 76148c5 commit 4f73d36
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ struct GenerationConfig {
StopCriteria stop_criteria = StopCriteria::HEURISTIC;
size_t num_return_sequences = 3; // is used by beam search, in other case is equal to batch size

float repetition_penalty = 1.0f;
float repetition_penalty = 1.0f; // based on token repetition in prompt and generated tests
float presence_penalty = 0.0f; // based on token repetition and generated tests
float frequence_penalty = 0.0f; // based on quantity token repetition and generated tests
float length_penalty = 1.0f;
size_t no_repeat_ngram_size = std::numeric_limits<size_t>::max();
std::function<bool(const Sequence&)> early_finish = [] (const Sequence&) { return false; };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,7 @@ GenerationConfig GenerationConfig::multinomial() {
multinomial.temperature = 0.8f;
multinomial.top_p = 0.8;
multinomial.top_k = 20;
multinomial.presence_penalty = 0.01f;
multinomial.frequence_penalty = 0.1f;
return multinomial;
}
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,23 @@ class RepetitionPenaltyTransform {
OPENVINO_ASSERT(m_penalty >= 0.0f, "repetition penalty must be a positive value");
}

std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits, const std::set<int64_t>& unique_input_ids) {
std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits,
const std::map<int64_t, size_t>& unique_input_ids,
const std::set<int64_t>& unique_prompt_ids = {}) {
std::vector<LogitWithIdx> output(input_logits.begin(), input_logits.end());
size_t vocab_size = input_logits.size();
for (auto input_id : unique_input_ids) {
for (const auto& prompt_id : unique_prompt_ids) {
OPENVINO_ASSERT((prompt_id >= 0) && (prompt_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(input_logits[prompt_id].second == prompt_id, "input_logits must have original index order");
auto logit_value = output[prompt_id].first;
if (logit_value >= 0) {
output[prompt_id].first /= m_penalty;
} else {
output[prompt_id].first *= m_penalty;
};
}
for (const auto& input_id_pair : unique_input_ids) {
const auto& input_id = input_id_pair.first;
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order");
auto logit_value = output[input_id].first;
Expand All @@ -300,9 +313,97 @@ class RepetitionPenaltyTransform {
}

std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits, const TokenIds& input_ids) {
std::set<int64_t> unique_input_ids(input_ids.begin(), input_ids.end());
std::map<int64_t, size_t> unique_input_ids;
for (const auto& input_id : input_ids) {
if (unique_input_ids.count(input_id)) {
unique_input_ids[input_id]++;
} else {
unique_input_ids.insert({input_id, 1});
}
}
return this->apply(input_logits, unique_input_ids);
}

private:
double m_penalty;
};

class FrequencyPenaltyTransform {
public:
FrequencyPenaltyTransform(double penalty) : m_penalty(penalty) {
OPENVINO_ASSERT(m_penalty >= -2.0f && m_penalty <= 2.0f, "repetition penalty must be a positive value");
}

std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits,
const std::map<int64_t, size_t>& unique_input_ids) {
std::vector<LogitWithIdx> output(input_logits.begin(), input_logits.end());
size_t vocab_size = input_logits.size();
for (const auto& input_id_pair : unique_input_ids) {
const auto& input_id = input_id_pair.first;
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order");
auto logit_value = output[input_id].first;
if (logit_value >= 0) {
output[input_id].first -= m_penalty * input_id_pair.second;
} else {
output[input_id].first += m_penalty * input_id_pair.second;
};
}
return output;
}

std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits, const TokenIds& input_ids) {
std::map<int64_t, size_t> unique_input_ids;
for (const auto& input_id : input_ids) {
if (unique_input_ids.count(input_id)) {
unique_input_ids[input_id]++;
} else {
unique_input_ids.insert({input_id, 1});
}
}
return this->apply(input_logits, unique_input_ids);
}

private:
double m_penalty;
};

class PresencePenaltyTransform {
public:
PresencePenaltyTransform(double penalty) : m_penalty(penalty) {
OPENVINO_ASSERT(m_penalty >= -2.0f && m_penalty <= 2.0f, "repetition penalty must be a positive value");
}

std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits,
const std::map<int64_t, size_t>& unique_input_ids) {
std::vector<LogitWithIdx> output(input_logits.begin(), input_logits.end());
size_t vocab_size = input_logits.size();
for (const auto& input_id_pair : unique_input_ids) {
const auto& input_id = input_id_pair.first;
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order");
auto logit_value = output[input_id].first;
if (logit_value >= 0) {
output[input_id].first -= m_penalty;
} else {
output[input_id].first += m_penalty;
};
}
return output;
}

std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits, const TokenIds& input_ids) {
std::map<int64_t, size_t> unique_input_ids;
for (const auto& input_id : input_ids) {
if (unique_input_ids.count(input_id)) {
unique_input_ids[input_id]++;
} else {
unique_input_ids.insert({input_id, 1});
}
}
return this->apply(input_logits, unique_input_ids);
}

private:
double m_penalty;
};
Expand Down Expand Up @@ -405,8 +506,17 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,

if (sampling_params.repetition_penalty != 1.0f) {
auto repetition_penalty_transform = RepetitionPenaltyTransform(sampling_params.repetition_penalty);
logit_vector = repetition_penalty_transform.apply(logit_vector, sequence_group->get_unique_generated_ids());
logit_vector = repetition_penalty_transform.apply(logit_vector, sequence_group->get_unique_generated_ids(), sequence_group->get_unique_prompt_ids());
}
if (sampling_params.presence_penalty != 0.0f) {
auto presence_penalty_transform = PresencePenaltyTransform(sampling_params.presence_penalty);
logit_vector = presence_penalty_transform.apply(logit_vector, sequence_group->get_unique_generated_ids());
}
if (sampling_params.frequence_penalty != 0.0f) {
auto frequence_penalty_transform = FrequencyPenaltyTransform(sampling_params.frequence_penalty);
logit_vector = frequence_penalty_transform.apply(logit_vector, sequence_group->get_unique_generated_ids());
}

std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
OPENVINO_ASSERT(running_sequences.size() == 1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ class SequenceGroup {
GenerationConfig m_sampling_params;
std::size_t m_block_size;
TokenIds m_prompt_ids;
std::set<int64_t> m_unique_generated_ids;
std::map<int64_t, size_t> m_unique_generated_ids;
std::set<int64_t> m_unique_prompt_ids;
GenerationStream::Ptr m_generation_stream;

uint64_t m_next_sequence_id = 0;
Expand Down Expand Up @@ -167,7 +168,9 @@ class SequenceGroup {

m_prompt_ids.resize(input_ids.get_size());
std::copy_n(input_ids.data<int64_t>(), input_ids.get_size(), m_prompt_ids.begin());
for (auto id: m_prompt_ids) { m_unique_generated_ids.insert(id); }
for (auto id: m_prompt_ids) {
m_unique_prompt_ids.insert(id);
}
}

void add_sequence(const Sequence::Ptr & sequence) {
Expand Down Expand Up @@ -329,12 +332,20 @@ class SequenceGroup {
return m_prompt_ids;
}

const std::set<int64_t>& get_unique_generated_ids() const {
const std::map<int64_t, size_t>& get_unique_generated_ids() const {
return m_unique_generated_ids;
}

const std::set<int64_t>& get_unique_prompt_ids() const {
return m_unique_prompt_ids;
}

void register_generated_token_id(int64_t token_id) {
    // Record one more occurrence of `token_id` among generated tokens.
    // operator[] value-initializes the size_t counter to 0 on first sight,
    // so a single lookup both inserts and increments — replaces the
    // count-then-insert pattern (two lookups) and drops the leftover
    // `insert(token_id)` line that does not match the map's value type.
    ++m_unique_generated_ids[token_id];
}

size_t get_num_logical_blocks() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,124 @@ TEST(RepetitionPenaltyTransformInitializationTest, ThrowsForInvalidPenalties) {

TEST(RepetitionPenaltyTransformInitializationTest, ThrowsForInvalidInputIds) {
auto transform = RepetitionPenaltyTransform(1.5);
EXPECT_THROW(transform.apply({ {43.0f, 0} }, std::set<int64_t>{1337} ), ov::Exception);
EXPECT_THROW(transform.apply({ {18.0f, 0} }, std::set<int64_t>{0, -1} ), ov::Exception);
EXPECT_THROW(transform.apply({ {43.0f, 0} }, std::map<int64_t, size_t>{{1337, 0}} ), ov::Exception);
EXPECT_THROW(transform.apply({ {18.0f, 0} }, std::map<int64_t, size_t>{{0, 1}, {-1, 1}} ), ov::Exception);
}

// ===================
struct FrequencyPenaltyTransformTestStruct {
float penalty;
std::vector<LogitWithIdx> input_logits;
TokenIds input_ids;
std::vector<LogitWithIdx> expected_output;
};

using FrequencyPenaltyTransformTest = testing::TestWithParam<FrequencyPenaltyTransformTestStruct>;

TEST_P(FrequencyPenaltyTransformTest, TransformResultEqualToReference) {
auto test_struct = GetParam();
auto transform = FrequencyPenaltyTransform(test_struct.penalty);
auto test_result = transform.apply(test_struct.input_logits, test_struct.input_ids);
ASSERT_EQ(test_result.size(), test_struct.expected_output.size());
for (size_t i = 0; i < test_result.size(); i++) {
EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6);
EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second);
}
};


const std::vector<FrequencyPenaltyTransformTestStruct> FREQUENCY_PENALTY_TRANSFORM_TEST_CASES = {
    { // basic case, indices are applied, order is left as-is
        0.5f,
        { {-1.0f, 0}, {2.0f, 1}, {3.0f, 2} },
        { 1, 0 },
        { {-0.5f, 0}, {1.5f, 1}, {3.0f, 2} }
    },
    { // negative scores case: negative logits move by +penalty, positive by -penalty
        -0.6f,
        { {-1.0f, 0}, {2.0f, 1}, {3.0f, 2} },
        { 0, 1, 1 },
        { {-1.6f, 0}, {3.2f, 1}, {3.0f, 2} }
    },
    { // repeated tokens: frequency penalty scales with the occurrence count
      // (token 2 appears twice, so it is penalized by 0.2 * 2 = 0.4 -> 2.6)
        0.2f,
        { {1.0f, 0}, {2.0f, 1}, {3.0f, 2} },
        { 2, 0, 2 },
        { {0.8f, 0}, {2.0f, 1}, {2.6f, 2} }
    },
};

INSTANTIATE_TEST_SUITE_P(VariousInputs,
FrequencyPenaltyTransformTest,
testing::ValuesIn(FREQUENCY_PENALTY_TRANSFORM_TEST_CASES));


TEST(FrequencyPenaltyTransformInitializationTest, ThrowsForInvalidPenalties) {
EXPECT_THROW(FrequencyPenaltyTransform(-3.0), ov::Exception);
EXPECT_THROW(FrequencyPenaltyTransform(+13.0), ov::Exception);
}

TEST(FrequencyPenaltyTransformInitializationTest, ThrowsForInvalidInputIds) {
auto transform = FrequencyPenaltyTransform(1.5);
EXPECT_THROW(transform.apply({ {43.0f, 0} }, std::map<int64_t, size_t>{{1337, 0}} ), ov::Exception);
EXPECT_THROW(transform.apply({ {18.0f, 0} }, std::map<int64_t, size_t>{{0, 1}, {-1, 1}} ), ov::Exception);
}

// ===================
struct PresencePenaltyTransformTestStruct {
float penalty;
std::vector<LogitWithIdx> input_logits;
TokenIds input_ids;
std::vector<LogitWithIdx> expected_output;
};

using PresencePenaltyTransformTest = testing::TestWithParam<PresencePenaltyTransformTestStruct>;

TEST_P(PresencePenaltyTransformTest, TransformResultEqualToReference) {
auto test_struct = GetParam();
auto transform = PresencePenaltyTransform(test_struct.penalty);
auto test_result = transform.apply(test_struct.input_logits, test_struct.input_ids);
ASSERT_EQ(test_result.size(), test_struct.expected_output.size());
for (size_t i = 0; i < test_result.size(); i++) {
EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6);
EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second);
}
};


const std::vector<PresencePenaltyTransformTestStruct> PRESENCE_PENALTY_TRANSFORM_TEST_CASES = {
    { // basic case, indices are applied, order is left as-is
        0.5f,
        { {-1.0f, 0}, {2.0f, 1}, {3.0f, 2} },
        { 1, 0 },
        { {-0.5f, 0}, {1.5f, 1}, {3.0f, 2} }
    },
    { // negative scores case: negative logits move by +penalty, positive by -penalty
        -0.6f,
        { {-1.0f, 0}, {2.0f, 1}, {3.0f, 2} },
        { 0, 1, 1 },
        { {-1.6f, 0}, {2.6f, 1}, {3.0f, 2} }
    },
    { // repeated tokens: presence penalty is count-independent, applied once
      // (token 2 appears twice but is penalized only by 0.2 -> 2.8)
        0.2f,
        { {1.0f, 0}, {2.0f, 1}, {3.0f, 2} },
        { 2, 0, 2 },
        { {0.8f, 0}, {2.0f, 1}, {2.8f, 2} }
    },
};

INSTANTIATE_TEST_SUITE_P(VariousInputs,
PresencePenaltyTransformTest,
testing::ValuesIn(PRESENCE_PENALTY_TRANSFORM_TEST_CASES));


TEST(PresencePenaltyTransformInitializationTest, ThrowsForInvalidPenalties) {
EXPECT_THROW(PresencePenaltyTransform(-3.0), ov::Exception);
EXPECT_THROW(PresencePenaltyTransform(+13.0), ov::Exception);
}

TEST(PresencePenaltyTransformInitializationTest, ThrowsForInvalidInputIds) {
auto transform = PresencePenaltyTransform(1.5);
EXPECT_THROW(transform.apply({ {43.0f, 0} }, std::map<int64_t, size_t>{{1337, 0}} ), ov::Exception);
EXPECT_THROW(transform.apply({ {18.0f, 0} }, std::map<int64_t, size_t>{{0, 1}, {-1, 1}} ), ov::Exception);
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ PYBIND11_MODULE(py_continuous_batching, m) {
.def_readwrite("stop_criteria", &GenerationConfig::stop_criteria)
.def_readwrite("num_return_sequences", &GenerationConfig::num_return_sequences)
.def_readwrite("repetition_penalty", &GenerationConfig::repetition_penalty)
.def_readwrite("presence_penalty", &GenerationConfig::presence_penalty)
.def_readwrite("frequence_penalty", &GenerationConfig::frequence_penalty)
.def_readwrite("length_penalty", &GenerationConfig::length_penalty)
.def_readwrite("no_repeat_ngram_size", &GenerationConfig::no_repeat_ngram_size)
.def_readwrite("temperature", &GenerationConfig::temperature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ def get_multinomial_temperature_and_repetition_penalty() -> GenerationConfig:
generation_config.repetition_penalty = 2.0
return generation_config

def get_multinomial_temperature_and_frequence_penalty() -> GenerationConfig:
    """Multinomial sampling config combining temperature with a frequency penalty.

    NOTE(review): the attribute name keeps the existing ``frequence_penalty``
    spelling exposed by GenerationConfig's bindings.
    """
    generation_config = GenerationConfig()
    generation_config.do_sample = True
    generation_config.temperature = 0.8
    generation_config.frequence_penalty = 0.5
    return generation_config

def get_multinomial_temperature_and_presence_penalty() -> GenerationConfig:
    """Multinomial sampling config combining temperature with a presence penalty."""
    generation_config = GenerationConfig()
    generation_config.do_sample = True
    generation_config.temperature = 0.8
    generation_config.presence_penalty = 0.1
    return generation_config

def get_test_dataset() -> Tuple[List[str], List[GenerationConfig]]:
prompts = [
"What is OpenVINO?",
Expand Down
Loading

0 comments on commit 4f73d36

Please sign in to comment.