From 226dad15ca6a9ef4e26178526e878e8fc5c85874 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 6 Sep 2020 21:16:27 -0700 Subject: [PATCH] major refactor (reuse-bart) --- src/transformers/modeling_fsmt.py | 158 ++++-------------------------- tests/test_modeling_fsmt.py | 34 +------ 2 files changed, 23 insertions(+), 169 deletions(-) diff --git a/src/transformers/modeling_fsmt.py b/src/transformers/modeling_fsmt.py index e6f6fea00fb5d0..5ccd41cda02ccc 100644 --- a/src/transformers/modeling_fsmt.py +++ b/src/transformers/modeling_fsmt.py @@ -46,7 +46,14 @@ add_start_docstrings_to_callable, replace_return_docstrings, ) -from .modeling_bart import DecoderLayer, EncoderLayer +from .modeling_bart import ( + DecoderLayer, + EncoderLayer, + LayerNorm, + _prepare_bart_decoder_inputs, + _reorder_buffer, + invert_mask, +) from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput from .modeling_utils import PreTrainedModel @@ -97,43 +104,36 @@ Here is how to compare BLEU scores against fairseq implementation: -# Note: to match fairseq params you need to set num_beams=50 in -# `configuration_fsmt.py` and lower BS as it'll need more GPU memory - -cd examples/seq2seq - # en-ru export PAIR=en-ru export DATA_DIR=data/$PAIR export SAVE_DIR=data/$PAIR export BS=8 +export NUM_BEAMS=50 mkdir -p $DATA_DIR sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target echo $PAIR -PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation +PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS # (fairseq BLEU: 36.4 http://matrix.statmt.org/matrix/output/1914?score_id=37605) - - # ru-en export PAIR=ru-en export DATA_DIR=data/$PAIR export SAVE_DIR=data/$PAIR export BS=8 +export NUM_BEAMS=50 mkdir -p $DATA_DIR sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target -echo $PAIR -PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation - -# (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937) +PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS +# (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937) # de-en @@ -142,11 +142,12 @@ export DATA_DIR=data/$PAIR export SAVE_DIR=data/$PAIR export BS=8 +export NUM_BEAMS=50 mkdir -p $DATA_DIR sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target echo $PAIR -PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation +PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS # (fairseq BLEU: 42.3 http://matrix.statmt.org/matrix/output/1902?run_id=6750) @@ -162,7 +163,7 @@ sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target echo $PAIR -PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation +PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS # (fairseq BLEU: 43.1 http://matrix.statmt.org/matrix/output/1909?run_id=6862) @@ -171,8 +172,7 @@ FSMT_START_DOCSTRING = r""" - This model is a PyTorch `torch.nn.Module `_ sub-class. Use it as a regular PyTorch Module and - refer to the PyTorch documentation for all matters related to general usage and behavior. + This model is a PyTorch `torch.nn.Module `_ sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior. Parameters: config (:class:`~transformers.FSMTConfig`): Model configuration class with all the parameters of the model. @@ -238,33 +238,6 @@ """ -def invert_mask(attention_mask): - """Turns 1->0, 0->1, False->True, True-> False""" - assert attention_mask.dim() == 2 - return attention_mask.eq(0) - - -def _prepare_fsmt_decoder_inputs( - config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32 -): - """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if - none are provided. This mimics the default behavior in fairseq. To override it pass in masks. - Note: this is not called during generation - """ - pad_token_id = config.pad_token_id - if decoder_input_ids is None: - decoder_input_ids = shift_tokens_right(input_ids, pad_token_id) - bsz, tgt_len = decoder_input_ids.size() - if decoder_padding_mask is None: - decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) - else: - decoder_padding_mask = invert_mask(decoder_padding_mask) - causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to( - dtype=causal_mask_dtype, device=decoder_input_ids.device - ) - return decoder_input_ids, decoder_padding_mask, causal_mask - - class PretrainedFSMTModel(PreTrainedModel): config_class = FSMTConfig base_model_prefix = "model" @@ -293,36 +266,6 @@ def dummy_inputs(self): return dummy_inputs -def _make_linear_from_emb(emb): - vocab_size, emb_size = emb.weight.shape - lin_layer = nn.Linear(vocab_size, emb_size, bias=False) - lin_layer.weight.data = emb.weight.data - return lin_layer - - -# Helper Functions, mostly for making masks -def _check_shapes(shape_1, shape2): - if shape_1 != shape2: - raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2)) - - -def shift_tokens_right(input_ids, pad_token_id): - """Shift input ids one token to the right, and wrap the last non pad token (usually ).""" - prev_output_tokens = input_ids.clone() - index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) - prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze() - prev_output_tokens[:, 1:] = input_ids[:, :-1] - return prev_output_tokens - - -def make_padding_mask(input_ids, padding_idx=1): - """True for pad tokens""" - padding_mask = input_ids.eq(padding_idx) - if not padding_mask.any(): - padding_mask = None - return padding_mask - - # Helper Modules @@ -592,70 +535,11 @@ def forward( ) -def _reorder_buffer(attn_cache, new_order): - for k, input_buffer_k in attn_cache.items(): - if input_buffer_k is not None: - attn_cache[k] = input_buffer_k.index_select(0, new_order) - return attn_cache - - -# XXX: remove this and its references -class LearnedPositionalEmbedding(nn.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - Padding ids are ignored by either offsetting based on padding_idx - or by setting padding_idx to None and ensuring that the appropriate - position ids are passed to the forward function. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset): - # FSMT is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models dont have this hack - self.offset = offset - assert padding_idx is not None - num_embeddings += offset - super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) - - def forward(self, input_ids, use_cache=False): - """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] - if use_cache: - positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing - else: - # starts at 0, ends at 1-seq_len - positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device) - return super().forward(positions + self.offset) - - -def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): - if torch.cuda.is_available(): - try: - from apex.normalization import FusedLayerNorm - - return FusedLayerNorm(normalized_shape, eps, elementwise_affine) - except ImportError: - pass - return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) - - -def fill_with_neg_inf(t): - """FP16-compatible function that fills a input_ids with -inf.""" - return t.float().fill_(float("-inf")).type_as(t) - - # Public API def _get_shape(t): return getattr(t, "shape", None) -# def output_projection(self): -# return nn.Linear( -# self.embed_tokens.weight.shape[1], -# self.embed_tokens.weight.shape[0], -# bias=False, -# ) - - @add_start_docstrings( "The bare FSMT Model outputting raw hidden-states without any specific head on top.", FSMT_START_DOCSTRING, @@ -713,7 +597,7 @@ def forward( # make masks if user doesn't supply if not use_cache: - decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_fsmt_decoder_inputs( + decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs( self.config, input_ids, decoder_input_ids=decoder_input_ids, @@ -772,13 +656,13 @@ def get_input_embeddings(self): return self.encoder.embed_tokens def set_input_embeddings(self, value): - self.encoder.embed_tokens = value # self.encoder_embed_tokens = value + self.encoder.embed_tokens = value def get_output_embeddings(self): return self.decoder.embed_tokens def set_output_embeddings(self, value): - self.decoder.embed_tokens = value # self.decoder_embed_tokens = value + self.decoder.embed_tokens = value @add_start_docstrings( @@ -935,8 +819,6 @@ def get_encoder(self): def get_output_embeddings(self): return self.model.decoder.embed_tokens - # XXX: it was, but probably is not needed here - # return _make_linear_from_emb(self.decoder.embed_tokens) # make it on the fly def make_positions(tensor, padding_idx: int): diff --git a/tests/test_modeling_fsmt.py b/tests/test_modeling_fsmt.py index 1304a112957d5d..66cd456907f9b2 100644 --- a/tests/test_modeling_fsmt.py +++ b/tests/test_modeling_fsmt.py @@ -30,13 +30,8 @@ import torch from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer - from transformers.modeling_fsmt import ( - SinusoidalPositionalEmbedding, - _prepare_fsmt_decoder_inputs, - invert_mask, - shift_tokens_right, - ) -PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.""" + from transformers.modeling_bart import _prepare_bart_decoder_inputs, invert_mask + from transformers.modeling_fsmt import SinusoidalPositionalEmbedding @require_torch @@ -164,7 +159,7 @@ def test_advanced_inputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.use_cache = False inputs_dict["input_ids"][:, -2:] = config.pad_token_id - decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs( + decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs( config, inputs_dict["input_ids"] ) model = FSMTModel(config).to(torch_device).eval() @@ -287,15 +282,6 @@ def test_generate_beam_search(self): self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length)) # TODO(SS): uneven length batches, empty inputs - def test_shift_tokens_right(self): - input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long() - shifted = shift_tokens_right(input_ids, 1) - n_pad_before = input_ids.eq(1).float().sum() - n_pad_after = shifted.eq(1).float().sum() - self.assertEqual(shifted.shape, input_ids.shape) - self.assertEqual(n_pad_after, n_pad_before - 1) - self.assertTrue(torch.eq(shifted[:, 0], 2).all()) - def test_generate_fp16(self): config, input_ids, batch_size = self._get_config_and_data() attention_mask = input_ids.ne(1).to(torch_device) @@ -310,20 +296,6 @@ def test_dummy_inputs(self): model = FSMTForConditionalGeneration(config).eval().to(torch_device) model(**model.dummy_inputs) - def test_prepare_fsmt_decoder_inputs(self): - config, *_ = self._get_config_and_data() - input_ids = _long_tensor(([4, 4, 2])) - decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]]) - ignore = float("-inf") - decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs( - config, input_ids, decoder_input_ids - ) - expected_causal_mask = torch.tensor( - [[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad - ).to(input_ids.device) - self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size()) - self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all()) - def test_resize_tokens_embeddings_more(self): config, input_ids, _ = self._get_config_and_data()