Skip to content

Commit

Permalink
Merge pull request #22 from Techievena/issue2
Browse files Browse the repository at this point in the history
Implementation of weights in lttoolbox
  • Loading branch information
TinoDidriksen authored Jul 31, 2018
2 parents c15afd1 + 21580c0 commit fd28b8e
Show file tree
Hide file tree
Showing 50 changed files with 1,723 additions and 1,029 deletions.
6 changes: 3 additions & 3 deletions lttoolbox/alphabet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ Alphabet::read(FILE *input)
int first = Compression::multibyte_read(input);
int second = Compression::multibyte_read(input);
pair<int, int> tmp(first - bias, second - bias);
int spair_size = a_new.spair.size();
int spair_size = a_new.spair.size();
a_new.spair[tmp] = spair_size;
a_new.spairinv.push_back(tmp);
}

*this = a_new;
}

Expand Down Expand Up @@ -303,7 +303,7 @@ Alphabet::createLoopbackSymbols(set<int> &symbols, Alphabet &basis, Side s, bool
it++)
{
// Only include tags that were actually seen on the correct side
if(tags.find(it->second) != tags.end())
if(tags.find(it->second) != tags.end())
{
includeSymbol(it->first);
symbols.insert(operator()(operator()(it->first),
Expand Down
2 changes: 1 addition & 1 deletion lttoolbox/alphabet.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Alphabet
* Symbol-identifier relationship. Only contains <tags>.
* @see slexicinv
*/
map<wstring, int, Ltstr> slexic;
map<wstring, int, Ltstr> slexic;

/**
* Identifier-symbol relationship. Only contains <tags>.
Expand Down
121 changes: 69 additions & 52 deletions lttoolbox/att_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
using namespace std;

AttCompiler::AttCompiler() :
starting_state(0)
starting_state(0),
default_weight(0.0000)
{
}

Expand All @@ -33,7 +34,7 @@ AttCompiler::~AttCompiler()
}

void
AttCompiler::clear()
AttCompiler::clear()
{
for (map<int, AttNode*>::const_iterator it = graph.begin(); it != graph.end();
++it)
Expand All @@ -49,14 +50,14 @@ AttCompiler::clear()
* @todo Are there other special symbols? If so, add them, and maybe use a map
* for conversion?
*/
void
AttCompiler::convert_hfst(wstring& symbol)
void
AttCompiler::convert_hfst(wstring& symbol)
{
if (symbol == L"@0@" || symbol == L"ε")
if (symbol == L"@0@" || symbol == L"ε")
{
symbol = L"";
}
else if (symbol == L"@_SPACE_@")
else if (symbol == L"@_SPACE_@")
{
symbol = L" ";
}
Expand All @@ -75,7 +76,7 @@ AttCompiler::is_word_punct(wchar_t symbol)
return true;
}

return false;
return false;
}

/**
Expand All @@ -88,7 +89,7 @@ AttCompiler::is_word_punct(wchar_t symbol)
* only) character otherwise.
*/
int
AttCompiler::symbol_code(const wstring& symbol)
AttCompiler::symbol_code(const wstring& symbol)
{
if (symbol.length() > 1) {
alphabet.includeSymbol(symbol);
Expand All @@ -111,8 +112,8 @@ AttCompiler::symbol_code(const wstring& symbol)
}
}

void
AttCompiler::parse(string const &file_name, wstring const &dir)
void
AttCompiler::parse(string const &file_name, wstring const &dir)
{
clear();

Expand All @@ -121,11 +122,12 @@ AttCompiler::parse(string const &file_name, wstring const &dir)
wstring line;
bool first_line = true; // First line -- see below
bool seen_input_symbol = false;
while (getline(infile, line))
while (getline(infile, line))
{
tokens.clear();
int from, to;
wstring upper, lower;
double weight;

if (line.length() == 0 && first_line)
{
Expand All @@ -139,30 +141,38 @@ AttCompiler::parse(string const &file_name, wstring const &dir)
}

/* Empty line. */
if (line.length() == 0)
if (line.length() == 0)
{
continue;
}
split(line, L'\t', tokens);

from = convert(tokens[0]);
from = static_cast<int>(convert(tokens[0]));

AttNode* source = get_node(from);
/* First line: the initial state is of both types. */
if (first_line)
if (first_line)
{
starting_state = from;
first_line = false;
}

/* Final state. */
if (tokens.size() <= 2)
if (tokens.size() <= 2)
{
finals.insert(from);
if (tokens.size() > 1)
{
weight = static_cast<double>(convert(tokens[1]));
}
else
{
weight = default_weight;
}
finals.insert(pair <int, double>(from, weight));
}
else
else
{
to = convert(tokens[1]);
to = static_cast<int>(convert(tokens[1]));
if(dir == L"RL")
{
upper = tokens[3];
Expand All @@ -177,17 +187,24 @@ AttCompiler::parse(string const &file_name, wstring const &dir)
convert_hfst(lower);
if(upper != L"")
{
seen_input_symbol = true;
seen_input_symbol = true;
}
/* skip lines that have an empty left side and output
/* skip lines that have an empty left side and output
if we haven't seen an input symbol */
if(upper == L"" && lower != L"" && !seen_input_symbol)
if(upper == L"" && lower != L"" && !seen_input_symbol)
{
continue;
}
int tag = alphabet(symbol_code(upper), symbol_code(lower));
/* We don't read the weights, even if they are defined. */
source->transductions.push_back(Transduction(to, upper, lower, tag));
if(tokens.size() > 4)
{
weight = static_cast<double>(convert(tokens[4]));
}
else
{
weight = default_weight;
}
source->transductions.push_back(Transduction(to, upper, lower, tag, weight));

get_node(to);
}
Expand All @@ -202,7 +219,7 @@ AttCompiler::parse(string const &file_name, wstring const &dir)

/** Extracts the sub-transducer made of states of type @p type. */
Transducer
AttCompiler::extract_transducer(TransducerType type)
AttCompiler::extract_transducer(TransducerType type)
{
Transducer transducer;
/* Correlation between the graph's state ids and those in the transducer. */
Expand All @@ -214,11 +231,11 @@ AttCompiler::extract_transducer(TransducerType type)

/* The final states. */
bool noFinals = true;
for (set<int>::const_iterator f = finals.begin(); f != finals.end(); ++f)
for (map<int, double>::const_iterator f = finals.begin(); f != finals.end(); ++f)
{
if (corr.find(*f) != corr.end())
if (corr.find(f->first) != corr.end())
{
transducer.setFinal(corr[*f]);
transducer.setFinal(corr[f->first], f->second);
noFinals = false;
}
}
Expand All @@ -229,7 +246,7 @@ AttCompiler::extract_transducer(TransducerType type)
wcerr << L"No final states (" << type << ")" << endl;
wcerr << L" were:" << endl;
wcerr << L"\t" ;
for (set<int>::const_iterator f = finals.begin(); f != finals.end(); ++f)
for (set<int>::const_iterator f = finals.begin(); f != finals.end(); ++f)
{
wcerr << *f << L" ";
}
Expand All @@ -245,14 +262,14 @@ AttCompiler::extract_transducer(TransducerType type)
*/
void
AttCompiler::_extract_transducer(TransducerType type, int from,
Transducer& transducer, map<int, int>& corr,
set<int>& visited)
Transducer& transducer, map<int, int>& corr,
set<int>& visited)
{
if (visited.find(from) != visited.end())
if (visited.find(from) != visited.end())
{
return;
}
else
else
{
visited.insert(from);
}
Expand All @@ -266,30 +283,30 @@ AttCompiler::_extract_transducer(TransducerType type, int from,
for (vector<Transduction>::const_iterator it = source->transductions.begin();
it != source->transductions.end(); ++it)
{
if ((it->type & type) != type)
if ((it->type & type) != type)
{
continue; // Not the right type
}
/* Is the target state new? */
bool new_to = corr.find(it->to) == corr.end();

if (new_from)
if (new_from)
{
corr[from] = transducer.size() + (new_to ? 1 : 0);
}
from_t = corr[from];

/* Now with the target state: */
if (!new_to)
if (!new_to)
{
/* We already know it, possibly by a different name: link them! */
to_t = corr[it->to];
transducer.linkStates(from_t, to_t, it->tag);
}
else
transducer.linkStates(from_t, to_t, it->tag, it->weight);
}
else
{
/* We haven't seen it yet: add a new state! */
to_t = transducer.insertNewSingleTransduction(it->tag, from_t);
to_t = transducer.insertNewSingleTransduction(it->tag, from_t, it->weight);
corr[it->to] = to_t;
}
_extract_transducer(type, it->to, transducer, corr, visited);
Expand All @@ -310,31 +327,31 @@ AttCompiler::_extract_transducer(TransducerType type, int from,
* @param visited the ids of states visited by this path.
* @param path are we in a path?
*/
void
void
AttCompiler::classify(int from, map<int, TransducerType>& visited, bool path,
TransducerType type)
TransducerType type)
{
AttNode* source = get_node(from);
if (visited.find(from) != visited.end())
if (visited.find(from) != visited.end())
{
if (path && ( (visited[from] & type) == type) )
if (path && ( (visited[from] & type) == type) )
{
return;
}
}

if (path)
if (path)
{
visited[from] |= type;
}

for (vector<Transduction>::iterator it = source->transductions.begin();
it != source->transductions.end(); ++it)
it != source->transductions.end(); ++it)
{
bool next_path = path;
int next_type = type;
bool first_transition = !path && it->upper != L"";
if (first_transition)
if (first_transition)
{
/* First transition: we now know the type of the path! */
bool upper_word = (it->upper.length() == 1 &&
Expand All @@ -344,8 +361,8 @@ AttCompiler::classify(int from, map<int, TransducerType>& visited, bool path,
if (upper_word) next_type |= WORD;
if (upper_punct) next_type |= PUNCT;
next_path = true;
}
else
}
else
{
/* Otherwise (not yet, already): target's type is the same as ours. */
next_type = type;
Expand All @@ -357,7 +374,7 @@ AttCompiler::classify(int from, map<int, TransducerType>& visited, bool path,

/** Writes the transducer to @p file_name in lt binary format. */
void
AttCompiler::write(FILE *output)
AttCompiler::write(FILE *output)
{
// FILE* output = fopen(file_name, "w");
Transducer punct_fst = extract_transducer(PUNCT);
Expand All @@ -377,13 +394,13 @@ AttCompiler::write(FILE *output)
}
Compression::wstring_write(L"main@standard", output);
Transducer word_fst = extract_transducer(WORD);
word_fst.write(output);
word_fst.write(output, 0, true);
wcout << L"main@standard" << " " << word_fst.size();
wcout << " " << word_fst.numberOfTransitions() << endl;
Compression::wstring_write(L"final@inconditional", output);
if(punct_fst.numberOfTransitions() != 0)
if(punct_fst.numberOfTransitions() != 0)
{
punct_fst.write(output);
punct_fst.write(output, 0, true);
wcout << L"final@inconditional" << " " << punct_fst.size();
wcout << " " << punct_fst.numberOfTransitions() << endl;
}
Expand Down
Loading

0 comments on commit fd28b8e

Please sign in to comment.