updated BERT processing to current interface #1223

Merged 1 commit on Nov 19, 2020

19 changes: 10 additions & 9 deletions trax/models/research/bert.py
@@ -19,6 +19,8 @@
 import jax
 import tensorflow as tf
 
+import gin
+
 from trax import fastmath
 from trax import layers as tl
 from trax.fastmath import numpy as np
@@ -35,7 +37,7 @@ def forward(self, inputs):
   def init_weights_and_state(self, input_signature):
     self.weights = np.zeros(input_signature.shape[-1])
 
-
+@gin.configurable()
 def BERTClassifierHead(n_classes):
   return tl.Serial([
       tl.Select([0], n_in=2),
@@ -46,7 +48,7 @@ def BERTClassifierHead(n_classes):
       tl.LogSoftmax(),
   ])
 
-
+@gin.configurable()
 def BERTRegressionHead():
   return tl.Serial([
       tl.Select([0], n_in=2),
@@ -138,10 +140,10 @@ def __init__(self, *sublayers, init_checkpoint=None):
           'Please manually specify the path to bert_model.ckpt')
     self.init_checkpoint = init_checkpoint
 
-  def new_weights(self, input_signature):
-    weights = super().new_weights(input_signature)
+  def init_weights_and_state(self, input_signature):
+    super().init_weights_and_state(input_signature)
     if self.init_checkpoint is None:
-      return weights
+      return
 
     print('Loading pre-trained weights from', self.init_checkpoint)
     ckpt = tf.train.load_checkpoint(self.init_checkpoint)
@@ -192,10 +194,9 @@ def reshape_bias(name):
         ckpt.get_tensor('bert/pooler/dense/bias'),
     ]
 
-    for a, b in zip(fastmath.tree_leaves(weights), new_w):
+    for a, b in zip(fastmath.tree_leaves(self.weights), new_w):
       assert a.shape == b.shape, (
           f'Expected shape {a.shape}, got shape {b.shape}')
-    weights = jax.tree_unflatten(jax.tree_structure(weights), new_w)
+    self.weights = jax.tree_unflatten(jax.tree_structure(self.weights), new_w)
     move_to_device = jax.jit(lambda x: x)
-    weights = jax.tree_map(move_to_device, weights)
-    return weights
+    self.weights = jax.tree_map(move_to_device, self.weights)
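
This hunk tracks the current trax Layer interface: init_weights_and_state assigns self.weights in place and returns nothing, whereas the old new_weights hook returned a weights pytree. Below is a minimal sketch of the current interface using a toy layer that mirrors the pattern at the top of bert.py; the layer name and shapes are illustrative and not part of this PR:

import numpy as np
from trax import layers as tl
from trax import shapes

class AddStoredBias(tl.Layer):
  """Toy layer mirroring the init_weights_and_state pattern in bert.py."""

  def forward(self, inputs):
    # Weights created in init_weights_and_state are available as self.weights.
    return inputs + self.weights

  def init_weights_and_state(self, input_signature):
    # Current interface: assign in place, return nothing.
    self.weights = np.zeros(input_signature.shape[-1])

layer = AddStoredBias()
layer.init(shapes.ShapeDtype((2, 8)))  # calls init_weights_and_state internally
print(layer.weights.shape)             # (8,)

After init(), a subclass such as the checkpoint loader above can simply overwrite self.weights, which is why it no longer needs to return anything.
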
60 changes: 60 additions & 0 deletions trax/supervised/configs/bert_glue.gin
@@ -0,0 +1,60 @@
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import trax.layers
import trax.models
import trax.optimizers
import trax.supervised.lr_schedules
import trax.supervised.pretrain_finetune
import trax.supervised.trainer_lib
import trax.models.research.bert
import trax.layers.metrics

# See https://www.tensorflow.org/datasets/catalog/glue -- valid dataset_name
# values are: glue/cola, glue/sst2, glue/mrpc, glue/qqp, glue/stsb, glue/mnli,
# glue/qnli, glue/rte, glue/wnli. However, training on WNLI with this setup is
# not recommended and will likely result in lower than baseline accuracy.


# Parameters for glue_inputs:
# ==============================================================================
glue_inputs.dataset_name = 'glue/mnli'
glue_inputs.batch_size = 16


# Parameters for bert_tokenizer:
# ==============================================================================
# Download the pre-trained model from https://github.com/google-research/bert, fill in the path to its vocabulary file, and uncomment:
# bert_tokenizer.vocab_path = 'PATH-TO-VOCAB'

# Parameters for train:
# ==============================================================================
train.optimizer = @trax.optimizers.Adam
train.eval_frequency = 20
train.eval_steps = 10
train.inputs = @pretrain_finetune.glue_inputs
train.model = @trax.models.BERT
train.steps = 2000
train.checkpoint_highest = 'accuracy'

# Parameters for BERT:
# ==============================================================================
# Download the pre-trained model from https://github.com/google-research/bert, fill in the path to its checkpoint, and uncomment:
# BERT.init_checkpoint = 'PATH-TO-MODEL-CHECKPOINT'
BERT.head = @bert.BERTClassifierHead
bert.BERTClassifierHead.n_classes = 3

# Parameters for lr_schedules:
# ==============================================================================
lr_schedules.multifactor.constant = 3e-5
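
A hedged sketch of consuming this config programmatically, equivalent to passing it to the trax trainer with --config_file; it assumes trainer_lib.train picks up the remaining settings from gin, and the vocab, checkpoint, and output paths are placeholders:

import gin
from trax.supervised import trainer_lib

gin.parse_config_file('trax/supervised/configs/bert_glue.gin')
# The two commented-out paths in the config above must be supplied, e.g.:
gin.bind_parameter('bert_tokenizer.vocab_path',
                   '/path/to/uncased_L-12_H-768_A-12/vocab.txt')
gin.bind_parameter('BERT.init_checkpoint',
                   '/path/to/uncased_L-12_H-768_A-12/bert_model.ckpt')
trainer_lib.train(output_dir='/tmp/bert_mnli')  # remaining settings come from gin
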
177 changes: 177 additions & 0 deletions trax/supervised/pretrain_finetune.py
@@ -0,0 +1,177 @@
# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""
data processing for BERT.

For now, this file only supports fine-tuning bert-base-uncased on GLUE.
"""
import functools

import gin
import numpy as onp

import tensorflow_datasets as tfds
from trax.data.inputs import Inputs


def _tfds_stream(n_devices, dataset_name, split, batch_size, data_dir,
                 shuffle_files, shuffle_buffer_size, batch_shuffle_size,
                 preprocess_fun, repeat=True):
  """Streams batches of examples from tfds, with pure-python preprocessing."""
  # TODO(piotrekp1): delete if switched to data_streams
  if batch_size % n_devices != 0:
    raise ValueError(f'Batch size ({batch_size}) not divisible'
                     f' by number of devices ({n_devices})')
  ds = tfds.load(
      name=dataset_name, split=split, data_dir=data_dir,
      shuffle_files=shuffle_files)
  if repeat:
    ds = ds.repeat()
  if shuffle_buffer_size is not None:
    ds = ds.shuffle(shuffle_buffer_size)
  ds = ds.batch(batch_size)
  if batch_shuffle_size is not None:
    ds = ds.shuffle(batch_shuffle_size)

  for batch in tfds.as_numpy(ds):
    if preprocess_fun is not None:
      yield preprocess_fun(batch)
    else:
      yield batch


@gin.configurable()
def tfds_inputs(
    dataset_name,
    preprocess_fun,
    batch_size,
    eval_batch_size=None,
    data_dir=None,
    train_split=tfds.Split.TRAIN,
    eval_split=tfds.Split.VALIDATION,
    shuffle_buffer_size=1024,
    batch_shuffle_size=128,
):
  """Tensorflow Datasets input pipeline, with pure-python preprocessing."""
  if eval_batch_size is None:
    eval_batch_size = batch_size
  return Inputs(
      train_stream=functools.partial(
          _tfds_stream,
          dataset_name=dataset_name,
          split=train_split,
          batch_size=batch_size,
          data_dir=data_dir,
          shuffle_files=True,
          shuffle_buffer_size=shuffle_buffer_size,
          batch_shuffle_size=batch_shuffle_size,
          preprocess_fun=preprocess_fun,
      ),
      eval_stream=functools.partial(
          _tfds_stream,
          dataset_name=dataset_name,
          split=eval_split,
          batch_size=eval_batch_size,
          data_dir=data_dir,
          shuffle_files=False,
          shuffle_buffer_size=None,
          batch_shuffle_size=None,
          preprocess_fun=preprocess_fun,
      ),
  )


@gin.configurable()
def bert_tokenizer(vocab_path=None):
  """Constructs a BERT tokenizer."""
  # This import is from https://github.com/google-research/bert, which is not
  # listed as a dependency in trax. (The bert-for-tf2 pip package is one
  # distribution that provides this module path.)
  # TODO(piotrekp1): use SubwordTextEncoder instead after fixing the differences.
  from bert.tokenization.bert_tokenization import FullTokenizer
  if vocab_path is None:
    raise ValueError('vocab_path is required to construct the BERT tokenizer.')
  tokenizer = FullTokenizer(vocab_path, do_lower_case=True)
  return tokenizer


def bert_preprocess(batch, tokenizer, key_a, key_b=None, max_len=128):
  """Tokenize and convert text to model inputs in a BERT format."""
  batch_size = batch['idx'].shape[0]
  input_ids = onp.zeros((batch_size, max_len), dtype=onp.int32)
  type_ids = onp.zeros((batch_size, max_len), dtype=onp.int32)
  for i in range(batch_size):
    sentence_a = batch[key_a][i]
    # 101 and 102 are the [CLS] and [SEP] ids in the BERT uncased vocabulary.
    tokens_a = [101] + tokenizer.convert_tokens_to_ids(
        tokenizer.tokenize(sentence_a)) + [102]

    if key_b is not None:
      sentence_b = batch[key_b][i]
      tokens_b = tokenizer.convert_tokens_to_ids(
          tokenizer.tokenize(sentence_b)) + [102]
    else:
      tokens_b = []

    ex_input_ids = (tokens_a + tokens_b)[:max_len]
    ex_type_ids = ([0] * len(tokens_a) + [1] * len(tokens_b))[:max_len]

    input_ids[i, :len(ex_input_ids)] = ex_input_ids
    type_ids[i, :len(ex_type_ids)] = ex_type_ids
  # Returns: token ids, segment ids, padding mask, labels, example weights.
  return input_ids, type_ids, input_ids > 0, batch['label'], onp.ones(batch_size)


@gin.configurable()
def glue_inputs(dataset_name=gin.REQUIRED, batch_size=16, eval_batch_size=None,
                data_dir=None, max_len=128, tokenizer=bert_tokenizer):
  """Input pipeline for fine-tuning BERT on GLUE tasks."""
  if callable(tokenizer):  # If we pass a function, e.g., through gin, call it.
    tokenizer = bert_tokenizer()

  eval_split = tfds.Split.VALIDATION
  if dataset_name == 'glue/mnli':
    eval_split = 'validation_matched'
  # TODO(kitaev): Support diagnostic dataset (AX)

  keys_lookup = {
      'glue/cola': ('sentence', None),
      'glue/sst2': ('sentence', None),
      'glue/mrpc': ('sentence1', 'sentence2'),
      'glue/qqp': ('question1', 'question2'),
      'glue/stsb': ('sentence1', 'sentence2'),
      'glue/mnli': ('premise', 'hypothesis'),  # TODO(kitaev): swap the two?
      'glue/qnli': ('question', 'sentence'),  # TODO(kitaev): swap the two?
      'glue/rte': ('sentence1', 'sentence2'),
      'glue/wnli': ('sentence1', 'sentence2'),
  }

  key_a, key_b = keys_lookup[dataset_name]

  preprocess_fn = functools.partial(bert_preprocess,
                                    tokenizer=tokenizer,
                                    key_a=key_a,
                                    key_b=key_b,
                                    max_len=max_len)
  return tfds_inputs(  # TODO(piotrekp1): use data_streams instead
      dataset_name=dataset_name,
      preprocess_fun=preprocess_fn,
      batch_size=batch_size,
      eval_batch_size=eval_batch_size,
      data_dir=data_dir,
      train_split=tfds.Split.TRAIN,
      eval_split=eval_split
  )

# TODO(piotrekp1): add glue evaluation
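
A hedged usage sketch of the pipeline above, not part of the PR: it builds the GLUE inputs directly and inspects one preprocessed batch. The vocab path is a placeholder, and running it requires the bert tokenizer dependency plus a TFDS download of glue/mrpc.

from trax.supervised import pretrain_finetune

tokenizer = pretrain_finetune.bert_tokenizer(vocab_path='/path/to/vocab.txt')
inputs = pretrain_finetune.glue_inputs(
    dataset_name='glue/mrpc', batch_size=8, tokenizer=tokenizer)

batch = next(inputs.train_stream(1))  # n_devices=1
input_ids, type_ids, mask, labels, weights = batch
print(input_ids.shape, type_ids.shape, mask.shape)  # (8, 128) each
print(labels.shape, weights.shape)                  # (8,) each

Each batch follows the bert_preprocess layout: token ids, segment ids, a padding mask, labels, and per-example weights.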