Skip to content

Commit

Permalink
FSTALIGN-63: Preserving inserts in NLP output (#48)
Browse files Browse the repository at this point in the history
Co-authored-by: Eugene Vignanker <eugene.vignanekr@rev.com>
  • Loading branch information
eugenevignanker and Eugene Vignanker authored Sep 11, 2023
1 parent 5f81a70 commit bd93737
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
14 changes: 9 additions & 5 deletions src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ void align_stitches_to_nlp(NlpFstLoader& refLoader, vector<Stitching> &stitches)
}
}

void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_file, Json::Value norm_json) {
void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_file, Json::Value norm_json, bool add_inserts = false) {
auto logger = logger::GetOrCreateLogger("fstalign");
logger->info("Writing nlp output");
// write header; 'comment' is there to store information about how well the alignment went
Expand All @@ -565,7 +565,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
<< endl;
for (auto &stitch : stitches) {
// if the comment starts with 'ins'
if (stitch.comment.find("ins") == 0) {
if (stitch.comment.find("ins") == 0 && !add_inserts) {
// there's no nlp row info for such a case, so skip over it
if (stitch.confidence >= 1) {
logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk, stitch.start_ts);
Expand Down Expand Up @@ -596,6 +596,10 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
}

ref_tk = original_nlp_token;
} else if (stitch.comment.find("ins") == 0) {
logger->debug("an insertion was found for {} {}", stitch.hyptk, stitch.comment);
ref_tk = "";
stitch.comment = "ins(" + stitch.hyptk + ")";
}

if (ref_tk == NOOP) {
Expand Down Expand Up @@ -627,8 +631,8 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
}
}

void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, string output_sbs, string output_nlp,
AlignerOptions alignerOptions) {
void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp) {
// int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
// string composition_approach, bool record_case_stats) {
auto logger = logger::GetOrCreateLogger("fstalign");
Expand Down Expand Up @@ -687,7 +691,7 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine

if (!output_nlp.empty()) {
ofstream nlp_ostream(output_nlp);
write_stitches_to_nlp(stitches, nlp_ostream, nlp_ref_loader->mJsonNorm);
write_stitches_to_nlp(stitches, nlp_ostream, nlp_ref_loader->mJsonNorm, add_inserts_nlp);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/fstalign.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ struct AlignerOptions {
// void HandleAlign(NlpFstLoader *refLoader, CtmFstLoader *hypLoader, SynonymEngine *engine, ofstream &output_nlp_file,
// int numBests, string symbols_filename, string composition_approach);

void HandleWer(FstLoader &refLoader, FstLoader &hypLoader, SynonymEngine &engine, string output_sbs, string output_nlp,
AlignerOptions alignerOptions);
void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp = false);
void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file,
AlignerOptions alignerOptions);

Expand Down
12 changes: 7 additions & 5 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ int main(int argc, char **argv) {
bool record_case_stats = false;
bool use_punctuation = false;
bool disable_approximate_alignment = false;
bool add_inserts_nlp = false;

bool disable_cutoffs = false;
bool disable_hyphen_ignore = false;
Expand Down Expand Up @@ -122,6 +123,7 @@ int main(int argc, char **argv) {
"Record precision/recall for how well the hypothesis"
"casing matches the reference.");
get_wer->add_flag("--use-punctuation", use_punctuation, "Treat punctuation from nlp rows as separate tokens");
get_wer->add_flag("--add-inserts-nlp", add_inserts_nlp, "Add inserts to NLP output");

// CLI11_PARSE(app, argc, argv);
try {
Expand All @@ -148,12 +150,12 @@ int main(int argc, char **argv) {
FSTALIGNER_VERSION_PATCH);

auto subcommand = app.get_subcommands()[0];
auto command = subcommand->get_name();
auto command = subcommand->get_name();


// loading hypothesis and reference inputs
std::unique_ptr<FstLoader> hyp = FstLoader::MakeHypothesisLoader(hyp_filename, hyp_json_norm_filename, use_punctuation, !symbols_filename.empty());
std::unique_ptr<FstLoader> ref = FstLoader::MakeReferenceLoader(ref_filename, wer_sidecar_filename, json_norm_filename, use_punctuation, !symbols_filename.empty());
std::unique_ptr<FstLoader> ref = FstLoader::MakeReferenceLoader(ref_filename, wer_sidecar_filename, json_norm_filename, use_punctuation, !symbols_filename.empty());

AlignerOptions alignerOptions;
alignerOptions.speaker_switch_context_size = speaker_switch_context_size;
Expand All @@ -176,7 +178,7 @@ int main(int argc, char **argv) {
}

if (command == "wer") {
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions);
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp);
} else if (command == "align") {
if (output_nlp.empty()) {
console->error("the output nlp file must be specified");
Expand Down Expand Up @@ -345,4 +347,4 @@ std::unique_ptr<FstLoader> FstLoader::MakeHypothesisLoader(const std::string& hy
hypOneBest->LoadTextFile(hyp_filename);
return hypOneBest;
}
}
}

0 comments on commit bd93737

Please sign in to comment.