diff --git a/scripts/machine_translation/index.rst b/scripts/machine_translation/index.rst index a228ee24ed..9824909677 100644 --- a/scripts/machine_translation/index.rst +++ b/scripts/machine_translation/index.rst @@ -10,7 +10,7 @@ Use the following command to train the GNMT model on the IWSLT2015 dataset. .. code-block:: console - $ MXNET_GPU_MEM_POOL_TYPE=Round python train_gnmt.py --src_lang en --tgt_lang vi --batch_size 128 \ + $ MXNET_GPU_MEM_POOL_TYPE=Round python train_gnmt_estimator.py --src_lang en --tgt_lang vi --batch_size 128 \ --optimizer adam --lr 0.001 --lr_update_factor 0.5 --beam_size 10 --bucket_scheme exp \ --num_hidden 512 --save_dir gnmt_en_vi_l2_h512_beam10 --epochs 12 --gpu 0 @@ -23,7 +23,7 @@ Use the following commands to train the Transformer model on the WMT14 dataset f .. code-block:: console - $ MXNET_GPU_MEM_POOL_TYPE=Round python train_transformer.py --dataset WMT2014BPE \ + $ MXNET_GPU_MEM_POOL_TYPE=Round python train_transformer_estimator.py --dataset WMT2014BPE \ --src_lang en --tgt_lang de --batch_size 2700 \ --optimizer adam --num_accumulated 16 --lr 2.0 --warmup_steps 4000 \ --save_dir transformer_en_de_u512 --epochs 30 --gpus 0,1,2,3,4,5,6,7 --scaled \ diff --git a/scripts/machine_translation/train_gnmt_estimator.py b/scripts/machine_translation/train_gnmt_estimator.py new file mode 100644 index 0000000000..8a4f152bb0 --- /dev/null +++ b/scripts/machine_translation/train_gnmt_estimator.py @@ -0,0 +1,212 @@ +""" +Google Neural Machine Translation +================================= + +This example shows how to implement the GNMT model with Gluon NLP Toolkit. 
+ +@article{wu2016google, + title={Google's neural machine translation system: + Bridging the gap between human and machine translation}, + author={Wu, Yonghui and Schuster, Mike and Chen, Zhifeng and Le, Quoc V and + Norouzi, Mohammad and Macherey, Wolfgang and Krikun, Maxim and Cao, Yuan and Gao, Qin and + Macherey, Klaus and others}, + journal={arXiv preprint arXiv:1609.08144}, + year={2016} +} +""" + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint:disable=redefined-outer-name,logging-format-interpolation,unexpected-keyword-arg + +import argparse +import random +import os +import logging +import numpy as np +import mxnet as mx +from mxnet import gluon +from mxnet.gluon.contrib.estimator import LoggingHandler, ValidationHandler + +import gluonnlp as nlp +from gluonnlp.model.translation import NMTModel +from gluonnlp.loss import MaskedSoftmaxCELoss +from gluonnlp.metric import LengthNormalizedLoss +from gluonnlp.estimator import MachineTranslationEstimator +from gluonnlp.estimator import MTGNMTBatchProcessor, MTGNMTGradientUpdateHandler +from gluonnlp.estimator import ComputeBleuHandler, ValBleuHandler +from gluonnlp.estimator import MTTransformerMetricHandler, MTGNMTLearningRateHandler +from gluonnlp.estimator import MTCheckpointHandler + +from gnmt import get_gnmt_encoder_decoder +from translation import BeamSearchTranslator +from utils import logging_config +from bleu import compute_bleu +import dataprocessor + +np.random.seed(100) +random.seed(100) +mx.random.seed(10000) + +nlp.utils.check_version('0.9.0') + +parser = argparse.ArgumentParser(description='Neural Machine Translation Example.' 
+ 'We train the Google NMT model') +parser.add_argument('--dataset', type=str, default='IWSLT2015', help='Dataset to use.') +parser.add_argument('--src_lang', type=str, default='en', help='Source language') +parser.add_argument('--tgt_lang', type=str, default='vi', help='Target language') +parser.add_argument('--epochs', type=int, default=40, help='upper epoch limit') +parser.add_argument('--num_hidden', type=int, default=128, help='Dimension of the embedding ' + 'vectors and states.') +parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--num_layers', type=int, default=2, help='number of layers in the encoder' + ' and decoder') +parser.add_argument('--num_bi_layers', type=int, default=1, + help='number of bidirectional layers in the encoder and decoder') +parser.add_argument('--batch_size', type=int, default=128, help='Batch size') +parser.add_argument('--beam_size', type=int, default=4, help='Beam size') +parser.add_argument('--lp_alpha', type=float, default=1.0, + help='Alpha used in calculating the length penalty') +parser.add_argument('--lp_k', type=int, default=5, help='K used in calculating the length penalty') +parser.add_argument('--test_batch_size', type=int, default=32, help='Test batch size') +parser.add_argument('--num_buckets', type=int, default=5, help='Bucket number') +parser.add_argument('--bucket_scheme', type=str, default='constant', + help='Strategy for generating bucket keys. 
It supports: ' + '"constant": all the buckets have the same width; ' + '"linear": the width of bucket increases linearly; ' + '"exp": the width of bucket increases exponentially') +parser.add_argument('--bucket_ratio', type=float, default=0.0, help='Ratio for increasing the ' + 'throughput of the bucketing') +parser.add_argument('--src_max_len', type=int, default=50, help='Maximum length of the source ' + 'sentence') +parser.add_argument('--tgt_max_len', type=int, default=50, help='Maximum length of the target ' + 'sentence') +parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm') +parser.add_argument('--lr', type=float, default=1E-3, help='Initial learning rate') +parser.add_argument('--lr_update_factor', type=float, default=0.5, + help='Learning rate decay factor') +parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping') +parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='report interval') +parser.add_argument('--save_dir', type=str, default='out_dir', + help='directory path to save the final model and training log') +parser.add_argument('--gpu', type=int, default=None, + help='id of the gpu to use. 
Set it to empty means to use cpu.') +args = parser.parse_args() +print(args) +logging_config(args.save_dir) + + +data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab\ + = dataprocessor.load_translation_data(dataset=args.dataset, bleu='tweaked', args=args) + +dataprocessor.write_sentences(val_tgt_sentences, os.path.join(args.save_dir, 'val_gt.txt')) +dataprocessor.write_sentences(test_tgt_sentences, os.path.join(args.save_dir, 'test_gt.txt')) + +data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False) +data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_val)]) +data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_test)]) +if args.gpu is None: + ctx = mx.cpu() + print('Use CPU') +else: + ctx = mx.gpu(args.gpu) + +encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder( + hidden_size=args.num_hidden, dropout=args.dropout, num_layers=args.num_layers, + num_bi_layers=args.num_bi_layers) +model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, + one_step_ahead_decoder=one_step_ahead_decoder, embed_size=args.num_hidden, + prefix='gnmt_') +model.initialize(init=mx.init.Uniform(0.1), ctx=ctx) +static_alloc = True +model.hybridize(static_alloc=static_alloc) +logging.info(model) + +translator = BeamSearchTranslator(model=model, beam_size=args.beam_size, + scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha, + K=args.lp_k), + max_length=args.tgt_max_len + 100) +logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k)) + + +loss_function = MaskedSoftmaxCELoss() +loss_function.hybridize(static_alloc=static_alloc) +trainer = gluon.Trainer(model.collect_params(), args.optimizer, {'learning_rate': args.lr}) + +train_data_loader, val_data_loader, test_data_loader \ + = 
dataprocessor.make_dataloader(data_train, data_val, data_test, args) + +train_metric = LengthNormalizedLoss(loss_function) +val_metric = LengthNormalizedLoss(loss_function) +batch_processor = MTGNMTBatchProcessor() +gnmt_estimator = MachineTranslationEstimator(net=model, loss=loss_function, + train_metrics=train_metric, + val_metrics=val_metric, + trainer=trainer, + context=ctx, + batch_processor=batch_processor) + +learning_rate_handler = MTGNMTLearningRateHandler(epochs=args.epochs, + lr_update_factor=args.lr_update_factor) + +gradient_update_handler = MTGNMTGradientUpdateHandler(clip=args.clip) + +metric_handler = MTTransformerMetricHandler(metrics=gnmt_estimator.train_metrics, + grad_interval=1) + +bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=val_tgt_sentences, + translator=translator, compute_bleu_fn=compute_bleu, + bleu='tweaked') + +test_bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=test_tgt_sentences, + translator=translator, compute_bleu_fn=compute_bleu, + bleu='tweaked') + +val_bleu_handler = ValBleuHandler(val_data=val_data_loader, + val_tgt_vocab=tgt_vocab, val_tgt_sentences=val_tgt_sentences, + translator=translator, compute_bleu_fn=compute_bleu, + bleu='tweaked') + +checkpoint_handler = MTCheckpointHandler(model_dir=args.save_dir) + +val_metric_handler = MTTransformerMetricHandler(metrics=gnmt_estimator.val_metrics) + +val_validation_handler = ValidationHandler(val_data=val_data_loader, + eval_fn=gnmt_estimator.evaluate, + event_handlers=val_metric_handler) + +logging_handler = LoggingHandler(log_interval=args.log_interval, + metrics=gnmt_estimator.train_metrics) + +event_handlers = [learning_rate_handler, gradient_update_handler, metric_handler, + val_bleu_handler, checkpoint_handler, val_validation_handler, logging_handler] + +gnmt_estimator.fit(train_data=train_data_loader, + val_data=val_data_loader, + epochs=args.epochs, + event_handlers=event_handlers, + batch_axis=0) + +val_event_handlers = 
[val_metric_handler, bleu_handler] +test_event_handlers = [val_metric_handler, test_bleu_handler] + +gnmt_estimator.evaluate(val_data=val_data_loader, event_handlers=val_event_handlers) +gnmt_estimator.evaluate(val_data=test_data_loader, event_handlers=test_event_handlers) diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py new file mode 100644 index 0000000000..ebea489773 --- /dev/null +++ b/scripts/machine_translation/train_transformer_estimator.py @@ -0,0 +1,322 @@ +""" +Transformer +================================= + +This example shows how to implement the Transformer model with Gluon NLP Toolkit. + +@inproceedings{vaswani2017attention, + title={Attention is all you need}, + author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, + Llion and Gomez, Aidan N and Kaiser, Lukasz and Polosukhin, Illia}, + booktitle={Advances in Neural Information Processing Systems}, + pages={6000--6010}, + year={2017} +} +""" + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint:disable=redefined-outer-name,logging-format-interpolation,unexpected-keyword-arg + +import argparse +import logging +import os +import random + +import numpy as np +import mxnet as mx +from mxnet import gluon +from mxnet.gluon.contrib.estimator import ValidationHandler + +import gluonnlp as nlp +from gluonnlp.loss import LabelSmoothing, MaskedSoftmaxCELoss +from gluonnlp.model.transformer import ParallelTransformer, get_transformer_encoder_decoder +from gluonnlp.model.translation import NMTModel +from gluonnlp.metric import LengthNormalizedLoss +from gluonnlp.estimator import MachineTranslationEstimator +from gluonnlp.estimator import MTTransformerBatchProcessor, MTTransformerParamUpdateHandler +from gluonnlp.estimator import TransformerLearningRateHandler, MTTransformerMetricHandler +from gluonnlp.estimator import TransformerGradientAccumulationHandler, ComputeBleuHandler +from gluonnlp.estimator import ValBleuHandler, MTCheckpointHandler +from gluonnlp.estimator import MTTransformerLoggingHandler + +import dataprocessor +from bleu import _bpe_to_words, compute_bleu +from translation import BeamSearchTranslator +from utils import logging_config + +np.random.seed(100) +random.seed(100) +mx.random.seed(10000) + +nlp.utils.check_version('0.9.0') + +parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description='Neural Machine Translation Example with the Transformer Model.') +parser.add_argument('--dataset', type=str.upper, default='WMT2016BPE', help='Dataset to use.', + choices=['IWSLT2015', 'WMT2016BPE', 'WMT2014BPE', 'TOY']) +parser.add_argument('--src_lang', type=str, default='en', help='Source language') +parser.add_argument('--tgt_lang', type=str, default='de', help='Target language') +parser.add_argument('--epochs', type=int, default=10, help='upper epoch limit') +parser.add_argument('--num_units', type=int, default=512, help='Dimension of the embedding ' + 'vectors and states.') 
+parser.add_argument('--hidden_size', type=int, default=2048, + help='Dimension of the hidden state in position-wise feed-forward networks.') +parser.add_argument('--dropout', type=float, default=0.1, + help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--epsilon', type=float, default=0.1, + help='epsilon parameter for label smoothing') +parser.add_argument('--num_layers', type=int, default=6, + help='number of layers in the encoder and decoder') +parser.add_argument('--num_heads', type=int, default=8, + help='number of heads in multi-head attention') +parser.add_argument('--scaled', action='store_true', help='Turn on to use scale in attention') +parser.add_argument('--batch_size', type=int, default=1024, + help='Batch size. Number of tokens per gpu in a minibatch') +parser.add_argument('--beam_size', type=int, default=4, help='Beam size') +parser.add_argument('--lp_alpha', type=float, default=0.6, + help='Alpha used in calculating the length penalty') +parser.add_argument('--lp_k', type=int, default=5, help='K used in calculating the length penalty') +parser.add_argument('--test_batch_size', type=int, default=256, help='Test batch size') +parser.add_argument('--num_buckets', type=int, default=10, help='Bucket number') +parser.add_argument('--bucket_scheme', type=str, default='constant', + help='Strategy for generating bucket keys. 
It supports: ' + '"constant": all the buckets have the same width; ' + '"linear": the width of bucket increases linearly; ' + '"exp": the width of bucket increases exponentially') +parser.add_argument('--bucket_ratio', type=float, default=0.0, help='Ratio for increasing the ' + 'throughput of the bucketing') +parser.add_argument('--src_max_len', type=int, default=-1, help='Maximum length of the source ' + 'sentence, -1 means no clipping') +parser.add_argument('--tgt_max_len', type=int, default=-1, help='Maximum length of the target ' + 'sentence, -1 means no clipping') +parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm') +parser.add_argument('--lr', type=float, default=1.0, help='Initial learning rate') +parser.add_argument('--warmup_steps', type=float, default=4000, + help='number of warmup steps used in NOAM\'s stepsize schedule') +parser.add_argument('--num_accumulated', type=int, default=1, + help='Number of steps to accumulate the gradients. ' + 'This is useful to mimic large batch training with limited gpu memory') +parser.add_argument('--magnitude', type=float, default=3.0, + help='Magnitude of Xavier initialization') +parser.add_argument('--average_checkpoint', action='store_true', + help='Turn on to perform final testing based on ' + 'the average of last few checkpoints') +parser.add_argument('--num_averages', type=int, default=5, + help='Perform final testing based on the ' + 'average of last num_averages checkpoints. ' + 'This is only used if average_checkpoint is True') +parser.add_argument('--average_start', type=int, default=5, + help='Perform average SGD on last average_start epochs') +parser.add_argument('--full', action='store_true', + help='In default, we use the test dataset in' + ' http://statmt.org/wmt14/test-filtered.tgz.' 
+ ' When the option full is turned on, we use the test dataset in' + ' http://statmt.org/wmt14/test-full.tgz') +parser.add_argument('--bleu', type=str, default='tweaked', + help='Schemes for computing bleu score. It can be: ' + '"tweaked": it uses similar steps in get_ende_bleu.sh in tensor2tensor ' + 'repository, where compound words are put in ATAT format; ' + '"13a": This uses official WMT tokenization and produces the same results' + ' as official script (mteval-v13a.pl) used by WMT; ' + '"intl": This use international tokenization in mteval-v14a.pl') +parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='report interval') +parser.add_argument('--save_dir', type=str, default='transformer_out', + help='directory path to save the final model and training log') +parser.add_argument('--gpus', type=str, + help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.' + '(using single gpu is suggested)') +args = parser.parse_args() +logging_config(args.save_dir) +logging.info(args) + + +data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab \ + = dataprocessor.load_translation_data(dataset=args.dataset, bleu=args.bleu, args=args) + +dataprocessor.write_sentences(val_tgt_sentences, os.path.join(args.save_dir, 'val_gt.txt')) +dataprocessor.write_sentences(test_tgt_sentences, os.path.join(args.save_dir, 'test_gt.txt')) + +data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False) +data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_val)]) +data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_test)]) + +ctx = [mx.cpu()] if args.gpus is None or args.gpus == '' else \ + [mx.gpu(int(x)) for x in args.gpus.split(',')] +num_ctxs = len(ctx) + +data_train_lengths, data_val_lengths, data_test_lengths = [dataprocessor.get_data_lengths(x) + for x in + 
[data_train, data_val, data_test]] + +if args.src_max_len <= 0 or args.tgt_max_len <= 0: + max_len = np.max( + [np.max(data_train_lengths, axis=0), np.max(data_val_lengths, axis=0), + np.max(data_test_lengths, axis=0)], + axis=0) +if args.src_max_len > 0: + src_max_len = args.src_max_len +else: + src_max_len = max_len[0] +if args.tgt_max_len > 0: + tgt_max_len = args.tgt_max_len +else: + tgt_max_len = max_len[1] +encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder( + units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout, + num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500), + max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled) +model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, + one_step_ahead_decoder=one_step_ahead_decoder, + share_embed=args.dataset not in ('TOY', 'IWSLT2015'), embed_size=args.num_units, + tie_weights=args.dataset not in ('TOY', 'IWSLT2015'), embed_initializer=None, + prefix='transformer_') +model.initialize(init=mx.init.Xavier(magnitude=args.magnitude), ctx=ctx) +static_alloc = True +model.hybridize(static_alloc=static_alloc) +logging.info(model) + +translator = BeamSearchTranslator(model=model, beam_size=args.beam_size, + scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha, + K=args.lp_k), + max_length=200) +logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k)) + +label_smoothing = LabelSmoothing(epsilon=args.epsilon, units=len(tgt_vocab)) +label_smoothing.hybridize(static_alloc=static_alloc) + +loss_function = MaskedSoftmaxCELoss(sparse_label=False) +loss_function.hybridize(static_alloc=static_alloc) + +test_loss_function = MaskedSoftmaxCELoss() +test_loss_function.hybridize(static_alloc=static_alloc) + +rescale_loss = 100. 
+parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss) +detokenizer = nlp.data.SacreMosesDetokenizer() + +trainer = gluon.Trainer(model.collect_params(), args.optimizer, + {'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9}) + +train_data_loader, val_data_loader, test_data_loader \ + = dataprocessor.make_dataloader(data_train, data_val, data_test, args, + use_average_length=True, num_shards=len(ctx)) + +if args.bleu == 'tweaked': + bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY') + split_compound_word = bpe + tokenized = True +elif args.bleu == '13a' or args.bleu == 'intl': + bpe = False + split_compound_word = False + tokenized = False +else: + raise NotImplementedError + +grad_interval = args.num_accumulated +average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start) + +train_metric = LengthNormalizedLoss(loss_function) +val_metric = LengthNormalizedLoss(test_loss_function) +batch_processor = MTTransformerBatchProcessor(rescale_loss=rescale_loss, + batch_size=args.batch_size, + label_smoothing=label_smoothing, + loss_function=loss_function) + +mt_estimator = MachineTranslationEstimator(net=model, loss=loss_function, + train_metrics=train_metric, + val_metrics=val_metric, + trainer=trainer, + context=ctx, + val_loss=test_loss_function, + batch_processor=batch_processor) + +param_update_handler = MTTransformerParamUpdateHandler(avg_start=average_start, + grad_interval=grad_interval) +learning_rate_handler = TransformerLearningRateHandler(lr=args.lr, num_units=args.num_units, + warmup_steps=args.warmup_steps, + grad_interval=grad_interval) +gradient_acc_handler = TransformerGradientAccumulationHandler(grad_interval=grad_interval, + batch_size=args.batch_size, + rescale_loss=rescale_loss) +metric_handler = MTTransformerMetricHandler(metrics=mt_estimator.train_metrics, + grad_interval=grad_interval) +bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, 
tgt_sentence=val_tgt_sentences, + translator=translator, compute_bleu_fn=compute_bleu, + tokenized=tokenized, tokenizer=args.bleu, + split_compound_word=split_compound_word, + bpe=bpe, bleu=args.bleu, detokenizer=detokenizer, + _bpe_to_words=_bpe_to_words) + +test_bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=test_tgt_sentences, + translator=translator, compute_bleu_fn=compute_bleu, + tokenized=tokenized, tokenizer=args.bleu, + split_compound_word=split_compound_word, + bpe=bpe, bleu=args.bleu, detokenizer=detokenizer, + _bpe_to_words=_bpe_to_words) + +val_bleu_handler = ValBleuHandler(val_data=val_data_loader, val_tgt_vocab=tgt_vocab, + val_tgt_sentences=val_tgt_sentences, translator=translator, + tokenized=tokenized, tokenizer=args.bleu, + split_compound_word=split_compound_word, bpe=bpe, + compute_bleu_fn=compute_bleu, + bleu=args.bleu, detokenizer=detokenizer, + _bpe_to_words=_bpe_to_words) + +checkpoint_handler = MTCheckpointHandler(model_dir=args.save_dir, + average_checkpoint=args.average_checkpoint, + num_averages=args.num_averages, + average_start=args.average_start) + +val_metric_handler = MTTransformerMetricHandler(metrics=mt_estimator.val_metrics) + +val_validation_handler = ValidationHandler(val_data=val_data_loader, + eval_fn=mt_estimator.evaluate, + event_handlers=val_metric_handler) + +log_interval = args.log_interval * grad_interval +logging_handler = MTTransformerLoggingHandler(log_interval=log_interval, + metrics=mt_estimator.train_metrics) + +event_handlers = [param_update_handler, + learning_rate_handler, + gradient_acc_handler, + metric_handler, + val_validation_handler, + val_bleu_handler, + checkpoint_handler, + logging_handler] + +mt_estimator.fit(train_data=train_data_loader, + val_data=val_data_loader, + epochs=args.epochs, + event_handlers=event_handlers, + batch_axis=0) + +val_event_handlers = [val_metric_handler, + bleu_handler] + +test_event_handlers = [val_metric_handler, + test_bleu_handler] + 
+mt_estimator.evaluate(val_data=val_data_loader, event_handlers=val_event_handlers) + +mt_estimator.evaluate(val_data=test_data_loader, event_handlers=test_event_handlers) diff --git a/src/gluonnlp/__init__.py b/src/gluonnlp/__init__.py index 7a588e8233..f9772b95fc 100644 --- a/src/gluonnlp/__init__.py +++ b/src/gluonnlp/__init__.py @@ -30,6 +30,7 @@ from . import vocab from . import optimizer from . import initializer +from . import estimator from .vocab import Vocab __version__ = '0.10.0.dev' @@ -43,7 +44,8 @@ 'initializer', 'optimizer', 'utils', - 'metric'] + 'metric', + 'estimator'] warnings.filterwarnings(module='gluonnlp', action='default', category=DeprecationWarning) utils.version.check_version('1.6.0', warning_only=True, library=mxnet) diff --git a/src/gluonnlp/estimator/__init__.py b/src/gluonnlp/estimator/__init__.py new file mode 100644 index 0000000000..1672dff82b --- /dev/null +++ b/src/gluonnlp/estimator/__init__.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=eval-used, redefined-outer-name + +""" Gluon NLP Estimator Module """ +from . import machine_translation_estimator, machine_translation_event_handler +from . 
import machine_translation_batch_processor + +from .machine_translation_estimator import * +from .machine_translation_event_handler import * +from .machine_translation_batch_processor import * + +__all__ = (machine_translation_estimator.__all__ + machine_translation_event_handler.__all__ + + machine_translation_batch_processor.__all__) diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py new file mode 100644 index 0000000000..89cb81ffae --- /dev/null +++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=eval-used, redefined-outer-name +""" Gluon Machine Translation Batch Processor """ + +import numpy as np +import mxnet as mx +from mxnet.gluon.contrib.estimator import BatchProcessor +from ..model.transformer import ParallelTransformer +from ..utils.parallel import Parallel + +__all__ = ['MTTransformerBatchProcessor', 'MTGNMTBatchProcessor'] + +class MTTransformerBatchProcessor(BatchProcessor): + """Batch processor for transformer training on Machine translation + + The batch training and validation procedure on transformer network + + Parameters + ---------- + rescale_loss : int + normalization constant for loss computation + batch_size : int + number of tokens per gpu in a minibatch + label_smoothing : HybridBlock + Apply label smoothing on the given network + loss_function : mxnet.gluon.loss + training loss function + """ + def __init__(self, rescale_loss=100, + batch_size=1024, + label_smoothing=None, + loss_function=None): + super(MTTransformerBatchProcessor, self).__init__() + self.rescale_loss = rescale_loss + self.batch_size = batch_size + self.label_smoothing = label_smoothing + self.loss_function = loss_function + self.parallel_model = None + + def _get_parallel_model(self, estimator): + if self.label_smoothing is None or self.loss_function is None: + raise ValueError('label smoothing or loss function cannot be none.') + if self.parallel_model is None: + self.parallel_model = ParallelTransformer(estimator.net, self.label_smoothing, + self.loss_function, self.rescale_loss) + self.parallel_model = Parallel(len(estimator.context), self.parallel_model) + + def fit_batch(self, estimator, train_batch, batch_axis=0): + self._get_parallel_model(estimator) + data = [shard[0] for shard in train_batch] + target = [shard[1] for shard in train_batch] + _, tgt_word_count, bs = np.sum([(shard[2].sum(), + shard[3].sum(), + shard[0].shape[0]) + for shard in + train_batch], + axis=0) + estimator.tgt_valid_length = tgt_word_count.asscalar() - bs + seqs = 
[[seq.as_in_context(context) for seq in shard] + for context, shard in zip(estimator.context, train_batch)] + Ls = [] + for seq in seqs: + self.parallel_model.put((seq, self.batch_size)) + Ls = [self.parallel_model.get() for _ in range(len(estimator.context))] + Ls = [l * self.batch_size * self.rescale_loss for l in Ls] + return data, [target, tgt_word_count - bs], None, Ls + + def evaluate_batch(self, estimator, val_batch, batch_axis=0): + ctx = estimator.context[0] + src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids = val_batch + src_seq = src_seq.as_in_context(ctx) + tgt_seq = tgt_seq.as_in_context(ctx) + src_valid_length = src_valid_length.as_in_context(ctx) + tgt_valid_length = tgt_valid_length.as_in_context(ctx) + + out, _ = estimator.val_net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) + loss = estimator.val_loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).sum().asscalar() + inst_ids = inst_ids.asnumpy().astype(np.int32).tolist() + loss = loss * (tgt_seq.shape[1] - 1) + val_tgt_valid_length = (tgt_valid_length - 1).sum().asscalar() + return src_seq, [tgt_seq, val_tgt_valid_length], out, loss + +class MTGNMTBatchProcessor(BatchProcessor): + """Batch processor for GNMT training + + Batch training and validation on the GNMT network for the machine translation task. 
+ """ + def __init__(self): + super(MTGNMTBatchProcessor, self).__init__() + + def fit_batch(self, estimator, train_batch, batch_axis=0): + ctx = estimator.context[0] + src_seq, tgt_seq, src_valid_length, tgt_valid_length = train_batch + src_seq = src_seq.as_in_context(ctx) + tgt_seq = tgt_seq.as_in_context(ctx) + src_valid_length = src_valid_length.as_in_context(ctx) + tgt_valid_length = tgt_valid_length.as_in_context(ctx) + with mx.autograd.record(): + out, _ = estimator.net(src_seq, tgt_seq[:, :-1], src_valid_length, + tgt_valid_length - 1) + loss = estimator.loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean() + loss = loss * (tgt_seq.shape[1] - 1) + log_loss = loss * tgt_seq.shape[0] + loss = loss / (tgt_valid_length - 1).mean() + loss.backward() + return src_seq, [tgt_seq, (tgt_valid_length - 1).sum()], out, log_loss + + def evaluate_batch(self, estimator, val_batch, batch_axis=0): + ctx = estimator.context[0] + src_seq, tgt_seq, src_valid_length, tgt_valid_length, _ = val_batch + src_seq = src_seq.as_in_context(ctx) + tgt_seq = tgt_seq.as_in_context(ctx) + src_valid_length = src_valid_length.as_in_context(ctx) + tgt_valid_length = tgt_valid_length.as_in_context(ctx) + out, _ = estimator.val_net(src_seq, tgt_seq[:, :-1], src_valid_length, + tgt_valid_length - 1) + loss = estimator.val_loss(out, tgt_seq[:, 1:], + tgt_valid_length - 1).sum().asscalar() + loss = loss * (tgt_seq.shape[1] - 1) + val_tgt_valid_length = (tgt_valid_length - 1).sum().asscalar() + return src_seq, [tgt_seq, val_tgt_valid_length], out, loss diff --git a/src/gluonnlp/estimator/machine_translation_estimator.py b/src/gluonnlp/estimator/machine_translation_estimator.py new file mode 100644 index 0000000000..334d6061dd --- /dev/null +++ b/src/gluonnlp/estimator/machine_translation_estimator.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=eval-used, redefined-outer-name +""" Gluon Machine Translation Estimator """ + +from mxnet.gluon.contrib.estimator import Estimator +from .machine_translation_batch_processor import MTTransformerBatchProcessor + +__all__ = ['MachineTranslationEstimator'] + +class MachineTranslationEstimator(Estimator): + """Estimator class for machine translation tasks + + Facilitates training and validation on machine translation tasks + Parameters + ---------- + net : gluon.Block + The model used for training. + loss : gluon.loss.Loss + Loss (objective) function to calculate during training. + train_metrics : EvalMetric or list of EvalMetric + Training metrics for evaluating models on training dataset. + val_metrics : EvalMetric or list of EvalMetric + Validation metrics for evaluating models on validation dataset. + initializer : Initializer + Initializer to initialize the network. + trainer : Trainer + Trainer to apply optimizer on network parameters. + context : Context or list of Context + Device(s) to run the training on. + val_net : gluon.Block + The model used for validation. The validation model does not necessarily belong to + the same model class as the training model. + val_loss : gluon.loss.loss + Loss (objective) function to calculate during validation. 
If set val_loss + None, it will use the same loss function as self.loss + batch_processor: BatchProcessor + BatchProcessor provides customized fit_batch() and evaluate_batch() methods + """ + def __init__(self, net, loss, + train_metrics=None, + val_metrics=None, + initializer=None, + trainer=None, + context=None, + val_loss=None, + val_net=None, + batch_processor=MTTransformerBatchProcessor()): + super().__init__(net=net, loss=loss, + train_metrics=train_metrics, + val_metrics=val_metrics, + initializer=initializer, + trainer=trainer, + context=context, + val_loss=val_loss, + val_net=val_net, + batch_processor=batch_processor) + self.tgt_valid_length = 0 + self.val_tgt_valid_length = 0 + self.avg_param = None + self.bleu_score = 0.0 diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py new file mode 100644 index 0000000000..8885f066c9 --- /dev/null +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -0,0 +1,481 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# pylint: disable=eval-used, redefined-outer-name
""" Gluon Machine Translation Event Handler """

