From 936f2232e3bb51b725cd55916f80345867498eb1 Mon Sep 17 00:00:00 2001
From: Jude Fernandes
Date: Mon, 8 Aug 2022 14:05:14 -0700
Subject: [PATCH 1/2] ran autoformatter

---
 parlai/core/torch_generator_agent.py | 21 ++++++---------------
 1 file changed, 6 insertions(+), 15 deletions(-)

diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py
index e9afe742caa..97f96998a30 100644
--- a/parlai/core/torch_generator_agent.py
+++ b/parlai/core/torch_generator_agent.py
@@ -1050,25 +1050,16 @@ def _treesearch_factory(self, device, verbose=False):
         else:
             raise ValueError(f"Can't use inference method {method}")
 
-    def _get_context(self, batch, batch_idx):
-        """
-        Set the beam context for n-gram context blocking.
-
-        Intentionally overridable for more complex model histories.
-        """
-        if self.beam_context_block_ngram <= 0:
-            # We aren't context blocking, return empty tensor
-            return torch.LongTensor()
-
-        ctxt = batch.text_vec[batch_idx]
-        if self.beam_block_full_context:
-            ctxt = batch.full_text_vec[batch_idx]
-        return ctxt
-
     def _get_batch_context(self, batch):
         """
         Version of TGA._get_context() that operates on full batches for speed.
         """
+        if hasattr(self, '_get_context'):
+            # Warn users that have subclassed with '_get_context'
+            warn_once(
+                "WARNING: TGA._get_context() is deprecated, use TGA._get_batch_context() instead"
+            )
+
         if self.beam_context_block_ngram <= 0:
             # We aren't context blocking, return empty tensor of the correct size
             return torch.zeros(batch.batchsize, 0, dtype=torch.long)

From 943c35944ee19941dd6beb6416c2bfe47d0389f1 Mon Sep 17 00:00:00 2001
From: Jude Fernandes
Date: Mon, 8 Aug 2022 14:11:14 -0700
Subject: [PATCH 2/2] changed warning language from 'deprecated' to 'removed'

---
 parlai/core/torch_generator_agent.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py
index 97f96998a30..bb3bb3fdefc 100644
--- a/parlai/core/torch_generator_agent.py
+++ b/parlai/core/torch_generator_agent.py
@@ -1057,7 +1057,7 @@ def _get_batch_context(self, batch):
         if hasattr(self, '_get_context'):
             # Warn users that have subclassed with '_get_context'
             warn_once(
-                "WARNING: TGA._get_context() is deprecated, use TGA._get_batch_context() instead"
+                "WARNING: TGA._get_context() has been removed, use TGA._get_batch_context() instead"
             )
 
         if self.beam_context_block_ngram <= 0:
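
Below the patches, for context, is a minimal self-contained sketch of the detection pattern the change introduces: because the base class's per-example hook has been deleted, hasattr() only finds '_get_context' when a subclass still defines it, which is exactly when the one-time warning should fire. This is not ParlAI code: BaseAgent and LegacyAgent are hypothetical stand-ins for TorchGeneratorAgent and a downstream agent, the batch is a plain list instead of a Batch of tensors, and the stdlib warnings module is used in place of ParlAI's warn_once.

import warnings


class BaseAgent:
    # The per-example hook _get_context() used to live here; it has been removed.
    def _get_batch_context(self, batch):
        if hasattr(self, '_get_context'):
            # True only when a subclass still defines the removed hook,
            # since BaseAgent itself no longer has _get_context.
            warnings.warn(
                "_get_context() has been removed, override _get_batch_context() instead"
            )
        # Build the context for the whole batch at once (stand-in logic).
        return [row for row in batch]


class LegacyAgent(BaseAgent):
    # Stale override of the removed per-example hook.
    def _get_context(self, batch, batch_idx):
        return batch[batch_idx]


LegacyAgent()._get_batch_context([[1, 2], [3, 4]])  # emits the warning
BaseAgent()._get_batch_context([[1, 2], [3, 4]])    # silent: no stale override

Running the last two lines shows the asymmetry: only the subclass that still carries the removed hook triggers the warning, which mirrors how the patched _get_batch_context() nudges downstream agents to migrate.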