
[BART/TGA] BART and TGA API Updates (#2840)
* bart updates

* black

* change decode_forced

* minor change

* black

* minor api change

* address eric's comment

* update bart test

* black

* fix test
klshuster authored Aug 4, 2020
1 parent 49442c0 commit 21463cc
Showing 5 changed files with 123 additions and 18 deletions.
47 changes: 39 additions & 8 deletions parlai/agents/bart/bart.py
@@ -15,7 +15,7 @@
"""
import os
import torch
from typing import Optional
from typing import Optional, Dict, Any

from parlai.agents.bart.convert_fairseq_to_parlai import ConversionScript
from parlai.agents.bart.modules import BartModel
@@ -24,7 +24,7 @@
from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import Batch, History
from parlai.core.torch_agent import Batch, History, TorchAgent
from parlai.utils.typing import TShared
from parlai.zoo.bart.build import download, CONVERSION_ARGS, BART_ARGS

@@ -53,6 +53,12 @@ def add_cmdline_args(argparser: ParlaiParser):
default=None,
help='fairseq checkpoint for bart',
)
group.add_argument(
'--output-conversion-path',
type=str,
default=None,
help='where to save fairseq conversion',
)
argparser.set_defaults(dict_tokenizer='gpt2')

def __init__(self, opt: Opt, shared: TShared = None):
@@ -83,27 +89,42 @@ def _initialize_bart(self, opt: Opt) -> Opt:
compare_init_model_opts(opt, opt)
return opt

def _convert_model(self, opt: Opt) -> Opt:
def _get_conversion_args(self, opt: Opt) -> Dict[str, Any]:
"""
Convert fairseq init model to ParlAI Model.
Get args for fairseq model conversion.
:param opt:
options
ParlAI Opt
:return opt:
return opt with new init_model path
:return args:
returns dictionary of args to send to conversion script.
"""
model_name = os.path.split(opt['init_fairseq_model'])[-1]
args = CONVERSION_ARGS.copy()

args['input'] = [opt['init_fairseq_model']]
if opt.get('model_file') and not os.path.exists(opt['model_file']):
args['output'] = opt['model_file']
elif opt.get('output_conversion_path'):
args['output'] = opt['output_conversion_path']
else:
args['output'] = os.path.join(
opt['datapath'], 'models/converted_fairseq_models/', model_name
)

return args

def _convert_model(self, opt: Opt) -> Opt:
"""
Convert fairseq init model to ParlAI Model.
:param opt:
options
:return opt:
return opt with new init_model path
"""
args = self._get_conversion_args(opt)
ConversionScript.main(**args)
opt['init_model'] = args['output']
return opt
@@ -119,6 +140,14 @@ def build_model(self) -> BartModel:
)
return model

def vectorize(self, *args, **kwargs):
"""
Override vectorize for generative models.
"""
kwargs['add_start'] = True # need start token for BART
kwargs['add_end'] = True
return TorchAgent.vectorize(self, *args, **kwargs)

def _set_text_vec(
self, obs: Message, history: History, truncate: Optional[int]
) -> Message:
@@ -138,7 +167,9 @@ def _set_text_vec(
)
return obs

def _get_initial_decoder_input(self, bsz: int, beam_size: int, dev: torch.device):
def _get_initial_decoder_input(
self, bsz: int, beam_size: int, dev: torch.device
) -> torch.LongTensor:
"""
Override to seed decoder with EOS token.
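
Taken together, the new `--output-conversion-path` flag and the `_get_conversion_args` helper make the conversion target configurable. A minimal standalone sketch of the precedence logic, assuming a plain dict in place of a full ParlAI Opt (the function name is illustrative):

import os
from typing import Any, Dict

def resolve_conversion_output(opt: Dict[str, Any]) -> str:
    """Mirror the output-path precedence in BartAgent._get_conversion_args."""
    model_name = os.path.split(opt['init_fairseq_model'])[-1]
    if opt.get('model_file') and not os.path.exists(opt['model_file']):
        # 1. convert straight into the not-yet-created model file
        return opt['model_file']
    if opt.get('output_conversion_path'):
        # 2. honor an explicit --output-conversion-path
        return opt['output_conversion_path']
    # 3. fall back to a default location under the ParlAI datapath
    return os.path.join(
        opt['datapath'], 'models/converted_fairseq_models/', model_name
    )

Since the fallback is derived from the checkpoint's basename, repeated conversions of the same fairseq model land in the same place.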
14 changes: 10 additions & 4 deletions parlai/agents/bart/convert_fairseq_to_parlai.py
@@ -132,10 +132,7 @@ def run(self):
self.agent.opt.pop('converting', None)
self.agent.save(self.opt['output'])
# 4. enjoy!
self.agent.observe(
{'text': "What's your favorite kind of ramen?", 'episode_done': False}
)
print(self.agent.act())
self.print_agent_act()

def get_parlai_opt(self) -> Opt:
"""
@@ -478,6 +475,15 @@ def convert_model_weight(self, opt: Opt) -> Dict[str, Any]:
return_dict['START'] = torch.LongTensor([1]) # type: ignore
return return_dict

def print_agent_act(self):
"""
Print a sample act from the converted agent.
"""
self.agent.observe(
{'text': "What's your favorite kind of ramen?", 'episode_done': False}
)
print(self.agent.act())


if __name__ == '__main__':
ConversionScript.main()
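
For reference, a hedged sketch of driving the conversion programmatically, mirroring what `BartAgent._convert_model` does above; the paths are placeholders, and `CONVERSION_ARGS` is assumed to supply the remaining defaults:

from parlai.agents.bart.convert_fairseq_to_parlai import ConversionScript
from parlai.zoo.bart.build import CONVERSION_ARGS

args = CONVERSION_ARGS.copy()
args['input'] = ['/path/to/fairseq/bart/model.pt']  # placeholder checkpoint
args['output'] = '/path/to/converted/parlai_model'  # placeholder output
ConversionScript.main(**args)  # convert, save, then print a sample act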
25 changes: 25 additions & 0 deletions parlai/agents/bart/modules.py
@@ -7,6 +7,7 @@
"""
import torch
import torch.nn.functional as F
from typing import Tuple, Any

from parlai.agents.transformer.modules import TransformerGeneratorModel

@@ -25,3 +26,27 @@ def output(self, tensor: torch.Tensor) -> torch.Tensor:
# project back to vocabulary
output = F.linear(tensor, self.embeddings.weight)
return output

def decode_forced(
self, encoder_states: Tuple[Any, ...], ys: torch.LongTensor
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
"""
Decode with a fixed, true sequence, computing loss.
Overriding `TGM.decode_forced` to bypass assertion that BOS is not present, and
additionally insert EOS as first token
"""
bsz = ys.size(0)
seqlen = ys.size(1)
inputs = ys.narrow(1, 0, seqlen - 1)
inputs = torch.cat(
[
torch.LongTensor([self.END_IDX]).detach().expand(bsz, 1).to(inputs),
inputs,
],
1,
)
latent, _ = self.decoder(inputs, encoder_states)
logits = self.output(latent)
_, preds = logits.max(dim=2)
return logits, preds
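
To make the EOS-seeding concrete, a toy tensor walkthrough of what the override feeds the decoder. Token ids are made up; in ParlAI the label vector `ys` already starts with BOS and ends with EOS (see the `vectorize` override in bart.py above):

import torch

END_IDX = 2  # illustrative id; the real value comes from the dictionary
ys = torch.LongTensor([[0, 31, 47, 2]])  # [BOS, w1, w2, EOS], batch of 1
bsz, seqlen = ys.size(0), ys.size(1)

inputs = ys.narrow(1, 0, seqlen - 1)  # drop the final token: [BOS, w1, w2]
inputs = torch.cat([torch.LongTensor([END_IDX]).expand(bsz, 1), inputs], dim=1)
print(inputs)  # tensor([[ 2,  0, 31, 47]]): the decoder sees [EOS, BOS, w1, w2]

This matches fairseq's BART convention, where decoding is primed with EOS rather than the usual start token.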
50 changes: 45 additions & 5 deletions parlai/core/torch_generator_agent.py
@@ -144,6 +144,20 @@ def __init__(
self.register_buffer('START', torch.LongTensor([start_idx]))
self.longest_label = longest_label

def _get_initial_forced_decoder_input(self, bsz: int, inputs: torch.LongTensor):
"""
Return initial input to the decoder.
:param bsz:
batchsize
:param inputs:
inputs to decode
:return initial_input:
initial input for the decoder.
"""
return torch.cat([self.START.detach().expand(bsz, 1), inputs], 1)

def decode_forced(self, encoder_states, ys):
"""
Decode with a fixed, true sequence, computing loss.
@@ -178,7 +192,7 @@ def decode_forced(self, encoder_states, ys):
"your model will have a double BOS token, which is probably not what "
"you intended."
)
inputs = torch.cat([self.START.detach().expand(bsz, 1), inputs], 1)
inputs = self._get_initial_forced_decoder_input(bsz, inputs)
latent, _ = self.decoder(inputs, encoder_states)
logits = self.output(latent)
_, preds = logits.max(dim=2)
@@ -969,7 +983,9 @@ def _get_context(self, batch, batch_idx):
"""
return batch.text_vec[batch_idx]

def _get_initial_decoder_input(self, bsz: int, beam_size: int, dev: torch.device):
def _get_initial_decoder_input(
self, bsz: int, beam_size: int, dev: torch.device
) -> torch.LongTensor:
"""
Return initial input to the decoder.
@@ -981,7 +997,7 @@ def _get_initial_decoder_input(self, bsz: int, beam_size: int, dev: torch.device
device to send input to.
:return initial_input:
initial input for the decoder.
initial input for the decoder
"""
return (
torch.LongTensor( # type: ignore
@@ -991,6 +1007,29 @@ def _get_initial_decoder_input(self, bsz: int, beam_size: int, dev: torch.device
.to(dev)
)

def _get_next_decoder_input(
self,
prev_input: torch.LongTensor,
selection: torch.LongTensor,
incr_state_inds: torch.LongTensor,
) -> torch.LongTensor:
"""
Return next decoder input.
:param prev_input:
previous input to decoder
:param selection:
token selections for current timestep
:param incr_state_inds:
incremental state indices
:return decoder input:
return decoder input for next timestep
"""
prev_input = torch.index_select(prev_input, 0, incr_state_inds)
decoder_input = torch.cat([prev_input, selection], dim=-1)
return decoder_input

def _generate(
self,
batch: Batch,
@@ -1092,11 +1131,12 @@ def _generate(
incr_state = model.reorder_decoder_incremental_state(
incr_state, incr_state_inds
)
decoder_input = torch.index_select(decoder_input, 0, incr_state_inds)
selection = torch.cat(
[b.get_output_from_current_step() for b in beams]
).unsqueeze(-1)
decoder_input = torch.cat([decoder_input, selection], dim=-1)
decoder_input = self._get_next_decoder_input(
decoder_input, selection, incr_state_inds
)

# get all finalized candidates for each sample (and validate them)
n_best_beam_preds_scores = [b.get_rescored_finished() for b in beams]
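
The point of the refactor is that beam bookkeeping becomes a hook rather than inline code, so subclasses such as BART can adjust decoder seeding without copying the generation loop. A standalone rendition of `_get_next_decoder_input` on a toy beam step (values are made up):

import torch

def next_decoder_input(prev_input, selection, incr_state_inds):
    # reorder rows so each beam continues from its surviving hypothesis...
    prev_input = torch.index_select(prev_input, 0, incr_state_inds)
    # ...then append the token selected at this timestep
    return torch.cat([prev_input, selection], dim=-1)

prev = torch.LongTensor([[2, 10], [2, 11]])  # two beams, two steps decoded
sel = torch.LongTensor([[5], [6]])           # tokens chosen this step
inds = torch.LongTensor([0, 0])              # both beams extend hypothesis 0
print(next_decoder_input(prev, sel, inds))
# tensor([[ 2, 10,  5],
#         [ 2, 10,  6]])

BART's agent overrides the companion `_get_initial_decoder_input` (shown in bart.py above) so that beam search starts from EOS instead of the default START token.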
5 changes: 4 additions & 1 deletion tests/nightly/gpu/test_bart.py
@@ -27,9 +27,12 @@ def test_bart(self):
learningrate=1,
batchsize=4,
num_epochs=1,
short_final_eval=True,
validation_max_exs=12,
)
)
self.assertAlmostEqual(test['ppl'], 1.0, places=2)
self.assertLessEqual(valid['ppl'], 11.0)
self.assertLessEqual(test['ppl'], 11.0)


if __name__ == '__main__':