import math
import os
import time

import numpy as np
import mxnet as mx
from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, EpochBegin
from mxnet.gluon.contrib.estimator import EpochEnd, BatchBegin, BatchEnd
from mxnet.gluon.contrib.estimator import GradientUpdateHandler, CheckpointHandler
from mxnet.gluon.contrib.estimator import MetricHandler, LoggingHandler
from mxnet import gluon
from mxnet.metric import Loss as MetricLoss
from ..metric.length_normalized_loss import LengthNormalizedLoss

__all__ = ['MTTransformerParamUpdateHandler', 'TransformerLearningRateHandler',
           'MTTransformerMetricHandler', 'TransformerGradientAccumulationHandler',
           'ComputeBleuHandler', 'ValBleuHandler', 'MTGNMTGradientUpdateHandler',
           'MTGNMTLearningRateHandler', 'MTCheckpointHandler',
           'MTTransformerLoggingHandler']

class MTTransformerParamUpdateHandler(EpochBegin, BatchEnd, EpochEnd):
    """Transformer average parameter update handler

    Maintains a running average of the transformer parameters during
    training (average SGD), stored on estimator.avg_param.

    Parameters
    ----------
    avg_start : int
        the starting step of performing average sgd update
    grad_interval : int
        The interval of updating average model parameters
    """
    def __init__(self, avg_start, grad_interval=1):
        self.batch_id = 0
        self.grad_interval = grad_interval
        self.step_num = 0
        self.avg_start = avg_start

    def _update_avg_param(self, estimator):
        # Initialize the average with a copy of the current parameters.
        if estimator.avg_param is None:
            estimator.avg_param = {k: v.data(estimator.context[0]).copy() for k, v in
                                   estimator.net.collect_params().items()}
        if self.step_num > self.avg_start:
            params = estimator.net.collect_params()
            # Incremental mean: each new observation is weighted 1/t.
            alpha = 1. / max(1, self.step_num - self.avg_start)
            for key, val in estimator.avg_param.items():
                val[:] += alpha * (params[key].data(estimator.context[0]) - val)

    def epoch_begin(self, estimator, *args, **kwargs):
        self.batch_id = 0

    def batch_end(self, estimator, *args, **kwargs):
        # A "step" is one optimizer update, i.e. every grad_interval batches.
        if self.batch_id % self.grad_interval == 0:
            self.step_num += 1
        if self.batch_id % self.grad_interval == self.grad_interval - 1:
            self._update_avg_param(estimator)
        self.batch_id += 1

    def epoch_end(self, estimator, *args, **kwargs):
        # Fold in any partial accumulation at the end of the epoch.
        self._update_avg_param(estimator)

