From 894e80c4e08cfc0b1f297383a9e16e290ac1ab2b Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Fri, 10 Dec 2021 16:10:59 -0500 Subject: [PATCH] remove unused files and bindings --- torchaudio/csrc/CMakeLists.txt | 5 - torchaudio/csrc/decoder/bindings/_decoder.cpp | 19 -- torchaudio/csrc/decoder/bindings/pybind.cpp | 53 ---- .../src/decoder/LexiconFreeDecoder.cpp | 207 --------------- .../decoder/src/decoder/LexiconFreeDecoder.h | 160 ------------ .../src/decoder/LexiconFreeSeq2SeqDecoder.cpp | 179 ------------- .../src/decoder/LexiconFreeSeq2SeqDecoder.h | 141 ---------- .../src/decoder/LexiconSeq2SeqDecoder.cpp | 243 ------------------ .../src/decoder/LexiconSeq2SeqDecoder.h | 165 ------------ .../csrc/decoder/src/decoder/lm/ConvLM.cpp | 239 ----------------- .../csrc/decoder/src/decoder/lm/ConvLM.h | 73 ------ .../csrc/decoder/src/decoder/lm/ZeroLM.cpp | 31 --- .../csrc/decoder/src/decoder/lm/ZeroLM.h | 32 --- torchaudio/prototype/__init__.py | 1 + torchaudio/prototype/ctc_decoder.py | 1 - 15 files changed, 1 insertion(+), 1548 deletions(-) delete mode 100644 torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.cpp delete mode 100644 torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h delete mode 100644 torchaudio/csrc/decoder/src/decoder/LexiconFreeSeq2SeqDecoder.cpp delete mode 100644 torchaudio/csrc/decoder/src/decoder/LexiconFreeSeq2SeqDecoder.h delete mode 100644 torchaudio/csrc/decoder/src/decoder/LexiconSeq2SeqDecoder.cpp delete mode 100644 torchaudio/csrc/decoder/src/decoder/LexiconSeq2SeqDecoder.h delete mode 100644 torchaudio/csrc/decoder/src/decoder/lm/ConvLM.cpp delete mode 100644 torchaudio/csrc/decoder/src/decoder/lm/ConvLM.h delete mode 100644 torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.cpp delete mode 100644 torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt index 774e5e763c3..6196006a0ce 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -134,14 +134,9 @@ if (BUILD_FL_DECODER) set( LIBTORCHAUDIO_DECODER_SOURCES decoder/src/decoder/LexiconDecoder.cpp - decoder/src/decoder/LexiconFreeDecoder.cpp - decoder/src/decoder/LexiconFreeSeq2SeqDecoder.cpp - decoder/src/decoder/LexiconSeq2SeqDecoder.cpp decoder/src/decoder/Trie.cpp decoder/src/decoder/Utils.cpp - decoder/src/decoder/lm/ConvLM.cpp decoder/src/decoder/lm/KenLM.cpp - decoder/src/decoder/lm/ZeroLM.cpp decoder/src/dictionary/String.cpp decoder/src/dictionary/System.cpp decoder/src/dictionary/Dictionary.cpp diff --git a/torchaudio/csrc/decoder/bindings/_decoder.cpp b/torchaudio/csrc/decoder/bindings/_decoder.cpp index 718987fcfa0..0ca784b0822 100644 --- a/torchaudio/csrc/decoder/bindings/_decoder.cpp +++ b/torchaudio/csrc/decoder/bindings/_decoder.cpp @@ -9,9 +9,6 @@ #include #include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h" -#include "torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h" - -// TODO: is this include necessary? #include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h" namespace py = pybind11; @@ -111,20 +108,4 @@ std::vector LexiconDecoder_decode( return decoder.decode(reinterpret_cast(emissions), T, N); } -void LexiconFreeDecoder_decodeStep( - LexiconFreeDecoder& decoder, - uintptr_t emissions, - int T, - int N) { - decoder.decodeStep(reinterpret_cast(emissions), T, N); -} - -std::vector LexiconFreeDecoder_decode( - LexiconFreeDecoder& decoder, - uintptr_t emissions, - int T, - int N) { - return decoder.decode(reinterpret_cast(emissions), T, N); -} - } // namespace diff --git a/torchaudio/csrc/decoder/bindings/pybind.cpp b/torchaudio/csrc/decoder/bindings/pybind.cpp index 52011a5d589..6c6b8199ec8 100644 --- a/torchaudio/csrc/decoder/bindings/pybind.cpp +++ b/torchaudio/csrc/decoder/bindings/pybind.cpp @@ -80,31 +80,6 @@ PYBIND11_MODULE(_torchaudio_decoder, m) { .def_readwrite("log_add", &LexiconDecoderOptions::logAdd) .def_readwrite("criterion_type", &LexiconDecoderOptions::criterionType); - py::class_(m, "LexiconFreeDecoderOptions") - .def( - py::init< - const int, - const int, - const double, - const double, - const double, - const bool, - const CriterionType>(), - "beam_size"_a, - "beam_size_token"_a, - "beam_threshold"_a, - "lm_weight"_a, - "sil_score"_a, - "log_add"_a, - "criterion_type"_a) - .def_readwrite("beam_size", &LexiconFreeDecoderOptions::beamSize) - .def_readwrite("beam_size_token", &LexiconFreeDecoderOptions::beamSizeToken) - .def_readwrite("beam_threshold", &LexiconFreeDecoderOptions::beamThreshold) - .def_readwrite("lm_weight", &LexiconFreeDecoderOptions::lmWeight) - .def_readwrite("sil_score", &LexiconFreeDecoderOptions::silScore) - .def_readwrite("log_add", &LexiconFreeDecoderOptions::logAdd) - .def_readwrite("criterion_type", &LexiconFreeDecoderOptions::criterionType); - py::class_(m, "DecodeResult") .def(py::init(), "length"_a) .def_readwrite("score", &DecodeResult::score) @@ -140,31 +115,6 @@ PYBIND11_MODULE(_torchaudio_decoder, m) { "look_back"_a = 0) .def("get_all_final_hypothesis", &LexiconDecoder::getAllFinalHypothesis); - py::class_(m, "LexiconFreeDecoder") - .def(py::init< - LexiconFreeDecoderOptions, - const LMPtr, - const int, - const int, - const std::vector&>()) - .def("decode_begin", &LexiconFreeDecoder::decodeBegin) - .def( - "decode_step", - &LexiconFreeDecoder_decodeStep, - "emissions"_a, - "T"_a, - "N"_a) - .def("decode_end", &LexiconFreeDecoder::decodeEnd) - .def("decode", &LexiconFreeDecoder_decode, "emissions"_a, "T"_a, "N"_a) - .def("prune", &LexiconFreeDecoder::prune, "look_back"_a = 0) - .def( - "get_best_hypothesis", - &LexiconFreeDecoder::getBestHypothesis, - "look_back"_a = 0) - .def( - "get_all_final_hypothesis", - &LexiconFreeDecoder::getAllFinalHypothesis); - // FLASHLIGHT DICTIONARY py::class_(m, "Dictionary") @@ -189,8 +139,5 @@ PYBIND11_MODULE(_torchaudio_decoder, m) { "indices"_a); m.def("create_word_dict", &createWordDict, "lexicon"_a); m.def("load_words", &loadWords, "filename"_a, "max_words"_a = -1); - m.def("pack_replabels", &packReplabels, "tokens"_a, "dict"_a, "max_reps"_a); - m.def( - "unpack_replabels", &unpackReplabels, "tokens"_a, "dict"_a, "max_reps"_a); #endif } diff --git a/torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.cpp b/torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.cpp deleted file mode 100644 index b71e31eccfb..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.cpp +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include -#include -#include - -#include "torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h" - -namespace torchaudio { -namespace lib { -namespace text { - -void LexiconFreeDecoder::decodeBegin() { - hyp_.clear(); - hyp_.emplace(0, std::vector()); - - /* note: the lm reset itself with :start() */ - hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, sil_); - nDecodedFrames_ = 0; - nPrunedFrames_ = 0; -} - -void LexiconFreeDecoder::decodeStep(const float* emissions, int T, int N) { - int startFrame = nDecodedFrames_ - nPrunedFrames_; - // Extend hyp_ buffer - if (hyp_.size() < startFrame + T + 2) { - for (int i = hyp_.size(); i < startFrame + T + 2; i++) { - hyp_.emplace(i, std::vector()); - } - } - - std::vector idx(N); - // Looping over all the frames - for (int t = 0; t < T; t++) { - std::iota(idx.begin(), idx.end(), 0); - if (N > opt_.beamSizeToken) { - std::partial_sort( - idx.begin(), - idx.begin() + opt_.beamSizeToken, - idx.end(), - [&t, &N, &emissions](const size_t& l, const size_t& r) { - return emissions[t * N + l] > emissions[t * N + r]; - }); - } - - candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_); - for (const LexiconFreeDecoderState& prevHyp : hyp_[startFrame + t]) { - const int prevIdx = prevHyp.token; - - for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) { - int n = idx[r]; - double amScore = emissions[t * N + n]; - if (nDecodedFrames_ + t > 0 && - opt_.criterionType == CriterionType::ASG) { - amScore += transitions_[n * N + prevIdx]; - } - double score = prevHyp.score + emissions[t * N + n]; - if (n == sil_) { - score += opt_.silScore; - } - - if ((opt_.criterionType == CriterionType::ASG && n != prevIdx) || - (opt_.criterionType == CriterionType::CTC && n != blank_ && - (n != prevIdx || prevHyp.prevBlank))) { - auto lmStateScorePair = lm_->score(prevHyp.lmState, n); - auto lmScore = lmStateScorePair.second; - - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - score + opt_.lmWeight * lmScore, - lmStateScorePair.first, - &prevHyp, - n, - false, // prevBlank - prevHyp.amScore + amScore, - prevHyp.lmScore + lmScore); - } else if (opt_.criterionType == CriterionType::CTC && n == blank_) { - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - score, - prevHyp.lmState, - &prevHyp, - n, - true, // prevBlank - prevHyp.amScore + amScore, - prevHyp.lmScore); - } else { - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - score, - prevHyp.lmState, - &prevHyp, - n, - false, // prevBlank - prevHyp.amScore + amScore, - prevHyp.lmScore); - } - } - } - - candidatesStore( - candidates_, - candidatePtrs_, - hyp_[startFrame + t + 1], - opt_.beamSize, - candidatesBestScore_ - opt_.beamThreshold, - opt_.logAdd, - false); - updateLMCache(lm_, hyp_[startFrame + t + 1]); - } - nDecodedFrames_ += T; -} - -void LexiconFreeDecoder::decodeEnd() { - candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_); - for (const LexiconFreeDecoderState& prevHyp : - hyp_[nDecodedFrames_ - nPrunedFrames_]) { - const LMStatePtr& prevLmState = prevHyp.lmState; - - auto lmStateScorePair = lm_->finish(prevLmState); - auto lmScore = lmStateScorePair.second; - - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score + opt_.lmWeight * lmScore, - lmStateScorePair.first, - &prevHyp, - sil_, - false, // prevBlank - prevHyp.amScore, - prevHyp.lmScore + lmScore); - } - - candidatesStore( - candidates_, - candidatePtrs_, - hyp_[nDecodedFrames_ - nPrunedFrames_ + 1], - opt_.beamSize, - candidatesBestScore_ - opt_.beamThreshold, - opt_.logAdd, - true); - ++nDecodedFrames_; -} - -std::vector LexiconFreeDecoder::getAllFinalHypothesis() const { - int finalFrame = nDecodedFrames_ - nPrunedFrames_; - return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame); -} - -DecodeResult LexiconFreeDecoder::getBestHypothesis(int lookBack) const { - int finalFrame = nDecodedFrames_ - nPrunedFrames_; - const LexiconFreeDecoderState* bestNode = - findBestAncestor(hyp_.find(finalFrame)->second, lookBack); - - return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack); -} - -int LexiconFreeDecoder::nHypothesis() const { - int finalFrame = nDecodedFrames_ - nPrunedFrames_; - return hyp_.find(finalFrame)->second.size(); -} - -int LexiconFreeDecoder::nDecodedFramesInBuffer() const { - return nDecodedFrames_ - nPrunedFrames_ + 1; -} - -void LexiconFreeDecoder::prune(int lookBack) { - if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) { - return; // Not enough decoded frames to prune - } - - /* (1) Find the last emitted word in the best path */ - int finalFrame = nDecodedFrames_ - nPrunedFrames_; - const LexiconFreeDecoderState* bestNode = - findBestAncestor(hyp_.find(finalFrame)->second, lookBack); - if (!bestNode) { - return; // Not enough decoded frames to prune - } - - int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack; - if (startFrame < 1) { - return; // Not enough decoded frames to prune - } - - /* (2) Move things from back of hyp_ to front and normalize scores */ - pruneAndNormalize(hyp_, startFrame, lookBack); - - nPrunedFrames_ = nDecodedFrames_ - lookBack; -} -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h b/torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h deleted file mode 100644 index 5acc7f1007d..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include - -#include "torchaudio/csrc/decoder/src/decoder/Decoder.h" -#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h" - -namespace torchaudio { -namespace lib { -namespace text { - -struct LexiconFreeDecoderOptions { - int beamSize; // Maximum number of hypothesis we hold after each step - int beamSizeToken; // Maximum number of tokens we consider at each step - double beamThreshold; // Threshold to prune hypothesis - double lmWeight; // Weight of lm - double silScore; // Silence insertion score - bool logAdd; - CriterionType criterionType; // CTC or ASG -}; - -/** - * LexiconFreeDecoderState stores information for each hypothesis in the beam. - */ -struct LexiconFreeDecoderState { - double score; // Accumulated total score so far - LMStatePtr lmState; // Language model state - const LexiconFreeDecoderState* parent; // Parent hypothesis - int token; // Label of token - bool prevBlank; // If previous hypothesis is blank (for CTC only) - - double amScore; // Accumulated AM score so far - double lmScore; // Accumulated LM score so far - - LexiconFreeDecoderState( - const double score, - const LMStatePtr& lmState, - const LexiconFreeDecoderState* parent, - const int token, - const bool prevBlank = false, - const double amScore = 0, - const double lmScore = 0) - : score(score), - lmState(lmState), - parent(parent), - token(token), - prevBlank(prevBlank), - amScore(amScore), - lmScore(lmScore) {} - - LexiconFreeDecoderState() - : score(0), - lmState(nullptr), - parent(nullptr), - token(-1), - prevBlank(false), - amScore(0.), - lmScore(0.) {} - - int compareNoScoreStates(const LexiconFreeDecoderState* node) const { - int lmCmp = lmState->compare(node->lmState); - if (lmCmp != 0) { - return lmCmp > 0 ? 1 : -1; - } else if (token != node->token) { - return token > node->token ? 1 : -1; - } else if (prevBlank != node->prevBlank) { - return prevBlank > node->prevBlank ? 1 : -1; - } - return 0; - } - - int getWord() const { - return -1; - } - - bool isComplete() const { - return true; - } -}; - -/** - * Decoder implements a beam seach decoder that finds the word transcription - * W maximizing: - * - * AM(W) + lmWeight_ * log(P_{lm}(W)) + silScore_ * |{i| pi_i = }| - * - * where P_{lm}(W) is the language model score, pi_i is the value for the i-th - * frame in the path leading to W and AM(W) is the (unnormalized) acoustic model - * score of the transcription W. We are allowed to generate words from all the - * possible combination of tokens. - */ -class LexiconFreeDecoder : public Decoder { - public: - LexiconFreeDecoder( - LexiconFreeDecoderOptions opt, - const LMPtr& lm, - const int sil, - const int blank, - const std::vector& transitions) - : opt_(std::move(opt)), - lm_(lm), - transitions_(transitions), - sil_(sil), - blank_(blank) {} - - void decodeBegin() override; - - void decodeStep(const float* emissions, int T, int N) override; - - void decodeEnd() override; - - int nHypothesis() const; - - void prune(int lookBack = 0) override; - - int nDecodedFramesInBuffer() const override; - - DecodeResult getBestHypothesis(int lookBack = 0) const override; - - std::vector getAllFinalHypothesis() const override; - - protected: - LexiconFreeDecoderOptions opt_; - LMPtr lm_; - std::vector transitions_; - - // All the hypothesis new candidates (can be larger than beamsize) proposed - // based on the ones from previous frame - std::vector candidates_; - - // This vector is designed for efficient sorting and merging the candidates_, - // so instead of moving around objects, we only need to sort pointers - std::vector candidatePtrs_; - - // Best candidate score of current frame - double candidatesBestScore_; - - // Index of silence label - int sil_; - - // Index of blank label (for CTC) - int blank_; - - // Vector of hypothesis for all the frames so far - std::unordered_map> hyp_; - - // These 2 variables are used for online decoding, for hypothesis pruning - int nDecodedFrames_; // Total number of decoded frames. - int nPrunedFrames_; // Total number of pruned frames from hyp_. -}; -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/csrc/decoder/src/decoder/LexiconFreeSeq2SeqDecoder.cpp b/torchaudio/csrc/decoder/src/decoder/LexiconFreeSeq2SeqDecoder.cpp deleted file mode 100644 index d323af35732..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/LexiconFreeSeq2SeqDecoder.cpp +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include -#include -#include - -#include "torchaudio/csrc/decoder/src/decoder/LexiconFreeSeq2SeqDecoder.h" - -namespace torchaudio { -namespace lib { -namespace text { - -void LexiconFreeSeq2SeqDecoder::decodeStep( - const float* emissions, - int T, - int N) { - // Extend hyp_ buffer - if (hyp_.size() < maxOutputLength_ + 2) { - for (int i = hyp_.size(); i < maxOutputLength_ + 2; i++) { - hyp_.emplace(i, std::vector()); - } - } - - // Start from here. - hyp_[0].clear(); - hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, -1, nullptr); - - // Decode frame by frame - int t = 0; - for (; t < maxOutputLength_; t++) { - candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_); - - // Batch forwarding - rawY_.clear(); - rawPrevStates_.clear(); - for (const LexiconFreeSeq2SeqDecoderState& prevHyp : hyp_[t]) { - const AMStatePtr& prevState = prevHyp.amState; - if (prevHyp.token == eos_) { - continue; - } - rawY_.push_back(prevHyp.token); - rawPrevStates_.push_back(prevState); - } - if (rawY_.size() == 0) { - break; - } - - std::vector> amScores; - std::vector outStates; - - std::tie(amScores, outStates) = - amUpdateFunc_(emissions, N, T, rawY_, rawPrevStates_, t); - - std::vector idx(amScores.back().size()); - - // Generate new hypothesis - for (int hypo = 0, validHypo = 0; hypo < hyp_[t].size(); hypo++) { - const LexiconFreeSeq2SeqDecoderState& prevHyp = hyp_[t][hypo]; - // Change nothing for completed hypothesis - if (prevHyp.token == eos_) { - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score, - prevHyp.lmState, - &prevHyp, - eos_, - nullptr, - prevHyp.amScore, - prevHyp.lmScore); - continue; - } - - const AMStatePtr& outState = outStates[validHypo]; - if (!outState) { - validHypo++; - continue; - } - - std::iota(idx.begin(), idx.end(), 0); - if (amScores[validHypo].size() > opt_.beamSizeToken) { - std::partial_sort( - idx.begin(), - idx.begin() + opt_.beamSizeToken, - idx.end(), - [&amScores, &validHypo](const size_t& l, const size_t& r) { - return amScores[validHypo][l] > amScores[validHypo][r]; - }); - } - - for (int r = 0; - r < std::min(amScores[validHypo].size(), (size_t)opt_.beamSizeToken); - r++) { - int n = idx[r]; - double amScore = amScores[validHypo][n]; - - if (n == eos_) { /* (1) Try eos */ - auto lmStateScorePair = lm_->finish(prevHyp.lmState); - auto lmScore = lmStateScorePair.second; - - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score + amScore + opt_.eosScore + opt_.lmWeight * lmScore, - lmStateScorePair.first, - &prevHyp, - n, - nullptr, - prevHyp.amScore + amScore, - prevHyp.lmScore + lmScore); - } else { /* (2) Try normal token */ - auto lmStateScorePair = lm_->score(prevHyp.lmState, n); - auto lmScore = lmStateScorePair.second; - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score + amScore + opt_.lmWeight * lmScore, - lmStateScorePair.first, - &prevHyp, - n, - outState, - prevHyp.amScore + amScore, - prevHyp.lmScore + lmScore); - } - } - validHypo++; - } - candidatesStore( - candidates_, - candidatePtrs_, - hyp_[t + 1], - opt_.beamSize, - candidatesBestScore_ - opt_.beamThreshold, - opt_.logAdd, - true); - updateLMCache(lm_, hyp_[t + 1]); - } // End of decoding - - while (t > 0 && hyp_[t].empty()) { - --t; - } - hyp_[maxOutputLength_ + 1].resize(hyp_[t].size()); - for (int i = 0; i < hyp_[t].size(); i++) { - hyp_[maxOutputLength_ + 1][i] = std::move(hyp_[t][i]); - } -} - -std::vector LexiconFreeSeq2SeqDecoder::getAllFinalHypothesis() - const { - return getAllHypothesis(hyp_.find(maxOutputLength_ + 1)->second, hyp_.size()); -} - -DecodeResult LexiconFreeSeq2SeqDecoder::getBestHypothesis( - int /* unused */) const { - return getHypothesis( - hyp_.find(maxOutputLength_ + 1)->second.data(), hyp_.size()); -} - -void LexiconFreeSeq2SeqDecoder::prune(int /* unused */) { - return; -} - -int LexiconFreeSeq2SeqDecoder::nDecodedFramesInBuffer() const { - /* unused function */ - return -1; -} -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/csrc/decoder/src/decoder/LexiconFreeSeq2SeqDecoder.h b/torchaudio/csrc/decoder/src/decoder/LexiconFreeSeq2SeqDecoder.h deleted file mode 100644 index 4bb0829a4ff..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/LexiconFreeSeq2SeqDecoder.h +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -#include "torchaudio/csrc/decoder/src/decoder/Decoder.h" -#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h" - -namespace torchaudio { -namespace lib { -namespace text { - -using AMStatePtr = std::shared_ptr; -using AMUpdateFunc = std::function< - std::pair>, std::vector>( - const float*, - const int, - const int, - const std::vector&, - const std::vector&, - int&)>; - -struct LexiconFreeSeq2SeqDecoderOptions { - int beamSize; // Maximum number of hypothesis we hold after each step - int beamSizeToken; // Maximum number of tokens we consider at each step - double beamThreshold; // Threshold to prune hypothesis - double lmWeight; // Weight of lm - double eosScore; // Score for inserting an EOS - bool logAdd; // If or not use logadd when merging hypothesis -}; - -/** - * LexiconFreeSeq2SeqDecoderState stores information for each hypothesis in the - * beam. - */ -struct LexiconFreeSeq2SeqDecoderState { - double score; // Accumulated total score so far - LMStatePtr lmState; // Language model state - const LexiconFreeSeq2SeqDecoderState* parent; // Parent hypothesis - int token; // Label of token - AMStatePtr amState; // Acoustic model state - - double amScore; // Accumulated AM score so far - double lmScore; // Accumulated LM score so far - - LexiconFreeSeq2SeqDecoderState( - const double score, - const LMStatePtr& lmState, - const LexiconFreeSeq2SeqDecoderState* parent, - const int token, - const AMStatePtr& amState = nullptr, - const double amScore = 0, - const double lmScore = 0) - : score(score), - lmState(lmState), - parent(parent), - token(token), - amState(amState), - amScore(amScore), - lmScore(lmScore) {} - - LexiconFreeSeq2SeqDecoderState() - : score(0), - lmState(nullptr), - parent(nullptr), - token(-1), - amState(nullptr), - amScore(0.), - lmScore(0.) {} - - int compareNoScoreStates(const LexiconFreeSeq2SeqDecoderState* node) const { - return lmState->compare(node->lmState); - } - - int getWord() const { - return -1; - } -}; - -/** - * Decoder implements a beam seach decoder that finds the token transcription - * W maximizing: - * - * AM(W) + lmWeight_ * log(P_{lm}(W)) + eosScore_ * |W_last == EOS| - * - * where P_{lm}(W) is the language model score. The sequence of tokens is not - * constrained by a lexicon, and thus the language model must operate at - * token-level. - * - * TODO: Doesn't support online decoding now. - * - */ -class LexiconFreeSeq2SeqDecoder : public Decoder { - public: - LexiconFreeSeq2SeqDecoder( - LexiconFreeSeq2SeqDecoderOptions opt, - const LMPtr& lm, - const int eos, - AMUpdateFunc amUpdateFunc, - const int maxOutputLength) - : opt_(std::move(opt)), - lm_(lm), - eos_(eos), - amUpdateFunc_(amUpdateFunc), - maxOutputLength_(maxOutputLength) {} - - void decodeStep(const float* emissions, int T, int N) override; - - void prune(int lookBack = 0) override; - - int nDecodedFramesInBuffer() const override; - - DecodeResult getBestHypothesis(int lookBack = 0) const override; - - std::vector getAllFinalHypothesis() const override; - - protected: - LexiconFreeSeq2SeqDecoderOptions opt_; - LMPtr lm_; - int eos_; - AMUpdateFunc amUpdateFunc_; - std::vector rawY_; - std::vector rawPrevStates_; - int maxOutputLength_; - - std::vector candidates_; - std::vector candidatePtrs_; - double candidatesBestScore_; - - std::unordered_map> hyp_; -}; -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/csrc/decoder/src/decoder/LexiconSeq2SeqDecoder.cpp b/torchaudio/csrc/decoder/src/decoder/LexiconSeq2SeqDecoder.cpp deleted file mode 100644 index df400693eb7..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/LexiconSeq2SeqDecoder.cpp +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include -#include -#include -#include -#include - -#include "torchaudio/csrc/decoder/src/decoder/LexiconSeq2SeqDecoder.h" - -namespace torchaudio { -namespace lib { -namespace text { - -void LexiconSeq2SeqDecoder::decodeStep(const float* emissions, int T, int N) { - // Extend hyp_ buffer - if (hyp_.size() < maxOutputLength_ + 2) { - for (int i = hyp_.size(); i < maxOutputLength_ + 2; i++) { - hyp_.emplace(i, std::vector()); - } - } - - // Start from here. - hyp_[0].clear(); - hyp_[0].emplace_back( - 0.0, lm_->start(0), lexicon_->getRoot(), nullptr, -1, -1, nullptr); - - auto compare = [](const LexiconSeq2SeqDecoderState& n1, - const LexiconSeq2SeqDecoderState& n2) { - return n1.score > n2.score; - }; - - // Decode frame by frame - int t = 0; - for (; t < maxOutputLength_; t++) { - candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_); - - // Batch forwarding - rawY_.clear(); - rawPrevStates_.clear(); - for (const LexiconSeq2SeqDecoderState& prevHyp : hyp_[t]) { - const AMStatePtr& prevState = prevHyp.amState; - if (prevHyp.token == eos_) { - continue; - } - rawY_.push_back(prevHyp.token); - rawPrevStates_.push_back(prevState); - } - if (rawY_.size() == 0) { - break; - } - - std::vector> amScores; - std::vector outStates; - - std::tie(amScores, outStates) = - amUpdateFunc_(emissions, N, T, rawY_, rawPrevStates_, t); - - std::vector idx(amScores.back().size()); - - // Generate new hypothesis - for (int hypo = 0, validHypo = 0; hypo < hyp_[t].size(); hypo++) { - const LexiconSeq2SeqDecoderState& prevHyp = hyp_[t][hypo]; - // Change nothing for completed hypothesis - if (prevHyp.token == eos_) { - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score, - prevHyp.lmState, - prevHyp.lex, - &prevHyp, - eos_, - -1, - nullptr, - prevHyp.amScore, - prevHyp.lmScore); - continue; - } - - const AMStatePtr& outState = outStates[validHypo]; - if (!outState) { - validHypo++; - continue; - } - - const TrieNode* prevLex = prevHyp.lex; - const float lexMaxScore = - prevLex == lexicon_->getRoot() ? 0 : prevLex->maxScore; - - std::iota(idx.begin(), idx.end(), 0); - if (amScores[validHypo].size() > opt_.beamSizeToken) { - std::partial_sort( - idx.begin(), - idx.begin() + opt_.beamSizeToken, - idx.end(), - [&amScores, &validHypo](const size_t& l, const size_t& r) { - return amScores[validHypo][l] > amScores[validHypo][r]; - }); - } - - for (int r = 0; - r < std::min(amScores[validHypo].size(), (size_t)opt_.beamSizeToken); - r++) { - int n = idx[r]; - double amScore = amScores[validHypo][n]; - - /* (1) Try eos */ - if (n == eos_ && (prevLex == lexicon_->getRoot())) { - auto lmStateScorePair = lm_->finish(prevHyp.lmState); - LMStatePtr lmState = lmStateScorePair.first; - double lmScore; - if (isLmToken_) { - lmScore = lmStateScorePair.second; - } else { - lmScore = lmStateScorePair.second - lexMaxScore; - } - - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score + amScore + opt_.eosScore + opt_.lmWeight * lmScore, - lmState, - lexicon_->getRoot(), - &prevHyp, - n, - -1, - nullptr, - prevHyp.amScore + amScore, - prevHyp.lmScore + lmScore); - } - - /* (2) Try normal token */ - if (n != eos_) { - auto searchLex = prevLex->children.find(n); - if (searchLex != prevLex->children.end()) { - auto lex = searchLex->second; - LMStatePtr lmState; - double lmScore; - if (isLmToken_) { - auto lmStateScorePair = lm_->score(prevHyp.lmState, n); - lmState = lmStateScorePair.first; - lmScore = lmStateScorePair.second; - } else { - // smearing - lmState = prevHyp.lmState; - lmScore = lex->maxScore - lexMaxScore; - } - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score + amScore + opt_.lmWeight * lmScore, - lmState, - lex.get(), - &prevHyp, - n, - -1, - outState, - prevHyp.amScore + amScore, - prevHyp.lmScore + lmScore); - - // If we got a true word - if (lex->labels.size() > 0) { - for (auto word : lex->labels) { - if (!isLmToken_) { - auto lmStateScorePair = lm_->score(prevHyp.lmState, word); - lmState = lmStateScorePair.first; - lmScore = lmStateScorePair.second - lexMaxScore; - } - candidatesAdd( - candidates_, - candidatesBestScore_, - opt_.beamThreshold, - prevHyp.score + amScore + opt_.wordScore + - opt_.lmWeight * lmScore, - lmState, - lexicon_->getRoot(), - &prevHyp, - n, - word, - outState, - prevHyp.amScore + amScore, - prevHyp.lmScore + lmScore); - if (isLmToken_) { - break; - } - } - } - } - } - } - validHypo++; - } - candidatesStore( - candidates_, - candidatePtrs_, - hyp_[t + 1], - opt_.beamSize, - candidatesBestScore_ - opt_.beamThreshold, - opt_.logAdd, - true); - updateLMCache(lm_, hyp_[t + 1]); - } // End of decoding - - while (t > 0 && hyp_[t].empty()) { - --t; - } - hyp_[maxOutputLength_ + 1].resize(hyp_[t].size()); - for (int i = 0; i < hyp_[t].size(); i++) { - hyp_[maxOutputLength_ + 1][i] = std::move(hyp_[t][i]); - } -} - -std::vector LexiconSeq2SeqDecoder::getAllFinalHypothesis() const { - return getAllHypothesis(hyp_.find(maxOutputLength_ + 1)->second, hyp_.size()); -} - -DecodeResult LexiconSeq2SeqDecoder::getBestHypothesis(int /* unused */) const { - return getHypothesis( - hyp_.find(maxOutputLength_ + 1)->second.data(), hyp_.size()); -} - -void LexiconSeq2SeqDecoder::prune(int /* unused */) { - return; -} - -int LexiconSeq2SeqDecoder::nDecodedFramesInBuffer() const { - /* unused function */ - return -1; -} -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/csrc/decoder/src/decoder/LexiconSeq2SeqDecoder.h b/torchaudio/csrc/decoder/src/decoder/LexiconSeq2SeqDecoder.h deleted file mode 100644 index 63e8fa7d54d..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/LexiconSeq2SeqDecoder.h +++ /dev/null @@ -1,165 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -#include "torchaudio/csrc/decoder/src/decoder/Decoder.h" -#include "torchaudio/csrc/decoder/src/decoder/Trie.h" -#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h" - -namespace torchaudio { -namespace lib { -namespace text { - -using AMStatePtr = std::shared_ptr; -using AMUpdateFunc = std::function< - std::pair>, std::vector>( - const float*, - const int, - const int, - const std::vector&, - const std::vector&, - int&)>; - -struct LexiconSeq2SeqDecoderOptions { - int beamSize; // Maximum number of hypothesis we hold after each step - int beamSizeToken; // Maximum number of tokens we consider at each step - double beamThreshold; // Threshold to prune hypothesis - double lmWeight; // Weight of lm - double wordScore; // Word insertion score - double eosScore; // Score for inserting an EOS - bool logAdd; // If or not use logadd when merging hypothesis -}; - -/** - * LexiconSeq2SeqDecoderState stores information for each hypothesis in the - * beam. - */ -struct LexiconSeq2SeqDecoderState { - double score; // Accumulated total score so far - LMStatePtr lmState; // Language model state - const TrieNode* lex; - const LexiconSeq2SeqDecoderState* parent; // Parent hypothesis - int token; // Label of token - int word; - AMStatePtr amState; // Acoustic model state - - double amScore; // Accumulated AM score so far - double lmScore; // Accumulated LM score so far - - LexiconSeq2SeqDecoderState( - const double score, - const LMStatePtr& lmState, - const TrieNode* lex, - const LexiconSeq2SeqDecoderState* parent, - const int token, - const int word, - const AMStatePtr& amState, - const double amScore = 0, - const double lmScore = 0) - : score(score), - lmState(lmState), - lex(lex), - parent(parent), - token(token), - word(word), - amState(amState), - amScore(amScore), - lmScore(lmScore) {} - - LexiconSeq2SeqDecoderState() - : score(0), - lmState(nullptr), - lex(nullptr), - parent(nullptr), - token(-1), - word(-1), - amState(nullptr), - amScore(0.), - lmScore(0.) {} - - int compareNoScoreStates(const LexiconSeq2SeqDecoderState* node) const { - int lmCmp = lmState->compare(node->lmState); - if (lmCmp != 0) { - return lmCmp > 0 ? 1 : -1; - } else if (lex != node->lex) { - return lex > node->lex ? 1 : -1; - } else if (token != node->token) { - return token > node->token ? 1 : -1; - } - return 0; - } - - int getWord() const { - return word; - } -}; - -/** - * Decoder implements a beam seach decoder that finds the token transcription - * W maximizing: - * - * AM(W) + lmWeight_ * log(P_{lm}(W)) + eosScore_ * |W_last == EOS| - * - * where P_{lm}(W) is the language model score. The transcription W is - * constrained by a lexicon. The language model may operate at word-level - * (isLmToken=false) or token-level (isLmToken=true). - * - * TODO: Doesn't support online decoding now. - * - */ -class LexiconSeq2SeqDecoder : public Decoder { - public: - LexiconSeq2SeqDecoder( - LexiconSeq2SeqDecoderOptions opt, - const TriePtr& lexicon, - const LMPtr& lm, - const int eos, - AMUpdateFunc amUpdateFunc, - const int maxOutputLength, - const bool isLmToken) - : opt_(std::move(opt)), - lm_(lm), - lexicon_(lexicon), - eos_(eos), - amUpdateFunc_(amUpdateFunc), - maxOutputLength_(maxOutputLength), - isLmToken_(isLmToken) {} - - void decodeStep(const float* emissions, int T, int N) override; - - void prune(int lookBack = 0) override; - - int nDecodedFramesInBuffer() const override; - - DecodeResult getBestHypothesis(int lookBack = 0) const override; - - std::vector getAllFinalHypothesis() const override; - - protected: - LexiconSeq2SeqDecoderOptions opt_; - LMPtr lm_; - TriePtr lexicon_; - int eos_; - AMUpdateFunc amUpdateFunc_; - std::vector rawY_; - std::vector rawPrevStates_; - int maxOutputLength_; - bool isLmToken_; - - std::vector candidates_; - std::vector candidatePtrs_; - double candidatesBestScore_; - - std::unordered_map> hyp_; -}; -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/csrc/decoder/src/decoder/lm/ConvLM.cpp b/torchaudio/csrc/decoder/src/decoder/lm/ConvLM.cpp deleted file mode 100644 index 175f9b42113..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/lm/ConvLM.cpp +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - */ - -#include -#include -#include - -#include "torchaudio/csrc/decoder/src/decoder/lm/ConvLM.h" - -namespace torchaudio { -namespace lib { -namespace text { - -ConvLM::ConvLM( - const GetConvLmScoreFunc& getConvLmScoreFunc, - const std::string& tokenVocabPath, - const Dictionary& usrTknDict, - int lmMemory, - int beamSize, - int historySize) - : lmMemory_(lmMemory), - beamSize_(beamSize), - getConvLmScoreFunc_(getConvLmScoreFunc), - maxHistorySize_(historySize) { - if (historySize < 1) { - throw std::invalid_argument("[ConvLM] History size is too small."); - } - - /* Load token vocabulary */ - // Note: fairseq vocab should start with: - // - 0 - 1, - 2, - 3 - std::cerr << "[ConvLM]: Loading vocabulary from " << tokenVocabPath << "\n"; - vocab_ = Dictionary(tokenVocabPath); - vocab_.setDefaultIndex(vocab_.getIndex(kUnkToken)); - vocabSize_ = vocab_.indexSize(); - std::cerr << "[ConvLM]: vocabulary size of convLM " << vocabSize_ << "\n"; - - /* Create index map */ - usrToLmIdxMap_.resize(usrTknDict.indexSize()); - for (int i = 0; i < usrTknDict.indexSize(); i++) { - auto token = usrTknDict.getEntry(i); - int lmIdx = vocab_.getIndex(token.c_str()); - usrToLmIdxMap_[i] = lmIdx; - } - - /* Refresh cache */ - cacheIndices_.reserve(beamSize_); - cache_.resize(beamSize_, std::vector(vocabSize_)); - slot_.reserve(beamSize_); - batchedTokens_.resize(beamSize_ * maxHistorySize_); -} - -LMStatePtr ConvLM::start(bool startWithNothing) { - cacheIndices_.clear(); - auto outState = std::make_shared(1); - if (!startWithNothing) { - outState->length = 1; - outState->tokens[0] = vocab_.getIndex(kEosToken); - } else { - throw std::invalid_argument( - "[ConvLM] Only support using EOS to start the sentence"); - } - return outState; -} - -std::pair ConvLM::scoreWithLmIdx( - const LMStatePtr& state, - const int tokenIdx) { - auto rawInState = std::static_pointer_cast(state).get(); - int inStateLength = rawInState->length; - std::shared_ptr outState; - - // Prepare output state - if (inStateLength == maxHistorySize_) { - outState = std::make_shared(maxHistorySize_); - std::copy( - rawInState->tokens.begin() + 1, - rawInState->tokens.end(), - outState->tokens.begin()); - outState->tokens[maxHistorySize_ - 1] = tokenIdx; - } else { - outState = std::make_shared(inStateLength + 1); - std::copy( - rawInState->tokens.begin(), - rawInState->tokens.end(), - outState->tokens.begin()); - outState->tokens[inStateLength] = tokenIdx; - } - - // Prepare score - float score = 0; - if (tokenIdx < 0 || tokenIdx >= vocabSize_) { - throw std::out_of_range( - "[ConvLM] Invalid query word: " + std::to_string(tokenIdx)); - } - - if (cacheIndices_.find(rawInState) != cacheIndices_.end()) { - // Cache hit - auto cacheInd = cacheIndices_[rawInState]; - if (cacheInd < 0 || cacheInd >= beamSize_) { - throw std::logic_error( - "[ConvLM] Invalid cache access: " + std::to_string(cacheInd)); - } - score = cache_[cacheInd][tokenIdx]; - } else { - // Cache miss - if (cacheIndices_.size() == beamSize_) { - cacheIndices_.clear(); - } - int newIdx = cacheIndices_.size(); - cacheIndices_[rawInState] = newIdx; - - std::vector lastTokenPositions = {rawInState->length - 1}; - cache_[newIdx] = - getConvLmScoreFunc_(rawInState->tokens, lastTokenPositions, -1, 1); - score = cache_[newIdx][tokenIdx]; - } - if (std::isnan(score) || !std::isfinite(score)) { - throw std::runtime_error( - "[ConvLM] Bad scoring from ConvLM: " + std::to_string(score)); - } - return std::make_pair(std::move(outState), score); -} - -std::pair ConvLM::score( - const LMStatePtr& state, - const int usrTokenIdx) { - if (usrTokenIdx < 0 || usrTokenIdx >= usrToLmIdxMap_.size()) { - throw std::out_of_range( - "[KenLM] Invalid user token index: " + std::to_string(usrTokenIdx)); - } - return scoreWithLmIdx(state, usrToLmIdxMap_[usrTokenIdx]); -} - -std::pair ConvLM::finish(const LMStatePtr& state) { - return scoreWithLmIdx(state, vocab_.getIndex(kEosToken)); -} - -void ConvLM::updateCache(std::vector states) { - int longestHistory = -1, nStates = states.size(); - if (nStates > beamSize_) { - throw std::invalid_argument( - "[ConvLM] Cache size too small (consider larger than beam size)."); - } - - // Refresh cache, store LM states that did not changed - slot_.clear(); - slot_.resize(beamSize_, nullptr); - for (const auto& state : states) { - auto rawState = std::static_pointer_cast(state).get(); - if (cacheIndices_.find(rawState) != cacheIndices_.end()) { - slot_[cacheIndices_[rawState]] = rawState; - } else if (rawState->length > longestHistory) { - // prepare intest history only for those which should be predicted - longestHistory = rawState->length; - } - } - cacheIndices_.clear(); - int cacheSize = 0; - for (int i = 0; i < beamSize_; i++) { - if (!slot_[i]) { - continue; - } - cache_[cacheSize] = cache_[i]; - cacheIndices_[slot_[i]] = cacheSize; - ++cacheSize; - } - - // Determine batchsize - if (longestHistory <= 0) { - return; - } - // batchSize * longestHistory = cacheSize; - int maxBatchSize = lmMemory_ / longestHistory; - if (maxBatchSize > nStates) { - maxBatchSize = nStates; - } - - // Run batch forward - int batchStart = 0; - while (batchStart < nStates) { - // Select batch - int nBatchStates = 0; - std::vector lastTokenPositions; - for (int i = batchStart; (nBatchStates < maxBatchSize) && (i < nStates); - i++, batchStart++) { - auto rawState = std::static_pointer_cast(states[i]).get(); - if (cacheIndices_.find(rawState) != cacheIndices_.end()) { - continue; - } - cacheIndices_[rawState] = cacheSize + nBatchStates; - int start = nBatchStates * longestHistory; - - for (int j = 0; j < rawState->length; j++) { - batchedTokens_[start + j] = rawState->tokens[j]; - } - start += rawState->length; - for (int j = 0; j < longestHistory - rawState->length; j++) { - batchedTokens_[start + j] = vocab_.getIndex(kPadToken); - } - lastTokenPositions.push_back(rawState->length - 1); - ++nBatchStates; - } - if (nBatchStates == 0 && batchStart >= nStates) { - // if all states were skipped - break; - } - - // Feed forward - if (nBatchStates < 1 || longestHistory < 1) { - throw std::logic_error( - "[ConvLM] Invalid batch: [" + std::to_string(nBatchStates) + " x " + - std::to_string(longestHistory) + "]"); - } - auto batchedProb = getConvLmScoreFunc_( - batchedTokens_, lastTokenPositions, longestHistory, nBatchStates); - - if (batchedProb.size() != vocabSize_ * nBatchStates) { - throw std::logic_error( - "[ConvLM] Batch X Vocab size " + std::to_string(batchedProb.size()) + - " mismatch with " + std::to_string(vocabSize_ * nBatchStates)); - } - // Place probabilities in cache - for (int i = 0; i < nBatchStates; i++, cacheSize++) { - std::memcpy( - cache_[cacheSize].data(), - batchedProb.data() + vocabSize_ * i, - vocabSize_ * sizeof(float)); - } - } -} -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/csrc/decoder/src/decoder/lm/ConvLM.h b/torchaudio/csrc/decoder/src/decoder/lm/ConvLM.h deleted file mode 100644 index cb471529602..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/lm/ConvLM.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - */ -#pragma once - -#include - -#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h" -#include "torchaudio/csrc/decoder/src/dictionary/Defines.h" -#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h" - -namespace torchaudio { -namespace lib { -namespace text { - -using GetConvLmScoreFunc = std::function(const std::vector&, const std::vector&, int, int)>; - -struct ConvLMState : LMState { - std::vector tokens; - int length; - - ConvLMState() : length(0) {} - explicit ConvLMState(int size) - : tokens(std::vector(size)), length(size) {} -}; - -class ConvLM : public LM { - public: - ConvLM( - const GetConvLmScoreFunc& getConvLmScoreFunc, - const std::string& tokenVocabPath, - const Dictionary& usrTknDict, - int lmMemory = 10000, - int beamSize = 2500, - int historySize = 49); - - LMStatePtr start(bool startWithNothing) override; - - std::pair score( - const LMStatePtr& state, - const int usrTokenIdx) override; - - std::pair finish(const LMStatePtr& state) override; - - void updateCache(std::vector states) override; - - private: - // This cache is also not thread-safe! - int lmMemory_; - int beamSize_; - std::unordered_map cacheIndices_; - std::vector> cache_; - std::vector slot_; - std::vector batchedTokens_; - - Dictionary vocab_; - GetConvLmScoreFunc getConvLmScoreFunc_; - - int vocabSize_; - int maxHistorySize_; - - std::pair scoreWithLmIdx( - const LMStatePtr& state, - const int tokenIdx); -}; -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.cpp b/torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.cpp deleted file mode 100644 index bd720a5f330..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.cpp +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h" - -#include - -namespace torchaudio { -namespace lib { -namespace text { - -LMStatePtr ZeroLM::start(bool /* unused */) { - return std::make_shared(); -} - -std::pair ZeroLM::score( - const LMStatePtr& state /* unused */, - const int usrTokenIdx) { - return std::make_pair(state->child(usrTokenIdx), 0.0); -} - -std::pair ZeroLM::finish(const LMStatePtr& state) { - return std::make_pair(state, 0.0); -} -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h b/torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h deleted file mode 100644 index dfe375ee7f0..00000000000 --- a/torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h" - -namespace torchaudio { -namespace lib { -namespace text { - -/** - * ZeroLM is a dummy language model class, which mimics the behavious of a - * uni-gram language model but always returns 0 as score. - */ -class ZeroLM : public LM { - public: - LMStatePtr start(bool startWithNothing) override; - - std::pair score( - const LMStatePtr& state, - const int usrTokenIdx) override; - - std::pair finish(const LMStatePtr& state) override; -}; -} // namespace text -} // namespace lib -} // namespace torchaudio diff --git a/torchaudio/prototype/__init__.py b/torchaudio/prototype/__init__.py index b988be7cc5a..0b90a9cd6d0 100644 --- a/torchaudio/prototype/__init__.py +++ b/torchaudio/prototype/__init__.py @@ -2,6 +2,7 @@ from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model from .rnnt_decoder import Hypothesis, RNNTBeamSearch + __all__ = [ "Emformer", "Hypothesis", diff --git a/torchaudio/prototype/ctc_decoder.py b/torchaudio/prototype/ctc_decoder.py index b1566f674f1..8632fb15c94 100644 --- a/torchaudio/prototype/ctc_decoder.py +++ b/torchaudio/prototype/ctc_decoder.py @@ -9,7 +9,6 @@ torchaudio._extension._load_lib('libtorchaudio_decoder') from torchaudio._torchaudio_decoder import ( CriterionType, - DecodeResult, KenLM, LexiconDecoder, LexiconDecoderOptions,