Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support removing bytes equal to 0 from the output in text normalization #56

Merged
merged 3 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR)

project(kaldifst CXX)

set(KALDIFST_VERSION "1.7.7")
set(KALDIFST_VERSION "1.7.8")

if(NOT CMAKE_BUILD_TYPE)
message(STATUS "No CMAKE_BUILD_TYPE given, default to Release")
Expand Down
50 changes: 41 additions & 9 deletions kaldifst/csrc/text-normalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <string>
#include <utility>

#include "fst/arcsort.h"
#include "kaldifst/csrc/kaldi-fst-io.h"
#include "kaldifst/csrc/table-matcher.h"

Expand Down Expand Up @@ -48,6 +47,44 @@ static fst::StdVectorFst StringToFst(const std::string &text) {
return ans;
}

// Converts a linear FST into a string by concatenating the output labels
// found along its single path, starting from the start state.
//
// @param fst  The FST to convert.
// @param remove_output_zero  True to drop output labels equal to 0 instead of
//                            appending them to the result.
// @return The collected output labels, each appended as one char. An empty
//         string is returned if the FST has no start state, if a non-final
//         state has no outgoing arc, if an arc leads to fst::kNoStateId, or
//         if a non-final state has more than one outgoing arc.
static std::string FstToString(const fst::StdVectorFst &fst,
                               bool remove_output_zero) {
  using Arc = fst::StdArc;
  using Weight = Arc::Weight;

  auto state = fst.Start();
  if (state == fst::kNoStateId) {
    // No start state: this is an empty FST.
    return "";
  }

  std::string bytes;

  // A final weight of Weight::Zero() marks a non-final state, so keep walking
  // until a final state is reached.
  while (fst.Final(state) == Weight::Zero()) {
    fst::ArcIterator<fst::Fst<Arc>> arcs(fst, state);
    if (arcs.Done()) {
      // Dead end before reaching a final state.
      return "";
    }

    // Copy the arc so it stays usable after the iterator is advanced.
    const Arc only_arc = arcs.Value();

    arcs.Next();
    if (!arcs.Done()) {
      // More than one outgoing arc: not a linear FST.
      return "";
    }

    if (only_arc.nextstate == fst::kNoStateId) {
      // Transition to an invalid state.
      return "";
    }

    const bool keep_label = !remove_output_zero || only_arc.olabel != 0;
    if (keep_label) {
      bytes.push_back(static_cast<char>(only_arc.olabel));
    }

    state = only_arc.nextstate;
  }

  return bytes;
}

TextNormalizer::TextNormalizer(const std::string &rule) {
rule_ = std::unique_ptr<fst::StdConstFst>(
CastOrConvertToConstFst(fst::ReadFstKaldiGeneric(rule)));
Expand All @@ -56,7 +93,8 @@ TextNormalizer::TextNormalizer(const std::string &rule) {
TextNormalizer::TextNormalizer(std::unique_ptr<fst::StdConstFst> rule)
: rule_(std::move(rule)) {}

std::string TextNormalizer::Normalize(const std::string &s) const {
std::string TextNormalizer::Normalize(const std::string &s,
bool remove_output_zero /*=true*/) const {
// Step 1: Convert the input text into an FST
fst::StdVectorFst text = StringToFst(s);

Expand All @@ -68,13 +106,7 @@ std::string TextNormalizer::Normalize(const std::string &s) const {
fst::StdVectorFst one_best;
fst::ShortestPath(composed_fst, &one_best, 1);

// Step 4: Concatenate the output labels of the best path
fst::StringPrinter<fst::StdArc> string_printer(fst::StringTokenType::BYTE);

std::string normalized;
string_printer(one_best, &normalized);

return normalized;
return FstToString(one_best, remove_output_zero);
}

} // namespace kaldifst
6 changes: 5 additions & 1 deletion kaldifst/csrc/text-normalizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ class TextNormalizer {

explicit TextNormalizer(std::unique_ptr<fst::StdConstFst> rule);

std::string Normalize(const std::string &s) const;
// @param s The input text to be normalized
// @param remove_output_zero True to remove bytes whose value is 0 from the
// output.
std::string Normalize(const std::string &s,
bool remove_output_zero = true) const;

private:
std::unique_ptr<fst::StdConstFst> rule_;
Expand Down
6 changes: 4 additions & 2 deletions kaldifst/python/csrc/text-normalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ void PybindTextNormalizer(py::module *m) {
using PyClass = TextNormalizer;
py::class_<PyClass>(*m, "TextNormalizer")
.def(py::init<const std::string &>(), py::arg("rule"))
.def("normalize", &PyClass::Normalize)
.def("__call__", &PyClass::Normalize);
.def("normalize", &PyClass::Normalize, py::arg("s"),
py::arg("remove_output_zero") = true)
.def("__call__", &PyClass::Normalize, py::arg("s"),
py::arg("remove_output_zero") = true);
}

} // namespace kaldifst
Loading