class MTGNMTLearningRateHandler(EpochEnd):
    """GNMT learning rate update handler

    Dynamically adjusts the learning rate during GNMT training: once the
    last third of the epochs is reached, the learning rate is decayed
    after every epoch.

    Parameters
    ----------
    epochs : int
        total number of epochs for GNMT training
    lr_update_factor : float
        the decaying factor of learning rate
    """
    def __init__(self, epochs, lr_update_factor):
        self.epoch_id = 0
        self.epochs = epochs
        self.lr_update_factor = lr_update_factor

    def epoch_end(self, estimator, *args, **kwargs):
        if self.epoch_id + 1 >= (self.epochs * 2) // 3:
            new_lr = estimator.trainer.learning_rate * self.lr_update_factor
            estimator.trainer.set_learning_rate(new_lr)
        self.epoch_id += 1

class TransformerLearningRateHandler(EpochBegin, BatchBegin):
    """Transformer learning rate update handler

    Applies the "Noam" warmup/decay schedule from Attention Is All You
    Need during transformer training.

    Parameters
    ----------
    lr : float
        initial learning rate for transformer training
    num_units : int
        dimension of the embedding vector
    warmup_steps : int
        number of warmup steps used in the training schedule
    grad_interval : int
        the interval of updating learning rate
    """
    def __init__(self, lr,
                 num_units=512,
                 warmup_steps=4000,
                 grad_interval=1):
        self.lr = lr
        self.num_units = num_units
        self.warmup_steps = warmup_steps
        self.grad_interval = grad_interval
        self.step_num = 0
        # Defensive init; epoch_begin resets this before the first batch.
        self.batch_id = 0

    def epoch_begin(self, estimator, *args, **kwargs):
        self.batch_id = 0

    def batch_begin(self, estimator, *args, **kwargs):
        if self.batch_id % self.grad_interval == 0:
            self.step_num += 1
            # lr = base / sqrt(d_model) * min(1/sqrt(step), step * warmup^-1.5)
            new_lr = self.lr / math.sqrt(self.num_units) * \
                min(1. / math.sqrt(self.step_num), self.step_num *
                    self.warmup_steps ** (-1.5))
            estimator.trainer.set_learning_rate(new_lr)
        self.batch_id += 1

