Commit
Merge pull request espnet#4132 from pyf98/mt
IWSLT'14 Results using ESPnet2-MT
ftshijt authored Mar 6, 2022
2 parents 997ebce + 11e3e7c commit bfb23b8
Showing 8 changed files with 119 additions and 67 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -35,7 +35,7 @@ ESPnet uses [pytorch](http://pytorch.org/) as a deep learning engine and also fo
- Support numbers of `ASR` recipes (WSJ, Switchboard, CHiME-4/5, Librispeech, TED, CSJ, AMI, HKUST, Voxforge, REVERB, etc.)
- Support numbers of `TTS` recipes with a similar manner to the ASR recipe (LJSpeech, LibriTTS, M-AILABS, etc.)
- Support numbers of `ST` recipes (Fisher-CallHome Spanish, Libri-trans, IWSLT'18, How2, Must-C, Mboshi-French, etc.)
-- Support numbers of `MT` recipes (IWSLT'16, the above ST recipes etc.)
+- Support numbers of `MT` recipes (IWSLT'14, IWSLT'16, the above ST recipes etc.)
- Support numbers of `SLU` recipes (CATSLU-MAPS, FSC, Grabo, IEMOCAP, JDCINAL, SNIPS, SLURP, SWBD-DA, etc.)
- Support numbers of `SE/SS` recipes (DNS-IS2020, LibriMix, SMS-WSJ, VCTK-noisyreverb, WHAM!, WHAMR!, WSJ-2mix, etc.)
- Support voice conversion recipe (VCC2020 baseline)
@@ -368,6 +368,7 @@ Available pretrained models in the demo script are listed as below.
| Must-C tst-COMMON (En->De) | 27.63 | [link](https://github.com/espnet/espnet/blob/master/egs/must_c/mt1/RESULTS.md#summary-4-gram-bleu) |
| IWSLT'14 test2014 (En->De) | 24.70 | [link](https://github.com/espnet/espnet/blob/master/egs/iwslt16/mt1/RESULTS.md#result) |
| IWSLT'14 test2014 (De->En) | 29.22 | [link](https://github.com/espnet/espnet/blob/master/egs/iwslt16/mt1/RESULTS.md#result) |
+| IWSLT'14 test2014 (De->En) | 32.2 | [link](https://github.com/espnet/espnet/blob/master/egs2/iwslt14/mt1/README.md) |
| IWSLT'16 test2014 (En->De) | 24.05 | [link](https://github.com/espnet/espnet/blob/master/egs/iwslt16/mt1/RESULTS.md#result) |
| IWSLT'16 test2014 (De->En) | 29.13 | [link](https://github.com/espnet/espnet/blob/master/egs/iwslt16/mt1/RESULTS.md#result) |
17 changes: 7 additions & 10 deletions egs2/TEMPLATE/mt1/mt.sh
@@ -460,7 +460,7 @@ if ! "${skip_data_prep}"; then

    if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
        if [ "${feats_type}" = raw ]; then
-           log "Stage 3: data/ -> ${data_feats}"
+           log "Stage 2: data/ -> ${data_feats}"

            for dset in "${train_set}" "${valid_set}" ${test_sets}; do
                if [ "${dset}" = "${train_set}" ] || [ "${dset}" = "${valid_set}" ]; then
@@ -508,19 +508,18 @@ if ! "${skip_data_prep}"; then

    if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then

-       # Then generate src lang
        if "${token_joint}"; then
            log "Merge src and target data if joint BPE"

            cat $tgt_bpe_train_text > ${data_feats}/${train_set}/text.${src_lang}_${tgt_lang}
-           [ -z "${src_bpe_train_text}" ] && cat ${src_bpe_train_text} >> ${data_feats}/${train_set}/text.${src_lang}_${tgt_lang}
+           [ ! -z "${src_bpe_train_text}" ] && cat ${src_bpe_train_text} >> ${data_feats}/${train_set}/text.${src_lang}_${tgt_lang}
            # Set the new text as the target text
            tgt_bpe_train_text="${data_feats}/${train_set}/text.${src_lang}_${tgt_lang}"
        fi

        # First generate tgt lang
        if [ "${tgt_token_type}" = bpe ]; then
-           log "Stage 5a: Generate token_list from ${tgt_bpe_train_text} using BPE for tgt_lang"
+           log "Stage 4a: Generate token_list from ${tgt_bpe_train_text} using BPE for tgt_lang"

            mkdir -p "${tgt_bpedir}"
            # shellcheck disable=SC2002
@@ -550,7 +549,7 @@ if ! "${skip_data_prep}"; then
} > "${tgt_token_list}"

elif [ "${tgt_token_type}" = char ] || [ "${tgt_token_type}" = word ]; then
log "Stage 5a: Generate character level token_list from ${tgt_bpe_train_text} for tgt_lang"
log "Stage 4a: Generate character level token_list from ${tgt_bpe_train_text} for tgt_lang"

_opts="--non_linguistic_symbols ${nlsyms_txt}"

@@ -593,10 +592,10 @@ if ! "${skip_data_prep}"; then

        # Then generate src lang
        if "${token_joint}"; then
-           log "Stage 5b: Skip separate token construction for src_lang when setting ${token_joint} as true"
+           log "Stage 4b: Skip separate token construction for src_lang when setting ${token_joint} as true"
        else
            if [ "${src_token_type}" = bpe ]; then
-               log "Stage 5b: Generate token_list from ${src_bpe_train_text} using BPE for src_lang"
+               log "Stage 4b: Generate token_list from ${src_bpe_train_text} using BPE for src_lang"

                mkdir -p "${src_bpedir}"
                # shellcheck disable=SC2002
@@ -626,7 +625,7 @@ if ! "${skip_data_prep}"; then
} > "${src_token_list}"

elif [ "${src_token_type}" = char ] || [ "${src_token_type}" = word ]; then
log "Stage 5b: Generate character level token_list from ${src_bpe_train_text} for src_lang"
log "Stage 4b: Generate character level token_list from ${src_bpe_train_text} for src_lang"

_opts="--non_linguistic_symbols ${nlsyms_txt}"

@@ -650,8 +649,6 @@ if ! "${skip_data_prep}"; then
log "Error: not supported --token_type '${src_token_type}'"
exit 2
fi


fi
fi

14 changes: 14 additions & 0 deletions egs2/iwslt14/mt1/README.md
@@ -0,0 +1,14 @@
# Results

## mt_train_mt_transformer_lr3e-3_warmup10k_share_enc_dec_input_dropout0.3_raw_bpe_tc10000
- mt_config: conf/tuning/train_mt_transformer_lr3e-3_warmup10k_share_enc_dec_input_dropout0.3.yaml
- inference_config: conf/decode_mt.yaml

### BLEU

Metric: BLEU-4, detokenized case-sensitive BLEU result (single-reference)

|dataset|bleu_score|verbose_score|
|---|---|---|
|beam5_maxlenratio1.6_penalty0.2/valid|33.3|68.4/42.9/28.9/19.8 (BP = 0.924 ratio = 0.927 hyp_len = 134328 ref_len = 144976)|
|beam5_maxlenratio1.6_penalty0.2/test|32.2|67.2/41.4/27.4/18.5 (BP = 0.933 ratio = 0.935 hyp_len = 119813 ref_len = 128122)|
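
The verbose_score column matches sacrebleu-style output: 1-gram to 4-gram precisions, brevity penalty (BP), and the hypothesis/reference length ratio. A minimal sketch of computing a single-reference, detokenized, case-sensitive corpus BLEU in this form with the sacrebleu Python package (the file names are hypothetical placeholders, not part of this commit):

```python
# Minimal sketch: single-reference, detokenized corpus BLEU via sacrebleu.
# hyp.detok.txt / ref.detok.txt are hypothetical one-sentence-per-line files.
import sacrebleu

with open("hyp.detok.txt", encoding="utf-8") as f:
    hyps = [line.rstrip("\n") for line in f]
with open("ref.detok.txt", encoding="utf-8") as f:
    refs = [line.rstrip("\n") for line in f]

# sacrebleu is case-sensitive by default; pass a single reference stream.
bleu = sacrebleu.corpus_bleu(hyps, [refs])
print(bleu.score)  # e.g. 32.2
print(bleu)        # verbose string: precisions (BP = ... ratio = ... ...)
```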
8 changes: 4 additions & 4 deletions egs2/iwslt14/mt1/conf/decode_mt.yaml
@@ -1,5 +1,5 @@
batch_size: 1
-beam_size: 10
nbest: 1
+beam_size: 5
lm_weight: 0.0

maxlenratio: 1.6
minlenratio: 0.0
penalty: 0.2
44 changes: 0 additions & 44 deletions egs2/iwslt14/mt1/conf/train_mt_transformer.yaml

This file was deleted.

1 change: 1 addition & 0 deletions egs2/iwslt14/mt1/conf/train_mt_transformer.yaml

Recreated, per the ESPnet2 recipe convention, as a symlink to the tuning config below.

59 changes: 59 additions & 0 deletions egs2/iwslt14/mt1/conf/tuning/train_mt_transformer_lr3e-3_warmup10k_share_enc_dec_input_dropout0.3.yaml
@@ -0,0 +1,59 @@
frontend: embed # embedding + positional encoding
frontend_conf:
embed_dim: 512
positional_dropout_rate: 0.3

encoder: transformer
encoder_conf:
output_size: 512
attention_heads: 4
linear_units: 1024
num_blocks: 6
dropout_rate: 0.3
positional_dropout_rate: 0.3
attention_dropout_rate: 0.3
input_layer: null
normalize_before: true

decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 1024
num_blocks: 6
dropout_rate: 0.3
positional_dropout_rate: 0.3
self_attention_dropout_rate: 0.3
src_attention_dropout_rate: 0.3

model_conf:
lsm_weight: 0.1
length_normalized_loss: false
share_decoder_input_output_embed: false
share_encoder_decoder_input_embed: true

num_att_plot: 1
log_interval: 100
num_workers: 2
batch_type: numel
batch_bins: 400000000
accum_grad: 1
max_epoch: 200
patience: none
init: none
best_model_criterion:
- - valid
- acc
- max
keep_nbest_models: 10

optim: adam
optim_conf:
lr: 0.003
betas:
- 0.9
- 0.98
eps: 0.000000001
weight_decay: 0.0001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 10000
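
The optimizer section above pairs Adam (peak lr 3e-3) with a warmuplr scheduler over 10000 steps. A small sketch of the learning-rate shape this implies, assuming the usual inverse-square-root (Noam-style) warmup that ESPnet2's warmuplr is commonly described as; treat the formula as an assumption, not a quote of the implementation:

```python
# Sketch of an inverse-square-root warmup schedule (assumed warmuplr shape):
# ramp up to the peak lr at warmup_steps, then decay roughly as step**-0.5.
def warmup_lr(step: int, base_lr: float = 3e-3, warmup_steps: int = 10000) -> float:
    step = max(step, 1)  # guard against step 0
    return base_lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)

print(warmup_lr(1_000))   # ~0.0003 (still warming up)
print(warmup_lr(10_000))  # ~0.003  (peak at warmup_steps)
print(warmup_lr(40_000))  # ~0.0015 (inverse-sqrt decay)
```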
15 changes: 7 additions & 8 deletions egs2/iwslt14/mt1/run.sh
@@ -10,13 +10,13 @@ tgt_lang=en

train_set=train
train_dev=valid
test_set="test"
test_set="test valid"

mt_config=conf/train_mt_transformer.yaml
inference_config=conf/decode_mt.yaml

src_nbpe=1000
-tgt_nbpe=1000
+tgt_nbpe=10000 # if token_joint is True, then only tgt_nbpe is used

# tc: truecase
# lc: lowercase
@@ -27,12 +27,11 @@ tgt_case=tc

./mt.sh \
    --ignore_init_mismatch true \
-   --stage 1 \
-   --stop_stage 13 \
    --use_lm false \
-   --token_joint false \
-   --nj 20 \
-   --inference_nj 20 \
+   --token_joint true \
+   --ngpu 1 \
+   --nj 16 \
+   --inference_nj 32 \
    --src_lang ${src_lang} \
    --tgt_lang ${tgt_lang} \
    --src_token_type "bpe" \
@@ -49,4 +48,4 @@ tgt_case=tc
    --test_sets "${test_set}" \
    --src_bpe_train_text "data/${train_set}/text.${src_case}.${src_lang}" \
    --tgt_bpe_train_text "data/${train_set}/text.${tgt_case}.${tgt_lang}" \
-   --lm_train_text "data/${train_set}/text.${tgt_case}.${tgt_lang}" "$@"
\ No newline at end of file
+   --lm_train_text "data/${train_set}/text.${tgt_case}.${tgt_lang}" "$@"
25 changes: 25 additions & 0 deletions espnet2/mt/espnet_model.py
@@ -54,6 +54,8 @@ def __init__(
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
+        share_decoder_input_output_embed: bool = False,
+        share_encoder_decoder_input_embed: bool = False,
    ):
        assert check_argument_types()

@@ -66,6 +68,29 @@ def __init__(
        self.ignore_id = ignore_id
        self.token_list = token_list.copy()

+        if share_decoder_input_output_embed:
+            if decoder.output_layer is not None:
+                decoder.output_layer.weight = decoder.embed[0].weight
+                logging.info(
+                    "Decoder input embedding and output linear layer are shared"
+                )
+            else:
+                logging.warning(
+                    "Decoder has no output layer, so it cannot be shared "
+                    "with input embedding"
+                )
+
+        if share_encoder_decoder_input_embed:
+            if src_vocab_size == vocab_size:
+                frontend.embed[0].weight = decoder.embed[0].weight
+                logging.info("Encoder and decoder input embeddings are shared")
+            else:
+                logging.warning(
+                    f"src_vocab_size ({src_vocab_size}) does not equal tgt_vocab_size"
+                    f" ({vocab_size}), so the encoder and decoder input embeddings "
+                    "cannot be shared"
+                )
+
        self.frontend = frontend
        self.preencoder = preencoder
        self.postencoder = postencoder
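
The added constructor logic ties parameters instead of copying them: assigning one module's weight Parameter to another leaves both pointing at the same tensor, so gradients from the embedding lookup and the output projection accumulate into a single weight matrix. A self-contained PyTorch sketch of the same tying trick (toy sizes and hypothetical names, not ESPnet code):

```python
# Standalone sketch of decoder input/output embedding tying.
import torch
import torch.nn as nn

vocab_size, d_model = 10000, 512
embed = nn.Embedding(vocab_size, d_model)                  # weight: (vocab, d_model)
output_layer = nn.Linear(d_model, vocab_size, bias=False)  # weight: (vocab, d_model)

# Tie: both modules now hold the very same Parameter object.
output_layer.weight = embed.weight
assert output_layer.weight.data_ptr() == embed.weight.data_ptr()

tokens = torch.randint(0, vocab_size, (2, 7))
logits = output_layer(embed(tokens))  # shape (2, 7, vocab_size)
```

This only works when the shapes line up, which is why the diff checks src_vocab_size == vocab_size before sharing the encoder and decoder input embeddings.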
