Skip to content

Commit

Permalink
Merge pull request espnet#4146 from siddalmia/fix_st_mt_scoring2
Browse files Browse the repository at this point in the history
scoring fixes MT and ST
  • Loading branch information
sw005320 authored Mar 8, 2022
2 parents 537514a + 041e132 commit cb8181a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 31 deletions.
69 changes: 43 additions & 26 deletions egs2/TEMPLATE/mt1/mt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1165,37 +1165,54 @@ if ! "${skip_eval}"; then
_scoredir="${_dir}/score_bleu"
mkdir -p "${_scoredir}"

paste \
<(<"${_data}/text.${tgt_case}.${tgt_lang}" \
${python} -m espnet2.bin.tokenize_text \
-f 2- --input - --output - \
--token_type word \
--non_linguistic_symbols "${nlsyms_txt}" \
--remove_non_linguistic_symbols true \
--cleaner "${cleaner}" \
) \
<(<"${_data}/text.${tgt_case}.${tgt_lang}" awk '{ print "(" $2 "-" $1 ")" }') \
>"${_scoredir}/ref.trn.org"
<"${_data}/text.${tgt_case}.${tgt_lang}" \
${python} -m espnet2.bin.tokenize_text \
-f 2- --input - --output - \
--token_type word \
--non_linguistic_symbols "${nlsyms_txt}" \
--remove_non_linguistic_symbols true \
--cleaner "${cleaner}" \
>"${_scoredir}/ref.trn"

#paste \
# <(<"${_data}/text.${tgt_case}.${tgt_lang}" \
# ${python} -m espnet2.bin.tokenize_text \
# -f 2- --input - --output - \
# --token_type word \
# --non_linguistic_symbols "${nlsyms_txt}" \
# --remove_non_linguistic_symbols true \
# --cleaner "${cleaner}" \
# ) \
# <(<"${_data}/text.${tgt_case}.${tgt_lang}" awk '{ print "(" $2 "-" $1 ")" }') \
# >"${_scoredir}/ref.trn.org"

# NOTE(kamo): Don't use cleaner for hyp
paste \
<(<"${_dir}/text" \
${python} -m espnet2.bin.tokenize_text \
-f 2- --input - --output - \
--token_type word \
--non_linguistic_symbols "${nlsyms_txt}" \
--remove_non_linguistic_symbols true \
) \
<(<"${_data}/text.${tgt_case}.${tgt_lang}" awk '{ print "(" $2 "-" $1 ")" }') \
>"${_scoredir}/hyp.trn.org"
<"${_dir}/text" \
${python} -m espnet2.bin.tokenize_text \
-f 2- --input - --output - \
--token_type word \
--non_linguistic_symbols "${nlsyms_txt}" \
--remove_non_linguistic_symbols true \
>"${_scoredir}/hyp.trn"

#paste \
# <(<"${_dir}/text" \
# ${python} -m espnet2.bin.tokenize_text \
# -f 2- --input - --output - \
# --token_type word \
# --non_linguistic_symbols "${nlsyms_txt}" \
# --remove_non_linguistic_symbols true \
# ) \
# <(<"${_data}/text.${tgt_case}.${tgt_lang}" awk '{ print "(" $2 "-" $1 ")" }') \
# >"${_scoredir}/hyp.trn.org"

# remove utterance id
perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/ref.trn.org" > "${_scoredir}/ref.trn"
perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/hyp.trn.org" > "${_scoredir}/hyp.trn"
#perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/ref.trn.org" > "${_scoredir}/ref.trn"
#perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/hyp.trn.org" > "${_scoredir}/hyp.trn"

# detokenizer
detokenizer.perl -l en -q < "${_scoredir}/ref.trn" > "${_scoredir}/ref.trn.detok"
detokenizer.perl -l en -q < "${_scoredir}/hyp.trn" > "${_scoredir}/hyp.trn.detok"
detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/ref.trn" > "${_scoredir}/ref.trn.detok"
detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/hyp.trn" > "${_scoredir}/hyp.trn.detok"