class MTGNMTGradientUpdateHandler(GradientUpdateHandler):
    """Gradient update handler of GNMT training

    Clips the global gradient norm to a threshold before each optimizer
    step during GNMT training.

    Parameters
    ----------
    clip : float
        gradient norm threshold. If the gradient norm exceeds this value,
        it is scaled down to the valid range.
    """
    def __init__(self, clip):
        super(MTGNMTGradientUpdateHandler, self).__init__()
        self.clip = clip

    def batch_end(self, estimator, *args, **kwargs):
        grads = [p.grad(estimator.context[0])
                 for p in estimator.net.collect_params().values()]
        gluon.utils.clip_global_norm(grads, self.clip)
        estimator.trainer.step(1)

class TransformerGradientAccumulationHandler(GradientUpdateHandler,
                                             TrainBegin,
                                             EpochBegin,
                                             EpochEnd):
    """Gradient accumulation handler for transformer training

    Accumulates gradients of the network for a few iterations, and updates
    network parameters with the accumulated gradients.

    Parameters
    ----------
    grad_interval : int
        the interval of updating gradients
    batch_size : int
        number of tokens per gpu in a minibatch
    rescale_loss : float
        normalization constant
    """
    def __init__(self, grad_interval=1,
                 batch_size=1024,
                 rescale_loss=100):
        super(TransformerGradientAccumulationHandler, self).__init__()
        self.grad_interval = grad_interval
        self.batch_size = batch_size
        self.rescale_loss = rescale_loss
        # Defensive init; train_begin/epoch_begin set these before use.
        self.batch_id = 0
        self.loss_denom = 0

    def _update_gradient(self, estimator):
        # Normalize the accumulated gradient by the number of target tokens
        # it was computed over.
        estimator.trainer.step(float(self.loss_denom) /
                               self.batch_size / self.rescale_loss)
        params = estimator.net.collect_params()
        params.zero_grad()
        self.loss_denom = 0

    def train_begin(self, estimator, *args, **kwargs):
        # 'add' makes backward() accumulate instead of overwrite gradients.
        params = estimator.net.collect_params()
        params.setattr('grad_req', 'add')
        params.zero_grad()

    def epoch_begin(self, estimator, *args, **kwargs):
        self.batch_id = 0
        self.loss_denom = 0

    def batch_end(self, estimator, *args, **kwargs):
        self.loss_denom += estimator.tgt_valid_length
        if self.batch_id % self.grad_interval == self.grad_interval - 1:
            self._update_gradient(estimator)
        self.batch_id += 1

    def epoch_end(self, estimator, *args, **kwargs):
        # Flush any partially accumulated gradients at epoch end.
        if self.loss_denom > 0:
            self._update_gradient(estimator)

