major refactor (reuse-bart)
stas00 committed Sep 7, 2020
1 parent 38cc9c1 commit 226dad1
Showing 2 changed files with 23 additions and 169 deletions.
158 changes: 20 additions & 138 deletions src/transformers/modeling_fsmt.py
@@ -46,7 +46,14 @@
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_bart import DecoderLayer, EncoderLayer
from .modeling_bart import (
DecoderLayer,
EncoderLayer,
LayerNorm,
_prepare_bart_decoder_inputs,
_reorder_buffer,
invert_mask,
)
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
from .modeling_utils import PreTrainedModel

Expand Down Expand Up @@ -97,43 +104,36 @@
Here is how to compare BLEU scores against fairseq implementation:
# Note: to match fairseq params you need to set num_beams=50 in
# `configuration_fsmt.py` and lower BS as it'll need more GPU memory
cd examples/seq2seq
# en-ru
export PAIR=en-ru
export DATA_DIR=data/$PAIR
export SAVE_DIR=data/$PAIR
export BS=8
export NUM_BEAMS=50
mkdir -p $DATA_DIR
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
echo $PAIR
PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
# (fairseq BLEU: 36.4 http://matrix.statmt.org/matrix/output/1914?score_id=37605)
# ru-en
export PAIR=ru-en
export DATA_DIR=data/$PAIR
export SAVE_DIR=data/$PAIR
export BS=8
export NUM_BEAMS=50
mkdir -p $DATA_DIR
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
echo $PAIR
PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
# (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937)
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
# (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937)
# de-en
@@ -142,11 +142,12 @@
export DATA_DIR=data/$PAIR
export SAVE_DIR=data/$PAIR
export BS=8
export NUM_BEAMS=50
mkdir -p $DATA_DIR
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
echo $PAIR
PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
# (fairseq BLEU: 42.3 http://matrix.statmt.org/matrix/output/1902?run_id=6750)
@@ -162,7 +163,7 @@
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
echo $PAIR
PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
# (fairseq BLEU: 43.1 http://matrix.statmt.org/matrix/output/1909?run_id=6862)
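
For completeness, the same checkpoints can be driven straight from Python instead of through run_eval.py. A minimal generation sketch (an illustration, not part of this commit; it assumes the stas/fsmt-wmt19-ru-en checkpoint named in the scripts above is available):

from transformers import FSMTForConditionalGeneration, FSMTTokenizer

mname = "stas/fsmt-wmt19-ru-en"  # checkpoint name taken from the scripts above
tokenizer = FSMTTokenizer.from_pretrained(mname)
model = FSMTForConditionalGeneration.from_pretrained(mname)

# "Machine learning is great."
batch = tokenizer(["Машинное обучение - это здорово."], return_tensors="pt")
# num_beams=50 mirrors the fairseq setting used for the BLEU numbers above
generated = model.generate(**batch, num_beams=50)
print(tokenizer.decode(generated[0], skip_special_tokens=True))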
@@ -171,8 +172,7 @@

FSMT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matters related to general usage and behavior.
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
Parameters:
config (:class:`~transformers.FSMTConfig`): Model configuration class with all the parameters of the model.
@@ -238,33 +238,6 @@
"""


def invert_mask(attention_mask):
"""Turns 1->0, 0->1, False->True, True-> False"""
assert attention_mask.dim() == 2
return attention_mask.eq(0)
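
A quick sketch of the helper's contract on a typical 2D attention mask (illustrative, not part of the commit):

import torch

attention_mask = torch.tensor([[1, 1, 0]])  # 1 = real token, 0 = padding
invert_mask(attention_mask)                 # tensor([[False, False, True]]) -- True marks padding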


def _prepare_fsmt_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
):
"""Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
Note: this is not called during generation
"""
pad_token_id = config.pad_token_id
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
bsz, tgt_len = decoder_input_ids.size()
if decoder_padding_mask is None:
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else:
decoder_padding_mask = invert_mask(decoder_padding_mask)
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
)
return decoder_input_ids, decoder_padding_mask, causal_mask
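
As an illustration (not part of the commit), for a 3-token target the triu call above yields a lower-triangular additive mask: position i may attend to positions <= i, and -inf knocks out future positions in the softmax:

import torch

tgt_len = 3
causal_mask = torch.triu(torch.zeros(tgt_len, tgt_len).fill_(float("-inf")), 1)
# tensor([[0., -inf, -inf],
#         [0.,   0., -inf],
#         [0.,   0.,   0.]])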


class PretrainedFSMTModel(PreTrainedModel):
config_class = FSMTConfig
base_model_prefix = "model"
@@ -293,36 +266,6 @@ def dummy_inputs(self):
return dummy_inputs


def _make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer


# Helper Functions, mostly for making masks
def _check_shapes(shape_1, shape2):
if shape_1 != shape2:
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))


def shift_tokens_right(input_ids, pad_token_id):
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
prev_output_tokens = input_ids.clone()
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = input_ids[:, :-1]
return prev_output_tokens
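
A worked example of the shift (a sketch assuming pad_token_id=1 and eos=2, the convention used in the tests below):

import torch

input_ids = torch.tensor([[5, 6, 7, 2, 1]])  # ends with <eos>=2 followed by <pad>=1
shift_tokens_right(input_ids, 1)
# tensor([[2, 5, 6, 7, 2]]) -- <eos> wraps around to position 0, everything else shifts right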


def make_padding_mask(input_ids, padding_idx=1):
"""True for pad tokens"""
padding_mask = input_ids.eq(padding_idx)
if not padding_mask.any():
padding_mask = None
return padding_mask


# Helper Modules


@@ -592,70 +535,11 @@ def forward(
)


def _reorder_buffer(attn_cache, new_order):
for k, input_buffer_k in attn_cache.items():
if input_buffer_k is not None:
attn_cache[k] = input_buffer_k.index_select(0, new_order)
return attn_cache
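
This is the beam-search cache shuffle: when beams are reordered during generation, every cached key/value tensor is re-indexed along its beam dimension. A sketch with hypothetical shapes:

import torch

attn_cache = {"prev_key": torch.arange(6.).view(3, 1, 2), "prev_value": None}  # 3 beams
new_order = torch.tensor([2, 0, 1])  # beam 2 becomes the new beam 0, etc.
attn_cache = _reorder_buffer(attn_cache, new_order)
# attn_cache["prev_key"][0] now holds what belonged to beam 2; None entries are skipped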


# XXX: remove this and its references
class LearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
Padding ids are ignored by either offsetting based on padding_idx
or by setting padding_idx to None and ensuring that the appropriate
position ids are passed to the forward function.
"""

def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
# FSMT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = offset
assert padding_idx is not None
num_embeddings += offset
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)

def forward(self, input_ids, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
if use_cache:
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
else:
# starts at 0, ends at seq_len - 1
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
return super().forward(positions + self.offset)
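
To make the offset hack concrete (an illustrative sketch): with offset=2, position 0 lives in embedding row 2, so the first two rows are reserved and never hit by real positions:

import torch

emb = LearnedPositionalEmbedding(num_embeddings=10, embedding_dim=4, padding_idx=1, offset=2)
input_ids = torch.zeros(1, 3, dtype=torch.long)  # only the [bsz, seqlen] shape matters here
out = emb(input_ids)        # looks up rows 2, 3, 4 (positions 0..2 plus the offset)
assert out.shape == (3, 4)  # broadcasts over the batch dim when added to token embeddings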


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
if torch.cuda.is_available():
try:
from apex.normalization import FusedLayerNorm

return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


def fill_with_neg_inf(t):
"""FP16-compatible function that fills a input_ids with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
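
A quick illustration of why the float round-trip is there: the tensor is upcast to float32, filled with -inf, then cast back, so the caller's dtype is preserved (sketch):

import torch

t = torch.zeros(2, 2, dtype=torch.float16)
fill_with_neg_inf(t)  # all -inf, and still float16 thanks to .type_as(t)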


# Public API
def _get_shape(t):
return getattr(t, "shape", None)


# def output_projection(self):
# return nn.Linear(
# self.embed_tokens.weight.shape[1],
# self.embed_tokens.weight.shape[0],
# bias=False,
# )


@add_start_docstrings(
"The bare FSMT Model outputting raw hidden-states without any specific head on top.",
FSMT_START_DOCSTRING,
@@ -713,7 +597,7 @@ def forward(

# make masks if user doesn't supply
if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_fsmt_decoder_inputs(
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
self.config,
input_ids,
decoder_input_ids=decoder_input_ids,
@@ -772,13 +656,13 @@ def get_input_embeddings(self):
return self.encoder.embed_tokens

def set_input_embeddings(self, value):
self.encoder.embed_tokens = value # self.encoder_embed_tokens = value
self.encoder.embed_tokens = value

def get_output_embeddings(self):
return self.decoder.embed_tokens

def set_output_embeddings(self, value):
self.decoder.embed_tokens = value # self.decoder_embed_tokens = value
self.decoder.embed_tokens = value


@add_start_docstrings(
@@ -935,8 +819,6 @@ def get_encoder(self):

def get_output_embeddings(self):
return self.model.decoder.embed_tokens
# XXX: it was, but probably is not needed here
# return _make_linear_from_emb(self.decoder.embed_tokens) # make it on the fly


def make_positions(tensor, padding_idx: int):
34 changes: 3 additions & 31 deletions tests/test_modeling_fsmt.py
@@ -30,13 +30,8 @@
import torch

from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
from transformers.modeling_fsmt import (
SinusoidalPositionalEmbedding,
_prepare_fsmt_decoder_inputs,
invert_mask,
shift_tokens_right,
)
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
from transformers.modeling_bart import _prepare_bart_decoder_inputs, invert_mask
from transformers.modeling_fsmt import SinusoidalPositionalEmbedding


@require_torch
@@ -164,7 +159,7 @@ def test_advanced_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
inputs_dict["input_ids"][:, -2:] = config.pad_token_id
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
config, inputs_dict["input_ids"]
)
model = FSMTModel(config).to(torch_device).eval()
@@ -287,15 +282,6 @@ def test_generate_beam_search(self):
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
# TODO(SS): uneven length batches, empty inputs

def test_shift_tokens_right(self):
input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long()
shifted = shift_tokens_right(input_ids, 1)
n_pad_before = input_ids.eq(1).float().sum()
n_pad_after = shifted.eq(1).float().sum()
self.assertEqual(shifted.shape, input_ids.shape)
self.assertEqual(n_pad_after, n_pad_before - 1)
self.assertTrue(torch.eq(shifted[:, 0], 2).all())

def test_generate_fp16(self):
config, input_ids, batch_size = self._get_config_and_data()
attention_mask = input_ids.ne(1).to(torch_device)
@@ -310,20 +296,6 @@ def test_dummy_inputs(self):
model = FSMTForConditionalGeneration(config).eval().to(torch_device)
model(**model.dummy_inputs)

def test_prepare_fsmt_decoder_inputs(self):
config, *_ = self._get_config_and_data()
input_ids = _long_tensor(([4, 4, 2]))
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
ignore = float("-inf")
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
config, input_ids, decoder_input_ids
)
expected_causal_mask = torch.tensor(
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad
).to(input_ids.device)
self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())

def test_resize_tokens_embeddings_more(self):
config, input_ids, _ = self._get_config_and_data()

