
Commit

T5 fine-tuning on GLUE tasks.
PiperOrigin-RevId: 347453698
henrykmichalewski authored and copybara-github committed Dec 14, 2020
1 parent d288e96 commit 2f490e8
Showing 5 changed files with 226 additions and 0 deletions.
1 change: 1 addition & 0 deletions trax/data/__init__.py
@@ -45,6 +45,7 @@ def data_configure(*args, **kwargs):
shuffle = inputs.shuffle
TFDS = data_configure(tf_inputs.TFDS)
CreateBertInputs = data_configure(tf_inputs.CreateBertInputs)
CreateT5GlueInputs = data_configure(tf_inputs.CreateT5GlueInputs)
Tokenize = data_configure(tf_inputs.Tokenize)
ConvertToUnicode = data_configure(tf_inputs.ConvertToUnicode)
tokenize = tf_inputs.tokenize
69 changes: 69 additions & 0 deletions trax/data/tf_inputs.py
@@ -1137,3 +1137,72 @@ def get_glue_key(task_name=gin.REQUIRED):
    raise KeyError(
        f'Wrong task name entered, available glue tasks: {list(glue_keys.keys())}. Entered: {task_name}'
    )


def get_glue_t5_labels(dataset_name):
  """Gets GLUE labels for T5 from the task name."""
  ext_task_name = dataset_name if dataset_name.startswith(
      'glue') else f'glue/{dataset_name}'
  # Labels inferred from the T5 paper: https://arxiv.org/pdf/1910.10683.pdf
  glue_t5_labels = {
      'glue/cola': ('not_acceptable', 'acceptable'),
      # SST-2 is sentiment classification; label names follow the TFDS dataset.
      'glue/sst2': ('negative', 'positive'),
      'glue/mrpc': ('not_equivalent', 'equivalent'),
      'glue/qqp': ('not_duplicate', 'duplicate'),
      # STS-B requires processing of float targets.
      # 'glue/stsb': ('sentence1', 'sentence2'),
      'glue/mnli': ('entailment', 'neutral', 'contradiction'),
      'glue/qnli': ('entailment', 'not_entailment'),
      'glue/rte': ('entailment', 'not_entailment'),
      # WNLI is used for evaluation and for training of T5; as explained in
      # Section 2.4 of https://arxiv.org/pdf/1910.10683.pdf, it overlaps with
      # WSC from SuperGLUE.
      # 'glue/wnli': ('sentence1', 'sentence2'),
  }
  try:
    return glue_t5_labels[ext_task_name]
  except KeyError:
    raise KeyError(
        f'Wrong task name entered, available glue tasks: {list(glue_t5_labels.keys())}. Entered: {dataset_name}'
    )
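
For illustration, the lookup accepts either the short task name or the full glue/... form; a couple of hypothetical interactive checks (not part of the committed file) would return the tuples defined above:

get_glue_t5_labels('qnli')        # ('entailment', 'not_entailment')
get_glue_t5_labels('glue/mnli')   # ('entailment', 'neutral', 'contradiction')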


def CreateT5GlueInputs(  # pylint: disable=invalid-name
    dataset_name='glue/qnli',
    label_names=None,
    train=True):
  """Prepares GLUE inputs for T5 models using the standard T5 preprocessor."""
  # Infer the label names from the task unless they are given explicitly.
  if label_names is None:
    label_names = get_glue_t5_labels(dataset_name)
  benchmark_name = dataset_name.split('/')[1]
  if train:
    dataset = tfds.load(name=dataset_name, split='train')
  elif dataset_name == 'glue/mnli':
    # TODO(henrykm): The other option for mnli is validation_mismatched.
    dataset = tfds.load(name=dataset_name, split='validation_matched')
  else:
    dataset = tfds.load(name=dataset_name, split='validation')
  proc_dataset = generic_text_dataset_preprocess_fn(
      dataset,
      spm_path=t5.data.DEFAULT_SPM_PATH,
      text_preprocess_fns=[
          lambda ds, training: t5.data.preprocessors.glue(  # pylint: disable=g-long-lambda
              ds,
              benchmark_name=benchmark_name,
              label_names=label_names)
      ],
      copy_plaintext=True,
      debug_print_examples=True,
      debug_print_examples_rate=1.0)

  # Trax data pipelines pass a generator from stage to stage; this first stage
  # ignores its input and streams (inputs, targets, weights) triples instead.
  def t5_yield_examples(generator=None):
    del generator
    while True:
      for example in proc_dataset:
        input_values = example['inputs']
        target_values = example['targets']
        yield (fastmath.numpy.array(input_values),
               fastmath.numpy.array(target_values),
               fastmath.numpy.array([1] * len(target_values)))

  return t5_yield_examples
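
The gin config below wires CreateT5GlueInputs into a trax input pipeline; as a rough Python sketch of the same composition (the lengths and batch size are illustrative values copied from that config, not a tested recipe):

from trax import data

# Mirror make_inputs.train_stream from the gin file: T5-preprocessed GLUE
# examples, shuffled, padded/truncated to fixed lengths, then batched.
train_stream = data.Serial(
    data.CreateT5GlueInputs(dataset_name='glue/qnli', train=True),
    data.Shuffle(),
    data.PadToLength(len_map={0: 512, 1: 512, 2: 512},
                     pad_value={0: 0, 1: 0, 2: 0}),
    data.TruncateToLength(len_map={0: (256,), 1: (256,), 2: (256,)}),
    data.Batch(batch_size=4),
)

inputs, targets, weights = next(train_stream())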
111 changes: 111 additions & 0 deletions trax/supervised/configs/t5_glue_classification.gin
@@ -0,0 +1,111 @@
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import trax.layers
import trax.models
import trax.data
import trax.optimizers
import trax.supervised.lr_schedules
import trax.supervised.trainer_lib
import trax.models.research.bert
import trax.layers.metrics

include 'c4.gin'

# See https://www.tensorflow.org/datasets/catalog/glue -- dataset_name values
# currently supported by CreateT5GlueInputs are: glue/cola, glue/sst2,
# glue/mrpc, glue/qqp, glue/mnli, glue/qnli, glue/rte (glue/stsb and glue/wnli
# are listed in TFDS but not yet handled by get_glue_t5_labels).

dataset_name = 'glue/qnli'

# Corresponds roughly to T5 'large' ~ 770m params, i.e. T5's `bi_v1_large.gin`.
d_model = 1024
d_ff = 4096
n_heads = 16
n_layers = 24
attn_kv = 64
dropout = 0.1
vocab_size = 32000

ff_chunk_size = 0
ff_sparsity = 0
loss_sparsity = 0

enc_attn_type = @Attention

MultiplicativeModularCausalAttention.sparsity = 16
MultiplicativeConvCausalAttention.sparsity = 16
MultiplicativeConvCausalAttention.length_kernel_size = 3

dec_attn_type = @CausalAttention

# Parameters for TFDS data pipeline:
# ==============================================================================
make_inputs.train_stream = [
  @train/data.CreateT5GlueInputs(),
  @data.Shuffle(),
  @data.PadToLength(),
  @data.TruncateToLength(),
  @data.Batch()
]
make_inputs.eval_stream = [
  @eval/data.CreateT5GlueInputs(),
  @data.Shuffle(),
  @data.PadToLength(),
  @data.TruncateToLength(),
  @data.Batch()
]
train/data.CreateT5GlueInputs.dataset_name = %dataset_name
train/data.CreateT5GlueInputs.train = True
eval/data.CreateT5GlueInputs.dataset_name = %dataset_name
eval/data.CreateT5GlueInputs.train = False

data.PadToLength.len_map = {0: 512, 1: 512, 2: 512}
data.PadToLength.pad_value = {0: 0, 1: 0, 2: 0}
data.TruncateToLength.len_map = {0: (256,), 1: (256,), 2: (256,)}
data.Batch.batch_size = 4

# Parameters for train:
# ==============================================================================
train.init_checkpoint = 'path_to_the_checkpoint'
train.optimizer = @trax.optimizers.Adam
train.eval_frequency = 20
train.eval_steps = 10
train.inputs = @trax.data.inputs.make_inputs
train.model = @trax.models.ConfigurableTransformer
train.steps = 200000
train.checkpoint_highest = 'accuracy'

# Parameters for ConfigurableTransformer:
# ==============================================================================
ConfigurableTransformer.d_model = %d_model
ConfigurableTransformer.d_ff = %d_ff
ConfigurableTransformer.dropout = %dropout
ConfigurableTransformer.ff_dropout = %dropout
ConfigurableTransformer.ff_chunk_size = %ff_chunk_size
ConfigurableTransformer.ff_sparsity = %ff_sparsity
ConfigurableTransformer.max_len = %max_length
ConfigurableTransformer.mode = 'train'
ConfigurableTransformer.n_heads = %n_heads
ConfigurableTransformer.n_encoder_layers = %n_layers
ConfigurableTransformer.n_decoder_layers = %n_layers
ConfigurableTransformer.input_vocab_size = %vocab_size
ConfigurableTransformer.encoder_attention_type = %enc_attn_type
ConfigurableTransformer.encoder_decoder_attention_type = %dec_attn_type
ConfigurableTransformer.loss_sparsity = %loss_sparsity

# Parameters for multifactor:
# ==============================================================================
multifactor.constant = 1e-3
multifactor.factors = 'constant'
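
One plausible way to launch a run with this config from Python (a minimal sketch: the config path, output directory, and checkpoint binding are placeholders, and the include of 'c4.gin' assumes the configs directory is on gin's search path):

import gin
from trax.supervised import trainer_lib

# Make the included 'c4.gin' resolvable, then load this config.
gin.add_config_file_search_path('trax/supervised/configs')
gin.parse_config_file('trax/supervised/configs/t5_glue_classification.gin')
# Replace the 'path_to_the_checkpoint' placeholder with a real checkpoint path.
gin.parse_config("train.init_checkpoint = '/path/to/pretrained/model.pkl.gz'")

trainer_lib.train(output_dir='/tmp/t5_glue_qnli')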
22 changes: 22 additions & 0 deletions trax/supervised/configs/t5_sweep.yaml
@@ -0,0 +1,22 @@
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

dataset_name: [
    'glue/mrpc',
    'glue/sst2',
    'glue/qqp',
    'glue/mnli',
    'glue/qnli',
    'glue/rte',
]
23 changes: 23 additions & 0 deletions trax/supervised/configs/t5_sweep_temperature.yaml
@@ -0,0 +1,23 @@
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

dataset_name: [
    'glue/mrpc',
    'glue/sst2',
    'glue/qqp',
    'glue/mnli',
    'glue/qnli',
    'glue/rte',
]
ConfigurableTransformer.ff_sparsity: ['64 64 0.0 1.0', '64 64 0.0 0.0']