class MTTransformerMetricHandler(MetricHandler, BatchBegin):
    """Metric update handler for transformer training

    Resets the local metric stats every few iterations and includes the
    LengthNormalizedLoss in metric updates.
    TODO : Refactor this event handler and share it with other estimators

    Parameters
    ----------
    grad_interval : int
        interval of resetting local metrics during transformer training;
        None disables the periodic reset
    """
    def __init__(self, *args, grad_interval=None, **kwargs):
        super(MTTransformerMetricHandler, self).__init__(*args, **kwargs)
        self.grad_interval = grad_interval

    def epoch_begin(self, estimator, *args, **kwargs):
        self.batch_id = 0
        for metric in self.metrics:
            metric.reset()

    def batch_begin(self, estimator, *args, **kwargs):
        if self.grad_interval is not None and self.batch_id % self.grad_interval == 0:
            for metric in self.metrics:
                metric.reset_local()
        self.batch_id += 1

    def batch_end(self, estimator, *args, **kwargs):
        pred = kwargs['pred']
        label = kwargs['label']
        loss = kwargs['loss']
        for metric in self.metrics:
            # Each metric type consumes a different (label, value) pair.
            if isinstance(metric, MetricLoss):
                metric.update(0, loss)
            elif isinstance(metric, LengthNormalizedLoss):
                metric.update(label, loss)
            else:
                metric.update(label, pred)

