[Continuous batching] Add finish reason to generation output #725

Merged · 4 commits · Aug 6, 2024
Changes from 3 commits
7 changes: 7 additions & 0 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -32,6 +32,12 @@ struct EncodedGenerationResult {
GenerationStatus m_status = GenerationStatus::RUNNING;
};

enum class GenerationFinishReason {
Contributor:
Do we really need this new enum? GenerationStatus is a more generic status, which could be extended to support FINISHED_BY_LENGTH and FINISHED_BY_STOP instead of the generic FINISHED.

Collaborator (author):
GenerationStatus is a status for the whole request, while GenerationFinishReason is per sequence. For example, with beam search or n > 1 we can read the generation status for the entire request (which would be finished) and read the finish reason for every output separately, in case some of the beams hit max_new_tokens while others stopped naturally on the EOS token (see the consumer sketch after this hunk).

I also considered using SequenceStatus, but that class is internal and we need one that's available to the user.

NONE = 0, // Default value, while generation is not yet finished
STOP = 1, // Generation finished naturally, by reaching the end-of-sequence (EOS) token
LENGTH = 2 // Generation finished by reaching the max_new_tokens limit
};

struct GenerationResult {
// request ID - obsolete when handle API is approved as handle will connect results with prompts.
uint64_t m_request_id;
@@ -49,6 +55,7 @@ struct GenerationResult {
struct GenerationOutput {
std::vector<int64_t> generated_token_ids;
float score;
GenerationFinishReason finish_reason;
};

using GenerationOutputs = std::unordered_map<uint64_t, GenerationOutput>;
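
To make the per-sequence point concrete, here is a minimal consumer sketch. It is not part of the PR: the enum and struct are repeated from the hunk above so the snippet compiles on its own, the library's namespace is omitted, and report_finish_reasons is an invented name.

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

// Repeated from the generation_handle.hpp hunk above.
enum class GenerationFinishReason {
    NONE = 0,   // generation not yet finished
    STOP = 1,   // finished naturally on the EOS token
    LENGTH = 2  // finished by hitting max_new_tokens
};

struct GenerationOutput {
    std::vector<int64_t> generated_token_ids;
    float score;
    GenerationFinishReason finish_reason;
};

using GenerationOutputs = std::unordered_map<uint64_t, GenerationOutput>;

// Invented helper: report why each sequence (e.g. each beam when n > 1)
// stopped, independently of the request-level GenerationStatus.
void report_finish_reasons(const GenerationOutputs& outputs) {
    for (const auto& [sequence_id, output] : outputs) {
        const char* reason = output.finish_reason == GenerationFinishReason::STOP   ? "EOS token"
                           : output.finish_reason == GenerationFinishReason::LENGTH ? "max_new_tokens"
                                                                                    : "still running";
        std::cout << "sequence " << sequence_id << ": "
                  << output.generated_token_ids.size() << " tokens, finished by "
                  << reason << '\n';
    }
}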
1 change: 1 addition & 0 deletions src/cpp/src/generation_handle.cpp
@@ -36,6 +36,7 @@ void add_partial_result(std::unordered_map<uint64_t, GenerationOutput>& partial_
} else {
partial_result_iter->second.generated_token_ids.push_back(iteration_result.second.generated_token_ids[0]);
partial_result_iter->second.score = iteration_result.second.score;
partial_result_iter->second.finish_reason = iteration_result.second.finish_reason;
}
}
}
4 changes: 4 additions & 0 deletions src/cpp/src/sampler.hpp
@@ -193,6 +193,8 @@ class GroupBeamSearcher {

// mark current sequence as finished
beam.m_sequence->set_status(SequenceStatus::FINISHED);
// Setting LENGTH since this branch runs when the sequence's generated token count reaches max_new_tokens
beam.m_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
// we also need to drop ongoing / forked sequences from the scheduler
sampler_output.m_dropped_sequences.push_back(sequence_id);
}
@@ -432,6 +434,8 @@ void GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutp
Sequence::Ptr forked_sequence = m_sequence_group->fork_sequence(candidate.m_sequence);
// and finish immediately
forked_sequence->set_status(SequenceStatus::FINISHED);
// Setting STOP since this branch runs when the sequence generated the EOS token
forked_sequence->set_finish_reason(GenerationFinishReason::STOP);

// TODO: make this simpler
// currently, we finish the sequence and then fork it
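Both sampler.hpp hunks repeat the same two-step pattern: mark the sequence FINISHED, then record why. A hypothetical helper could keep the pair from drifting apart; this is a sketch only (finish_sequence is an invented name, and the internal Sequence::Ptr, SequenceStatus, and GenerationFinishReason types from the surrounding code are assumed to be in scope):

// Hypothetical helper, not in this PR: finishing a sequence always records
// why it finished, so status and reason can never disagree.
template <typename SequencePtr>
void finish_sequence(SequencePtr sequence, GenerationFinishReason reason) {
    sequence->set_status(SequenceStatus::FINISHED);
    sequence->set_finish_reason(reason);
}

The two call sites above would then collapse to finish_sequence(beam.m_sequence, GenerationFinishReason::LENGTH) and finish_sequence(forked_sequence, GenerationFinishReason::STOP).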
21 changes: 19 additions & 2 deletions src/cpp/src/sequence_group.hpp
@@ -33,6 +33,7 @@ class Sequence {
uint64_t m_grouped_id;
uint64_t m_id = _get_next_global_sequence_id();
SequenceStatus m_status = SequenceStatus::RUNNING;
GenerationFinishReason m_finish_reason = GenerationFinishReason::NONE;
float m_cumulative_log_prob = 0.0f;

public:
@@ -91,6 +92,14 @@ class Sequence {
m_status = status;
}

GenerationFinishReason get_finish_reason() const {
return m_finish_reason;
}

void set_finish_reason(GenerationFinishReason finish_reason) {
m_finish_reason = finish_reason;
}

// appends new tokens to a generated part
void append_token(int64_t token_id, float log_prob) {
m_cumulative_log_prob += log_prob;
@@ -102,6 +111,7 @@
OPENVINO_ASSERT(m_generated_ids.size());
output.score = get_cumulative_log_probs();
output.generated_token_ids = std::vector<int64_t> {m_generated_ids.back()};
output.finish_reason = get_finish_reason();
return output;
}

@@ -205,6 +215,13 @@ class SequenceGroup {
running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos) {
// stop sequence by max_new_tokens or EOS token
running_sequence->set_status(SequenceStatus::FINISHED);

if (running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos) {
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
} else if (m_sampling_params.max_new_tokens == generated_len) {
running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
}

dropped_seq_ids.push_back(running_sequence->get_id());
}
}
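
One subtlety in the branch above: the EOS check runs before the length check, so a sequence that emits the EOS token exactly when generated_len reaches max_new_tokens reports STOP, not LENGTH. A standalone restatement of that ordering (sketch only; classify is an invented name and the enum is repeated from generation_handle.hpp):

#include <cstddef>

enum class GenerationFinishReason { NONE = 0, STOP = 1, LENGTH = 2 };  // as in generation_handle.hpp

// Invented helper mirroring the branch above: EOS wins over the length limit
// when both conditions hold at the same generation step.
GenerationFinishReason classify(bool last_token_is_eos, bool ignore_eos,
                                std::size_t generated_len, std::size_t max_new_tokens) {
    if (last_token_is_eos && !ignore_eos)
        return GenerationFinishReason::STOP;
    if (generated_len == max_new_tokens)
        return GenerationFinishReason::LENGTH;
    return GenerationFinishReason::NONE;
}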
@@ -451,15 +468,15 @@ class SequenceGroup {
for (auto& sequence: m_sequences) {
GenerationOutput output;
output.generated_token_ids = sequence->get_generated_ids();
- output.score = sequence->get_beam_search_score(m_sampling_params);
+ output.score = m_sampling_params.is_beam_search() ? sequence->get_beam_search_score(m_sampling_params) : sequence->get_cumulative_log_probs();
output.finish_reason = sequence->get_finish_reason();
outputs.emplace(sequence->get_grouped_id(), output);
}
m_generation_stream->push(outputs);
}

void push_partial_outputs() {
GenerationOutputs outputs;
// TODO: support streaming for n seqs
for (auto& sequence : m_sequences) {
// todo: check seq.is_finished() to generate without several </s>
// or is it ok to use padding?