Whisper pipeline: implement chunk streamer for long-form audio processing #1148

Open · wants to merge 1 commit into base: master
19 changes: 15 additions & 4 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -8,23 +8,34 @@
namespace ov {
namespace genai {

/**
 * @brief Base class for streamers. To use it, inherit from this class and implement the put and end methods.
 *
 * @param m_tokenizer tokenizer
 */
class StreamerBase {
public:
    /// @brief put is called every time a new token is decoded.
    /// @return Flag indicating whether generation should be stopped; if true, generation stops.
    virtual bool put(int64_t token) = 0;

    /// @brief end is called at the end of generation. It can be used to flush the cache if your streamer has one.
    virtual void end() = 0;

    virtual ~StreamerBase() = default;
[Review comment, Contributor, on "virtual ~StreamerBase() = default;"]: it's better to move the dtor definition to a .cpp file and export this class. It will help with the RTTI issue on some platforms.
};

/**
 * @brief Base class for chunk streamers. To use it, inherit from this class and implement the put_chunk method.
 *
 * @param m_tokenizer tokenizer
 */
class ChunkStreamerBase : public StreamerBase {
public:
    /// @brief put_chunk is called every time a new token chunk is generated.
    /// @return Flag indicating whether generation should be stopped; if true, generation stops.
    virtual bool put_chunk(std::vector<int64_t> tokens) = 0;
};

} // namespace genai
} // namespace ov
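
For illustration, a minimal user-defined chunk streamer could look like the sketch below. The class name and the stopping rule are hypothetical; only the StreamerBase and ChunkStreamerBase interfaces above come from this diff.

#include <iostream>
#include <vector>
#include "openvino/genai/streamer_base.hpp"

// Hypothetical streamer: counts tokens and requests a stop after a limit.
class TokenCountingStreamer : public ov::genai::ChunkStreamerBase {
public:
    explicit TokenCountingStreamer(size_t max_tokens) : m_max_tokens(max_tokens) {}

    bool put(int64_t token) override {
        return ++m_seen >= m_max_tokens;  // returning true stops generation
    }

    bool put_chunk(std::vector<int64_t> tokens) override {
        m_seen += tokens.size();
        return m_seen >= m_max_tokens;
    }

    void end() override {
        std::cout << "tokens seen: " << m_seen << std::endl;
    }

private:
    size_t m_max_tokens;
    size_t m_seen = 0;
};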
9 changes: 8 additions & 1 deletion src/cpp/include/openvino/genai/whisper_pipeline.hpp
@@ -18,6 +18,10 @@ using OptionalWhisperGenerationConfig = std::optional<WhisperGenerationConfig>;

using RawSpeechInput = std::vector<float>;

// The return flag indicates whether generation should be stopped: false means continue generation, true means stop.
using ChunkStreamerVariant =
    std::variant<std::function<bool(std::string)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;

struct WhisperDecodedResultChunk {
    // start of chunk in seconds
    float start_ts;
@@ -83,7 +87,7 @@ class OPENVINO_GENAI_EXPORTS WhisperPipeline {
     */
    WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input,
                                   OptionalWhisperGenerationConfig generation_config = std::nullopt,
-                                  StreamerVariant streamer = std::monostate());
+                                  ChunkStreamerVariant streamer = std::monostate());

    /**
     * @brief High level generate that receives raw speech as a vector of floats and returns decoded output.
@@ -105,4 +109,7 @@ class OPENVINO_GENAI_EXPORTS WhisperPipeline {
    WhisperGenerationConfig get_generation_config() const;
    void set_generation_config(const WhisperGenerationConfig& config);
};

OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> chunk_streamer(ChunkStreamerVariant func);
OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> whisper_generation_config(const WhisperGenerationConfig& config);
[Review comment, Contributor]: I think we can overload the existing functions like streamer and generation_config instead of introducing Whisper-specific ones. Example:

OPENVINO_GENAI_EXPORTS
std::pair<std::string, ov::Any> generation_config(const ImageGenerationConfig& generation_config);

} // namespace ov::genai
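
Given the signature above, streaming can be done with a plain callback; here is a minimal sketch (the model directory, device string, and audio contents are placeholders):

#include <iostream>
#include "openvino/genai/whisper_pipeline.hpp"

int main() {
    // "whisper-base" and "CPU" are placeholder arguments.
    ov::genai::WhisperPipeline pipeline("whisper-base", "CPU");

    // 16 kHz mono PCM; one second of silence as a stand-in for real audio.
    ov::genai::RawSpeechInput raw_speech(16000, 0.0f);

    // The std::function alternative of ChunkStreamerVariant: return true to stop generation.
    auto result = pipeline.generate(raw_speech, std::nullopt, [](std::string subword) {
        std::cout << subword << std::flush;
        return false;  // keep generating
    });
    std::cout << std::endl;
    return 0;
}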
25 changes: 23 additions & 2 deletions src/cpp/src/text_callback_streamer.cpp
@@ -23,7 +23,8 @@ bool TextCallbackStreamer::put(int64_t token) {
        return on_finalized_subword_callback(res.str());
    }

-    constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
+    // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
+    constexpr char replacement[] = "\xef\xbf\xbd";
    if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
        // Don't print incomplete text
        return on_finalized_subword_callback(res.str());
@@ -41,13 +42,33 @@ void TextCallbackStreamer::end() {
    std::stringstream res;
    std::string text = m_tokenizer.decode(m_tokens_cache);
    if (text.size() <= print_len)
-        return ;
+        return;
    res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
    m_tokens_cache.clear();
    print_len = 0;
    on_finalized_subword_callback(res.str());
    return;
}

bool ChunkTextCallbackStreamer::put(int64_t token) {
    return ov::genai::TextCallbackStreamer::put(token);
}

bool ChunkTextCallbackStreamer::put_chunk(std::vector<int64_t> tokens) {
    if (tokens.empty()) {
        return false;
    }

    if (tokens.size() > 1) {
        // Cache all but the last token; decoding the last one below flushes
        // the accumulated text through the regular put() path.
        m_tokens_cache.insert(m_tokens_cache.end(), tokens.begin(), tokens.end() - 1);
    }

    return ov::genai::TextCallbackStreamer::put(tokens.back());
}

void ChunkTextCallbackStreamer::end() {
    ov::genai::TextCallbackStreamer::end();
}

} // namespace genai
} // namespace ov
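
The batching pattern in put_chunk above can be shown in isolation; this stand-in sketch (hypothetical names, no Tokenizer involved) mirrors the control flow: buffer everything except the last token, then let the ordinary single-token path flush the accumulated cache.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the cache-then-flush logic of put_chunk().
struct ChunkBuffer {
    std::vector<int64_t> cache;

    bool put(int64_t token) {
        cache.push_back(token);
        std::cout << "decoding " << cache.size() << " cached tokens" << std::endl;
        return false;  // never request a stop in this sketch
    }

    bool put_chunk(std::vector<int64_t> tokens) {
        if (tokens.empty())
            return false;
        if (tokens.size() > 1)
            cache.insert(cache.end(), tokens.begin(), tokens.end() - 1);
        return put(tokens.back());  // the last token triggers the flush
    }
};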
21 changes: 17 additions & 4 deletions src/cpp/src/text_callback_streamer.hpp
@@ -9,19 +9,32 @@
namespace ov {
namespace genai {

-class TextCallbackStreamer: public StreamerBase {
+class TextCallbackStreamer : public StreamerBase {
public:
    bool put(int64_t token) override;
    void end() override;

    TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback);

-    std::function<bool(std::string)> on_finalized_subword_callback = [](std::string words)->bool { return false; };
-private:
+    std::function<bool(std::string)> on_finalized_subword_callback = [](std::string words) -> bool {
+        return false;
+    };
+
+protected:
    Tokenizer m_tokenizer;
    std::vector<int64_t> m_tokens_cache;
    size_t print_len = 0;
};

class ChunkTextCallbackStreamer : public TextCallbackStreamer, public ChunkStreamerBase {
public:
    bool put(int64_t token) override;
    bool put_chunk(std::vector<int64_t> tokens) override;
    void end() override;

    ChunkTextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback)
        : TextCallbackStreamer(tokenizer, callback) {}
};

} // namespace genai
} // namespace ov
14 changes: 14 additions & 0 deletions src/cpp/src/utils.cpp
@@ -175,6 +175,20 @@ ov::genai::StreamerVariant get_streamer_from_map(const ov::AnyMap& config_map) {
    return streamer;
}

ov::genai::ChunkStreamerVariant get_chunk_streamer_from_map(const ov::AnyMap& config_map) {
    ov::genai::ChunkStreamerVariant streamer = std::monostate();

    if (config_map.count(STREAMER_ARG_NAME)) {
        auto any_val = config_map.at(STREAMER_ARG_NAME);
        if (any_val.is<std::shared_ptr<ov::genai::ChunkStreamerBase>>()) {
            streamer = any_val.as<std::shared_ptr<ov::genai::ChunkStreamerBase>>();
        } else if (any_val.is<std::function<bool(std::string)>>()) {
            streamer = any_val.as<std::function<bool(std::string)>>();
        }
    }
    return streamer;
}
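
A sketch of how a caller-supplied AnyMap resolves through this helper; it assumes STREAMER_ARG_NAME expands to the "streamer" key and that the helper lives in the ov::genai::utils namespace:

// Hypothetical resolution example.
std::function<bool(std::string)> callback = [](std::string) { return false; };
ov::AnyMap config_map{{"streamer", callback}};  // assumes STREAMER_ARG_NAME == "streamer"

ov::genai::ChunkStreamerVariant streamer = ov::genai::utils::get_chunk_streamer_from_map(config_map);
// streamer now holds the std::function<bool(std::string)> alternative.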

ov::genai::OptionalGenerationConfig get_config_from_map(const ov::AnyMap& config_map) {
    if (config_map.count(CONFIG_ARG_NAME))
        return config_map.at(CONFIG_ARG_NAME).as<ov::genai::GenerationConfig>();
2 changes: 2 additions & 0 deletions src/cpp/src/utils.hpp
@@ -4,6 +4,7 @@
#pragma once

#include "openvino/genai/llm_pipeline.hpp"
#include "openvino/genai/whisper_pipeline.hpp"
#include "openvino/runtime/core.hpp"

#include "visual_language/processor_config.hpp"
@@ -50,6 +51,7 @@ Config from_config_json_if_exists(const std::filesystem::path& models_path, cons
}

ov::genai::StreamerVariant get_streamer_from_map(const ov::AnyMap& config_map);
ov::genai::ChunkStreamerVariant get_chunk_streamer_from_map(const ov::AnyMap& config_map);
[Review comment, Contributor]: it's a Whisper-specific entity. Maybe we can move it to the whisper files? The same in other places like py_utils.hpp. These utils are supposed to be generic ones.

ov::genai::OptionalGenerationConfig get_config_from_map(const ov::AnyMap& config_map);

22 changes: 9 additions & 13 deletions src/cpp/src/whisper/whisper.cpp
@@ -10,7 +10,6 @@

#include "logit_processor.hpp"
#include "openvino/genai/perf_metrics.hpp"
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/whisper_generation_config.hpp"
#include "openvino/genai/whisper_pipeline.hpp"
#include "timestamps.hpp"
@@ -222,9 +221,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta

    std::vector<int64_t> output_tokens{output_token};

-    const size_t timestamp_begin = config.no_timestamps_token_id + 1;
-    bool is_timestamp = output_token >= timestamp_begin;
-    if (!is_timestamp && streamer && streamer->put(output_token)) {
+    if (!return_timestamps && streamer && streamer->put({output_token})) {
        return {true, output_tokens};
    }

@@ -238,7 +235,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
        auto output_token = decode_with_past(encoder_hidden_state,
                                             models.decoder_with_past,
                                             output_tokens.back(),
-                                            init_ids.size() + output_tokens.size() - 1,
+                                            init_ids.size() + i,
                                             config,
                                             raw_metrics,
                                             return_timestamps,
@@ -253,9 +250,8 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
        }

        output_tokens.push_back(output_token);
-        bool is_timestamp = output_token >= timestamp_begin;

-        if (!is_timestamp && streamer && streamer->put(output_token)) {
+        if (!return_timestamps && streamer && streamer->put({output_token})) {
            return {true, output_tokens};
        }
    }
@@ -273,15 +269,10 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
                                       const RawSpeechInput& raw_speech,
                                       ov::genai::WhisperInitializedModels& models,
                                       WhisperFeatureExtractor& feature_extractor,
-                                      const std::shared_ptr<StreamerBase> streamer) {
-    OPENVINO_ASSERT(!streamer || !config.return_timestamps, "Streamer is not supported with 'return_timestamps' enabled.");
+                                      const std::shared_ptr<ChunkStreamerBase> streamer) {
    auto input_features = feature_extractor.extract(raw_speech);

    const bool is_shortform = input_features.n_frames <= feature_extractor.nb_max_frames;

-    OPENVINO_ASSERT(!streamer || is_shortform, "Streamer is not supported for long-form audio processing.");

    // long-form audio processing requires timestamps to be enabled
    const bool return_timestamps = config.return_timestamps || !is_shortform;

@@ -344,6 +335,11 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
                                 extracted_segments.non_timestamp_tokens.begin(),
                                 extracted_segments.non_timestamp_tokens.end());

+            if (streamer && streamer->put_chunk(extracted_segments.non_timestamp_tokens)) {
+                cancelled = true;
+                break;
+            }
+
            segment_offset = extracted_segments.last_offset;
        } else {
            output_tokens.insert(output_tokens.end(), chunk_output_tokens.begin(), chunk_output_tokens.end());
2 changes: 1 addition & 1 deletion src/cpp/src/whisper/whisper.hpp
@@ -31,7 +31,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
                                       const ov::genai::RawSpeechInput& raw_speech,
                                       ov::genai::WhisperInitializedModels& models,
                                       ov::genai::WhisperFeatureExtractor& feature_extractor,
-                                      const std::shared_ptr<StreamerBase> streamer);
+                                      const std::shared_ptr<ChunkStreamerBase> streamer);

} // namespace genai
} // namespace ov
25 changes: 19 additions & 6 deletions src/cpp/src/whisper_pipeline.cpp
@@ -60,18 +60,18 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi

    WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input,
                                   OptionalWhisperGenerationConfig generation_config,
-                                  StreamerVariant streamer) override {
+                                  ChunkStreamerVariant streamer) override {
        auto start_time = std::chrono::steady_clock::now();
        WhisperGenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
        config.validate();

-        std::shared_ptr<StreamerBase> streamer_ptr;
+        std::shared_ptr<ChunkStreamerBase> streamer_ptr;
        if (auto streamer_obj = std::get_if<std::monostate>(&streamer)) {
            streamer_ptr = nullptr;
-        } else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
+        } else if (auto streamer_obj = std::get_if<std::shared_ptr<ChunkStreamerBase>>(&streamer)) {
            streamer_ptr = *streamer_obj;
        } else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
-            streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
+            streamer_ptr = std::make_shared<ChunkTextCallbackStreamer>(m_tokenizer, *callback);
        }

        auto generate_result = ov::genai::whisper_generate(config,
@@ -114,6 +114,19 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
    }
};

std::pair<std::string, Any> chunk_streamer(ChunkStreamerVariant func) {
    if (auto streamer_obj = std::get_if<std::shared_ptr<ChunkStreamerBase>>(&func)) {
        return {utils::STREAMER_ARG_NAME, Any::make<std::shared_ptr<ChunkStreamerBase>>(*streamer_obj)};
    } else {
        auto callback = std::get<std::function<bool(std::string)>>(func);
        return {utils::STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
    }
}

std::pair<std::string, Any> whisper_generation_config(const WhisperGenerationConfig& config) {
    return {utils::CONFIG_ARG_NAME, Any::make<WhisperGenerationConfig>(config)};
}

} // namespace genai
} // namespace ov
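
Based on the two helpers above, the map-based generate overload (shown further down) can be driven like this sketch. It reuses the pipeline and raw_speech names from the earlier sketch; the callback body is illustrative, and it assumes WhisperGenerationConfig exposes a max_new_tokens field:

// Hypothetical property-style call.
ov::genai::WhisperGenerationConfig config = pipeline.get_generation_config();
config.max_new_tokens = 100;  // assumes this field exists on the config

ov::AnyMap properties{
    ov::genai::whisper_generation_config(config),
    ov::genai::chunk_streamer([](std::string subword) {
        std::cout << subword << std::flush;
        return false;
    }),
};
auto result = pipeline.generate(raw_speech, properties);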

@@ -132,7 +145,7 @@ ov::genai::WhisperPipeline::WhisperPipeline(const std::filesystem::path& models_

ov::genai::WhisperDecodedResults ov::genai::WhisperPipeline::generate(const RawSpeechInput& raw_speech_input,
                                                                      OptionalWhisperGenerationConfig generation_config,
-                                                                     StreamerVariant streamer) {
+                                                                     ChunkStreamerVariant streamer) {
    return m_impl->generate(raw_speech_input, generation_config, streamer);
}

@@ -142,7 +155,7 @@ ov::genai::WhisperDecodedResults ov::genai::WhisperPipeline::generate(const RawS
    WhisperGenerationConfig config = (config_arg.has_value()) ? *config_arg : get_generation_config();
    config.update_generation_config(config_map);

-    return m_impl->generate(raw_speech_input, config, utils::get_streamer_from_map(config_map));
+    return m_impl->generate(raw_speech_input, config, utils::get_chunk_streamer_from_map(config_map));
}

ov::genai::WhisperGenerationConfig ov::genai::WhisperPipeline::get_generation_config() const {
2 changes: 1 addition & 1 deletion src/cpp/src/whisper_pipeline_base.hpp
@@ -38,7 +38,7 @@ class WhisperPipeline::WhisperPipelineImplBase {

    virtual WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input,
                                           OptionalWhisperGenerationConfig generation_config,
-                                          StreamerVariant streamer) = 0;
+                                          ChunkStreamerVariant streamer) = 0;

    virtual ~WhisperPipelineImplBase() = default;
};
22 changes: 12 additions & 10 deletions src/cpp/src/whisper_pipeline_static.cpp
@@ -277,13 +277,11 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
                                                 std::vector<int32_t> init_ids,
                                                 const size_t max_new_tokens,
                                                 const bool return_timestamps,
-                                                const std::shared_ptr<ov::genai::StreamerBase> streamer) {
+                                                const std::shared_ptr<ov::genai::ChunkStreamerBase> streamer) {
    int64_t output_token = decode(encoder_hidden_state, models.decoder, init_ids, config, true, return_timestamps);
    std::vector<int64_t> output_tokens{output_token};

-    const size_t timestamp_begin = config.no_timestamps_token_id + 1;
-    bool is_timestamp = output_token >= timestamp_begin;
-    if (!is_timestamp && streamer && streamer->put(output_token)) {
+    if (!return_timestamps && streamer && streamer->put(output_token)) {
        return {true, output_tokens};
    }

@@ -307,9 +305,8 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
        }

        output_tokens.push_back(output_token);
-        bool is_timestamp = output_token >= timestamp_begin;

-        if (!is_timestamp && streamer && streamer->put(output_token)) {
+        if (!return_timestamps && streamer && streamer->put(output_token)) {
            return {true, output_tokens};
        }
    }
@@ -360,17 +357,17 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
    const RawSpeechInput& raw_speech_input,
    OptionalWhisperGenerationConfig generation_config,
-   StreamerVariant streamer) {
+   ChunkStreamerVariant streamer) {
    WhisperGenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
    config.validate();

-   std::shared_ptr<StreamerBase> streamer_ptr;
+   std::shared_ptr<ChunkStreamerBase> streamer_ptr;
    if (auto streamer_obj = std::get_if<std::monostate>(&streamer)) {
        streamer_ptr = nullptr;
-   } else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
+   } else if (auto streamer_obj = std::get_if<std::shared_ptr<ChunkStreamerBase>>(&streamer)) {
        streamer_ptr = *streamer_obj;
    } else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
-       streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
+       streamer_ptr = std::make_shared<ChunkTextCallbackStreamer>(m_tokenizer, *callback);
    }

    auto input_features = m_feature_extractor.extract(raw_speech_input);

@@ -428,6 +425,11 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
                                 extracted_segments.non_timestamp_tokens.begin(),
                                 extracted_segments.non_timestamp_tokens.end());

+            if (streamer_ptr && streamer_ptr->put_chunk(extracted_segments.non_timestamp_tokens)) {
+                cancelled = true;
+                break;
+            }
+
            segment_offset = extracted_segments.last_offset;
        } else {
            output_tokens.insert(output_tokens.end(), chunk_output_tokens.begin(), chunk_output_tokens.end());