class MTCheckpointHandler(CheckpointHandler, TrainEnd):
    """Checkpoint handler for machine translation tasks training

    Saves model parameter checkpoints and averaged-parameter checkpoints
    during transformer or GNMT training.

    Parameters
    ----------
    average_checkpoint : bool
        whether to store the average of the last few checkpoints
    num_averages : int
        number of last model checkpoints to be averaged
    average_start : int
        perform average sgd over the last average_start epochs
    epochs : int
        total epochs of machine translation model training
    """
    def __init__(self, *args,
                 average_checkpoint=None,
                 num_averages=None,
                 average_start=0,
                 epochs=0,
                 **kwargs):
        super(MTCheckpointHandler, self).__init__(*args, **kwargs)
        self.bleu_score = 0.
        self.average_checkpoint = average_checkpoint
        self.num_averages = num_averages
        self.average_start = average_start
        self.epochs = epochs

    def epoch_end(self, estimator, *args, **kwargs):
        # Track the best validation BLEU seen so far and always keep a
        # per-epoch snapshot for later checkpoint averaging.
        if estimator.bleu_score > self.bleu_score:
            self.bleu_score = estimator.bleu_score
            save_path = os.path.join(self.model_dir, 'valid_best.params')
            estimator.net.save_parameters(save_path)
        save_path = os.path.join(self.model_dir, 'epoch{:d}.params'.format(self.current_epoch))
        estimator.net.save_parameters(save_path)
        self.current_epoch += 1

    def train_end(self, estimator, *args, **kwargs):
        ctx = estimator.context
        if estimator.avg_param is not None:
            save_path = os.path.join(self.model_dir, 'average.params')
            mx.nd.save(save_path, estimator.avg_param)
        if self.average_checkpoint:
            for j in range(self.num_averages):
                params = mx.nd.load(os.path.join(self.model_dir,
                                                 'epoch{:d}.params'.format(self.epochs - j - 1)))
                alpha = 1. / (j + 1)
                for k, v in estimator.net._collect_params_with_prefix().items():
                    for c in ctx:
                        # BUGFIX: incremental mean over the last num_averages
                        # checkpoints, avg_j = avg_{j-1} + (x_j - avg_{j-1})/(j+1).
                        # The previous `=` discarded the accumulated average.
                        v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c))
            save_path = os.path.join(self.model_dir,
                                     'average_checkpoint_{}.params'.format(self.num_averages))
            estimator.net.save_parameters(save_path)
        elif self.average_start:
            for k, v in estimator.net.collect_params().items():
                v.set_data(estimator.avg_param[k])
            save_path = os.path.join(self.model_dir, 'average.params')
            estimator.net.save_parameters(save_path)
        else:
            estimator.net.load_parameters(os.path.join(self.model_dir,
                                                       'valid_best.params'), ctx)


