From bd937376bd46dd2ce884a8f801141660333529cb Mon Sep 17 00:00:00 2001 From: Eugene Vignanker <39065538+eugenevignanker@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:01:38 -0700 Subject: [PATCH] FSTALIGN-63: Preserving inserts in NLP output (#48) Co-authored-by: Eugene Vignanker --- src/fstalign.cpp | 14 +++++++++----- src/fstalign.h | 4 ++-- src/main.cpp | 12 +++++++----- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/fstalign.cpp b/src/fstalign.cpp index 4509533..28eaa6f 100644 --- a/src/fstalign.cpp +++ b/src/fstalign.cpp @@ -555,7 +555,7 @@ void align_stitches_to_nlp(NlpFstLoader& refLoader, vector &stitches) } } -void write_stitches_to_nlp(vector& stitches, ofstream &output_nlp_file, Json::Value norm_json) { +void write_stitches_to_nlp(vector& stitches, ofstream &output_nlp_file, Json::Value norm_json, bool add_inserts = false) { auto logger = logger::GetOrCreateLogger("fstalign"); logger->info("Writing nlp output"); // write header; 'comment' is there to store information about how well the alignment went @@ -565,7 +565,7 @@ void write_stitches_to_nlp(vector& stitches, ofstream &output_nlp_fil << endl; for (auto &stitch : stitches) { // if the comment starts with 'ins' - if (stitch.comment.find("ins") == 0) { + if (stitch.comment.find("ins") == 0 && !add_inserts) { // there's no nlp row info for such case, let's skip over it if (stitch.confidence >= 1) { logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk, stitch.start_ts); @@ -596,6 +596,10 @@ void write_stitches_to_nlp(vector& stitches, ofstream &output_nlp_fil } ref_tk = original_nlp_token; + } else if (stitch.comment.find("ins") == 0) { + logger->debug("an insertion was found for {} {}", stitch.hyptk, stitch.comment); + ref_tk = ""; + stitch.comment = "ins(" + stitch.hyptk + ")"; } if (ref_tk == NOOP) { @@ -627,8 +631,8 @@ void write_stitches_to_nlp(vector& stitches, ofstream &output_nlp_fil } } -void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, string output_sbs, string output_nlp, - AlignerOptions alignerOptions) { +void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp, + AlignerOptions alignerOptions, bool add_inserts_nlp) { // int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename, // string composition_approach, bool record_case_stats) { auto logger = logger::GetOrCreateLogger("fstalign"); @@ -687,7 +691,7 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine if (!output_nlp.empty()) { ofstream nlp_ostream(output_nlp); - write_stitches_to_nlp(stitches, nlp_ostream, nlp_ref_loader->mJsonNorm); + write_stitches_to_nlp(stitches, nlp_ostream, nlp_ref_loader->mJsonNorm, add_inserts_nlp); } } diff --git a/src/fstalign.h b/src/fstalign.h index 0fd60ca..929c220 100644 --- a/src/fstalign.h +++ b/src/fstalign.h @@ -50,8 +50,8 @@ struct AlignerOptions { // void HandleAlign(NlpFstLoader *refLoader, CtmFstLoader *hypLoader, SynonymEngine *engine, ofstream &output_nlp_file, // int numBests, string symbols_filename, string composition_approach); -void HandleWer(FstLoader &refLoader, FstLoader &hypLoader, SynonymEngine &engine, string output_sbs, string output_nlp, - AlignerOptions alignerOptions); +void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp, + AlignerOptions alignerOptions, bool add_inserts_nlp = false); void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file, AlignerOptions alignerOptions); diff --git a/src/main.cpp b/src/main.cpp index 6cfb8ce..c2d63fd 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -34,6 +34,7 @@ int main(int argc, char **argv) { bool record_case_stats = false; bool use_punctuation = false; bool disable_approximate_alignment = false; + bool add_inserts_nlp = false; bool disable_cutoffs = false; bool disable_hyphen_ignore = false; @@ -122,6 +123,7 @@ int main(int argc, char **argv) { "Record precision/recall for how well the hypothesis" "casing matches the reference."); get_wer->add_flag("--use-punctuation", use_punctuation, "Treat punctuation from nlp rows as separate tokens"); + get_wer->add_flag("--add-inserts-nlp", add_inserts_nlp, "Add inserts to NLP output"); // CLI11_PARSE(app, argc, argv); try { @@ -148,12 +150,12 @@ int main(int argc, char **argv) { FSTALIGNER_VERSION_PATCH); auto subcommand = app.get_subcommands()[0]; - auto command = subcommand->get_name(); - + auto command = subcommand->get_name(); + // loading "reference" inputs std::unique_ptr hyp = FstLoader::MakeHypothesisLoader(hyp_filename, hyp_json_norm_filename, use_punctuation, !symbols_filename.empty()); - std::unique_ptr ref = FstLoader::MakeReferenceLoader(ref_filename, wer_sidecar_filename, json_norm_filename, use_punctuation, !symbols_filename.empty()); + std::unique_ptr ref = FstLoader::MakeReferenceLoader(ref_filename, wer_sidecar_filename, json_norm_filename, use_punctuation, !symbols_filename.empty()); AlignerOptions alignerOptions; alignerOptions.speaker_switch_context_size = speaker_switch_context_size; @@ -176,7 +178,7 @@ int main(int argc, char **argv) { } if (command == "wer") { - HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions); + HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp); } else if (command == "align") { if (output_nlp.empty()) { console->error("the output nlp file must be specified"); @@ -345,4 +347,4 @@ std::unique_ptr FstLoader::MakeHypothesisLoader(const std::string& hy hypOneBest->LoadTextFile(hyp_filename); return hypOneBest; } -} +}