Merge branch 'espnet:master' into master
roshansh-cmu authored Mar 8, 2022
2 parents de5e713 + cb8181a commit ab2fa25
Showing 3 changed files with 64 additions and 43 deletions.
egs2/TEMPLATE/mt1/mt.sh (69 changes: 43 additions & 26 deletions)

@@ -1165,37 +1165,54 @@ if ! "${skip_eval}"; then
         _scoredir="${_dir}/score_bleu"
         mkdir -p "${_scoredir}"

-        paste \
-            <(<"${_data}/text.${tgt_case}.${tgt_lang}" \
-                ${python} -m espnet2.bin.tokenize_text \
-                    -f 2- --input - --output - \
-                    --token_type word \
-                    --non_linguistic_symbols "${nlsyms_txt}" \
-                    --remove_non_linguistic_symbols true \
-                    --cleaner "${cleaner}" \
-            ) \
-            <(<"${_data}/text.${tgt_case}.${tgt_lang}" awk '{ print "(" $2 "-" $1 ")" }') \
-                >"${_scoredir}/ref.trn.org"
+        <"${_data}/text.${tgt_case}.${tgt_lang}" \
+            ${python} -m espnet2.bin.tokenize_text \
+                -f 2- --input - --output - \
+                --token_type word \
+                --non_linguistic_symbols "${nlsyms_txt}" \
+                --remove_non_linguistic_symbols true \
+                --cleaner "${cleaner}" \
+                >"${_scoredir}/ref.trn"
+
+        #paste \
+        #    <(<"${_data}/text.${tgt_case}.${tgt_lang}" \
+        #        ${python} -m espnet2.bin.tokenize_text \
+        #            -f 2- --input - --output - \
+        #            --token_type word \
+        #            --non_linguistic_symbols "${nlsyms_txt}" \
+        #            --remove_non_linguistic_symbols true \
+        #            --cleaner "${cleaner}" \
+        #    ) \
+        #    <(<"${_data}/text.${tgt_case}.${tgt_lang}" awk '{ print "(" $2 "-" $1 ")" }') \
+        #        >"${_scoredir}/ref.trn.org"

         # NOTE(kamo): Don't use cleaner for hyp
-        paste \
-            <(<"${_dir}/text" \
-                ${python} -m espnet2.bin.tokenize_text \
-                    -f 2- --input - --output - \
-                    --token_type word \
-                    --non_linguistic_symbols "${nlsyms_txt}" \
-                    --remove_non_linguistic_symbols true \
-            ) \
-            <(<"${_data}/text.${tgt_case}.${tgt_lang}" awk '{ print "(" $2 "-" $1 ")" }') \
-                >"${_scoredir}/hyp.trn.org"
+        <"${_dir}/text" \
+            ${python} -m espnet2.bin.tokenize_text \
+                -f 2- --input - --output - \
+                --token_type word \
+                --non_linguistic_symbols "${nlsyms_txt}" \
+                --remove_non_linguistic_symbols true \
+                >"${_scoredir}/hyp.trn"
+
+        #paste \
+        #    <(<"${_dir}/text" \
+        #        ${python} -m espnet2.bin.tokenize_text \
+        #            -f 2- --input - --output - \
+        #            --token_type word \
+        #            --non_linguistic_symbols "${nlsyms_txt}" \
+        #            --remove_non_linguistic_symbols true \
+        #    ) \
+        #    <(<"${_data}/text.${tgt_case}.${tgt_lang}" awk '{ print "(" $2 "-" $1 ")" }') \
+        #        >"${_scoredir}/hyp.trn.org"

         # remove utterance id
-        perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/ref.trn.org" > "${_scoredir}/ref.trn"
-        perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/hyp.trn.org" > "${_scoredir}/hyp.trn"
+        #perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/ref.trn.org" > "${_scoredir}/ref.trn"
+        #perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/hyp.trn.org" > "${_scoredir}/hyp.trn"

         # detokenizer
-        detokenizer.perl -l en -q < "${_scoredir}/ref.trn" > "${_scoredir}/ref.trn.detok"
-        detokenizer.perl -l en -q < "${_scoredir}/hyp.trn" > "${_scoredir}/hyp.trn.detok"
+        detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/ref.trn" > "${_scoredir}/ref.trn.detok"
+        detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/hyp.trn" > "${_scoredir}/hyp.trn.detok"

         if [ ${tgt_case} = "tc" ]; then
             echo "Case sensitive BLEU result (single-reference)" >> ${_scoredir}/result.tc.txt
@@ -1238,7 +1255,7 @@ if ! "${skip_eval}"; then

                #
                perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/ref.trn.org.${ref_idx}" > "${_scoredir}/ref.trn.${ref_idx}"
-               detokenizer.perl -l en -q < "${_scoredir}/ref.trn.${ref_idx}" > "${_scoredir}/ref.trn.detok.${ref_idx}"
+               detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/ref.trn.${ref_idx}" > "${_scoredir}/ref.trn.detok.${ref_idx}"
                remove_punctuation.pl < "${_scoredir}/ref.trn.detok.${ref_idx}" > "${_scoredir}/ref.trn.detok.lc.rm.${ref_idx}"
                case_sensitive_refs="${case_sensitive_refs} ${_scoredir}/ref.trn.detok.${ref_idx}"
                case_insensitive_refs="${case_insensitive_refs} ${_scoredir}/ref.trn.detok.lc.rm.${ref_idx}"
egs2/TEMPLATE/st1/st.sh (32 changes: 16 additions & 16 deletions)

@@ -296,18 +296,8 @@ fi
 # Extra files for translation process
 utt_extra_files="text.${src_case}.${src_lang} text.${tgt_case}.${tgt_lang}"
 # Use the same text as ST for bpe training if not specified.
-if "${token_joint}"; then
-    # if token_joint, the bpe training will use both src_lang and tgt_lang to train a single bpe model
-    [ -z "${src_bpe_train_text}" ] && src_bpe_train_text="${data_feats}/${train_set}/text.${src_case}.${src_lang}"
-    [ -z "${tgt_bpe_train_text}" ] && tgt_bpe_train_text="${data_feats}/${train_set}/text.${tgt_case}.${tgt_lang}"
-
-    # Prepare data as text.${src_lang}_${tgt_lang})
-    cat $src_bpe_train_text $tgt_bpe_train_text > ${data_feats}/${train_set}/text.${src_lang}_${tgt_lang}
-    tgt_bpe_train_text="${data_feats}/${train_set}/text.${src_lang}_${tgt_lang}"
-else
-    [ -z "${src_bpe_train_text}" ] && src_bpe_train_text="${data_feats}/${train_set}/text.${src_case}.${src_lang}"
-    [ -z "${tgt_bpe_train_text}" ] && tgt_bpe_train_text="${data_feats}/${train_set}/text.${tgt_case}.${tgt_lang}"
-fi
+[ -z "${src_bpe_train_text}" ] && src_bpe_train_text="${data_feats}/${train_set}/text.${src_case}.${src_lang}"
+[ -z "${tgt_bpe_train_text}" ] && tgt_bpe_train_text="${data_feats}/${train_set}/text.${tgt_case}.${tgt_lang}"
 # Use the same text as ST for lm training if not specified.
 [ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text.${tgt_case}.${tgt_lang}"
 # Use the same text as ST for lm training if not specified.
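Note: the defaults for src_bpe_train_text and tgt_bpe_train_text are now set unconditionally here, and the token_joint concatenation that used to run in this block moves into stage 5 (next hunk). Presumably this is because at this point in the script, during option setup, the text files under ${data_feats} may not exist yet, so the cat belongs inside the stage that actually consumes them.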
@@ -743,6 +733,16 @@ if ! "${skip_data_prep}"; then
     fi

     if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+        # Combine source and target texts when using joint tokenization
+        if "${token_joint}"; then
+            log "Merge src and target data if joint BPE"
+
+            cat $tgt_bpe_train_text > ${data_feats}/${train_set}/text.${src_lang}_${tgt_lang}
+            [ ! -z "${src_bpe_train_text}" ] && cat ${src_bpe_train_text} >> ${data_feats}/${train_set}/text.${src_lang}_${tgt_lang}
+            # Set the new text as the target text
+            tgt_bpe_train_text="${data_feats}/${train_set}/text.${src_lang}_${tgt_lang}"
+        fi
+
         # First generate tgt lang
         if [ "${tgt_token_type}" = bpe ]; then
             log "Stage 5a: Generate token_list from ${tgt_bpe_train_text} using BPE for tgt_lang"
@@ -1484,8 +1484,8 @@ if ! "${skip_eval}"; then
         perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/hyp.trn.org" > "${_scoredir}/hyp.trn"

         # detokenizer
-        detokenizer.perl -l en -q < "${_scoredir}/ref.trn" > "${_scoredir}/ref.trn.detok"
-        detokenizer.perl -l en -q < "${_scoredir}/hyp.trn" > "${_scoredir}/hyp.trn.detok"
+        detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/ref.trn" > "${_scoredir}/ref.trn.detok"
+        detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/hyp.trn" > "${_scoredir}/hyp.trn.detok"

         if [ ${tgt_case} = "tc" ]; then
             echo "Case sensitive BLEU result (single-reference)" >> ${_scoredir}/result.tc.txt
@@ -1528,7 +1528,7 @@ if ! "${skip_eval}"; then

                #
                perl -pe 's/\([^\)]+\)//g;' "${_scoredir}/ref.trn.org.${ref_idx}" > "${_scoredir}/ref.trn.${ref_idx}"
-               detokenizer.perl -l en -q < "${_scoredir}/ref.trn.${ref_idx}" > "${_scoredir}/ref.trn.detok.${ref_idx}"
+               detokenizer.perl -l ${tgt_lang} -q < "${_scoredir}/ref.trn.${ref_idx}" > "${_scoredir}/ref.trn.detok.${ref_idx}"
                remove_punctuation.pl < "${_scoredir}/ref.trn.detok.${ref_idx}" > "${_scoredir}/ref.trn.detok.lc.rm.${ref_idx}"
                case_sensitive_refs="${case_sensitive_refs} ${_scoredir}/ref.trn.detok.${ref_idx}"
                case_insensitive_refs="${case_insensitive_refs} ${_scoredir}/ref.trn.detok.lc.rm.${ref_idx}"
@@ -1551,7 +1551,7 @@ if ! "${skip_eval}"; then
         done

         # Show results in Markdown syntax
-        scripts/utils/show_st_result.sh --case $tgt_case "${st_exp}" > "${st_exp}"/RESULTS.md
+        scripts/utils/show_translation_result.sh --case $tgt_case "${st_exp}" > "${st_exp}"/RESULTS.md
         cat "${st_exp}"/RESULTS.md
     fi
 else
espnet2/bin/mt_inference.py (6 changes: 5 additions & 1 deletion)

@@ -213,7 +213,11 @@ def __call__(
         assert isinstance(hyp, Hypothesis), type(hyp)

         # remove sos/eos and get results
-        token_int = hyp.yseq[1:-1].tolist()
+        # token_int = hyp.yseq[1:-1].tolist()
+        # TODO(sdalmia): check why the above line doesn't work
+        token_int = hyp.yseq.tolist()
+        token_int = list(filter(lambda x: x != self.mt_model.sos, token_int))
+        token_int = list(filter(lambda x: x != self.mt_model.eos, token_int))

         # remove blank symbol id, which is assumed to be 0
         token_int = list(filter(lambda x: x != 0, token_int))
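Note: rather than assuming yseq is exactly sos + tokens + eos and slicing off the first and last positions, the new code filters out every occurrence of the sos and eos ids (the blank id 0 is filtered just below). This also behaves sensibly when eos is missing or appears more than once in the hypothesis; the TODO records that the root cause of the slicing failure is still unverified.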