class ComputeBleuHandler(BatchEnd, EpochEnd):
    """Bleu score computation handler

    This event handler serves as a temporary workaround for computing the
    BLEU score during estimator training: batch_end translates each batch,
    epoch_end reorders the translations and stores the corpus BLEU on
    estimator.bleu_score.
    TODO: please remove this event handler after bleu metrics is merged to api
    """
    def __init__(self,
                 tgt_vocab,
                 tgt_sentence,
                 translator,
                 compute_bleu_fn,
                 tokenized=True,
                 tokenizer='13a',
                 split_compound_word=False,
                 bpe=False,
                 bleu='13a',
                 detokenizer=None,
                 _bpe_to_words=None):
        self.tgt_vocab = tgt_vocab
        self.tgt_sentence = tgt_sentence
        self.translator = translator
        self.compute_bleu_fn = compute_bleu_fn
        self.tokenized = tokenized
        self.tokenizer = tokenizer
        self.split_compound_word = split_compound_word
        self.bpe = bpe
        self.bleu = bleu
        self.detokenizer = detokenizer
        self._bpe_to_words = _bpe_to_words

        # Accumulated over an epoch; consumed and implicitly superseded by
        # the next epoch's batches.
        self.all_inst_ids = []
        self.translation_out = []

    def batch_end(self, estimator, *args, **kwargs):
        ctx = estimator.context[0]
        batch = kwargs['batch']
        # Only the source side is needed for translation; the reference
        # sentences come from self.tgt_sentence.
        src_seq, _, src_valid_length, _, inst_ids = batch
        src_seq = src_seq.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx)
        self.all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
        samples, _, sample_valid_length = self.translator.translate(
            src_seq=src_seq, src_valid_length=src_valid_length)
        # Keep only the top beam; strip <bos>/<eos>.
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_valid_length = sample_valid_length[:, 0].asnumpy()
        for i in range(max_score_sample.shape[0]):
            self.translation_out.append(
                [self.tgt_vocab.idx_to_token[ele] for ele in
                 max_score_sample[i][1:(sample_valid_length[i] - 1)]])

    def epoch_end(self, estimator, *args, **kwargs):
        # Restore the original corpus order before scoring.
        real_translation_out = [None for _ in range(len(self.all_inst_ids))]
        for ind, sentence in zip(self.all_inst_ids, self.translation_out):
            if self.bleu == 'tweaked':
                real_translation_out[ind] = sentence
            elif self.bleu == '13a' or self.bleu == 'intl':
                real_translation_out[ind] = self.detokenizer(self._bpe_to_words(sentence))
            else:
                raise NotImplementedError
        estimator.bleu_score, _, _, _, _ = \
            self.compute_bleu_fn([self.tgt_sentence],
                                 real_translation_out,
                                 tokenized=self.tokenized,
                                 tokenizer=self.tokenizer,
                                 split_compound_word=self.split_compound_word,
                                 bpe=self.bpe)
        print(estimator.bleu_score)

