diff --git a/baselines/t5/configs/tasks/paracrawl.gin b/baselines/t5/configs/tasks/paracrawl.gin new file mode 100644 index 000000000..2fe255f9a --- /dev/null +++ b/baselines/t5/configs/tasks/paracrawl.gin @@ -0,0 +1,78 @@ +# Copyright 2023 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Task-specific configurations for Paracrawl training and evaluation. +from __gin__ import dynamic_registration + +import seqio +import __main__ as train_script +from t5x import checkpoints +from t5x import utils +from t5x import decoding + +# Register necessary SeqIO Tasks/Mixtures. +import data.mixtures # local file import from baselines.t5 + +# Change to paracrawl_wmt_deen_precleaned for the baseline on precleaned data. +MIXTURE_OR_TASK_NAME = 'paracrawl_wmt_deen' + +# Eval mixture with evaluation on both dev and test sets. +EVAL_MIXTURE_OR_TASK_NAME = 'paracrawl_wmt_deen_eval_mixture' + +# Disable caching since ub tasks are not cached in the official directory. +USE_CACHED_TASKS = False + +BASE_LEARNING_RATE = 2.0 + +# Adjust checkpoint saving. +utils.SaveCheckpointConfig: + period = 5000 + dtype = 'float32' + keep = 8 # Keep the 8 best checkpoints. + save_dataset = False # Don't checkpoint dataset state. 
+ checkpointer_cls = @checkpoints.SaveBestCheckpointer + +checkpoints.SaveBestCheckpointer: + metric_name_to_monitor = 'inference_eval/paracrawl/eval/bleu' + metric_mode = 'max' + keep_checkpoints_without_metrics = False + +train_script.train.infer_eval_dataset_cfg = @train_infer/utils.DatasetConfig() +train_script.train.inference_evaluator_cls = @seqio.Evaluator + +train_infer/utils.DatasetConfig: + mixture_or_task_name = %EVAL_MIXTURE_OR_TASK_NAME + task_feature_lengths = None + split = 'validation' + batch_size = %BATCH_SIZE + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS + pack = False + module = None # %MIXTURE_OR_TASK_MODULE + +decoding.beam_search.max_decode_len = 256 + +# Log inference eval results via Python logging, TensorBoard, and JSON. +# NOTE(review): @seqio.JSONLogger is enabled here and write_n_results below is +# None (write all inferences); remove @seqio.JSONLogger from logger_cls if +# JSON output storage needs to be reduced. +seqio.Evaluator: + logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] + num_examples = None # Use all examples in the infer_eval dataset. + +seqio.JSONLogger.write_n_results = None # Write all inferences. + +utils.create_learning_rate_scheduler: + factors = 'constant * rsqrt_decay' + base_learning_rate = %BASE_LEARNING_RATE + warmup_steps = 5000 +