Paracrawl baseline task #1296

Open · wants to merge 1 commit into base: main
78 changes: 78 additions & 0 deletions baselines/t5/configs/tasks/paracrawl.gin
@@ -0,0 +1,78 @@
# Copyright 2023 The Uncertainty Baselines Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Task-specific configurations for Paracrawl training and evaluation.
from __gin__ import dynamic_registration

import seqio
import __main__ as train_script
from t5x import checkpoints
from t5x import utils
from t5x import decoding

# Register necessary SeqIO Tasks/Mixtures.
import data.mixtures # local file import from baselines.t5

# Change to paracrawl_wmt_deen_precleaned for the baseline on precleaned data.
MIXTURE_OR_TASK_NAME = 'paracrawl_wmt_deen'

# Eval mixture with evaluation on both dev and test sets.
EVAL_MIXTURE_OR_TASK_NAME = 'paracrawl_wmt_deen_eval_mixture'

# Disable caching since uncertainty-baselines tasks are not cached in the
# official directory.
USE_CACHED_TASKS = False

BASE_LEARNING_RATE = 2.0

# Adjust checkpoint saving.
utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = 8  # Keep the 8 best checkpoints.
  save_dataset = False  # Don't checkpoint dataset state.
  checkpointer_cls = @checkpoints.SaveBestCheckpointer

checkpoints.SaveBestCheckpointer:
  metric_name_to_monitor = 'inference_eval/paracrawl/eval/bleu'
  metric_mode = 'max'
  keep_checkpoints_without_metrics = False
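
# Note: with keep_checkpoints_without_metrics = False, checkpoints saved before
# a BLEU value is available for them are also eligible for deletion.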

train_script.train.infer_eval_dataset_cfg = @train_infer/utils.DatasetConfig()
train_script.train.inference_evaluator_cls = @seqio.Evaluator

train_infer/utils.DatasetConfig:
  mixture_or_task_name = %EVAL_MIXTURE_OR_TASK_NAME
  task_feature_lengths = None
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False
  module = None  # %MIXTURE_OR_TASK_MODULE
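# With task_feature_lengths = None, seqio.Evaluator infers feature lengths from
# the longest examples in the eval split.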

decoding.beam_search.max_decode_len = 256

# Log inference-eval results via Python logging, TensorBoard, and JSON.
seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = None  # Use all examples in the infer_eval dataset.

seqio.JSONLogger.write_n_results = None # Write all inferences.

utils.create_learning_rate_scheduler:
  factors = 'constant * rsqrt_decay'
  base_learning_rate = %BASE_LEARNING_RATE
  warmup_steps = 5000
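
As a sanity check on these values, here is a minimal sketch of the resulting schedule, assuming t5x's utils.create_learning_rate_scheduler follows the standard T5 convention in which the 'constant' factor contributes BASE_LEARNING_RATE and 'rsqrt_decay' contributes 1/sqrt(max(step, warmup_steps)):

import math

def paracrawl_learning_rate(step, base_learning_rate=2.0, warmup_steps=5000):
  """Sketch of 'constant * rsqrt_decay': flat during warmup, then 1/sqrt decay."""
  return base_learning_rate / math.sqrt(max(step, warmup_steps))

# Constant at 2.0 / sqrt(5000) ~= 0.0283 until step 5000, then decays as
# 2.0 / sqrt(step).
print(paracrawl_learning_rate(1000))   # ~0.0283 (still in the flat region)
print(paracrawl_learning_rate(20000))  # ~0.0141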

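The task and mixture names above ('paracrawl_wmt_deen', 'paracrawl_wmt_deen_eval_mixture') come from the local data.mixtures import near the top of the file. For readers unfamiliar with SeqIO, a hypothetical sketch of what such a registration can look like follows; the TFDS name and version, prompt string, and vocabulary path are illustrative assumptions, not the repository's actual definitions:

import seqio
import tensorflow as tf

# Hypothetical vocabulary path -- a real config points at the model's
# SentencePiece model.
VOCAB = seqio.SentencePieceVocabulary('gs://path/to/sentencepiece.model')

def deen_to_inputs_and_targets(dataset):
  """Maps raw {'de': ..., 'en': ...} examples to text-to-text form."""
  def _map(ex):
    return {
        'inputs': tf.strings.join(['translate German to English: ', ex['de']]),
        'targets': ex['en'],
    }
  return dataset.map(_map, num_parallel_calls=tf.data.AUTOTUNE)

seqio.TaskRegistry.add(
    'paracrawl_wmt_deen',  # name referenced by MIXTURE_OR_TASK_NAME above
    source=seqio.TfdsDataSource(tfds_name='para_crawl/ende:1.2.0'),
    preprocessors=[
        deen_to_inputs_and_targets,
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={
        'inputs': seqio.Feature(vocabulary=VOCAB),
        'targets': seqio.Feature(vocabulary=VOCAB),
    },
)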