class ValBleuHandler(EpochEnd):
    """Handler of validation Bleu score computation

    Similar to ComputeBleuHandler, but translates the whole validation
    dataset in one pass at epoch end and stores the corpus BLEU on
    estimator.bleu_score.
    TODO: please remove this event handler after bleu metric is available in the api
    """
    def __init__(self, val_data,
                 val_tgt_vocab,
                 val_tgt_sentences,
                 translator,
                 compute_bleu_fn,
                 tokenized=True,
                 tokenizer='13a',
                 split_compound_word=False,
                 bpe=False,
                 bleu='13a',
                 detokenizer=None,
                 _bpe_to_words=None):
        self.val_data = val_data
        self.val_tgt_vocab = val_tgt_vocab
        self.val_tgt_sentences = val_tgt_sentences
        self.translator = translator
        self.tokenized = tokenized
        self.tokenizer = tokenizer
        self.split_compound_word = split_compound_word
        self.bpe = bpe
        self.compute_bleu_fn = compute_bleu_fn
        self.bleu = bleu
        self.detokenizer = detokenizer
        self._bpe_to_words = _bpe_to_words

    def epoch_end(self, estimator, *args, **kwargs):
        translation_out = []
        all_inst_ids = []
        for src_seq, _, src_valid_length, _, inst_ids in self.val_data:
            # Only the source side is needed for translation.
            src_seq = src_seq.as_in_context(estimator.context[0])
            src_valid_length = src_valid_length.as_in_context(estimator.context[0])
            all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
            samples, _, sample_valid_length = self.translator.translate(
                src_seq=src_seq, src_valid_length=src_valid_length)
            # Keep only the top beam; strip <bos>/<eos>.
            max_score_sample = samples[:, 0, :].asnumpy()
            sample_valid_length = sample_valid_length[:, 0].asnumpy()
            for i in range(max_score_sample.shape[0]):
                translation_out.append(
                    [self.val_tgt_vocab.idx_to_token[ele] for ele in
                     max_score_sample[i][1:(sample_valid_length[i] - 1)]])
        # Restore the original corpus order before scoring.
        real_translation_out = [None for _ in range(len(all_inst_ids))]
        for ind, sentence in zip(all_inst_ids, translation_out):
            if self.bleu == 'tweaked':
                real_translation_out[ind] = sentence
            elif self.bleu == '13a' or self.bleu == 'intl':
                real_translation_out[ind] = self.detokenizer(self._bpe_to_words(sentence))
            else:
                raise NotImplementedError
        estimator.bleu_score, _, _, _, _ = \
            self.compute_bleu_fn([self.val_tgt_sentences],
                                 real_translation_out,
                                 tokenized=self.tokenized,
                                 tokenizer=self.tokenizer,
                                 split_compound_word=self.split_compound_word,
                                 bpe=self.bpe)
        print(estimator.bleu_score)

class MTTransformerLoggingHandler(LoggingHandler):
    """Logging handler for transformer training

    Logs the training metrics for transformer training. This handler is
    introduced because the multi-device batch layout cannot be handled by
    the default LoggingHandler.
    """
    def __init__(self, *args, **kwargs):
        super(MTTransformerLoggingHandler, self).__init__(*args, **kwargs)

    def batch_end(self, estimator, *args, **kwargs):
        if isinstance(self.log_interval, int):
            batch_time = time.time() - self.batch_start
            msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
            cur_batches = kwargs['batch']
            # One shard per device; samples are counted across all shards.
            for batch in cur_batches:
                self.processed_samples += batch[0].shape[0]
            msg += '[Samples %s]' % (self.processed_samples)
            self.log_interval_time += batch_time
            if self.batch_index % self.log_interval == 0:
                msg += 'time/interval: %.3fs ' % self.log_interval_time
                self.log_interval_time = 0
                for metric in self.metrics:
                    name, val = metric.get()
                    msg += '%s: %.4f, ' % (name, val)
                estimator.logger.info(msg.rstrip(', '))
        self.batch_index += 1