diff --git a/docs/code/losses.rst b/docs/code/losses.rst
index df1a14a3..87a0c6b0 100644
--- a/docs/code/losses.rst
+++ b/docs/code/losses.rst
@@ -68,6 +68,14 @@ Entropy
 .. autofunction:: texar.losses.sequence_entropy_with_logits
+DEBLEU
+==================
+
+:hidden:`debleu`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: texar.losses.debleu
+
+
 Loss Utils
 ===========
diff --git a/docs/code/modules.rst b/docs/code/modules.rst
index d3d6443e..5aac39c0 100644
--- a/docs/code/modules.rst
+++ b/docs/code/modules.rst
@@ -134,6 +134,11 @@ Decoders
 .. autoclass:: texar.modules.GumbelSoftmaxEmbeddingHelper
     :members:
+:hidden:`TeacherMaskSoftmaxEmbeddingHelper`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.modules.TeacherMaskSoftmaxEmbeddingHelper
+    :members:
+
 :hidden:`get_helper`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: texar.modules.get_helper
diff --git a/docs/code/utils.rst b/docs/code/utils.rst
index 5c113c1a..726d4739 100644
--- a/docs/code/utils.rst
+++ b/docs/code/utils.rst
@@ -278,3 +278,21 @@ AverageRecorder
 ==========================
 .. autoclass:: texar.utils.AverageRecorder
     :members:
+
+Trigger
+==========================
+
+:hidden:`Trigger`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.utils.Trigger
+    :members:
+
+:hidden:`ScheduledStepsTrigger`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.utils.ScheduledStepsTrigger
+    :members:
+
+:hidden:`BestEverConvergenceTrigger`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.utils.BestEverConvergenceTrigger
+    :members:
diff --git a/examples/differentiable_expected_bleu/README.md b/examples/differentiable_expected_bleu/README.md
new file mode 100644
index 00000000..5bff077d
--- /dev/null
+++ b/examples/differentiable_expected_bleu/README.md
@@ -0,0 +1,48 @@
+# Seq2seq Model #
+
+This example builds an attentional seq2seq model for machine translation, trained with Differentiable Expected BLEU (DEBLEU) and a Teacher Mask. See https://openreview.net/pdf?id=S1x2aiRqFX for the paper that this example implements.
+
+### Dataset ###
+
+ * iwslt14: The benchmark [IWSLT2014](https://sites.google.com/site/iwsltevaluation2014/home) (de-en) machine translation dataset.
+
+Download the data with the following command:
+
+```bash
+python prepare_data.py --data de-en
+```
+
+### Train the model ###
+
+Train the model with the following command:
+
+```bash
+python differentiable_expected_bleu.py --config_model config_model_medium --config_data config_data_iwslt14_de-en --config_train config_train --expr_name iwslt14_de-en --restore_from "" --reinitialize
+```
+
+Here:
+ * `--config_model` specifies the model config. Note that the `.py` suffix must not be included.
+ * `--config_data` specifies the data config.
+ * `--config_train` specifies the training config.
+ * `--expr_name` specifies the experiment name, used as the name of the directory in which all information is saved and restored.
+ * `--restore_from` specifies the checkpoint path to restore from. If not specified (or an empty string is given), the latest checkpoint in `expr_name` is restored.
+ * `--reinitialize` is a flag indicating whether to reinitialize the state of the optimizers before training and after annealing. Enabled by default.
+
+[config_model_medium.py](./config_model_medium.py) specifies a single-layer seq2seq model with Luong attention and a bi-directional RNN encoder.
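As a side note on the `--config_*` flags: the training script resolves them by importing the named modules (which is why the `.py` suffix is omitted), and every hyperparameter is a plain module attribute. A minimal sketch of the mechanism, assuming it is run from this example directory:

```python
import importlib

# The flag value is a module name, e.g. "config_model_medium", so the
# corresponding .py file must be importable from the working directory.
config_model = importlib.import_module("config_model_medium")

# Hyperparameters are ordinary module attributes.
print(config_model.num_units)   # 256 in config_model_medium.py
print(config_model.encoder)     # the encoder hparams dict
```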
+ +[config_model_large.py](./config_model_large.py) specifies a seq2seq model with Luong attention, 2-layer bi-directional RNN encoder, single-layer RNN decoder, and a connector between the final state of the encoder and the initial state of the decoder. The size of this model is quite large. + +[config_data_iwslt14_de-en.py](./config_data_iwslt14_de-en.py) specifies the IWSLT'14 German-English dataset. + +[config_train.py](./config_train.py) specifies the training (including annealing) configs. + +## Results ## + +On the IWSLT'14 German-English dataset, we ran both configs for 4~5 times. Here are the average BLEU scores attained: + +| config | inference beam size | Cross-Entropy baseline | DEBLEU | improvement | +| :------------------------------------------------: | :-----------------: | :--------------------: | :----: | :---------: | +| [config_model_medium.py](./config_model_medium.py) | 1 | 26.12 | 27.40 | 1.28 | +| [config_model_medium.py](./config_model_medium.py) | 5 | 27.03 | 27.72 | 0.70 | +| [config_model_large.py](./config_model_large.py) | 1 | 25.24 | 26.47 | 1.23 | +| [config_model_large.py](./config_model_large.py) | 5 | 26.33 | 26.87 | 0.54 | diff --git a/examples/differentiable_expected_bleu/config_data_iwslt14_de-en.py b/examples/differentiable_expected_bleu/config_data_iwslt14_de-en.py new file mode 100644 index 00000000..ae3979f5 --- /dev/null +++ b/examples/differentiable_expected_bleu/config_data_iwslt14_de-en.py @@ -0,0 +1,59 @@ +source_vocab_file = 'data/iwslt14_de-en/vocab.de' +target_vocab_file = 'data/iwslt14_de-en/vocab.en' + +train_0 = { + 'batch_size': 80, + 'allow_smaller_final_batch': False, + 'source_dataset': { + "files": 'data/iwslt14_de-en/train.de', + 'vocab_file': source_vocab_file, + 'max_seq_length': 50 + }, + 'target_dataset': { + 'files': 'data/iwslt14_de-en/train.en', + 'vocab_file': target_vocab_file, + 'max_seq_length': 50 + }, +} + +train_1 = { + 'batch_size': 160, + 'allow_smaller_final_batch': False, + 'source_dataset': { + "files": 'data/iwslt14_de-en/train.de', + 'vocab_file': source_vocab_file, + 'max_seq_length': 50 + }, + 'target_dataset': { + 'files': 'data/iwslt14_de-en/train.en', + 'vocab_file': target_vocab_file, + 'max_seq_length': 50 + }, +} + + +val = { + 'batch_size': 80, + 'shuffle': False, + 'source_dataset': { + "files": 'data/iwslt14_de-en/valid.de', + 'vocab_file': source_vocab_file, + }, + 'target_dataset': { + 'files': 'data/iwslt14_de-en/valid.en', + 'vocab_file': target_vocab_file, + }, +} + +test = { + 'batch_size': 80, + 'shuffle': False, + 'source_dataset': { + "files": 'data/iwslt14_de-en/test.de', + 'vocab_file': source_vocab_file, + }, + 'target_dataset': { + 'files': 'data/iwslt14_de-en/test.en', + 'vocab_file': target_vocab_file, + }, +} diff --git a/examples/differentiable_expected_bleu/config_data_iwslt14_en-fr.py b/examples/differentiable_expected_bleu/config_data_iwslt14_en-fr.py new file mode 100644 index 00000000..4c4482f7 --- /dev/null +++ b/examples/differentiable_expected_bleu/config_data_iwslt14_en-fr.py @@ -0,0 +1,45 @@ +source_vocab_file = 'data/iwslt14_en-fr/vocab.en' +target_vocab_file = 'data/iwslt14_en-fr/vocab.fr' + +batch_size = 80 + +train = { + 'batch_size': batch_size, + 'allow_smaller_final_batch': False, + 'source_dataset': { + "files": 'data/iwslt14_en-fr/train.en', + 'vocab_file': source_vocab_file, + 'max_seq_length': 50 + }, + 'target_dataset': { + 'files': 'data/iwslt14_en-fr/train.fr', + 'vocab_file': target_vocab_file, + 'max_seq_length': 50 + }, +} + +val = { + 'batch_size': 
batch_size, + 'shuffle': False, + 'source_dataset': { + "files": 'data/iwslt14_en-fr/valid.en', + 'vocab_file': source_vocab_file, + }, + 'target_dataset': { + 'files': 'data/iwslt14_en-fr/valid.fr', + 'vocab_file': target_vocab_file, + }, +} + +test = { + 'batch_size': batch_size, + 'shuffle': False, + 'source_dataset': { + "files": 'data/iwslt14_en-fr/test.en', + 'vocab_file': source_vocab_file, + }, + 'target_dataset': { + 'files': 'data/iwslt14_en-fr/test.fr', + 'vocab_file': target_vocab_file, + }, +} diff --git a/examples/differentiable_expected_bleu/config_model_large.py b/examples/differentiable_expected_bleu/config_model_large.py new file mode 100644 index 00000000..16dba9b9 --- /dev/null +++ b/examples/differentiable_expected_bleu/config_model_large.py @@ -0,0 +1,39 @@ +# Attentional Seq2seq model. +# Hyperparameters not specified here will take the default values. + +num_units = 1000 +embedding_dim = 500 + +embedder = { + 'dim': embedding_dim +} + +encoder = { + 'rnn_cell_fw': { + 'kwargs': { + 'num_units': num_units + }, + 'num_layers': 2 + }, + 'output_layer_fw': { + 'dropout_rate': 0 + } +} + +connector = { + 'activation_fn': 'tanh' +} + +decoder = { + 'rnn_cell': { + 'kwargs': { + 'num_units': num_units + }, + }, + 'attention': { + 'kwargs': { + 'num_units': num_units, + }, + 'attention_layer_size': num_units + } +} diff --git a/examples/differentiable_expected_bleu/config_model_medium.py b/examples/differentiable_expected_bleu/config_model_medium.py new file mode 100644 index 00000000..7750a97c --- /dev/null +++ b/examples/differentiable_expected_bleu/config_model_medium.py @@ -0,0 +1,40 @@ +# Attentional Seq2seq model. +# Hyperparameters not specified here will take the default values. + +num_units = 256 +embedding_dim = 256 +dropout = 0.2 + +embedder = { + 'dim': embedding_dim +} + +encoder = { + 'rnn_cell_fw': { + 'kwargs': { + 'num_units': num_units + }, + 'dropout': { + 'input_keep_prob': 1. - dropout + } + } +} + +connector = None + +decoder = { + 'rnn_cell': { + 'kwargs': { + 'num_units': num_units + }, + 'dropout': { + 'input_keep_prob': 1. - dropout + } + }, + 'attention': { + 'kwargs': { + 'num_units': num_units, + }, + 'attention_layer_size': num_units + } +} diff --git a/examples/differentiable_expected_bleu/config_train.py b/examples/differentiable_expected_bleu/config_train.py new file mode 100644 index 00000000..09d3464f --- /dev/null +++ b/examples/differentiable_expected_bleu/config_train.py @@ -0,0 +1,80 @@ +max_epochs = 1000 +steps_per_eval = 500 +tau = 1. +infer_beam_width = 1 +infer_max_decoding_length = 50 + +threshold_steps = 10000 +minimum_interval_steps = 10000 +phases = [ + # (config_data, config_train, mask_pattern) + ("train_0", "xe_0", None), + ("train_0", "xe_1", None), + ("train_0", "debleu_0", (2, 2)), + ("train_1", "debleu_0", (4, 2)), + ("train_1", "debleu_1", (1, 0)), +] + +train_xe_0 = { + "optimizer": { + "type": "AdamOptimizer", + "kwargs": { + "learning_rate": 1e-3 + } + }, + "gradient_clip": { + "type": "clip_by_global_norm", + "kwargs": { + "clip_norm": 5. + } + }, + "name": "XE_0" +} + +train_xe_1 = { + "optimizer": { + "type": "AdamOptimizer", + "kwargs": { + "learning_rate": 1e-5 + } + }, + "gradient_clip": { + "type": "clip_by_global_norm", + "kwargs": { + "clip_norm": 5. + } + }, + "name": "XE_1" +} + +train_debleu_0 = { + "optimizer": { + "type": "AdamOptimizer", + "kwargs": { + "learning_rate": 1e-5 + } + }, + "gradient_clip": { + "type": "clip_by_global_norm", + "kwargs": { + "clip_norm": 5. 
+ } + }, + "name": "DEBLEU_0" +} + +train_debleu_1 = { + "optimizer": { + "type": "AdamOptimizer", + "kwargs": { + "learning_rate": 1e-6 + } + }, + "gradient_clip": { + "type": "clip_by_global_norm", + "kwargs": { + "clip_norm": 5. + } + }, + "name": "DEBLEU_1" +} diff --git a/examples/differentiable_expected_bleu/differentiable_expected_bleu.py b/examples/differentiable_expected_bleu/differentiable_expected_bleu.py new file mode 100755 index 00000000..0c414b21 --- /dev/null +++ b/examples/differentiable_expected_bleu/differentiable_expected_bleu.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DEBLEU. +""" +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +#pylint: disable=invalid-name, too-many-arguments, too-many-locals + +import importlib +import os +import tensorflow as tf +import texar as tx + +from nltk.translate.bleu_score import corpus_bleu + +flags = tf.flags + +flags.DEFINE_string("config_model", "config_model_medium", "The model config.") +flags.DEFINE_string("config_data", "config_data_iwslt14_de-en", + "The dataset config.") +flags.DEFINE_string("config_train", "config_train", "The training config.") +flags.DEFINE_string("expr_name", "iwslt14_de-en", "The experiment name. " + "Used as the directory name of run.") +flags.DEFINE_string("restore_from", "", "The specific checkpoint path to " + "restore from. If not specified, the latest checkpoint in " + "expr_name is restored.") +flags.DEFINE_boolean("reinitialize", True, "Whether to reinitialize the state " + "of the optimizers before training and after annealing.") + +FLAGS = flags.FLAGS + +config_model = importlib.import_module(FLAGS.config_model) +config_data = importlib.import_module(FLAGS.config_data) +config_train = importlib.import_module(FLAGS.config_train) +expr_name = FLAGS.expr_name +restore_from = FLAGS.restore_from +reinitialize = FLAGS.reinitialize +phases = config_train.phases + +xe_names = ('xe_0', 'xe_1') +debleu_names = ('debleu_0', 'debleu_1') + +dir_model = os.path.join(expr_name, 'ckpt') +dir_best = os.path.join(expr_name, 'ckpt-best') +ckpt_model = os.path.join(dir_model, 'model.ckpt') +ckpt_best = os.path.join(dir_best, 'model.ckpt') + + +def get_scope_by_name(tensor): + return tensor.name[: tensor.name.rfind('/') + 1] + + +def build_model(batch, train_data): + """Assembles the seq2seq model. 
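+
+    Constructs the source/target embedders, the bidirectional encoder, the
+    attentional decoder, the cross-entropy and DEBLEU train ops, and the
+    teacher-mask helper, and returns the dict of train ops, the teacher-mask
+    helper, the (n_unmask, n_mask) placeholders, and the beam-search
+    inference outputs.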
+ """ + train_ops = {} + + source_embedder = tx.modules.WordEmbedder( + vocab_size=train_data.source_vocab.size, hparams=config_model.embedder) + + encoder = tx.modules.BidirectionalRNNEncoder( + hparams=config_model.encoder) + + enc_outputs, enc_final_state = encoder( + source_embedder(batch['source_text_ids'])) + + target_embedder = tx.modules.WordEmbedder( + vocab_size=train_data.target_vocab.size, hparams=config_model.embedder) + + decoder = tx.modules.AttentionRNNDecoder( + memory=tf.concat(enc_outputs, axis=2), + memory_sequence_length=batch['source_length'], + vocab_size=train_data.target_vocab.size, + hparams=config_model.decoder) + + if config_model.connector is None: + dec_initial_state = None + + else: + enc_final_state = tf.contrib.framework.nest.map_structure( + lambda *args: tf.concat(args, -1), *enc_final_state) + + if isinstance(decoder.cell, tf.nn.rnn_cell.LSTMCell): + connector = tx.modules.MLPTransformConnector( + decoder.state_size.h, hparams=config_model.connector) + dec_initial_h = connector(enc_final_state.h) + dec_initial_state = (dec_initial_h, enc_final_state.c) + else: + connector = tx.modules.MLPTransformConnector( + decoder.state_size, hparams=config_model.connector) + dec_initial_state = connector(enc_final_state) + + # cross-entropy + teacher-forcing pretraining + tf_outputs, _, _ = decoder( + decoding_strategy='train_greedy', + initial_state=dec_initial_state, + inputs=target_embedder(batch['target_text_ids'][:, :-1]), + sequence_length=batch['target_length']-1) + + loss_xe = tx.losses.sequence_sparse_softmax_cross_entropy( + labels=batch['target_text_ids'][:, 1:], + logits=tf_outputs.logits, + sequence_length=batch['target_length']-1) + + train_ops[xe_names[0]] = tx.core.get_train_op( + loss_xe, + hparams=config_train.train_xe_0) + train_ops[xe_names[1]] = tx.core.get_train_op( + loss_xe, + hparams=config_train.train_xe_1) + + # teacher mask + DEBLEU fine-tuning + n_unmask = tf.placeholder(tf.int32, shape=[], name="n_unmask") + n_mask = tf.placeholder(tf.int32, shape=[], name="n_mask") + tm_helper = tx.modules.TeacherMaskSoftmaxEmbeddingHelper( + # must not remove last token, since it may be used as mask + inputs=batch['target_text_ids'], + sequence_length=batch['target_length']-1, + embedding=target_embedder, + n_unmask=n_unmask, + n_mask=n_mask, + tau=config_train.tau) + + tm_outputs, _, _ = decoder( + helper=tm_helper, + initial_state=dec_initial_state) + + loss_debleu = tx.losses.debleu( + labels=batch['target_text_ids'][:, 1:], + probs=tm_outputs.sample_id, + sequence_length=batch['target_length']-1) + + train_ops[debleu_names[0]] = tx.core.get_train_op( + loss_debleu, + hparams=config_train.train_debleu_0) + train_ops[debleu_names[1]] = tx.core.get_train_op( + loss_debleu, + hparams=config_train.train_debleu_1) + + # inference: beam search decoding + start_tokens = tf.ones_like(batch['target_length']) * \ + train_data.target_vocab.bos_token_id + end_token = train_data.target_vocab.eos_token_id + + bs_outputs, _, _ = tx.modules.beam_search_decode( + decoder_or_cell=decoder, + embedding=target_embedder, + start_tokens=start_tokens, + end_token=end_token, + initial_state=dec_initial_state, + beam_width=config_train.infer_beam_width, + max_decoding_length=config_train.infer_max_decoding_length) + + return train_ops, tm_helper, (n_unmask, n_mask), bs_outputs + + +def main(): + """Entrypoint. 
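+
+    Builds the data iterators, the model, and the convergence/annealing
+    triggers, restores a checkpoint if one is available, then alternates
+    evaluation and training epochs, advancing through the phases defined in
+    the training config whenever validation BLEU is found to have converged.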
+ """ + train_0_data = tx.data.PairedTextData(hparams=config_data.train_0) + train_1_data = tx.data.PairedTextData(hparams=config_data.train_1) + val_data = tx.data.PairedTextData(hparams=config_data.val) + test_data = tx.data.PairedTextData(hparams=config_data.test) + data_iterator = tx.data.FeedableDataIterator( + {'train_0': train_0_data, 'train_1': train_1_data, + 'val': val_data, 'test': test_data}) + data_batch = data_iterator.get_next() + + global_step = tf.train.create_global_step() + + train_ops, tm_helper, mask_pattern_, infer_outputs = build_model( + data_batch, train_0_data) + + def get_train_op_scope(name): + return get_scope_by_name(train_ops[name]) + + train_op_initializers = { + name: tf.variables_initializer( + tf.get_collection( + tf.GraphKeys.GLOBAL_VARIABLES, + scope=get_train_op_scope(name)), + name='train_{}_op_initializer'.format(name)) + for name in (xe_names + debleu_names)} + + summary_tm = [ + tf.summary.scalar('tm/n_unmask', tm_helper.n_unmask), + tf.summary.scalar('tm/n_mask', tm_helper.n_mask)] + summary_ops = { + name: tf.summary.merge( + tf.get_collection( + tf.GraphKeys.SUMMARIES, + scope=get_train_op_scope(name)) + + (summary_tm if name in debleu_names else []), + name='summary_{}'.format(name)) + for name in (xe_names + debleu_names)} + + global convergence_trigger + convergence_trigger = tx.utils.BestEverConvergenceTrigger( + None, + lambda state: state, + config_train.threshold_steps, + config_train.minimum_interval_steps) + + saver = tf.train.Saver(max_to_keep=None) + + def _save_to(directory, step): + print('saving to {} ...'.format(directory)) + saved_path = saver.save(sess, directory, global_step=step) + + for trigger_name in ['convergence_trigger', 'annealing_trigger']: + trigger = globals()[trigger_name] + trigger_path = '{}.{}'.format(saved_path, trigger_name) + print('saving {} ...'.format(trigger_name)) + with open(trigger_path, 'wb') as pickle_file: + trigger.save_to_pickle(pickle_file) + + print('saved to {}'.format(saved_path)) + + def _restore_from_path(ckpt_path, restore_trigger_names=None): + print('restoring from {} ...'.format(ckpt_path)) + saver.restore(sess, ckpt_path) + + if restore_trigger_names is None: + restore_trigger_names = ['convergence_trigger', 'annealing_trigger'] + + for trigger_name in restore_trigger_names: + trigger = globals()[trigger_name] + trigger_path = '{}.{}'.format(ckpt_path, trigger_name) + if os.path.exists(trigger_path): + print('restoring {} ...'.format(trigger_name)) + with open(trigger_path, 'rb') as pickle_file: + trigger.restore_from_pickle(pickle_file) + else: + print('cannot find previous {} state.'.format(trigger_name)) + + print('done.') + + def _restore_from(directory, restore_trigger_names=None): + if os.path.exists(directory): + ckpt_path = tf.train.latest_checkpoint(directory) + _restore_from_path(ckpt_path, restore_trigger_names) + + else: + print('cannot find checkpoint directory {}'.format(directory)) + + def _train_epoch(sess, summary_writer, mode, train_op, summary_op): + print('in _train_epoch') + + data_iterator.restart_dataset(sess, mode) + feed_dict = { + tx.global_mode(): tf.estimator.ModeKeys.TRAIN, + data_iterator.handle: data_iterator.get_handle(sess, mode), + } + if mask_pattern is not None: + feed_dict.update( + {mask_pattern_[_]: mask_pattern[_] for _ in range(2)}) + + while True: + try: + loss, summary, step = sess.run( + (train_op, summary_op, global_step), feed_dict) + + summary_writer.add_summary(summary, step) + + if step % config_train.steps_per_eval == 0: + global 
triggered + _eval_epoch(sess, summary_writer, 'val') + if triggered: + break + + except tf.errors.OutOfRangeError: + break + + print('end _train_epoch') + + def _eval_epoch(sess, summary_writer, mode): + print('in _eval_epoch with mode {}'.format(mode)) + + data_iterator.restart_dataset(sess, mode) + feed_dict = { + tx.global_mode(): tf.estimator.ModeKeys.EVAL, + data_iterator.handle: data_iterator.get_handle(sess, mode) + } + + ref_hypo_pairs = [] + fetches = [ + data_batch['target_text'][:, 1:], + infer_outputs.predicted_ids[:, :, 0] + ] + + while True: + try: + target_texts_ori, output_ids = sess.run(fetches, feed_dict) + target_texts = tx.utils.strip_special_tokens( + target_texts_ori.tolist(), is_token_list=True) + output_texts = tx.utils.map_ids_to_strs( + ids=output_ids.tolist(), vocab=val_data.target_vocab, + join=False) + + ref_hypo_pairs.extend( + zip(map(lambda x: [x], target_texts), output_texts)) + + except tf.errors.OutOfRangeError: + break + + refs, hypos = zip(*ref_hypo_pairs) + bleu = corpus_bleu(refs, hypos) * 100 + print('{} BLEU: {}'.format(mode, bleu)) + + step = tf.train.global_step(sess, global_step) + + summary = tf.Summary() + summary.value.add(tag='{}/BLEU'.format(mode), simple_value=bleu) + summary_writer.add_summary(summary, step) + summary_writer.flush() + + if mode == 'val': + global triggered + triggered = convergence_trigger(step, bleu) + if triggered: + print('triggered!') + + if convergence_trigger.best_ever_step == step: + print('updated best val bleu: {}'.format( + convergence_trigger.best_ever_score)) + + _save_to(ckpt_best, step) + + print('end _eval_epoch') + return bleu + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + sess.run(tf.tables_initializer()) + + def action(i): + if i >= len(phases) - 1: + return i + i += 1 + train_data_name, train_op_name, mask_pattern = phases[i] + if reinitialize: + sess.run(train_op_initializers[train_op_name]) + return i + + global annealing_trigger + annealing_trigger = tx.utils.Trigger(0, action) + + def _restore_and_anneal(): + _restore_from(dir_best, ['convergence_trigger']) + annealing_trigger.trigger() + + if restore_from: + _restore_from_path(restore_from) + else: + _restore_from(dir_model) + + summary_writer = tf.summary.FileWriter( + os.path.join(expr_name, 'log'), sess.graph, flush_secs=30) + + epoch = 0 + while epoch < config_train.max_epochs: + train_data_name, train_op_name, mask_pattern = phases[ + annealing_trigger.user_state] + train_op = train_ops[train_op_name] + summary_op = summary_ops[train_op_name] + + print('epoch #{} {}:'.format( + epoch, (train_data_name, train_op_name, mask_pattern))) + + val_bleu = _eval_epoch(sess, summary_writer, 'val') + test_bleu = _eval_epoch(sess, summary_writer, 'test') + if triggered: + _restore_and_anneal() + continue + + step = tf.train.global_step(sess, global_step) + + print('epoch: {}, step: {}, val BLEU: {}, test BLEU: {}'.format( + epoch, step, val_bleu, test_bleu)) + + _train_epoch(sess, summary_writer, train_data_name, + train_op, summary_op) + if triggered: + _restore_and_anneal() + continue + + epoch += 1 + + step = tf.train.global_step(sess, global_step) + _save_to(ckpt_model, step) + + test_bleu = _eval_epoch(sess, summary_writer, 'test') + print('epoch: {}, test BLEU: {}'.format(epoch, test_bleu)) + + +if __name__ == '__main__': + main() + diff --git a/examples/differentiable_expected_bleu/prepare_data.py b/examples/differentiable_expected_bleu/prepare_data.py new file mode 100644 
index 00000000..8a19075b --- /dev/null +++ b/examples/differentiable_expected_bleu/prepare_data.py @@ -0,0 +1,50 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Downloads data. +""" +import tensorflow as tf +import texar as tx + +import os + +# pylint: disable=invalid-name + +flags = tf.flags + +flags.DEFINE_string("data", "de-en", "Data to download [de-en|en-fr]") + +FLAGS = flags.FLAGS + +def prepare_data(): + """Downloads data. + """ + if FLAGS.data == 'de-en': + tx.data.maybe_download( + urls='https://drive.google.com/file/d/' + '1y4mUWXRS2KstgHopCS9koZ42ENOh6Yb9/view?usp=sharing', + path='./', + filenames='iwslt14_de-en.zip', + extract=True) + os.rename(os.path.join('data', 'iwslt14'), + os.path.join('data', 'iwslt14_de-en')) + else: + raise ValueError('Unknown data: {}'.format(FLAGS.data)) + +def main(): + """Entrypoint. + """ + prepare_data() + +if __name__ == '__main__': + main() diff --git a/texar/core/optimization.py b/texar/core/optimization.py index 1e24e9b3..af48c17d 100644 --- a/texar/core/optimization.py +++ b/texar/core/optimization.py @@ -125,7 +125,7 @@ def default_optimization_hparams(): :tf_main:`tf.clip_by_average_norm `, etc. "type" specifies the gradient clip function, and can be a function, - or its name or mudule path. If function name is provided, the + or its name or module path. If function name is provided, the function must be from module :tf_main:`tf < >` or :mod:`texar.custom`. "kwargs" specifies keyword arguments to the function, except arguments diff --git a/texar/losses/__init__.py b/texar/losses/__init__.py index c684911c..c8d09cfc 100644 --- a/texar/losses/__init__.py +++ b/texar/losses/__init__.py @@ -27,3 +27,4 @@ from texar.losses.adv_losses import * from texar.losses.rewards import * from texar.losses.entropy import * +from texar.losses.debleu import * diff --git a/texar/losses/debleu.py b/texar/losses/debleu.py new file mode 100644 index 00000000..1d1307db --- /dev/null +++ b/texar/losses/debleu.py @@ -0,0 +1,214 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Differentiable Expected BLEU loss +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +# pylint: disable=invalid-name, not-context-manager, protected-access, +# pylint: disable=too-many-arguments + +__all__ = [ + "debleu", +] + +def batch_gather(params, indices, name=None): + """This function is copied and modified from tensorflow 11.0. See + https://www.tensorflow.org/api_docs/python/tf/batch_gather for details. + Gather slices from `params` according to `indices` with leading batch dims. + This operation assumes that the leading dimensions of `indices` are dense, + and the gathers on the axis corresponding to the last dimension of `indices`. + More concretely it computes: + result[i1, ..., in] = params[i1, ..., in-1, indices[i1, ..., in]] + Therefore `params` should be a Tensor of shape [A1, ..., AN, B1, ..., BM], + `indices` should be a Tensor of shape [A1, ..., AN-1, C] and `result` will be + a Tensor of size `[A1, ..., AN-1, C, B1, ..., BM]`. + In the case in which indices is a 1D tensor, this operation is equivalent to + `tf.gather`. + See also `tf.gather` and `tf.gather_nd`. + Args: + params: A Tensor. The tensor from which to gather values. + indices: A Tensor. Must be one of the following types: int32, int64. Index + tensor. Must be in range `[0, params.shape[axis]`, where `axis` is the + last dimension of `indices` itself. + name: A name for the operation (optional). + Returns: + A Tensor. Has the same type as `params`. + Raises: + ValueError: if `indices` has an unknown shape. + """ + + with tf.name_scope(name): + indices = tf.convert_to_tensor(indices, name="indices") + params = tf.convert_to_tensor(params, name="params") + indices_shape = tf.shape(indices) + params_shape = tf.shape(params) + + ndims = indices.shape.ndims + if ndims is None: + raise ValueError("batch_gather does not allow indices with unknown " + "shape.") + batch_indices = indices + indices_dtype = indices.dtype.base_dtype + accum_dim_value = tf.ones((), dtype=indices_dtype) + # Use correct type for offset index computation + casted_params_shape = tf.cast(params_shape, indices_dtype) + for dim in range(ndims-1, 0, -1): + dim_value = casted_params_shape[dim-1] + accum_dim_value *= casted_params_shape[dim] + start = tf.zeros((), dtype=indices_dtype) + step = tf.ones((), dtype=indices_dtype) + dim_indices = tf.range(start, dim_value, step) + dim_indices *= accum_dim_value + dim_shape = tf.stack( + [1] * (dim - 1) + [dim_value] + [1] * (ndims - dim), axis=0) + batch_indices += tf.reshape(dim_indices, dim_shape) + + flat_indices = tf.reshape(batch_indices, [-1]) + outer_shape = params_shape[ndims:] + flat_inner_shape = tf.reduce_prod(params_shape[:ndims]) + + flat_params = tf.reshape( + params, tf.concat([[flat_inner_shape], outer_shape], axis=0)) + flat_result = tf.gather(flat_params, flat_indices) + result = tf.reshape( + flat_result, tf.concat([indices_shape, outer_shape], axis=0)) + final_shape = indices.get_shape()[:ndims-1].merge_with( + params.get_shape()[:ndims -1]) + final_shape = final_shape.concatenate(indices.get_shape()[ndims-1]) + final_shape = final_shape.concatenate(params.get_shape()[ndims:]) + result.set_shape(final_shape) + return result + +def debleu(labels, probs, sequence_length, time_major=False, + min_fn=lambda x: tf.minimum(1., x), max_order=4, + weights=[.1, .3, .3, .3], epsilon=1e-9, name=None): + """Computes Differentiable Expected BLEU (DEBLEU). 
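The loss is a weighted sum of negative log expected n-gram precisions (the brevity-penalty factor is omitted), computed from the decoder's output probabilities so that the objective is differentiable with respect to the model parameters.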
See + https://openreview.net/pdf?id=S1x2aiRqFX for details. + + Args: + labels: Target sequence token indexes, i.e. y* in the paper. + + - If :attr:`time_major` is `False` (default), this must be\ + a tensor of shape `[batch_size, max_time]`. + + - If `time_major` is `True`, this must be a tensor of shape\ + `[max_time, batch_size].` + probs: Probabilities generated by model, i.e. y in the paper. This must + have the shape of + `[max_time, batch_size, vocab_size]` or + `[batch_size, max_time, vocab_size]` according to + the value of `time_major`. + sequence_length: A tensor of shape `[batch_size]`. Time steps beyond + the respective sequence lengths will have zero losses. + time_major (bool): The shape format of the inputs. If `True`, + :attr:`labels` and :attr:`probs` must have shape + `[max_time, batch_size, ...]`. If `False` + (default), they must have shape `[batch_size, max_time, ...]`. + min_fn (function, optional): A python function that implements the min + operation in Eq.14 in the paper. Default to tf.minimum(1., x). + max_order (int, optional): Maximum order of grams calculated. Default + to 4. + weights (optional): A tensor (or simply Python list) of shape + `[max_order]` of which the i-th scalar is the weight of (i+1) gram + precision. Default to `[0.1, 0.3, 0.3, 0.3]`. + epsilon (float, optional): A small value added before applying + logarithm in Eq.17 in the paper. This is in order to avoid infinite + gradients. Default to 1e-9. + name (str, optional): A name for the operation. + + Returns: + A tensor containing the loss of rank 0. + + Example: + + .. code-block:: python + + embedder = WordEmbedder(vocab_size=data.vocab.size) + decoder = BasicRNNDecoder(vocab_size=data.vocab.size) + + tm_helper = texar.modules.TeacherMaskSoftmaxEmbeddingHelper( + inputs=data_batch['text_ids'], + sequence_length=data_batch['length']-1, + embedding=embedder, + n_unmask=1, + n_mask=0, + tau=1.) + + outputs, _, _ = decoder(helper=tm_helper) + + loss = debleu( + labels=data_batch['text_ids'][:, 1:], + probs=outputs.sample_ids, + sequence_length=data_batch['length']-1) + + """ + with tf.name_scope(name, "debleu"): + X = probs # p_theta(y) + Y = labels # y* + + if time_major: + X = tf.transpose(X, [1, 0, 2]) + Y = tf.transpose(Y, [1, 0]) + + T_X = tf.shape(X)[1] # max T + T_Y = tf.shape(Y)[1] # max T* + + # XY denotes p(y_i=y*_j) + XY = batch_gather(X, tf.tile(tf.expand_dims(Y, 1), [1, T_X, 1])) + # YY denotes 1(y*_j=y*_j') + YY = tf.to_float(tf.equal(tf.expand_dims(Y, 2), tf.expand_dims(Y, 1))) + + maskX = tf.sequence_mask( + sequence_length + 1, maxlen=T_X + 1, dtype=tf.float32) + maskY = tf.sequence_mask( + sequence_length + 1, maxlen=T_Y + 1, dtype=tf.float32) + matchXY = tf.expand_dims(maskX, 2) * tf.expand_dims(maskY, 1) + matchYY = tf.minimum(tf.expand_dims(maskY, 2), + tf.expand_dims(maskY, 1)) + + tot = [] + o = [] + + for order in range(max_order): # order = n - 1 + # Eq.20 + matchXY = XY[:, : T_X - order, : T_Y - order] * matchXY[:, 1:, 1:] + matchYY = YY[:, : T_Y - order, : T_Y - order] * matchYY[:, 1:, 1:] + cntYX = tf.reduce_sum(matchXY, 1, keepdims=True) + cntYY = tf.reduce_sum(matchYY, 1, keepdims=True) + # Eq.14 + o_order = tf.reduce_sum(tf.reduce_sum( + min_fn(cntYY / (cntYX - matchXY + 1)) + * matchXY / tf.maximum(1., cntYY), + 2), 1) + # calculate (T - n + 1); max(1, .) 
is to avoid being divided by 0 + tot_order = tf.maximum(1, sequence_length - order) + tot.append(tot_order) + o.append(o_order) + + tot = tf.stack(tot, 1) + o = tf.stack(o, 1) + # Eq.15 + prec = tf.reduce_sum(o, 0) / tf.to_float(tf.reduce_sum(tot, 0)) + # add epsilon in order to avoid inf gradient + neglog_prec = -tf.log(prec + epsilon) + # Eq.17; constant about BP is omitted + loss = tf.reduce_sum(weights * neglog_prec, 0) + + return loss diff --git a/texar/modules/decoders/rnn_decoder_helpers.py b/texar/modules/decoders/rnn_decoder_helpers.py index 559f3c29..4af4f874 100644 --- a/texar/modules/decoders/rnn_decoder_helpers.py +++ b/texar/modules/decoders/rnn_decoder_helpers.py @@ -38,6 +38,7 @@ "_get_training_helper", "GumbelSoftmaxEmbeddingHelper", "SoftmaxEmbeddingHelper", + "TeacherMaskSoftmaxEmbeddingHelper", ] def default_helper_train_hparams(): @@ -185,6 +186,17 @@ def _get_training_helper( #pylint: disable=invalid-name return helper +def get_embedding_and_fn(embedding): + if isinstance(embedding, EmbedderBase): + embedding = embedding.embedding + + if callable(embedding): + raise ValueError("`embedding` must be an embedding tensor or an " + "instance of subclass of `EmbedderBase`.") + else: + return embedding, (lambda ids: tf.nn.embedding_lookup(embedding, ids)) + + class SoftmaxEmbeddingHelper(TFHelper): """A helper that feeds softmax probabilities over vocabulary to the next step. @@ -215,17 +227,7 @@ class SoftmaxEmbeddingHelper(TFHelper): def __init__(self, embedding, start_tokens, end_token, tau, stop_gradient=False, use_finish=True): - if isinstance(embedding, EmbedderBase): - embedding = embedding.embedding - - if callable(embedding): - raise ValueError("`embedding` must be an embedding tensor or an " - "instance of subclass of `EmbedderBase`.") - else: - self._embedding = embedding - self._embedding_fn = ( - lambda ids: tf.nn.embedding_lookup(embedding, ids)) - + self._embedding, self._embedding_fn = get_embedding_and_fn(embedding) self._start_tokens = tf.convert_to_tensor( start_tokens, dtype=tf.int32, name="start_tokens") self._end_token = tf.convert_to_tensor( @@ -326,3 +328,146 @@ def sample(self, time, outputs, state, name=None): sample_ids = tf.stop_gradient(sample_ids_hard - sample_ids) \ + sample_ids return sample_ids + + +class TeacherMaskSoftmaxEmbeddingHelper(TFTrainingHelper): + """A helper that implements the Teacher Mask described in the paper + https://openreview.net/pdf?id=S1x2aiRqFX. In an unmasked step, it feeds + softmax probabilities over vocabulary to the next step. In a masked step, + it feeds the one-hot distribution of the target labels (:attr:`inputs`) + to the next step. + Uses the softmax probability or one-hot vector to pass through word + embeddings to get the next input (i.e., a mixed word embedding). + In this implementation, all sequences in a batch shares the same teacher + mask. + + A subclass of + :tf_main:`TrainingHelper `. + Used as a helper to :class:`~texar.modules.RNNDecoderBase` :meth:`_build` + in training mode. + + Args: + inputs (2D Tensor): Target sequence token indexes. It should be a tensor + of shape `[batch_size, max_time]`. Must append both BOS and EOS + tokens to each sequence. + sequence_length (1D Tensor): Lengths of input token sequences. These + lengths should include the BOS tokens but exclude the EOS tokens. + embedding: An embedding argument (:attr:`params`) for + :tf_main:`tf.nn.embedding_lookup `, or an + instance of subclass of :class:`texar.modules.EmbedderBase`. 
+ Note that other callables are not acceptable here. + n_unmask: An int scalar tensor denotes the mask pattern together with + :attr:`n_mask`. See the paper for details. + n_mask: An int scalar tensor denotes the mask pattern together with + :attr:`n_unmask`. See the paper for details. + tau (float, optional): A float scalar tensor, the softmax temperature. + Default to 1. + seed (int, optional): The random seed used to shift the mask. + stop_gradient (bool): Whether to stop the gradient backpropagation + when feeding softmax vector to the next step. + name (str, optional): A name for the module. + + Example: + + .. code-block:: python + + embedder = WordEmbedder(vocab_size=data.vocab.size) + decoder = BasicRNNDecoder(vocab_size=data.vocab.size) + + tm_helper = texar.modules.TeacherMaskSoftmaxEmbeddingHelper( + inputs=data_batch['text_ids'], + sequence_length=data_batch['length']-1, + embedding=embedder, + n_unmask=1, + n_mask=0, + tau=1.) + + outputs, _, _ = decoder(helper=tm_helper) + + loss = debleu( + labels=data_batch['text_ids'][:, 1:], + probs=outputs.sample_ids, + sequence_length=data_batch['length']-1) + + """ + + def __init__(self, inputs, sequence_length, embedding, n_unmask, + n_mask, tau=1., time_major=False, seed=None, + stop_gradient=False, name=None): + with tf.variable_scope(name, "TeacherMaskSoftmaxEmbeddingHelper", + [embedding, tau, seed, stop_gradient]): + super(TeacherMaskSoftmaxEmbeddingHelper, self).__init__( + inputs=inputs, + sequence_length=sequence_length, + time_major=time_major) + + self._embedding, self._embedding_fn = get_embedding_and_fn( + embedding) + self._tau = tau + self._seed = seed + self._stop_gradient = stop_gradient + + self._zero_next_inputs = tf.zeros_like( + self._embedding_fn(self._zero_inputs)) + + self._n_unmask = n_unmask + self._n_mask = n_mask + self._n_cycle = tf.add( + self._n_unmask, self._n_mask, name="n_cycle") + self._n_shift = tf.random_uniform( + [], maxval=self._n_cycle, dtype=self._n_cycle.dtype, + seed=self._seed, name="n_shift") + + @property + def sample_ids_dtype(self): + return tf.float32 + + @property + def sample_ids_shape(self): + return self._embedding.get_shape()[:1] + + @property + def n_unmask(self): + return self._n_unmask + + @property + def n_mask(self): + return self._n_mask + + def _is_masked(self, time): + return (time + self._n_shift) % self._n_cycle < self._n_mask + + def initialize(self, name=None): + finished = tf.equal(0, self._sequence_length) + all_finished = tf.reduce_all(finished) + next_inputs = tf.cond( + all_finished, + lambda: self._zero_next_inputs, + lambda: self._embedding_fn(self._input_tas.read(0))) + return (finished, next_inputs) + + def sample(self, time, outputs, state, name=None): + """Returns `sample_id` of shape `[batch_size, vocab_size]`. In an + unmasked step, it is softmax distributions over vocabulary with + temperature :attr:`tau`; in a masked step, it is one-hot + representations of :attr:`input` in the next step. 
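+        That is, at a masked step the decoder is fed the ground-truth target
+        token (via its one-hot vector), i.e. plain teacher forcing, while at
+        an unmasked step it is fed the softmax distribution, which
+        `next_inputs` turns into a mixture of word embeddings.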
+ """ + next_time = time + 1 + sample_ids = tf.cond( + self._is_masked(next_time), + lambda: tf.one_hot(self._input_tas.read(next_time), + self._embedding.get_shape()[0]), + lambda: tf.nn.softmax(outputs / self._tau)) + return sample_ids + + def next_inputs(self, time, outputs, state, sample_ids, name=None): + next_time = time + 1 + finished = (next_time >= self._sequence_length) + all_finished = tf.reduce_all(finished) + if self._stop_gradient: + sample_ids = tf.stop_gradient(sample_ids) + next_inputs = tf.cond( + all_finished, + lambda: self._zero_next_inputs, + lambda: tf.matmul(sample_ids, self._embedding)) + return (finished, next_inputs, state) diff --git a/texar/utils/__init__.py b/texar/utils/__init__.py index d22e2050..ab284e9c 100644 --- a/texar/utils/__init__.py +++ b/texar/utils/__init__.py @@ -29,3 +29,4 @@ from texar.utils.mode import * from texar.utils.average_recorder import * from texar.utils.utils_io import * +from texar.utils.triggers import * diff --git a/texar/utils/triggers.py b/texar/utils/triggers.py new file mode 100644 index 00000000..d4aefdaf --- /dev/null +++ b/texar/utils/triggers.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Triggers. +""" +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import pickle + +try: + import queue +except ImportError: + import Queue as queue + +#pylint: disable=invalid-name, too-many-arguments, too-many-locals + +__all__ = [ + "Trigger", + "ScheduledStepsTrigger", + "BestEverConvergenceTrigger", +] + + +class Trigger(object): + """This is the base class of all triggers. A trigger maintains some + user-defined :attr:`user_state` and does some :attr:`action` when certain + condition is met. Specifically, the user calls the trigger periodically. + Every time the trigger is called, it will send all arguments to + :meth:`_predicate`, which returns a boolean value indicates whether the + condition is met. Once the condition is met, the trigger will then execute + `user_state = action(user_state)` to update the :attr:`user_state`. + :attr:`user_state` should completely define the current state of the + trigger, and, therefore, enables saving and restoring :attr:`user_state`. + It is the user's responsibility to keep :attr:`action` away from any + possible corruption of restored state. + + Args: + initial_user_state: A (any kind of picklable) object representing the + initial :attr:`user_state`. + action (callable): A callable which is called to update + :attr:`user_state` every time the trigger is triggered. See above + for detailed explanation. + .. document private functions + .. 
automethod:: __call__ + """ + + def __init__(self, initial_user_state, action): + self._user_state = initial_user_state + if not callable(action): + raise ValueError("Action {} is not callable".format(action)) + self._action = action + + def _predicate(self, *args, **kwargs): + """Returns True when the condition is met and we should do something. + """ + raise NotImplementedError + + def trigger(self): + """Executes `user_state = action(user_state)`. User can manually call + this method to trigger it. + """ + self._user_state = self._action(self._user_state) + + def __call__(self, *args, **kwargs): + """The trigger must be called to update the internal state and + automatically triggers when the condition is found met. + + Returns: + A boolean denotes whether triggered this time. + """ + pred = self._predicate(*args, **kwargs) + if pred: + self.trigger() + return pred + + def _make_state(self, names): + return {name: getattr(self, name) for name in names} + + @property + def _state_names(self): + """Returns a list of names of attributes of the trigger object that can + be saved and restored as trigger state. + """ + return ['_user_state'] + + @property + def state(self): + """The current state which can be used to save and restore the trigger. + The state is consisted of the internal state used to determine whether + the condition is met, and the user-defined :attr:`user_state`. + """ + return self._make_state(self._state_names) + + @property + def user_state(self): + """The user-defined :attr:`user_state`. + """ + return self._user_state + + def restore_from_state(self, state): + """Restore the trigger state from the previous saved state. + + Args: + state: The state previously obtained by :attr:`state`. + """ + for name, value in state.items(): + setattr(self, name, value) + + def save_to_pickle(self, file): + """Write a pickled representation of the state of the trigger to the + open file-like object :attr:`file`. + + Args: + file: The open file-like object to which we write. As described in + pickle official document, it must have a `write()` method that + accepts a single string argument. + """ + pickle.dump(self.state, file) + + def restore_from_pickle(self, file): + """Read a string from the open file-like object :attr:`file` and + restore the trigger state from it. + + Args: + file: The open file-like object from which we read. As described in + pickle official document, it must have a `read()` method that + takes an integer argument, and a `readline()` method that + requires no arguments, and both methods should return a string. + """ + self.restore_from_state(pickle.load(file)) + + +class ScheduledStepsTrigger(Trigger): + """A trigger that triggers after the training step have iterated over some + user-designated steps. This means that it will trigger if there is at least + one `step` in user-designated set of :attr:`steps` within the range + `(last_called_step, current_step]`. + + There are **2 ways** provided to specify the set of :attr:`steps`: + + 1. :attr:`steps` is a callable. When calling + `steps(last_called_step, current_step)`, it is assumed to return + a boolean indicating whether there is at least one `step` in the set + within the range `(last_called_step, current_step]`. For example, + :code:`steps = lambda l, r: l // n != r // n` denotes the set + `{i * n for any integer i}` where `n` is some integer. This option + enables user to define any set of steps, even an infinite set. 
Note + that in this case the trigger will never trigger when being called + for the first time, because `last_called_step` is undefined at this + time. User can manually call it to specify an initial step before + training. + + 2. :attr:`steps` is a `list` or `tuple` containing numbers in ascending + order. These numbers compose the whole set. + + Args: + initial_user_state: A (any kind of picklable) object representing the + initial :attr:`user_state`. + action (callable): A callable which is called to update + :attr:`user_state` every time the trigger is triggered. + steps (list, tuple, or callable): Represents the user-designated set of + :attr:`steps` described above. + .. document private functions + .. automethod:: __call__ + """ + + def __init__(self, initial_user_state, action, steps): + super(ScheduledStepsTrigger, self).__init__(initial_user_state, action) + self._steps = steps + + if callable(self._steps): + self._last_called_step = None + + else: + self._index = 0 + + @property + def _state_names(self): + return super(ScheduledStepsTrigger, self)._state_names + [ + '_last_called_step' if callable(self._steps) else '_index'] + + @property + def last_called_step(self): + """The step when the trigger is latest called. + """ + return self._last_called_step + + def _predicate(self, step): + if callable(self._steps): + if self._last_called_step is not None: + ret = self._steps(self._last_called_step, step) + else: + ret = False + + self._last_called_step = step + + else: + ret = False + while self._index < len(self._steps) and \ + self._steps[self._index] <= step: + ret = True + self._index += 1 + + return ret + + def __call__(self, step): + """The trigger must be called to update the current training step + (:attr:`step`). + + Args: + step (int): Current training step to update. The training step must + be updated in ascending order. + + Returns: + A boolean denotes whether triggered this time. + """ + return super(ScheduledStepsTrigger, self).__call__(step) + + +class BestEverConvergenceTrigger(Trigger): + """A trigger that maintains the best value of a metric. It triggers when + the best value of the metric has not been updated for at least + :attr:`threshold_steps`. In order to avoid it triggers two frequently, it + will not trigger again within :attr:`minimum_interval_steps` once it + triggers. + + Args: + initial_user_state: A (any kind of picklable) object representing the + initial :attr:`user_state`. + action (callable): A callable which is called to update + :attr:`user_state` every time the trigger is triggered. + threshold_steps (int): Number of steps it should trigger after the best + value was last updated. + minimum_interval_steps (int): Minimum number of steps between twice + firing of the trigger. + .. document private functions + .. 
automethod:: __call__ + """ + + def __init__(self, initial_user_state, action, threshold_steps, + minimum_interval_steps): + super(BestEverConvergenceTrigger, self).__init__( + initial_user_state, action) + self._threshold_steps = threshold_steps + self._minimum_interval_steps = minimum_interval_steps + self._last_triggered_step = None + self._best_ever_step = None + self._best_ever_score = None + + def _predicate(self, step, score): + if self._best_ever_score is None or self._best_ever_score < score: + self._best_ever_score = score + self._best_ever_step = step + + if (self._last_triggered_step is None or + step - self._last_triggered_step >= + self._minimum_interval_steps) and \ + step - self._best_ever_step >= self._threshold_steps: + self._last_triggered_step = step + return True + return False + + def __call__(self, step, score): + """The trigger must be called to update the current training step + (:attr:`step`) and the current value of the maintained metric + (:attr:`score`). + + Args: + step (int): Current training step to update. The training step must + be updated in ascending order. + score (float or int): Current value of the maintained metric. + + Returns: + A boolean denotes whether triggered this time. + """ + return super(BestEverConvergenceTrigger, self).__call__(step, score) + + @property + def _state_names(self): + return super(BestEverConvergenceTrigger, self)._state_names + [ + '_last_triggered_step', '_best_ever_step', '_best_ever_score'] + + @property + def last_triggered_step(self): + """The step at which the Trigger last triggered. + """ + return self._last_triggered_step + + @property + def best_ever_step(self): + """The step at which the best-ever score is reached. + """ + return self._best_ever_step + + @property + def best_ever_score(self): + """The best-ever score. + """ + return self._best_ever_score diff --git a/texar/utils/triggers_test.py b/texar/utils/triggers_test.py new file mode 100644 index 00000000..979b95ed --- /dev/null +++ b/texar/utils/triggers_test.py @@ -0,0 +1,125 @@ +""" +Unit tests for triggers. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import random +import bisect + +from texar.utils.triggers import * + + +class TriggerTest(tf.test.TestCase): + """Tests :class:`~texar.utils.Trigger`. + """ + + def test(self): + trigger = Trigger(0, lambda x: x+1) + for step in range(100): + trigger.trigger() + self.assertEqual(trigger.user_state, step+1) + + +class ScheduledStepsTriggerTest(tf.test.TestCase): + """Tests :class:`~texar.utils.ScheduledStepsTrigger`. 
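+    Both the callable form and the explicit-list form of `steps` are
+    checked against a brute-force reference implementation.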
+ """ + + def test(self): + for i in range(100): + n = random.randint(1, 100) + m = random.randint(1, n) + p = random.uniform(0, 0.3) + f = lambda l, r: l // n != r // n + trigger = ScheduledStepsTrigger(0, lambda x: x+1, f) + + last_called_step = None + + for step in range(n): + if random.random() < p: + if last_called_step is not None: + triggered_ = f(last_called_step, step) + else: + triggered_ = False + + last_called_step = step + + triggered = trigger(step) + + self.assertEqual(trigger.last_called_step, last_called_step) + self.assertEqual(triggered, triggered_) + + for i in range(100): + n = random.randint(1, 100) + m = random.randint(1, n) + p = random.uniform(0, 0.3) + q = random.uniform(0, 0.3) + steps = [step for step in range(n) if random.random() < q] + f = lambda l, r: bisect.bisect_right(steps, l) < \ + bisect.bisect_right(steps, r) + trigger = ScheduledStepsTrigger(0, lambda x: x+1, steps) + + last_called_step = -1 + + for step in range(n): + if random.random() < p: + triggered_ = f(last_called_step, step) + last_called_step = step + + triggered = trigger(step) + + self.assertEqual(triggered, triggered_) + + trigger = ScheduledStepsTrigger(0, lambda x: x+1, []) + for step in range(100): + trigger.trigger() + self.assertEqual(trigger.user_state, step+1) + + +class BestEverConvergenceTriggerTest(tf.test.TestCase): + """Tests :class:`~texar.utils.BestEverConvergenceTrigger`. + """ + + def test(self): + for i in range(100): + n = random.randint(1, 100) + seq = list(range(n)) + random.shuffle(seq) + threshold_steps = random.randint(0, n // 2 + 1) + minimum_interval_steps = random.randint(0, n // 2 + 1) + trigger = BestEverConvergenceTrigger( + 0, lambda x: x+1, threshold_steps, minimum_interval_steps) + + best_ever_step, best_ever_score, last_triggered_step = -1, -1, None + + for step, score in enumerate(seq): + if score > best_ever_score: + best_ever_step = step + best_ever_score = score + + triggered_ = step - best_ever_step >= threshold_steps and \ + (last_triggered_step is None or + step - last_triggered_step >= minimum_interval_steps) + if triggered_: + last_triggered_step = step + + triggered = trigger(step, score) + + self.assertEqual(trigger.best_ever_step, best_ever_step) + self.assertEqual(trigger.best_ever_score, best_ever_score) + self.assertEqual(trigger.last_triggered_step, + last_triggered_step) + self.assertEqual(triggered, triggered_) + + trigger = BestEverConvergenceTrigger(0, lambda x: x+1, 0, 0) + for step in range(100): + trigger.trigger() + self.assertEqual(trigger.user_state, step+1) + + +if __name__ == "__main__": + tf.test.main() +
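As a usage note for the trigger utilities above, here is a small standalone sketch of how `BestEverConvergenceTrigger` is intended to be driven from an evaluation loop and checkpointed, mirroring what the example script does with validation BLEU. The score sequence and threshold values are made up for illustration; only APIs defined in `texar/utils/triggers.py` in this diff are used, and the snippet assumes Texar with this change is installed.

```python
import io

from texar.utils.triggers import BestEverConvergenceTrigger

# user_state holds the index of the current training phase; the action
# advances it by one whenever the trigger fires.
trigger = BestEverConvergenceTrigger(
    initial_user_state=0,
    action=lambda phase: phase + 1,
    threshold_steps=3,            # fire after 3 steps without a new best score
    minimum_interval_steps=5)     # but at most once every 5 steps

val_bleu = [20.0, 21.5, 21.4, 21.3, 21.2, 21.1]   # made-up validation scores
for step, bleu in enumerate(val_bleu):
    if trigger(step, bleu):       # True when the phase index is advanced
        print('annealed to phase', trigger.user_state, 'at step', step)

# The trigger state can be checkpointed alongside the model and restored later.
buf = io.BytesIO()
trigger.save_to_pickle(buf)
buf.seek(0)

restored = BestEverConvergenceTrigger(0, lambda phase: phase + 1, 3, 5)
restored.restore_from_pickle(buf)
assert restored.best_ever_step == trigger.best_ever_step
assert restored.user_state == trigger.user_state
```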