LAnguage Modelling Benchmarks (LAMB) is for tuning and testing TensorFlow language models. It was used in the following papers (also see the citations section):

- On the state of the art of evaluation in neural language models. See ./experiment/on-the-state/README.md for more.
- Pushing the bounds of dropout. See ./experiment/pushing-the-bounds/README.md for more.
- Mogrifier LSTM. See ./experiment/mogrifier/README.md for more.
The default dataset locations are `~/data/<dataset-name>/`. See `lib/config/{ptb,wikitext-2,wikitext-103,enwik8}.sh` for the defaults.
To train a small LSTM on Penn Treebank, run this script:

```sh
experiment/train_ptb_10m_lstm_d1.sh
```

In the script, model configuration, data files, etc. are specified by setting shell variables:

```sh
training_file="ptb.train.txt"
validation_file="ptb.valid.txt"
model="lstm"
hidden_size=500
```
These shell variables are passed as command-line arguments to the underlying Python program. The options are documented in the reference section below.
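As a rough illustration, the mapping from shell variables to command-line flags might look like the sketch below. This is an assumption for illustration only; in particular, the entry-point path `lamb/main.py` is hypothetical and not taken from this README.

```python
# Hypothetical illustration of how the shell variables above become
# command-line flags; the entry point "lamb/main.py" is an assumption.
opts = {
    "training_file": "ptb.train.txt",
    "validation_file": "ptb.valid.txt",
    "model": "lstm",
    "hidden_size": 500,
}
argv = ["python", "lamb/main.py"] + [f"--{k}={v}" for k, v in opts.items()]
print(" ".join(argv))
# python lamb/main.py --training_file=ptb.train.txt ... --hidden_size=500
```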
To test a trained model:

```sh
experiment/test.sh run "mymodel" "experiment_dir_of_training_run"
```

Evaluation results are printed as they happen (see the section on evaluation). In the output, lines matching `final valid* xe:` contain the validation set cross-entropy. Lines of special interest are those with `final {valid,test}` in them. The format is the following:

```
final ${dataset}_${eval_method}[_${dropout_multiplier}][_t${softmax_temp}]
```
For `eval_method=arithmetic` with `eval_dropout_multiplier=0.8` and `eval_softmax_temperature=0.9`, results may look like this after 200 optimization steps and 2 evaluations:

```
turn: 2 (eval), step: 200 (opt) (5.29/s)
final valid_mca_d0.8_t0.9 xe: 5.315
final test_mca_d0.8_t0.9 xe: 5.289
```
... except that training runs normally don't have test set results (see `eval_on_test`). Test runs are essentially training runs with no optimization steps.
To install the dependencies, for example:

```sh
conda create -n tfp3.7 python=3.7 numpy scipy
conda activate tfp3.7
conda install cudatoolkit
conda install cudnn
conda install tensorflow-gpu=1.15
conda install tensorflow-probability-gpu=1.15
conda install tensorflow-probability
pip install -e <path-to-git-checkout>
```
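A quick sanity check that the environment is usable (using the TF 1.x API that this codebase targets):

```python
# Sanity check for the conda environment above (TF 1.x API).
import tensorflow as tf

print(tf.__version__)               # expect something like 1.15.x
print(tf.test.is_gpu_available())   # True if CUDA and cuDNN are usable
```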
A value given for an option is converted to the data type corresponding to the option in question. In the following, options are listed with their data type and default value (e.g. `model` (string, lstm) means that the variable `model` has type string and default value `lstm`). If there is no default value listed, then the option is mandatory.
- `training_file`: The file with the training data, one line per example. Newlines are translated to an end-of-sentence token.
- `validation_file`: A file of the same format as `training_file`. During training, the model is evaluated periodically on data from `validation_file`. Most notably, early stopping and hyperparameter tuning are based on performance on this set of examples. This must not be specified when doing cross-validation, as in that case the evaluation set is constructed from the training set.
- `test_file`: A file of the same format as `training_file`. During training, the model is evaluated periodically on data from `test_file` and the results are logged. As opposed to `validation_file`, this dataset has no effect on training or tuning. The empty string (the default) turns off evaluation on the test set.
- The encoding of `training_file`, `validation_file` and `test_file`.
- Whether to do word- or character-based modelling. If word based, lines are split at whitespace into tokens. Otherwise, lines are simply split into characters.
- `episodic`: If true, iterate over examples (lines in the data files) in random order. If false, iterate mostly sequentially, carrying over model state from the previous example to the next.
- `num_params` (integer, -1): An upper bound on the total number of trainable parameters over all parts of the model (including the recurrent cell and input/output embeddings). If this is set to a meaningful value (i.e. not -1, the default), then `hidden_size` is set to the largest possible value such that the parameter budget is not exceeded.
- `share_input_and_output_embeddings` (boolean, false): Whether the input and output embeddings are the same matrix (transposed) or independent (the default). If true, then `input_embedding_size` and `output_embedding_size` must be the same.
- `input_embedding_size` (integer, -1): The length of the vector that represents an input token. If -1 (the default), then it's determined by `input_embedding_ratio`.
- `output_embedding_size` (integer, -1): The length of the vector that represents an output token. If -1 (the default), then it's determined by `output_embedding_ratio`. If, after applying the defaulting rules, `output_embedding_size` is not equal to `hidden_size`, then the cell output is linearly transformed to `output_embedding_size` before the final linear transform into the softmax.
- `input_embedding_ratio` (float, 1.0): If `input_embedding_size` is not specified (i.e. -1), then it's set to `round(input_embedding_ratio*hidden_size)`.
- `output_embedding_ratio` (float, -1.0): If `output_embedding_size` is not specified (i.e. -1), then it's set to `round(output_embedding_ratio*hidden_size)`. The default value of -1 makes `output_embedding_ratio` default to the value of `input_embedding_ratio` so that one can tune easily with `share_input_and_output_embeddings=true`.
- `mos_num_components` (integer, 0): See Breaking the softmax bottleneck. The default of 0 turns this feature off.
- `embedding_dropout` (float, 0.0): The probability that all occurrences of a word are dropped from a batch (see the sketch after this list).
- The probability that a token will be dropped (i.e. the input at that step becomes zero). This can be thought of as a version of `embedding_dropout` that has different masks per time step.
- `input_dropout`: The dropout rate (here and elsewhere, 0 means deterministic operation) for the input to the first layer (i.e. just after the input embeddings). This drops out individual elements of the embedding vector.
- `output_dropout`: The dropout rate for just after the cell output.
- `downprojected_output_dropout` (float, -1.0): The dropout rate for the projection of the cell output. Only used if `output_embedding_size` is different from `hidden_size` or if `mos_num_components` is not 1. Defaults to `output_dropout` if set to -1.
- `shared_mask_dropout` (boolean, false): Whether to use the same dropout mask for all time steps for `input_dropout`, `inter_layer_dropout`, `output_dropout` and `downprojected_output_dropout`.
- Whether to compute the logits from the cell output in a single operation or per time step. The single operation is faster but uses more GPU memory. Also, see `swap_memory`.
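The sketch below contrasts `embedding_dropout` with the per-step token dropout described above. This is illustrative NumPy only; the shapes and variable names are assumptions, not lamb's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens = np.array([[1, 2, 1, 3],
                   [2, 1, 4, 2]])  # [batch, time] token ids
vocab_size, p = 5, 0.5

# embedding_dropout: one mask per word type, so every occurrence of a
# given word in the batch shares the same fate.
type_mask = rng.random(vocab_size) >= p
keep_by_type = type_mask[tokens]

# token dropout: an independent mask per (example, time step) position.
keep_by_position = rng.random(tokens.shape) >= p

print(keep_by_type)
print(keep_by_position)
```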
- `model` (string, lstm): One of `lstm`, `rhn` (Recurrent Highway Network), `nas`.
- `num_layers`: The number of same-sized LSTM cells stacked on top of each other, or the number of processing steps per input an RHN does. Has no effect on NAS.
- `lstm_skip_connection` (boolean, true): If true, for multi-layer (`num_layers>1`) LSTMs, the output is computed as the sum of the outputs of the individual layers.
- `feature_mask_rounds` (integer, 0): The Mogrifier LSTM is implemented in terms of the feature masking option. The LSTM-specific feature masking option involves gating the input and the state before they are used for calculating everything else (i.e. `i`, `j`, `o`, `f`). This allows input features to be reweighted based on the state, and state features to be reweighted based on the input. See the Mogrifier LSTM paper for details. When `feature_mask_rounds` is 0, there is no extra gating in the LSTM. When it is >= 1, the input is gated: `x *= 2*sigmoid(affine(h))`. When it is >= 2, the state is also gated: `h *= 2*sigmoid(affine(x))`. For higher numbers of rounds, the alternating gating continues (see the sketch after this list).
- `feature_mask_rank` (integer, 0): If 0, the linear transforms described above are full-rank, dense matrices. If greater than 0, the matrix representing the linear transform is factorized as the product of two low-rank matrices (`[*, rank]` and `[rank, *]`). This greatly reduces the number of parameters.
- `hidden_size`: A comma-separated list of integers representing the number of units in the state of the recurrent cell per layer. Must not be longer than `num_layers`. If it's shorter, then the missing values are assumed to be equal to the last specified one. For example, for a 3-layer network, "512,256" results in the first layer having 512 units and the second and the third having 256. If "-1" (the default), an attempt is made to deduce it from `num_params`, assuming all layers have the same size.
- Whether to perform Layer Normalization (currently only implemented for LSTMs).
- `activation_fn` (string, tf.tanh): The non-linearity for the update candidate ('j') and the output ('o') in an LSTM, or the output ('h') in an RHN.
- `tie_forget_and_input_gates` (boolean, false): In an LSTM, whether the input gate ('i') is set to 1 minus the forget gate ('f'). In an RHN, whether the transform gate ('t') is set to 1 minus the carry gate ('c').
- `cap_input_gate` (boolean, true): Whether to cap the input gate at `1-f` if `tie_forget_and_input_gates` is off. Currently only affects LSTMs. This makes learning more stable, especially in the early stages of training.
- `trainable_initial_state` (boolean, true): Whether the initial state of the recurrent cells is learnt or set to a fixed zero vector. In non-episodic mode, this switch is forced off.
- `inter_layer_dropout` (float, 0.0): The input dropout for layers other than the first one. Defaults to no dropout, but setting it to -1 makes it inherit `input_dropout`. It has no effect on RHNs, since the input is not fed to their higher layers.
- This is the dropout rate for the recurrent state from the previous time step ('h' in an LSTM, 's' in an RHN). See Yarin Gal's "A Theoretically Grounded Application of Dropout in Recurrent Neural Networks". The dropout mask is the same for all time steps of a given example in one batch.
- This is the Recurrent Dropout (see "Recurrent Dropout without Memory Loss") rate on the update candidate ('j' in an LSTM, 'h' in an RHN). Should have been named Update Dropout.
- `cell_clip`: If set to a positive value, the cell state ('c' in an LSTM, 's' in an RHN) is clipped to the `[-cell_clip, cell_clip]` range after each iteration.
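For concreteness, here is a minimal NumPy sketch of the alternating gating performed by `feature_mask_rounds` in the full-rank case. `Q` and `R` stand in for the learned transforms (biases omitted); this illustrates the rule quoted above, not lamb's implementation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mogrify(x, h, Q, R, rounds):
    # Odd rounds gate the input with the state; even rounds gate the
    # state with the (already gated) input.
    for i in range(1, rounds + 1):
        if i % 2 == 1:
            x = 2 * sigmoid(Q @ h) * x
        else:
            h = 2 * sigmoid(R @ x) * h
    return x, h

rng = np.random.default_rng(0)
x, h = rng.standard_normal(4), rng.standard_normal(3)
Q, R = rng.standard_normal((4, 3)), rng.standard_normal((3, 4))
x, h = mogrify(x, h, Q, R, rounds=4)
```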
- `model_average` (string, arithmetic): Pushing the bounds of dropout makes the point that the actual dropout objective being optimized is a lower bound on the true objectives of many different models. If we construct the lower bound from multiple samples though (à la IWAE), the lower bound gets tighter. `model_average` is the training-time equivalent of `eval_method` and determines what kind of model (and consequently, averaging) is to be used. One of `geometric`, `power` and `arithmetic`. Only in effect if `num_training_samples > 1`.
- `num_training_samples` (integer, 1): The number of samples from which to compute the objective (see `model_average`). Each training example being presented is run through the network `num_training_samples` times, so the effective batch size is `batch_size * num_training_samples`. Increasing the number of samples doesn't seem to help generalization, though.
- `l2_penalty`: The L2 penalty on all trainable parameters.
- `l1_penalty`: The L1 penalty on all trainable parameters.
- `activation_norm_penalty` (float, 0.0): Activation Norm Penalty (see Regularizing and optimizing LSTM language models by Merity et al.).
- `drop_state_probability` (float, 0.0): In non-episodic mode, model state is carried over from batch to batch. Not feeding back the state with probability `drop_state_probability` encourages the model to work well starting from the zero state, which brings it closer to the test regime.
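A minimal sketch of what `drop_state_probability` does between batches in non-episodic mode, assuming a simple array-valued state (illustrative only):

```python
import numpy as np

def next_initial_state(prev_final_state, drop_state_probability, rng):
    # With some probability, start the next batch from the zero state
    # instead of the state carried over from the previous batch.
    if rng.random() < drop_state_probability:
        return np.zeros_like(prev_final_state)
    return prev_final_state
```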
- `embedding_init_factor` (float, 1.0): All input embedding weights are initialized from a truncated normal distribution with mean 0 and `stddev=sqrt(embedding_init_factor/input_embedding_size)`.
- `scale_input_embeddings` (boolean, false): This is not strictly an initialization option, but it serves a similar purpose. Input embeddings are initialized from a distribution whose variance is inversely proportional to `input_embedding_size`. Since every layer in the network is initialized to produce output with approximately the same variance as its input, changing the embedding size has a potentially strong, undesirable effect on optimization. Set `scale_input_embeddings` to `true` to multiply input embeddings by `sqrt(input_embedding_size)` to cancel this effect. As opposed to just changing `embedding_init_factor`, this multiplication has the benefit that the input embedding matrix is of the right scale for use as the output embedding matrix should `share_input_and_output_embeddings` be turned on.
- `cell_init_factor`: The various weight matrices in the recurrent cell (of which there are 8 in an LSTM, 4/2 in an RHN) are initialized independently with `stddev=sqrt(cell_init_factor/fan_in)`, while biases are initialized with `stddev=sqrt(cell_init_factor/hidden_size)` (see the sketch after this list).
- `forget_bias` (float, 1.0): Sometimes initializing the biases of the forget gate ('f') in the LSTM (or those of the carry gate ('c') in an RHN) to a small positive value (typically 1.0, the default) makes the initial phase of optimization faster. Higher values make the network forget less of its state over time. With deeper architectures and no skip connections (see `num_layers` and `lstm_skip_connection`), this may actually make optimization harder. The value of `forget_bias` is used as the mean of the initialization distribution, with unchanged variance.
- `output_init_factor` (float, 1.0): If `share_input_and_output_embeddings` is false, then the output projection (also known as the output embeddings) is initialized with `stddev=sqrt(output_init_factor/fan_in)`. If `share_input_and_output_embeddings` is true, then this only affects the linear transform of the cell output (see `output_embedding_size`).
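The initialization rules above in NumPy form. This is a sketch: the resample-beyond-two-standard-deviations convention and the example sizes are assumptions, not lamb's code.

```python
import numpy as np

def truncated_normal(shape, stddev, rng):
    # Resample anything beyond two standard deviations, as truncated
    # normal initializers conventionally do.
    x = rng.standard_normal(shape) * stddev
    bad = np.abs(x) > 2 * stddev
    while bad.any():
        x[bad] = rng.standard_normal(bad.sum()) * stddev
        bad = np.abs(x) > 2 * stddev
    return x

rng = np.random.default_rng(0)
cell_init_factor, fan_in, hidden_size = 1.0, 650, 650
weights = truncated_normal((fan_in, hidden_size),
                           np.sqrt(cell_init_factor / fan_in), rng)
biases = truncated_normal((hidden_size,),
                          np.sqrt(cell_init_factor / hidden_size), rng)
```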
- `steps_per_turn` (integer, 1000): The number of optimization steps between two successive evaluations. After this many steps, performance is evaluated and logged on the training, validation and test sets (if specified). One so-called turn consists of `steps_per_turn` optimization steps.
- `turns`: The number of evaluations beyond which training cannot continue (also see early stopping).
- `print_training_stats_every_num_steps` (integer, 1000): Debug printing frequency.
- `optimizer_type` (string, rmsprop): The optimizer algorithm. One of `rmsprop`, `adam`, `adagrad`, `adadelta` and `sgd`.
- Here, RMSProp is actually Adam with `beta1=0.0`, so that Adam's highly useful correction to the computed statistics is in effect, which allows higher initial learning rates. Only applies when `optimizer_type` is `rmsprop`.
- `rmsprop_epsilon`: Similar to `adam_epsilon`. Only applies when `optimizer_type` is `rmsprop`.
- `max_grad_norm`: If non-zero, gradients are rescaled so that their norm does not exceed `max_grad_norm`.
- `batch_size`: Batch size for training. Also the evaluation batch size, unless `min_non_episodic_eval_examples_per_stripe` overrides it.
- `accum_batch_size` (integer, -1): The number of examples that are fed to the network at the same time. Set this to a divisor of `batch_size` to reduce memory usage at the cost of possibly slower training. Using `accum_batch_size` does not change the results.
- For episodic operation, examples that have more tokens than this are truncated when the training and test files are loaded. For non-episodic operation, this is the window size of truncated backpropagation.
- `trigger_averaging_turns` (integer, -1): The number of turns of no improvement on the validation set after which weight averaging is turned on. Weight averaging is a trivial generalization of the idea behind Averaged SGD: it keeps track of the average weights, updating the average after each optimization step (see the sketch after this list). Weight averaging does not affect training directly, only through evaluation. This feature is an alternative to dropping the learning rate.
- `trigger_averaging_at_the_latest` (integer, -1): If optimization reaches turn `trigger_averaging_at_the_latest`, then averaging is guaranteed to be turned on. Set this to be somewhat smaller than `turns` so that all runs get at least some averaging, which should make the results more comparable.
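A sketch of the running weight average mentioned above (illustrative; lamb's actual bookkeeping may differ):

```python
import numpy as np

class WeightAverage:
    """Running mean of the weights, updated after each optimization step.

    Only used for evaluation; training continues with the raw weights.
    """

    def __init__(self):
        self.mean = None
        self.count = 0

    def update(self, weights):
        self.count += 1
        if self.mean is None:
            self.mean = np.array(weights, dtype=float)
        else:
            self.mean += (weights - self.mean) / self.count
```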
- `drop_learning_rate_turns` (integer, -1): If the validation score doesn't improve for `drop_learning_rate_turns` turns, then the learning rate is multiplied by `drop_learning_rate_multiplier`, possibly repeatedly (see the sketch after this list).
- `drop_learning_rate_multiplier` (float, 1.0): Set this to a value less than 1.0.
- `drop_learning_rate_at_the_latest` (integer, -1): If optimization reaches turn `drop_learning_rate_at_the_latest` without the learning rate having been dropped yet, then it is dropped regardless of whether the validation curve is still improving. Set this to be somewhat smaller than `turns` so that all runs get at least one drop, which should make the results more comparable.
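The drop-learning-rate rule described above, as a sketch (not lamb's code):

```python
def dropped_learning_rate(base_rate, multiplier, num_drops):
    # Each drop multiplies the learning rate by `multiplier` (< 1.0);
    # drops can happen repeatedly, so the effect compounds.
    return base_rate * multiplier ** num_drops

def should_drop(turn, best_turn, drop_turns, at_the_latest, num_drops):
    stalled = drop_turns > 0 and turn - best_turn >= drop_turns
    overdue = at_the_latest > 0 and turn >= at_the_latest and num_drops == 0
    return stalled or overdue
```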
- `early_stopping_turns` (integer, -1): Maximum number of turns without improvement in validation cross-entropy before stopping.
- `early_stopping_rampup_turns` (integer, 0): The effective `early_stopping_turns` starts out at 1 and is increased linearly to the specified `early_stopping_turns` over `early_stopping_rampup_turns` turns.
- `early_stopping_worst_xe_target` (float, ''): If the estimated best possible validation cross-entropy (extrapolated from the progress made over the most recent `early_stopping_turns`, subject to rampup) is worse than `early_stopping_worst_xe_target`, then training is stopped. This is actually a string of comma-separated floats. The first value is in effect when the learning rate has not been dropped yet, the second value is in effect when it has been dropped once, and so on. The last element of the list applies to any further learning rate drops.
- `early_stopping_slowest_rate` (float, 0.0): The rate is defined as the average improvement in validation cross-entropy over the effective `early_stopping_turns` (see `early_stopping_rampup_turns`). If the rate is less than `early_stopping_slowest_rate`, then stop early.
- `crossvalidate` (boolean, false): If true, randomly split the training set into `crossvalidation_folds` folds, evaluate performance on each and average the cross-entropies. Repeat the entire process `crossvalidation_rounds` times and average the averages.
- `crossvalidation_folds` (integer, 10): The number of folds to split the training set into.
- `crossvalidation_rounds` (integer, 1): If `crossvalidate`, then do this many rounds of `crossvalidation_folds`-fold cross-validation. Set this to a value larger than one if the variance of the cross-validation score over the random splits is too high.
The model being trained is evaluated periodically (see `turns` and `steps_per_turn`) on the validation set (see `validation_file`) and also on the training set (see `training_file`). Evaluation on the training set is different from the training loss in that it does not include regularization terms such as `l2_penalty`, and it is performed the same way as evaluation on the validation set (see `eval_method`).
To evaluate a saved model, one typically wants to do no training, disable saving of checkpoints and evaluate on the test set, which corresponds to this:

```sh
turns=0
save_checkpoints=false
eval_on_test=true
```

Furthermore, `load_checkpoint`, and in all likelihood `config_file`, must be set. This is all taken care of by the `experiment/test.sh` script.
- `max_training_eval_batches` (integer, 100): When evaluating performance on the training set, it is enough to get a rough estimate. If specified, at most `max_training_eval_batches` batches will be evaluated. Set this to zero to turn off evaluation on the training set entirely. Set it to -1 to evaluate on the entire training set.
- `max_eval_eval_batches` (integer, -1): Evaluation can be pretty expensive with large datasets. For expediency, one can impose a limit on the number of batches of examples to work with on the validation set.
- `max_test_eval_batches` (integer, -1): Same as `max_eval_eval_batches` but for the test set.
- `min_non_episodic_eval_examples_per_stripe` (integer, 100): By default, evaluation is performed using the training batch size, causing each "stripe" in a batch to cover roughly `dataset_size/batch_size` examples. With a small dataset in a non-episodic setting, that may make the evaluation quite pessimistic. This flag ensures that the batch size for evaluation is small enough that at least this many examples are processed in the same stripe.
- `eval_on_test`: Even if `test_file` is provided, evaluation on the test dataset is not performed by default. Set this to true to do that. Flipping this switch makes it easy to test a model by loading a checkpoint and its saved configuration without having to remember what the dataset was.
- `eval_method` (string, deterministic): One of `deterministic`, `geometric`, `power` and `arithmetic`. This determines how dropout is applied at evaluation time. `deterministic` is also known as standard dropout: dropout is turned off at evaluation time and a single deterministic pass propagates the expectation of each unit through the network. `geometric` performs a renormalized geometric average of predicted probabilities over randomly sampled dropout masks. `power` computes the power mean with exponent `eval_power_mean_power`. `arithmetic` computes the arithmetic average. See Pushing the bounds of dropout for a more detailed discussion, and the sketch after this list.
- The number of samples to average probabilities over at evaluation time. Needs some source of stochasticity (currently only dropout) to be meaningful. When it's zero, the model is run in deterministic mode. Training evaluation is always performed in deterministic mode for expediency.
- `eval_softmax_temperature` (float, 1.0): Set this to a value lower than 1 to smooth the distribution a bit at evaluation time to counter overfitting. Set it to a value between -1 and 0 to search for the optimal temperature between `-value` and 1 on the validation set. For example, `eval_softmax_temperature=-0.8` will search for the optimal temperature between 0.8 and 1.0.
- `eval_power_mean_power` (float, 1.0): The exponent of the renormalized power mean used to compute predicted probabilities. Only has an effect if `eval_method=power`.
- `eval_dropout_multiplier` (float, 1.0): At evaluation time, all dropout probabilities used for training are multiplied by this. Does not affect the `eval_method=deterministic` case. See Pushing the bounds of dropout for details.
- `validation_prediction_file` (string): The name of the file where log probabilities for the validation file are written. The file is replaced by a newer version each time the model is evaluated. The file lists tokens and predicted log probabilities on alternating lines. Currently only implemented for deterministic evaluation.
- Whether model weights shall be updated at evaluation time (see [Dynamic Evaluation of Neural Sequence Models](https://arxiv.org/abs/1709.07432) by Krause et al.). This forces the batch size at evaluation time to 1, which makes it very slow, so it is best left off until the final evaluation. Whereas RMSProp maintains an online estimate of gradient variance, dynamic evaluation bases its estimate on training statistics, which are affected by `max_training_eval_batches` and `batch_size`. Also, when doing dynamic evaluation, it might make sense to turn off some regularizers such as `l2_penalty`, or hacks like `max_grad_norm`.
- `dyneval_learning_rate` (float, 0.001): The learning rate for dynamic evaluation.
- `dyneval_decay_rate` (float, 0.02): The rate with which weights revert to the mean, which is defined as the trained weights.
- This serves a similar purpose to `rmsprop_epsilon`, but for dynamic evaluation.
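The sketch below spells out the averaging schemes named in `eval_method`, applied to per-token predicted probabilities collected over several dropout samples. Illustrative NumPy only, not lamb's implementation.

```python
import numpy as np

def average_predictions(probs, method, power=1.0):
    # probs: [num_samples, vocab_size] predicted probabilities, one row
    # per sampled dropout mask. All entries must be positive.
    if method == "arithmetic":
        p = probs.mean(axis=0)
    elif method == "geometric":
        p = np.exp(np.log(probs).mean(axis=0))
    elif method == "power":
        p = (probs ** power).mean(axis=0) ** (1.0 / power)
    else:
        raise ValueError("unknown method: " + method)
    return p / p.sum()  # renormalize to a proper distribution
```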
- `name` (string, see below): The name of the experiment. Defaults to the git version concatenated with the basename of the script (without the `.sh`). See `experiment_dir`.
- `experiment_dir` (string, `./` + `$name`): Directory for saving configuration, logs and checkpoint files. Lamb's git version is saved in `lamb_version` along with any uncommitted changes in the checkout (if in a git tree). `stdout` and `stderr` are also captured. If `save_checkpoints` is true, checkpoints are saved here. Also see `save_config`.
- `save_config`: All options are saved in `$experiment_dir/config`, except those for which it doesn't make sense.
ensure_new_experiment
(boolean, true)If
ensure_new_experiment
is true, a random suffix is appended toexperiment_dir
to ensure the experiment starts afresh. When ensure_new_experiment is false andexperiment_dir
exists, the last checkpoint will be loaded on startup from that directory. -
This is to load
$experiment_dir/config
that gets saved automatically whensave_checkpoints
is true. It is not needed if one usesexperiment/test.sh
for evaluation. If a configuration option is set explicitly and is also in the configuration file, then the explicit version overrides the one in the configuration file.
-
save_checkpoints
(boolean, true)Whether to save any checkpoints.
save_checkpoints
also affects saving of the configuration (seeconfig_file
.If save_checkpoints is true, then two checkpoints are saved:
$experiment_dir/best
and$experiment_dir/last
. Iflast
exists, it will be loaded automatically on startup and training will continue from that state. If that's undesirable, use a differentexperiment_dir
or delete the checkpoint manually. Thebest
checkpoint corresponds to the best validation result seen so far during preriodic model evaluation during training. -
The name of the checkpoint file to load instead of loading
$experiment_dir/last
or randomly initializing. Absolute or relative toexperiment_dir
. -
load_optimizer_state
(boolean, true)Set this to
false
to preventload_checkpoint
from attempting to restore optimizer state. This effectively reinitializes the optimizer and also allows changing the optimizer type. It does not affect automatic loading of the latest checkpoint (seeexperiment_dir
).
- The random seed. Both Python and TensorFlow seeds are initialized with this value (see the sketch after this list). Due to non-determinism in TensorFlow, training runs are not exactly reproducible even with the same seed.
- `swap_memory`: Transparently swap tensors that are produced in forward inference but needed for backprop from GPU to CPU. This allows training RNNs that would otherwise not fit on a single GPU, but slows things down a bit.
- `log_device_placement` (boolean, false): Log TensorFlow device placement.
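For reference, TF1-style seeding corresponding to the description above might look like this (a sketch; `tf.set_random_seed` is the TF 1.x call, renamed `tf.random.set_seed` in TF 2):

```python
import random

import numpy as np
import tensorflow as tf

seed = 42
random.seed(seed)          # Python's own RNG
np.random.seed(seed)       # NumPy
tf.set_random_seed(seed)   # TF 1.x graph-level seed
```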
This is not an official Google product.
- On the state of the art of evaluation in neural language models

```
@inproceedings{melis2018on,
  title={On the State of the Art of Evaluation in Neural Language Models},
  author={G{\'a}bor Melis and Chris Dyer and Phil Blunsom},
  booktitle={International Conference on Learning Representations},
  year={2018},
  url={https://openreview.net/forum?id=ByJHuTgA-},
}
```

- Pushing the bounds of dropout

```
@article{melis2018pushing,
  title={Pushing the bounds of dropout},
  author={Melis, G{\'a}bor and Blundell, Charles and Ko{\v{c}}isk{\'y}, Tom{\'a}{\v{s}} and Hermann, Karl Moritz and Dyer, Chris and Blunsom, Phil},
  journal={arXiv preprint arXiv:1805.09208},
  year={2018}
}
```

- Mogrifier LSTM

```
@inproceedings{melis2020mogrifier,
  title={Mogrifier LSTM},
  author={Melis, G{\'a}bor and Ko{\v{c}}isk{\'y}, Tom{\'a}{\v{s}} and Blunsom, Phil},
  booktitle={International Conference on Learning Representations},
  year={2020},
  url={https://openreview.net/forum?id=SJe5P6EYvS},
}
```