diff --git a/parlai/agents/bart/bart.py b/parlai/agents/bart/bart.py
index 514f2fc2b8a..0033ac8a984 100644
--- a/parlai/agents/bart/bart.py
+++ b/parlai/agents/bart/bart.py
@@ -15,7 +15,7 @@
 """
 import os
 import torch
-from typing import Optional
+from typing import Optional, Dict, Any
 
 from parlai.agents.bart.convert_fairseq_to_parlai import ConversionScript
 from parlai.agents.bart.modules import BartModel
@@ -24,7 +24,7 @@
 from parlai.core.message import Message
 from parlai.core.opt import Opt
 from parlai.core.params import ParlaiParser
-from parlai.core.torch_agent import Batch, History
+from parlai.core.torch_agent import Batch, History, TorchAgent
 from parlai.utils.typing import TShared
 from parlai.zoo.bart.build import download, CONVERSION_ARGS, BART_ARGS
@@ -53,6 +53,12 @@ def add_cmdline_args(argparser: ParlaiParser):
             default=None,
             help='fairseq checkpoint for bart',
         )
+        group.add_argument(
+            '--output-conversion-path',
+            type=str,
+            default=None,
+            help='where to save fairseq conversion',
+        )
         argparser.set_defaults(dict_tokenizer='gpt2')
 
     def __init__(self, opt: Opt, shared: TShared = None):
@@ -83,15 +89,15 @@ def _initialize_bart(self, opt: Opt) -> Opt:
         compare_init_model_opts(opt, opt)
         return opt
 
-    def _convert_model(self, opt: Opt) -> Opt:
+    def _get_conversion_args(self, opt: Opt) -> Dict[str, Any]:
         """
-        Convert fairseq init model to ParlAI Model.
+        Get args for fairseq model conversion.
 
         :param opt:
-            options
+            ParlAI Opt
 
-        :return opt:
-            return opt with new init_model path
+        :return args:
+            returns dictionary of args to send to conversion script.
         """
         model_name = os.path.split(opt['init_fairseq_model'])[-1]
         args = CONVERSION_ARGS.copy()
@@ -99,11 +105,26 @@ def _convert_model(self, opt: Opt) -> Opt:
         args['input'] = [opt['init_fairseq_model']]
         if opt.get('model_file') and not os.path.exists(opt['model_file']):
             args['output'] = opt['model_file']
+        elif opt.get('output_conversion_path'):
+            args['output'] = opt['output_conversion_path']
         else:
             args['output'] = os.path.join(
                 opt['datapath'], 'models/converted_fairseq_models/', model_name
             )
 
+        return args
+
+    def _convert_model(self, opt: Opt) -> Opt:
+        """
+        Convert fairseq init model to ParlAI Model.
+
+        :param opt:
+            options
+
+        :return opt:
+            return opt with new init_model path
+        """
+        args = self._get_conversion_args(opt)
         ConversionScript.main(**args)
         opt['init_model'] = args['output']
         return opt
@@ -119,6 +140,14 @@ def build_model(self) -> BartModel:
             )
         return model
 
+    def vectorize(self, *args, **kwargs):
+        """
+        Override vectorize for generative models.
+        """
+        kwargs['add_start'] = True  # need start token for BART
+        kwargs['add_end'] = True
+        return TorchAgent.vectorize(self, *args, **kwargs)
+
     def _set_text_vec(
         self, obs: Message, history: History, truncate: Optional[int]
     ) -> Message:
@@ -138,7 +167,9 @@ def _set_text_vec(
         )
         return obs
 
-    def _get_initial_decoder_input(self, bsz: int, beam_size: int, dev: torch.device):
+    def _get_initial_decoder_input(
+        self, bsz: int, beam_size: int, dev: torch.device
+    ) -> torch.LongTensor:
         """
         Override to seed decoder with EOS token.
         """
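Note: the new --output-conversion-path flag slots into the output-path precedence that _get_conversion_args implements above. Below is a minimal standalone sketch of that precedence; the opt values and paths are hypothetical, for illustration only, and not part of the patch.

import os

# Hypothetical opt values, for illustration only.
opt = {
    'init_fairseq_model': '/checkpoints/bart.large/model.pt',
    'model_file': None,
    'output_conversion_path': '/tmp/bart.converted',
    'datapath': '/data',
}
model_name = os.path.split(opt['init_fairseq_model'])[-1]
if opt.get('model_file') and not os.path.exists(opt['model_file']):
    # 1. a not-yet-built --model-file wins
    output = opt['model_file']
elif opt.get('output_conversion_path'):
    # 2. otherwise the new flag decides
    output = opt['output_conversion_path']
else:
    # 3. fall back to a default location under the ParlAI datapath
    output = os.path.join(
        opt['datapath'], 'models/converted_fairseq_models/', model_name
    )
print(output)  # -> /tmp/bart.converted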
diff --git a/parlai/agents/bart/convert_fairseq_to_parlai.py b/parlai/agents/bart/convert_fairseq_to_parlai.py
index a3011d6a893..c59c84170fe 100644
--- a/parlai/agents/bart/convert_fairseq_to_parlai.py
+++ b/parlai/agents/bart/convert_fairseq_to_parlai.py
@@ -132,10 +132,7 @@ def run(self):
         self.agent.opt.pop('converting', None)
         self.agent.save(self.opt['output'])
         # 4. enjoy!
-        self.agent.observe(
-            {'text': "What's your favorite kind of ramen?", 'episode_done': False}
-        )
-        print(self.agent.act())
+        self.print_agent_act()
 
     def get_parlai_opt(self) -> Opt:
         """
@@ -478,6 +475,15 @@ def convert_model_weight(self, opt: Opt) -> Dict[str, Any]:
         return_dict['START'] = torch.LongTensor([1])  # type: ignore
         return return_dict
 
+    def print_agent_act(self):
+        """
+        Print a sample act from the converted agent.
+        """
+        self.agent.observe(
+            {'text': "What's your favorite kind of ramen?", 'episode_done': False}
+        )
+        print(self.agent.act())
+
 
 if __name__ == '__main__':
     ConversionScript.main()
diff --git a/parlai/agents/bart/modules.py b/parlai/agents/bart/modules.py
index ce60119dc03..495476ab864 100644
--- a/parlai/agents/bart/modules.py
+++ b/parlai/agents/bart/modules.py
@@ -7,6 +7,7 @@
 """
 import torch
 import torch.nn.functional as F
+from typing import Tuple, Any
 
 from parlai.agents.transformer.modules import TransformerGeneratorModel
 
@@ -25,3 +26,27 @@ def output(self, tensor: torch.Tensor) -> torch.Tensor:
         # project back to vocabulary
         output = F.linear(tensor, self.embeddings.weight)
         return output
+
+    def decode_forced(
+        self, encoder_states: Tuple[Any, ...], ys: torch.LongTensor
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        """
+        Decode with a fixed, true sequence, computing loss.
+
+        Overrides `TGM.decode_forced` to bypass the assertion that BOS is not
+        present, and to additionally insert EOS as the first input token.
+        """
+        bsz = ys.size(0)
+        seqlen = ys.size(1)
+        inputs = ys.narrow(1, 0, seqlen - 1)
+        inputs = torch.cat(
+            [
+                torch.LongTensor([self.END_IDX]).detach().expand(bsz, 1).to(inputs),
+                inputs,
+            ],
+            1,
+        )
+        latent, _ = self.decoder(inputs, encoder_states)
+        logits = self.output(latent)
+        _, preds = logits.max(dim=2)
+        return logits, preds
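BartModel.decode_forced above builds its teacher-forcing input by dropping the last target token and prepending EOS, matching fairseq BART's convention of conditioning generation on EOS. A self-contained sketch of that construction follows; the token ids (BOS=0, EOS=2) are toy values assumed for illustration, and ys already starts with BOS because of the vectorize override in bart.py.

import torch

END_IDX = 2  # assumed EOS id, for illustration
ys = torch.LongTensor([[0, 100, 101, 102, 2]])  # [BOS, w1, w2, w3, EOS]
bsz, seqlen = ys.size(0), ys.size(1)
inputs = ys.narrow(1, 0, seqlen - 1)  # drop the final token
inputs = torch.cat(
    [torch.LongTensor([END_IDX]).expand(bsz, 1).to(inputs), inputs], 1
)
print(inputs)  # tensor([[  2,   0, 100, 101, 102]]), i.e. [EOS, BOS, w1, w2, w3]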
diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py
index f50ddc87dcf..0a17321f668 100644
--- a/parlai/core/torch_generator_agent.py
+++ b/parlai/core/torch_generator_agent.py
@@ -144,6 +144,20 @@ def __init__(
         self.register_buffer('START', torch.LongTensor([start_idx]))
         self.longest_label = longest_label
 
+    def _get_initial_forced_decoder_input(self, bsz: int, inputs: torch.LongTensor):
+        """
+        Return initial input to the decoder.
+
+        :param bsz:
+            batchsize
+        :param inputs:
+            inputs to decode
+
+        :return initial_input:
+            initial input for the decoder.
+        """
+        return torch.cat([self.START.detach().expand(bsz, 1), inputs], 1)
+
     def decode_forced(self, encoder_states, ys):
         """
         Decode with a fixed, true sequence, computing loss.
@@ -178,7 +192,7 @@ def decode_forced(self, encoder_states, ys):
                 "your model will have a double BOS token, which is probably not what "
                 "you intended."
             )
-        inputs = torch.cat([self.START.detach().expand(bsz, 1), inputs], 1)
+        inputs = self._get_initial_forced_decoder_input(bsz, inputs)
         latent, _ = self.decoder(inputs, encoder_states)
         logits = self.output(latent)
         _, preds = logits.max(dim=2)
@@ -969,7 +983,9 @@ def _get_context(self, batch, batch_idx):
         """
         return batch.text_vec[batch_idx]
 
-    def _get_initial_decoder_input(self, bsz: int, beam_size: int, dev: torch.device):
+    def _get_initial_decoder_input(
+        self, bsz: int, beam_size: int, dev: torch.device
+    ) -> torch.LongTensor:
         """
         Return initial input to the decoder.
 
@@ -981,7 +997,7 @@ def _get_initial_decoder_input(self, bsz: int, beam_size: int, dev: torch.device
             device to send input to.
 
         :return initial_input:
-            initial input for the decoder.
+            initial input for the decoder
         """
         return (
             torch.LongTensor(  # type: ignore
                 [self.END_IDX])
@@ -991,6 +1007,29 @@ def _get_initial_decoder_input(self, bsz: int, beam_size: int, dev: torch.device
             .expand(bsz * beam_size, 1)
             .to(dev)
         )
 
+    def _get_next_decoder_input(
+        self,
+        prev_input: torch.LongTensor,
+        selection: torch.LongTensor,
+        incr_state_inds: torch.LongTensor,
+    ) -> torch.LongTensor:
+        """
+        Return next decoder input.
+
+        :param prev_input:
+            previous input to decoder
+        :param selection:
+            token selections for current timestep
+        :param incr_state_inds:
+            incremental state indices
+
+        :return decoder_input:
+            decoder input for next timestep
+        """
+        prev_input = torch.index_select(prev_input, 0, incr_state_inds)
+        decoder_input = torch.cat([prev_input, selection], dim=-1)
+        return decoder_input
+
     def _generate(
         self,
         batch: Batch,
@@ -1092,11 +1131,12 @@ def _generate(
             incr_state = model.reorder_decoder_incremental_state(
                 incr_state, incr_state_inds
             )
-            decoder_input = torch.index_select(decoder_input, 0, incr_state_inds)
             selection = torch.cat(
                 [b.get_output_from_current_step() for b in beams]
             ).unsqueeze(-1)
-            decoder_input = torch.cat([decoder_input, selection], dim=-1)
+            decoder_input = self._get_next_decoder_input(
+                decoder_input, selection, incr_state_inds
+            )
 
         # get all finalized candidates for each sample (and validate them)
         n_best_beam_preds_scores = [b.get_rescored_finished() for b in beams]
diff --git a/tests/nightly/gpu/test_bart.py b/tests/nightly/gpu/test_bart.py
index 1c794f21718..143e482b0c5 100644
--- a/tests/nightly/gpu/test_bart.py
+++ b/tests/nightly/gpu/test_bart.py
@@ -27,9 +27,12 @@ def test_bart(self):
                 learningrate=1,
                 batchsize=4,
                 num_epochs=1,
+                short_final_eval=True,
+                validation_max_exs=12,
             )
         )
-        self.assertAlmostEqual(test['ppl'], 1.0, places=2)
+        self.assertLessEqual(valid['ppl'], 11.0)
+        self.assertLessEqual(test['ppl'], 11.0)
 
 
 if __name__ == '__main__':
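For reference, a standalone sketch of what the extracted _get_next_decoder_input helper computes at each beam-search step: rows of the previous decoder input are reordered via index_select to follow their parent beams, and the tokens just chosen are appended. The token ids below are toy values for illustration only.

import torch

decoder_input = torch.LongTensor([[2, 10], [2, 11], [2, 12]])  # three beams
incr_state_inds = torch.LongTensor([0, 0, 2])  # hyps 0 and 1 both extend old beam 0
selection = torch.LongTensor([[20], [21], [22]])  # tokens chosen this step
prev_input = torch.index_select(decoder_input, 0, incr_state_inds)
decoder_input = torch.cat([prev_input, selection], dim=-1)
print(decoder_input)  # rows: [2, 10, 20], [2, 10, 21], [2, 12, 22]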