[runtime] refactor decoder, asr_model to support more platforms #993

Merged · 6 commits · Mar 29, 2022
Changes from all commits
4 changes: 2 additions & 2 deletions runtime/core/bin/decoder_main.cc
@@ -66,8 +66,8 @@ int main(int argc, char *argv[]) {
feature_pipeline->set_input_finished();
LOG(INFO) << "num frames " << feature_pipeline->num_frames();

wenet::TorchAsrDecoder decoder(feature_pipeline, decode_resource,
*decode_config);
wenet::AsrDecoder decoder(feature_pipeline, decode_resource,
*decode_config);

int wave_dur =
static_cast<int>(static_cast<float>(wav_reader.num_sample()) /
4 changes: 2 additions & 2 deletions runtime/core/bin/label_checker_main.cc
@@ -194,8 +194,8 @@ int main(int argc, char *argv[]) {
feature_pipeline->set_input_finished();
decode_resource->fst = decoding_fst;
LOG(INFO) << "num frames " << feature_pipeline->num_frames();
wenet::TorchAsrDecoder decoder(feature_pipeline, decode_resource,
*decode_config);
wenet::AsrDecoder decoder(feature_pipeline, decode_resource,
*decode_config);
while (true) {
wenet::DecodeState state = decoder.Decode();
if (state == wenet::DecodeState::kEndFeats) {
196 changes: 196 additions & 0 deletions runtime/core/decoder/asr_decoder.cc
@@ -0,0 +1,196 @@
// Copyright 2020 Mobvoi Inc. All Rights Reserved.
// Author: binbinzhang@mobvoi.com (Binbin Zhang)
// di.wu@mobvoi.com (Di Wu)

#include "decoder/asr_decoder.h"

#include <ctype.h>

#include <algorithm>
#include <limits>
#include <utility>

#include "utils/timer.h"

namespace wenet {

AsrDecoder::AsrDecoder(
std::shared_ptr<FeaturePipeline> feature_pipeline,
std::shared_ptr<DecodeResource> resource, const DecodeOptions& opts)
: feature_pipeline_(std::move(feature_pipeline)),
// Make a copy of the ASR model, since decoding changes the inner
// state of the model
model_(resource->model->Copy()),
post_processor_(resource->post_processor),
symbol_table_(resource->symbol_table),
fst_(resource->fst),
unit_table_(resource->unit_table),
opts_(opts),
ctc_endpointer_(new CtcEndpoint(opts.ctc_endpoint_config)) {
if (opts_.reverse_weight > 0) {
// Check if the model has a right-to-left decoder
CHECK(model_->is_bidirectional_decoder());
}
if (nullptr == fst_) {
searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts,
resource->context_graph));
} else {
searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts,
resource->context_graph));
}
ctc_endpointer_->frame_shift_in_ms(frame_shift_in_ms());
}

void AsrDecoder::Reset() {
start_ = false;
result_.clear();
num_frames_ = 0;
global_frame_offset_ = 0;
model_->Reset();
searcher_->Reset();
feature_pipeline_->Reset();
ctc_endpointer_->Reset();
}

void AsrDecoder::ResetContinuousDecoding() {
global_frame_offset_ = num_frames_;
start_ = false;
result_.clear();
model_->Reset();
searcher_->Reset();
ctc_endpointer_->Reset();
}


DecodeState AsrDecoder::Decode() { return this->AdvanceDecoding(); }


void AsrDecoder::Rescoring() {
// Do attention rescoring
Timer timer;
AttentionRescoring();
LOG(INFO) << "Rescoring cost latency: " << timer.Elapsed() << "ms.";
}


DecodeState AsrDecoder::AdvanceDecoding() {
DecodeState state = DecodeState::kEndBatch;
const int subsampling_rate = model_->subsampling_rate();
const int right_context = model_->right_context();
const int feature_dim = feature_pipeline_->feature_dim();
model_->set_chunk_size(opts_.chunk_size);
model_->set_num_left_chunks(opts_.num_left_chunks);
int num_requried_frames = model_->num_frames_for_chunk(start_);
std::vector<std::vector<float>> chunk_feats;
// If Read() returns false, we have reached the end of the input
if (!feature_pipeline_->Read(num_requried_frames, &chunk_feats)) {
state = DecodeState::kEndFeats;
}

num_frames_ += chunk_feats.size();
LOG(INFO) << "Required " << num_requried_frames << " get "
<< chunk_feats.size();
Timer timer;
std::vector<std::vector<float>> ctc_log_probs;
model_->ForwardEncoder(chunk_feats, &ctc_log_probs);
int forward_time = timer.Elapsed();
timer.Reset();
searcher_->Search(ctc_log_probs);
int search_time = timer.Elapsed();
VLOG(3) << "forward takes " << forward_time << " ms, search takes "
<< search_time << " ms";
UpdateResult();

if (ctc_endpointer_->IsEndpoint(ctc_log_probs, DecodedSomething())) {
LOG(INFO) << "Endpoint is detected at " << num_frames_;
state = DecodeState::kEndpoint;
}

start_ = true;
return state;
}


void AsrDecoder::UpdateResult(bool finish) {
const auto& hypotheses = searcher_->Outputs();
const auto& inputs = searcher_->Inputs();
const auto& likelihood = searcher_->Likelihood();
const auto& times = searcher_->Times();
result_.clear();

CHECK_EQ(hypotheses.size(), likelihood.size());
for (size_t i = 0; i < hypotheses.size(); i++) {
const std::vector<int>& hypothesis = hypotheses[i];

DecodeResult path;
path.score = likelihood[i];
int offset = global_frame_offset_ * feature_frame_shift_in_ms();
for (size_t j = 0; j < hypothesis.size(); j++) {
std::string word = symbol_table_->Find(hypothesis[j]);
// A detailed explanation of this if-else branch can be found in
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
if (searcher_->Type() == kWfstBeamSearch) {
path.sentence += (' ' + word);
} else {
path.sentence += (word);
}
}

// Timestamps are only supported in the final result.
// Timestamps from CtcWfstBeamSearch may be inaccurate due to the various
// FST operations applied when building the decoding graph. So here we use
// the timestamps of the input (e2e model units), which are more accurate
// but require the symbol table of the e2e model used in training.
if (unit_table_ != nullptr && finish) {
const std::vector<int>& input = inputs[i];
const std::vector<int>& time_stamp = times[i];
CHECK_EQ(input.size(), time_stamp.size());
for (size_t j = 0; j < input.size(); j++) {
std::string word = unit_table_->Find(input[j]);
int start = j > 0 ? ((time_stamp[j - 1] + time_stamp[j]) / 2 *
frame_shift_in_ms())
: 0;
int end = j < input.size() - 1 ?
((time_stamp[j] + time_stamp[j + 1]) / 2 * frame_shift_in_ms()) :
model_->offset() * frame_shift_in_ms();
WordPiece word_piece(word, offset + start, offset + end);
path.word_pieces.emplace_back(word_piece);
}
}
if (post_processor_ != nullptr) {
  path.sentence = post_processor_->Process(path.sentence, finish);
}
result_.emplace_back(path);
}

if (DecodedSomething()) {
VLOG(1) << "Partial CTC result " << result_[0].sentence;
}
}
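
A quick worked example of the boundary arithmetic in `UpdateResult()` above, with made-up numbers (the frame shift and encoder offset below are illustrative assumptions, not values from this PR): each unit starts at the midpoint between its own timestamp and its predecessor's, and ends at the midpoint with its successor, or at the encoder offset for the last unit.

```cpp
// Illustrative sketch only: frame_shift_ms = 40 and offset = 12 are assumed.
#include <iostream>
#include <vector>

int main() {
  const int frame_shift_ms = 40;                  // assumed frame_shift_in_ms()
  const int offset = 12;                          // assumed model_->offset()
  const std::vector<int> time_stamp = {2, 5, 9};  // subsampled frame indices
  for (std::size_t j = 0; j < time_stamp.size(); ++j) {
    int start = j > 0 ? (time_stamp[j - 1] + time_stamp[j]) / 2 * frame_shift_ms
                      : 0;
    int end = j + 1 < time_stamp.size()
                  ? (time_stamp[j] + time_stamp[j + 1]) / 2 * frame_shift_ms
                  : offset * frame_shift_ms;
    std::cout << "unit " << j << ": [" << start << ", " << end << ") ms\n";
  }
  // Prints: unit 0: [0, 120) ms, unit 1: [120, 280) ms, unit 2: [280, 480) ms
}
```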

void AsrDecoder::AttentionRescoring() {
searcher_->FinalizeSearch();
UpdateResult(true);
// No need to do rescoring
if (0.0 == opts_.rescoring_weight) {
return;
}
// Inputs() returns the N-best input ids, which are the basic units for
// rescoring. In CtcPrefixBeamSearch, inputs are the same as outputs.
const auto& hypotheses = searcher_->Inputs();
int num_hyps = hypotheses.size();
if (num_hyps <= 0) {
return;
}

std::vector<float> rescoring_score;
model_->AttentionRescoring(hypotheses, opts_.reverse_weight,
&rescoring_score);

// Combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; ++i) {
result_[i].score = opts_.rescoring_weight * rescoring_score[i] +
opts_.ctc_weight * result_[i].score;
}
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
}

} // namespace wenet
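
Putting the pieces together, a minimal driver loop for the refactored decoder looks like the sketch below. It mirrors the pattern already visible in decoder_main.cc and label_checker_main.cc above; the setup of feature_pipeline, decode_resource, and decode_config is assumed to happen as in those binaries, and the result() accessor is assumed from the collapsed part of asr_decoder.h.

```cpp
// Minimal decode loop, assuming setup as in decoder_main.cc.
wenet::AsrDecoder decoder(feature_pipeline, decode_resource, *decode_config);
while (true) {
  wenet::DecodeState state = decoder.Decode();  // consume one chunk
  if (state == wenet::DecodeState::kEndFeats) {
    decoder.Rescoring();  // finalize search and run attention rescoring
    break;
  }
  // Between chunks, decoder.result() holds the partial CTC hypotheses.
}
```

During Rescoring(), each hypothesis ends up with final score = rescoring_weight * attention score + ctc_weight * CTC score. With illustrative weights 1.0 and 0.5 and scores -4.2 (attention) and -6.0 (CTC), the combined score is 1.0 * (-4.2) + 0.5 * (-6.0) = -7.2, after which the N-best list is re-sorted.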
runtime/core/decoder/asr_decoder.h
@@ -2,8 +2,8 @@
// Author: binbinzhang@mobvoi.com (Binbin Zhang)
// di.wu@mobvoi.com (Di Wu)

#ifndef DECODER_TORCH_ASR_DECODER_H_
#define DECODER_TORCH_ASR_DECODER_H_
#ifndef DECODER_ASR_DECODER_H_
#define DECODER_ASR_DECODER_H_

#include <memory>
#include <string>
@@ -12,22 +12,19 @@

#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "torch/script.h"
#include "torch/torch.h"

#include "decoder/asr_model.h"
#include "decoder/context_graph.h"
#include "decoder/ctc_endpoint.h"
#include "decoder/ctc_prefix_beam_search.h"
#include "decoder/ctc_wfst_beam_search.h"
#include "decoder/torch_asr_model.h"
#include "decoder/search_interface.h"
#include "frontend/feature_pipeline.h"
#include "post_processor/post_processor.h"
#include "utils/utils.h"

namespace wenet {

using TorchModule = torch::jit::script::Module;

struct DecodeOptions {
// chunk_size is the frame number of one chunk after subsampling.
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in
@@ -79,7 +76,7 @@ enum DecodeState {
// DecodeResource is thread safe and can be shared by multiple
// decoding threads
struct DecodeResource {
std::shared_ptr<TorchAsrModel> model = nullptr;
std::shared_ptr<AsrModel> model = nullptr;
std::shared_ptr<fst::SymbolTable> symbol_table = nullptr;
std::shared_ptr<fst::Fst<fst::StdArc>> fst = nullptr;
std::shared_ptr<fst::SymbolTable> unit_table = nullptr;
@@ -88,11 +85,11 @@ struct DecodeResource {
};

// ASR decoder
class TorchAsrDecoder {
class AsrDecoder {
public:
TorchAsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
std::shared_ptr<DecodeResource> resource,
const DecodeOptions& opts);
AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
std::shared_ptr<DecodeResource> resource,
const DecodeOptions& opts);

DecodeState Decode();
void Rescoring();
@@ -122,12 +119,10 @@ class TorchAsrDecoder {
DecodeState AdvanceDecoding();
void AttentionRescoring();

float AttentionDecoderScore(const torch::Tensor& prob,
const std::vector<int>& hyp, int eos);
void UpdateResult(bool finish = false);

std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<TorchAsrModel> model_;
std::shared_ptr<AsrModel> model_;
std::shared_ptr<PostProcessor> post_processor_;

std::shared_ptr<fst::Fst<fst::StdArc>> fst_ = nullptr;
@@ -137,15 +132,7 @@ class TorchAsrDecoder {
std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
const DecodeOptions& opts_;
// cache feature
std::vector<std::vector<float>> cached_feature_;
bool start_ = false;

torch::jit::IValue subsampling_cache_;
// transformer/conformer encoder layers output cache
torch::jit::IValue elayers_output_cache_;
torch::jit::IValue conformer_cnn_cache_;
std::vector<torch::Tensor> encoder_outs_;
int offset_ = 0; // offset
// For continuous decoding
int num_frames_ = 0;
int global_frame_offset_ = 0;
@@ -157,9 +144,9 @@ class TorchAsrDecoder {
std::vector<DecodeResult> result_;

public:
WENET_DISALLOW_COPY_AND_ASSIGN(TorchAsrDecoder);
WENET_DISALLOW_COPY_AND_ASSIGN(AsrDecoder);
};

} // namespace wenet

#endif // DECODER_TORCH_ASR_DECODER_H_
#endif // DECODER_ASR_DECODER_H_
62 changes: 62 additions & 0 deletions runtime/core/decoder/asr_model.cc
@@ -0,0 +1,62 @@
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Author: binbin.zhang@horizon.ai (Binbin Zhang)

#include "decoder/asr_model.h"

#include <memory>
#include <utility>

#include "torch/script.h"
#include "torch/torch.h"

namespace wenet {

int AsrModel::num_frames_for_chunk(bool start) const {
int num_requried_frames = 0;
if (chunk_size_ > 0) {
if (!start) { // First batch
int context = right_context_ + 1; // Add current frame
num_requried_frames = (chunk_size_ - 1) * subsampling_rate_ + context;
} else {
num_requried_frames = chunk_size_ * subsampling_rate_;
}
} else {
num_requried_frames = std::numeric_limits<int>::max();
}
return num_requried_frames;
}
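
As a concrete check of the branch above (the constants below are illustrative; right_context in particular depends on the model, and is not taken from this PR):

```cpp
// Illustrative constants: chunk_size = 16, subsampling_rate = 4,
// right_context = 6 (model-dependent assumptions).
int chunk_size = 16, subsampling_rate = 4, right_context = 6;
// First chunk: needs the full receptive field of the last output frame.
int first = (chunk_size - 1) * subsampling_rate + (right_context + 1);  // 67
// Later chunks: cached features supply the extra context, so only a
// stride's worth of new frames is needed.
int later = chunk_size * subsampling_rate;  // 64
// chunk_size <= 0 requests INT_MAX frames, i.e. whole-utterance decoding.
```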


void AsrModel::CacheFeature(
const std::vector<std::vector<float>>& chunk_feats) {
// Cache feature for next chunk
const int cached_feature_size = 1 + right_context_ - subsampling_rate_;
if (chunk_feats.size() >= cached_feature_size) {
// TODO(Binbin Zhang): Here we only handle the case where
// chunk_feats.size() >= cached_feature_size, which is consistent
// with our current models. Refine it later if we have new models or
// new requirements.
cached_feature_.resize(cached_feature_size);
for (int i = 0; i < cached_feature_size; ++i) {
cached_feature_[i] =
chunk_feats[chunk_feats.size() - cached_feature_size + i];
}
}
}
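
With the same illustrative constants as above (subsampling_rate = 4, right_context = 6), cached_feature_size is 1 + 6 - 4 = 3: the last three frames of each chunk are kept so that, together with the 64 new frames of the next chunk, the encoder again sees the 67 frames its receptive field requires.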


void AsrModel::ForwardEncoder(
const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>> *ctc_prob) {
ctc_prob->clear();
int num_frames = cached_feature_.size() + chunk_feats.size();
if (num_frames > right_context_ + 1) {
this->ForwardEncoderFunc(chunk_feats, ctc_prob);
this->CacheFeature(chunk_feats);
}
}

} // namespace wenet
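
asr_model.h itself is not shown in this diff, so the virtual interface below is inferred from the call sites in asr_decoder.cc and asr_model.cc, and the backend name OnnxAsrModel is invented for illustration (TorchAsrModel is the only implementation in this PR). The point of the refactor is that a new platform only has to subclass AsrModel and fill in these hooks:

```cpp
// Hypothetical backend skeleton; names and exact signatures are inferred,
// not copied from asr_model.h.
#include <memory>
#include <vector>

#include "decoder/asr_model.h"

namespace wenet {

class OnnxAsrModel : public AsrModel {
 public:
  // Each decoding thread copies the model because decoding mutates state
  // (see resource->model->Copy() in the AsrDecoder constructor).
  std::shared_ptr<AsrModel> Copy() const override;
  void Reset() override;
  void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                          float reverse_weight,
                          std::vector<float>* rescoring_score) override;

 protected:
  // Encoder + CTC head for one chunk; the base class handles feature
  // caching around this call via ForwardEncoder()/CacheFeature().
  void ForwardEncoderFunc(const std::vector<std::vector<float>>& chunk_feats,
                          std::vector<std::vector<float>>* ctc_prob) override;
};

}  // namespace wenet
```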