if [ ${tgt_case} = "tc" ]; then
echo "Case sensitive BLEU result (single-reference)" >> ${_scoredir}/result.tc.txt
Expand Down Expand Up @@ -1238,7 +1255,7 @@ if ! "${skip_eval}"; then

#
perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/ref.trn.org.${ref_idx}" > "${_scoredir}/ref.trn.${ref_idx}"
detokenizer.perl -l en -q < "${_scoredir}/ref.trn.${ref_idx}" > "${_scoredir}/ref.trn.detok.${ref_idx}"
detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/ref.trn.${ref_idx}" > "${_scoredir}/ref.trn.detok.${ref_idx}"
remove_punctuation.pl < "${_scoredir}/ref.trn.detok.${ref_idx}" > "${_scoredir}/ref.trn.detok.lc.rm.${ref_idx}"
case_sensitive_refs="${case_sensitive_refs} ${_scoredir}/ref.trn.detok.${ref_idx}"
case_insensitive_refs="${case_insensitive_refs} ${_scoredir}/ref.trn.detok.lc.rm.${ref_idx}"
Expand Down
8 changes: 4 additions & 4 deletions egs2/TEMPLATE/st1/st.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1484,8 +1484,8 @@ if ! "${skip_eval}"; then
perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/hyp.trn.org" > "${_scoredir}/hyp.trn"

# detokenizer
detokenizer.perl -l en -q < "${_scoredir}/ref.trn" > "${_scoredir}/ref.trn.detok"
detokenizer.perl -l en -q < "${_scoredir}/hyp.trn" > "${_scoredir}/hyp.trn.detok"
detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/ref.trn" > "${_scoredir}/ref.trn.detok"
detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/hyp.trn" > "${_scoredir}/hyp.trn.detok"

if [ ${tgt_case} = "tc" ]; then
echo "Case sensitive BLEU result (single-reference)" >> ${_scoredir}/result.tc.txt
Expand Down Expand Up @@ -1528,7 +1528,7 @@ if ! "${skip_eval}"; then

#
perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/ref.trn.org.${ref_idx}" > "${_scoredir}/ref.trn.${ref_idx}"
detokenizer.perl -l en -q < "${_scoredir}/ref.trn.${ref_idx}" > "${_scoredir}/ref.trn.detok.${ref_idx}"
detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/ref.trn.${ref_idx}" > "${_scoredir}/ref.trn.detok.${ref_idx}"
remove_punctuation.pl < "${_scoredir}/ref.trn.detok.${ref_idx}" > "${_scoredir}/ref.trn.detok.lc.rm.${ref_idx}"
case_sensitive_refs="${case_sensitive_refs} ${_scoredir}/ref.trn.detok.${ref_idx}"
case_insensitive_refs="${case_insensitive_refs} ${_scoredir}/ref.trn.detok.lc.rm.${ref_idx}"
Expand All @@ -1551,7 +1551,7 @@ if ! "${skip_eval}"; then
done

# Show results in Markdown syntax
scripts/utils/show_st_result.sh --case $tgt_case "${st_exp}" > "${st_exp}"/RESULTS.md
scripts/utils/show_translation_result.sh --case $tgt_case "${st_exp}" > "${st_exp}"/RESULTS.md
cat "${st_exp}"/RESULTS.md
fi
else
Expand Down
6 changes: 5 additions & 1 deletion espnet2/bin/mt_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,11 @@ def __call__(
assert isinstance(hyp, Hypothesis), type(hyp)

# remove sos/eos and get results
token_int = hyp.yseq[1:-1].tolist()
# token_int = hyp.yseq[1:-1].tolist()
# TODO(sdalmia): check why the above line doesn't work
token_int = hyp.yseq.tolist()
token_int = list(filter(lambda x: x != self.mt_model.sos, token_int))
token_int = list(filter(lambda x: x != self.mt_model.eos, token_int))

# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
Expand Down

0 comments on commit cb8181a

Please sign in to comment.