Skip to content

Commit

Permalink
filterFinals method for biltrans
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-martian committed Aug 29, 2024
1 parent a67d1e0 commit 906fd46
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 26 deletions.
1 change: 0 additions & 1 deletion lttoolbox/fst_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,6 @@ FSTProcessor::analysis(InputFile& input, UFILE *output)
bool last_incond = false;
bool last_postblank = false;
bool last_preblank = false;
//State current_state = initial_state;
ReusableState current_state;
current_state.init(&root);
UString lf; // analysis (lexical form and tags)
Expand Down
93 changes: 68 additions & 25 deletions lttoolbox/reusable_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,25 @@ void ReusableState::extract(size_t pos, UString& result, double& weight,
}
}

void NFinals(std::vector<std::pair<double, UString>>& results,
// Comparator for (string, weight) entries that looks only at the weight:
// returns true exactly when a's weight is smaller than b's. The string part
// is never consulted, so entries with equal weights compare equivalent.
bool comp_pair(const std::pair<UString, double>& a,
               const std::pair<UString, double>& b)
{
  const double lhs_weight = a.second;
  const double rhs_weight = b.second;
  return lhs_weight < rhs_weight;
}

void NFinals(std::vector<std::pair<UString, double>>& results,
size_t maxAnalyses, size_t maxWeightClasses)
{
if (results.empty()) return;
sort(results.begin(), results.end());
sort(results.begin(), results.end(), comp_pair);
if (maxAnalyses < results.size()) {
results.erase(results.begin()+maxAnalyses, results.end());
}
if (maxWeightClasses < results.size()) {
double last_weight = results[0].first + 1;
double last_weight = results[0].second + 1;
for (size_t i = 0; i < results.size(); i++) {
if (results[i].first != last_weight) {
last_weight = results[i].first;
if (results[i].second != last_weight) {
last_weight = results[i].second;
if (maxWeightClasses == 0) {
results.erase(results.begin()+i, results.end());
return;
Expand All @@ -264,16 +270,13 @@ void NFinals(std::vector<std::pair<double, UString>>& results,
}
}

UString ReusableState::filterFinals(const std::map<Node*, double>& finals,
const Alphabet& alphabet,
const std::set<UChar32>& escaped_chars,
bool display_weights,
int max_analyses, int max_weight_classes,
bool uppercase, bool firstupper,
int firstchar) const
void ReusableState::gatherFinals(const std::map<Node*, double>& finals,
const Alphabet& alphabet,
const std::set<UChar32>& escaped_chars,
bool uppercase, bool firstupper,
int firstchar,
std::vector<std::pair<UString, double>>& results) const
{
std::vector<std::pair<double, UString>> results;

UString temp;
double weight;
for (size_t i = start; i < end; i++) {
Expand All @@ -286,29 +289,69 @@ UString ReusableState::filterFinals(const std::map<Node*, double>& finals,
int idx = (temp[firstchar] == '~' ? firstchar + 1 : firstchar);
temp[idx] = u_toupper(temp[idx]);
}
results.push_back({weight, temp});
results.push_back({temp, weight});
}
}
}

// Appends a "<W:weight>" tag to `s` when `display_weights` is true; otherwise
// leaves `s` untouched.
//
// s               - string to append to (modified in place)
// w               - weight to print, formatted with "%f"
// display_weights - when false this function is a no-op
void appendWeight(UString& s, double w, bool display_weights) {
  if (!display_weights) return;
  // "%f" prints every integer digit of the value, so a 16-unit buffer is
  // overrun as soon as the weight reaches 10000 (or -1000). Size the buffer
  // for the longest finite double (about 320 units including the tag) and
  // bound the write so it can never run past the end.
  UChar wbuf[384]{};
  // Pass one less than the capacity: the zero-initialised final unit then
  // always terminates the string, even if the output were ever truncated.
  const int32_t capacity = static_cast<int32_t>(sizeof(wbuf) / sizeof(wbuf[0]));
  u_snprintf(wbuf, capacity - 1, "<W:%f>", w);
  s += wbuf;
}

// Builds the output string for the analyses gathered from this state.
//
// The (string, weight) candidates are collected by gatherFinals(), trimmed by
// NFinals() according to max_analyses / max_weight_classes, and then joined:
// every distinct string is emitted once, prefixed with '/', and followed by a
// weight tag from appendWeight() when display_weights is true.
//
// Returns the concatenation, or an empty string when nothing is left after
// trimming.
UString ReusableState::filterFinals(const std::map<Node*, double>& finals,
                                    const Alphabet& alphabet,
                                    const std::set<UChar32>& escaped_chars,
                                    bool display_weights,
                                    int max_analyses, int max_weight_classes,
                                    bool uppercase, bool firstupper,
                                    int firstchar) const
{
  std::vector<std::pair<UString, double>> results;
  gatherFinals(finals, alphabet, escaped_chars, uppercase, firstupper, firstchar,
               results);
  NFinals(results, max_analyses, max_weight_classes);

  UString temp;
  // Entries are (string, weight); de-duplicate on the string so the same
  // analysis is never printed twice.
  std::set<UString> seen;
  for (auto& it : results) {
    if (seen.find(it.first) != seen.end()) continue;
    seen.insert(it.first);
    temp += '/';
    temp += it.first;
    appendWeight(temp, it.second, display_weights);
  }
  return temp;
}

// Variant of the final-filtering step that hands back one string per
// analysis rather than a single joined string.
//
// Candidates come from gatherFinals() and are trimmed by NFinals(); each
// distinct string is then pushed onto the returned vector in that order,
// with a weight tag added by appendWeight() when display_weights is true.
std::vector<UString>
ReusableState::filterFinalsArray(const std::map<Node*, double>& finals,
                                 const Alphabet& alphabet,
                                 const std::set<UChar32>& escaped_chars,
                                 bool display_weights,
                                 int max_analyses, int max_weight_classes,
                                 bool uppercase, bool firstupper,
                                 int firstchar) const
{
  std::vector<std::pair<UString, double>> candidates;
  gatherFinals(finals, alphabet, escaped_chars, uppercase, firstupper, firstchar,
               candidates);
  NFinals(candidates, max_analyses, max_weight_classes);

  std::vector<UString> ret;
  std::set<UString> emitted;
  for (auto& entry : candidates) {
    // insert() reports whether the string was newly added; skip repeats.
    const bool is_new = emitted.insert(entry.first).second;
    if (!is_new) {
      continue;
    }
    ret.push_back(entry.first);
    appendWeight(ret.back(), entry.second, display_weights);
  }
  return ret;
}

bool ReusableState::lastPartHasRequiredSymbol(size_t pos, int32_t symbol,
int32_t separator)
{
Expand Down
14 changes: 14 additions & 0 deletions lttoolbox/reusable_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class ReusableState {
const Alphabet& alphabet,
const std::set<UChar32>& escaped_chars, bool uppercase) const;

// Shared first step of filterFinals() and filterFinalsArray(): appends
// (string, weight) entries to `results` via push_back.
// `results` is an out-parameter; existing contents are kept.
// NOTE(review): presumably one entry is produced per path of this state that
// ends in a node of `finals` — confirm against the full definition.
void gatherFinals(const std::map<Node*, double>& finals,
const Alphabet& alphabet,
const std::set<UChar32>& escaped_chars,
bool uppercase, bool firstupper, int firstchar,
std::vector<std::pair<UString, double>>& results) const;

public:
ReusableState();
~ReusableState();
Expand Down Expand Up @@ -69,6 +75,14 @@ class ReusableState {
int max_analyses, int max_weight_classes,
bool uppercase, bool firstupper,
int firstchar = 0) const;
// Like filterFinals(), but returns each distinct analysis as its own element
// instead of joining them with '/'. The candidates are limited by
// max_analyses / max_weight_classes, and a weight tag is appended to every
// element when display_weights is true.
std::vector<UString>
filterFinalsArray(const std::map<Node*, double>& finals,
const Alphabet& alphabet,
const std::set<UChar32>& escaped_chars,
bool display_weights,
int max_analyses, int max_weight_classes,
bool uppercase, bool firstupper,
int firstchar = 0) const;

bool lastPartHasRequiredSymbol(size_t pos, int32_t symbol, int32_t separator);
bool hasSymbol(int32_t symbol);
Expand Down

0 comments on commit 906fd46

Please sign in to comment.