Skip to content

Commit

Permalink
[runtime] Configurable blank token idx (#2366)
Browse files Browse the repository at this point in the history
* Pass in blank id

* Fix

* ctc endpointing blank id

---------

Co-authored-by: hzhou245 <hzhou245@bloomberg.net>
  • Loading branch information
zhr1201 and hzhou245 authored Feb 26, 2024
1 parent fbbecfd commit 89ef2e7
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
7 changes: 4 additions & 3 deletions runtime/core/decoder/ctc_wfst_beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void CtcWfstBeamSearch::Search(const std::vector<std::vector<float>>& logp) {
}
// Every time we get the log posterior, we decode it all before return
for (int i = 0; i < logp.size(); i++) {
float blank_score = std::exp(logp[i][0]);
float blank_score = std::exp(logp[i][opts_.blank]);
if (blank_score > opts_.blank_skip_thresh * opts_.blank_scale) {
VLOG(3) << "skipping frame " << num_frames_ << " score " << blank_score;
is_last_frame_blank_ = true;
Expand All @@ -88,7 +88,8 @@ void CtcWfstBeamSearch::Search(const std::vector<std::vector<float>>& logp) {
std::max_element(logp[i].begin(), logp[i].end()) - logp[i].begin();
// Optional, adding one blank frame if we has skipped it in two same
// symbols
if (cur_best != 0 && is_last_frame_blank_ && cur_best == last_best_) {
if (cur_best != opts_.blank && is_last_frame_blank_ &&
cur_best == last_best_) {
decodable_.AcceptLoglikes(last_frame_prob_);
decoder_.AdvanceDecoding(&decodable_, 1);
decoded_frames_mapping_.push_back(num_frames_ - 1);
Expand Down Expand Up @@ -168,7 +169,7 @@ void CtcWfstBeamSearch::ConvertToInputs(const std::vector<int>& alignment,
if (time != nullptr) time->clear();
for (int cur = 0; cur < alignment.size(); ++cur) {
// ignore blank
if (alignment[cur] - 1 == 0) continue;
if (alignment[cur] - 1 == opts_.blank) continue;
// merge continuous same label
if (cur > 0 && alignment[cur] == alignment[cur - 1]) continue;

Expand Down
1 change: 1 addition & 0 deletions runtime/core/decoder/ctc_wfst_beam_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct CtcWfstBeamSearchOptions : public kaldi::LatticeFasterDecoderConfig {
// search
float blank_skip_thresh = 0.98;
float blank_scale = 1.0;
int blank = 0;
};

class CtcWfstBeamSearch : public SearchInterface {
Expand Down
5 changes: 5 additions & 0 deletions runtime/core/decoder/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ DEFINE_int32(min_active, 200, "min active states in ctc wfst search");
DEFINE_double(beam, 16.0, "beam in ctc wfst search");
DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale for ctc wfst search");
DEFINE_int32(blank_id, 0,
"blank token idx for ctc wfst search and ctc prefix beam search");
DEFINE_double(blank_skip_thresh, 1.0,
"blank skip thresh for ctc wfst search, 1.0 means no skip");
DEFINE_double(blank_scale, 1.0, "blank scale for ctc wfst search");
Expand Down Expand Up @@ -145,13 +147,16 @@ std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
decode_config->ctc_wfst_search_opts.beam = FLAGS_beam;
decode_config->ctc_wfst_search_opts.lattice_beam = FLAGS_lattice_beam;
decode_config->ctc_wfst_search_opts.acoustic_scale = FLAGS_acoustic_scale;
decode_config->ctc_wfst_search_opts.blank = FLAGS_blank_id;
decode_config->ctc_wfst_search_opts.blank_skip_thresh =
FLAGS_blank_skip_thresh;
decode_config->ctc_wfst_search_opts.blank_scale = FLAGS_blank_scale;
decode_config->ctc_wfst_search_opts.length_penalty = FLAGS_length_penalty;
decode_config->ctc_wfst_search_opts.nbest = FLAGS_nbest;
decode_config->ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
decode_config->ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
decode_config->ctc_prefix_search_opts.blank = FLAGS_blank_id;
decode_config->ctc_endpoint_config.blank = FLAGS_blank_id;
return decode_config;
}

Expand Down
2 changes: 1 addition & 1 deletion runtime/core/utils/string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ std::string ProcessBlank(const std::string& str, bool lowercase) {
std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> converter;
std::wstring wsresult = converter.from_bytes(result);
for (auto& c : wsresult) {
c = lowercase ? tolower(c, loc) : toupper(c, loc);
c = lowercase ? tolower(c, loc) : c;
}
result = converter.to_bytes(wsresult);
} catch (std::exception& e) {
Expand Down

0 comments on commit 89ef2e7

Please sign in to comment.