diff --git a/.circleci/config.yml b/.circleci/config.yml index 3b529d21a8d..15d54cc64ec 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -235,26 +235,26 @@ commands: - setupcuda - fixgit - restore_cache: - key: deps-20210722-<< parameters.cachename >>-{{ checksum "requirements.txt" }} + key: deps-2021222-<< parameters.cachename >>-{{ checksum "requirements.txt" }} - setup - installdeps - << parameters.more_installs >> - save_cache: - key: deps-20210722-<< parameters.cachename >>-{{ checksum "requirements.txt" }} + key: deps-2021222-<< parameters.cachename >>-{{ checksum "requirements.txt" }} paths: - "~/venv/bin" - "~/venv/lib" - findtests: marker: << parameters.marker >> - restore_cache: - key: data-20210722-<< parameters.cachename >>-{{ checksum "teststorun.txt" }} + key: data-2021222-<< parameters.cachename >>-{{ checksum "teststorun.txt" }} - run: name: Run tests no_output_timeout: 60m command: | coverage run -m pytest -m << parameters.marker >> << parameters.pytest_flags >> --junitxml=test-results/junit.xml - save_cache: - key: data-20210722-<< parameters.cachename >>-{{ checksum "teststorun.txt" }} + key: data-2021222-<< parameters.cachename >>-{{ checksum "teststorun.txt" }} paths: - "~/ParlAI/data" - codecov @@ -271,12 +271,12 @@ commands: - checkout - fixgit - restore_cache: - key: deps-20210722-bw-{{ checksum "requirements.txt" }} + key: deps-2021222-bw-{{ checksum "requirements.txt" }} - setup - installdeps - installtorchgpu17 - save_cache: - key: deps-20210722-bw-{{ checksum "requirements.txt" }} + key: deps-2021222-bw-{{ checksum "requirements.txt" }} paths: - "~/venv/bin" - "~/venv/lib" diff --git a/conftest.py b/conftest.py index fcd15be8248..adc55819d22 100644 --- a/conftest.py +++ b/conftest.py @@ -105,7 +105,7 @@ def pytest_collection_modifyitems(config, items): class_mapping[class_name].append(item) test_groupings = list(class_mapping.keys()) - random.Random(1337).shuffle(test_groupings) + random.Random(1339).shuffle(test_groupings) filtered_tests = filter_tests_with_circleci(test_groupings) new_items = [] diff --git a/parlai/agents/hugging_face/t5.py b/parlai/agents/hugging_face/t5.py index b95257c298e..90ec7a1e336 100644 --- a/parlai/agents/hugging_face/t5.py +++ b/parlai/agents/hugging_face/t5.py @@ -31,7 +31,7 @@ def check_hf_version(v: Tuple[int, int]) -> bool: """ - Check that HF version is greater than 4.3 + Check that HF version is greater than 4.3. """ main, sub = v return main > 4 or (main == 4 and sub >= 3) @@ -195,7 +195,7 @@ def _generate( generation_params.update(overrides) outputs = self.model.t5.generate(**generation_params) - outputs = [(outputs[i], 0) for i in range(outputs.size(0))] + outputs = [(outputs[i], 0, None) for i in range(outputs.size(0))] return outputs, [] diff --git a/parlai/agents/rag/model_types.py b/parlai/agents/rag/model_types.py index 003a1ef18b0..09b13f7ccb7 100644 --- a/parlai/agents/rag/model_types.py +++ b/parlai/agents/rag/model_types.py @@ -203,8 +203,10 @@ def rerank_beams( self, model: RagModel, batch: Batch, - n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]], - ) -> List[List[Tuple[torch.LongTensor, torch.Tensor]]]: + n_best_beam_preds_scores: List[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]] + ], + ) -> List[List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]]: """ Optionally rerank beams. 
""" @@ -423,8 +425,10 @@ def rerank_beams( self, model: RagModel, batch: Batch, - n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]], - ) -> List[List[List[torch.LongTensor]]]: + n_best_beam_preds_scores: List[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]] + ], + ) -> List[List[Tuple[torch.LongTensor, Optional[Dict]]]]: """ Rerank beams in RAG-Sequence, accounting for document probabilities as well. @@ -517,7 +521,9 @@ def augment_batch_for_generation(self, batch: Batch, model: RagModel) -> Batch: def fast_generation( cls, doc_indices: List[int], - n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]], + n_best_beam_preds_scores: List[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]] + ], doc_log_probs: torch.Tensor, n_docs: int, ): @@ -534,22 +540,26 @@ def fast_generation( number of docs per example :return sorted_hyps: - return list of (hyp, score) tuples, sorted by their score. + return list of (hyp, score, token metadata) tuples, sorted by their score. """ - marginalized_hypos: Dict[str, List[torch.Tensor]] = {} + marginalized_hypos: Dict[ + str, Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]] + ] = {} for doc_idx in doc_indices: doc_hypos = n_best_beam_preds_scores[doc_idx] doc_score = doc_log_probs[doc_idx % n_docs] - for hypo, hypo_score in doc_hypos: + for hypo, hypo_score, token_metadata in doc_hypos: score = hypo_score + doc_score hypo_tokens = str(hypo.tolist()) if hypo_tokens in marginalized_hypos: marginalised_hypo = marginalized_hypos[hypo_tokens] - marginalised_hypo[1] = torch.log( - marginalised_hypo[1].exp() + score.exp() + marginalised_hypo = ( + marginalised_hypo[0], + torch.log(marginalised_hypo[1].exp() + score.exp()), + marginalised_hypo[2], ) else: - marginalized_hypos[hypo_tokens] = [hypo, score] + marginalized_hypos[hypo_tokens] = (hypo, score, token_metadata) sorted_by_score = sorted(marginalized_hypos.values(), key=lambda h: -h[1]) return sorted_by_score @@ -560,7 +570,7 @@ def thorough_generation( new_input: torch.LongTensor, null_idx: int, model: RagModel, - ) -> List[Tuple[torch.LongTensor, torch.Tensor]]: + ) -> List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]: """ Apply RAG-sequence thorough generation for a single batch item. @@ -572,7 +582,7 @@ def thorough_generation( input for the model :return sorted_hyps: - return list of (hyp, score) tuples, sorted by their score. + return list of (hyp, score, token_metadata) tuples, sorted by their score. """ # deduplicate, exclude BOS Token hyps = list({str(h.tolist()): h[1:] for h in hyps}.values()) # type: ignore @@ -586,7 +596,7 @@ def thorough_generation( new_ys.unsqueeze(1).unsqueeze(-1), scores.unsqueeze(1), null_idx ) # type: ignore sorted_by_score = [ - (hyps[idx], loss[idx]) for idx in loss.sort()[-1] + (hyps[idx], loss[idx], None) for idx in loss.sort()[-1] ] # sort ascending return sorted_by_score @@ -834,8 +844,8 @@ def reorder_decoder_incremental_state( """ For RAG Token, we send each decoder input through n_docs times. - Similarly to reordering the encoder states, we need to reorder according - to the documents dimensions. + Similarly to reordering the encoder states, we need to reorder according to the + documents dimensions. 
""" assert incremental_state is not None incremental_state = fix_incremental_state( @@ -866,8 +876,10 @@ def rerank_beams( self, model: RagModel, batch: Batch, - n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]], - ) -> List[List[Tuple[torch.LongTensor, torch.Tensor]]]: + n_best_beam_preds_scores: List[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]] + ], + ) -> List[List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]]: """ We don't re-rank beams for RAG Token. """ @@ -1169,8 +1181,10 @@ def rerank_beams( self, model: RagModel, batch: Batch, - n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]], - ) -> List[List[Tuple[torch.LongTensor, torch.Tensor]]]: + n_best_beam_preds_scores: List[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]] + ], + ) -> List[List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]]: """ Re-rank beams. @@ -1183,7 +1197,9 @@ def rerank_beams( Thorough decoding is identical RAG Sequence. """ - new_n_best: List[List[Tuple[torch.LongTensor, torch.Tensor]]] = [] + new_n_best: List[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]] + ] = [] if self.turn_marginalize == 'doc_only' and not self.thorough: # no doc log probs here; just re-sorting beams input_turns_cnt = batch.input_turns_cnt @@ -1197,6 +1213,7 @@ def rerank_beams( new_beam = ( beam[0], beam[1] * self.discount_factor ** (it - i - 1), + beam[2], ) n_best_i.append(new_beam) new_n_best.append(sorted(n_best_i, key=lambda x: -x[1])) diff --git a/parlai/agents/rag/rag.py b/parlai/agents/rag/rag.py index c0c3364934c..2c30980b54b 100644 --- a/parlai/agents/rag/rag.py +++ b/parlai/agents/rag/rag.py @@ -86,7 +86,9 @@ def _generate( beam_size: int, max_ts: int, prefix_tokens: Optional[torch.LongTensor] = None, - ) -> Tuple[List[Tuple[torch.LongTensor, torch.Tensor]], List[TreeSearch]]: + ) -> Tuple[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]], List[TreeSearch] + ]: """ Override since T5 needs to call TGA generate. """ @@ -664,7 +666,9 @@ def _generate( beam_size: int, max_ts: int, prefix_tokens: Optional[torch.LongTensor] = None, - ) -> Tuple[List[Tuple[torch.LongTensor, torch.Tensor]], List[TreeSearch]]: + ) -> Tuple[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]], List[TreeSearch] + ]: """ Override TGA._generate to potentially call ReGReT. @@ -674,7 +678,7 @@ def _generate( beam_preds_scores, _ = self._regret_generate( batch, beam_size, self.regret_intermediate_maxlen, prefix_tokens ) - preds, _ = zip(*beam_preds_scores) + preds, _, _ = zip(*beam_preds_scores) new_batch = self._regret_rebatchify(batch, preds) # type: ignore gen_outs = self._rag_generate(new_batch, beam_size, max_ts, prefix_tokens) else: @@ -685,8 +689,10 @@ def _generate( def _rerank_beams( self, batch: Batch, - n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]], - ) -> List[List[Tuple[torch.LongTensor, torch.Tensor]]]: + n_best_beam_preds_scores: List[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]] + ], + ) -> List[List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]]: """ Optional rerank beams, according to RAG Model type. @@ -710,7 +716,9 @@ def _rag_generate( beam_size: int, max_ts: int, prefix_tokens: Optional[torch.LongTensor] = None, - ) -> Tuple[List[Tuple[torch.LongTensor, torch.Tensor]], List[TreeSearch]]: + ) -> Tuple[ + List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]], List[TreeSearch] + ]: """ Separate from _generate to handle regret. 
""" @@ -887,7 +895,7 @@ def get_model_output(self, batch: Batch) -> Tuple[Any, ...]: beam_preds_scores, beams = self._regret_generate( batch, self.beam_size, self.regret_intermediate_maxlen ) - regret_preds, _ = zip(*beam_preds_scores) + regret_preds, _, _ = zip(*beam_preds_scores) new_batch = self._regret_rebatchify(batch, regret_preds) # type: ignore regret_model_output = self.model( *self._model_input(new_batch), ys=batch.label_vec diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 0c16ba4779c..baa0fb33e78 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -758,7 +758,7 @@ def train_step(self, batch): # out of sync! catch up with the other workers self._fake_forward_backward_pass() - def _construct_token_losses(self, labels, model_output): + def _construct_label_token_losses(self, labels, model_output): # Get non-aggregated losses scores, _, _ = model_output score_view = scores.reshape(-1, scores.size(-1)) @@ -777,6 +777,14 @@ def _construct_token_losses(self, labels, model_output): ) return token_losses + def _construct_generated_token_details(self, tokens, tokens_metadata): + tokens_as_txt = [self.dict[int(token)] for token in tokens] + + details_lists = [tokens_as_txt, tokens_metadata["logprobs"].tolist()] + if "ranks" in tokens_metadata: + details_lists.append(tokens_metadata["ranks"].tolist()) + return list(zip(*details_lists)) + def _compute_fairseq_bleu(self, batch: Batch, preds): """ Compute BLEU score between text and label, using the FAIRSeq BLEU Scorer. @@ -857,15 +865,17 @@ def eval_step(self, batch): self.model.eval() cand_scores = None token_losses = None + text_token_info = None if batch.label_vec is not None: # calculate loss on targets with teacher forcing loss, model_output = self.compute_loss(batch, return_output=True) if self.output_token_losses: - token_losses = self._construct_token_losses( + token_losses = self._construct_label_token_losses( batch.label_vec, model_output ) + beam_preds_scores = None preds = None if self.skip_generation: warn_once("--skip-generation true produces limited metrics") @@ -875,15 +885,25 @@ def eval_step(self, batch): beam_preds_scores, beams = self._generate( batch, self.beam_size, maxlen, prefix_tokens=prefix_tokens ) - preds, scores = zip(*beam_preds_scores) + preds, _, _ = zip(*beam_preds_scores) self._add_generation_metrics(batch, preds) # bsz x beamsize beam_texts: List[List[Tuple[str, float]]] = [] + beam_texts_token_info: List[List[List[Tuple]]] = [] for beam in beams: beam_texts.append([]) - for tokens, score in beam.get_rescored_finished(): + if self.output_token_losses: + beam_texts_token_info.append([]) + + for tokens, score, token_metadata in beam.get_rescored_finished(): try: + if self.output_token_losses: + beam_texts_token_info[-1].append( + self._construct_generated_token_details( + tokens, token_metadata + ) + ) beam_texts[-1].append((self._v2t(tokens), score.item())) except KeyError: logging.error("Decoding error: %s", tokens) @@ -894,18 +914,31 @@ def eval_step(self, batch): if self.rank_candidates: cand_choices, cand_scores = self.rank_eval_label_candidates(batch, bsz) - text = [self._v2t(p) for p in preds] if preds is not None else None + text = ( + [self._v2t(pred_data[0]) for pred_data in beam_preds_scores] + if beam_preds_scores is not None + else None + ) + + if self.output_token_losses and beam_preds_scores is not None: + text_token_info = [] + for beam_text_token_info in beam_texts_token_info: + 
text_token_info.append(beam_text_token_info[0]) + if text and self.compute_tokenized_bleu: # compute additional bleu scores self._compute_fairseq_bleu(batch, preds) retval = Output( text, cand_choices, token_losses=token_losses, cand_scores=cand_scores ) + if not self.skip_generation: retval.beam_texts = beam_texts + retval.beam_texts_token_info = beam_texts_token_info + retval.text_token_info = text_token_info return retval - def _treesearch_factory(self, device): + def _treesearch_factory(self, device, verbose=False): method = self.opt.get('inference', 'greedy') beam_size = self.opt.get('beam_size', 1) if method == 'greedy': @@ -919,6 +952,7 @@ def _treesearch_factory(self, device): bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, ) elif method == 'beam': return BeamSearch( @@ -931,6 +965,7 @@ def _treesearch_factory(self, device): bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, ) elif method == 'delayedbeam': return DelayedBeamSearch( @@ -945,6 +980,7 @@ def _treesearch_factory(self, device): bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, ) elif method == 'topk': return TopKSampling( @@ -958,6 +994,7 @@ def _treesearch_factory(self, device): bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, ) elif method == 'nucleus': return NucleusSampling( @@ -971,6 +1008,7 @@ def _treesearch_factory(self, device): bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, ) else: raise ValueError(f"Can't use inference method {method}") @@ -1083,7 +1121,7 @@ def _generate( :return: tuple (beam_pred_scores, beams) - - beam_preds_scores: list of (prediction, score) pairs for each sample in + - beam_preds_scores: list of (prediction, score, token_metadata) tuples for each sample in Batch - beams :list of Beam instances defined in Beam class, can be used for any following postprocessing, e.g. dot logging. @@ -1103,13 +1141,16 @@ def _generate( batchsize = batch.batchsize batch_context_list = self._get_batch_context(batch).tolist() beams = [ - self._treesearch_factory(dev) + self._treesearch_factory(dev, verbose=self.output_token_losses) .set_batch_context(batch_context_list, batch_idx) .set_block_list(self.beam_block_list) for batch_idx in range(batchsize) ] else: - beams = [self._treesearch_factory(dev) for _ in range(bsz)] + beams = [ + self._treesearch_factory(dev, verbose=self.output_token_losses) + for _ in range(bsz) + ] # repeat encoder outputs and decoder inputs decoder_input = self._get_initial_decoder_input(bsz, beam_size, dev) @@ -1202,13 +1243,32 @@ class _HypothesisTail(object): """ # use slots because we don't want dynamic attributes here - __slots__ = ['timestep', 'hypid', 'score', 'tokenid'] + __slots__ = ['timestep', 'hypid', 'score', 'tokenid', 'token_score', 'token_rank'] - def __init__(self, timestep, hypid, score, tokenid): + def __init__(self, timestep, hypid, score, tokenid, token_score, token_rank): self.timestep = timestep self.hypid = hypid self.score = score self.tokenid = tokenid + self.token_score = token_score + self.token_rank = token_rank + + +class _PathSelection(object): + """ + Output of TreeSearch:select_paths. + + Represents output of path selection process. 
+ """ + + __slots__ = ['hypothesis_ids', 'token_ids', 'scores', 'token_scores', 'token_ranks'] + + def __init__(self, hypothesis_ids, token_ids, scores, token_scores, token_ranks): + self.hypothesis_ids = hypothesis_ids + self.token_ids = token_ids + self.scores = scores + self.token_scores = token_scores + self.token_ranks = token_ranks class TreeSearch(object): @@ -1231,6 +1291,7 @@ def __init__( min_length=3, device='cpu', length_penalty=0.65, + verbose=False, ): """ Instantiate Beam object. @@ -1273,6 +1334,18 @@ def __init__( self.outputs = [ torch.Tensor(self.beam_size).long().fill_(self.bos).to(self.device) ] + + self.verbose = verbose + self.token_scores, self.token_ranks = None, None + if self.verbose: + # beam token scores + self.token_scores = torch.zeros( + (self.beam_size, 1), device=self.device + ) # log prob of bos token is 0 + # beam token ranks + self.token_ranks = torch.ones( + (self.beam_size, 1), device=self.device + ) # bos token is prob 1 and so rank 1 # keeps tuples (score, time_step, hyp_id) self.finished = [] self.eos_top = False @@ -1324,7 +1397,7 @@ def get_backtrack_from_current_step(self): return self.bookkeep[-1] @abstractmethod - def select_paths(self, logprobs, prior_scores, current_length): + def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: """ Select the next vocabulary item in these beams. @@ -1337,7 +1410,7 @@ def select_paths(self, logprobs, prior_scores, current_length): :param current_length: the current length in tokens :return: - a (hypothesis_ids, token_id, scores) tuple, where: + a {hypothesis_ids, token_ids, scores, token_scores, token_ranks} , where: - hypothesis_ids is a LongTensor of hypotheses we're extending. May have repeats, but should always be (beamsize) long. @@ -1345,6 +1418,10 @@ def select_paths(self, logprobs, prior_scores, current_length): each of the hypotheses. - scores is a (beamsize) Tensor with the updated cumulative log-probs of each beam. + - token_scores is a (beamsize) Tensor with the log-probs of the next-token choices for + each of the hypotheses. + - token_ranks is a (beamsize) Tensor with the ranks of the next-token choices for + each of the hypotheses. 
""" pass @@ -1420,21 +1497,28 @@ def advance(self, logprobs): self.context_block_ngram, logprobs, self.context ) - hyp_ids, tok_ids, self.scores = self.select_paths( - logprobs, self.scores, current_length - ) + path_selection = self.select_paths(logprobs, self.scores, current_length) + self.scores = path_selection.scores # use clone() here to ensure that self.all_scores will not be changed # later due to any penalties to self.scores self.all_scores.append(self.scores.clone()) - self.outputs.append(tok_ids) - self.bookkeep.append(hyp_ids) - tok_id_list = tok_ids.tolist() + self.outputs.append(path_selection.token_ids) + self.bookkeep.append(path_selection.hypothesis_ids) + tok_id_list = path_selection.token_ids.tolist() self.partial_hyps = [ - self.partial_hyps[hyp_ids[i]] + [tok_id_list[i]] + self.partial_hyps[path_selection.hypothesis_ids[i]] + [tok_id_list[i]] for i in range(self.beam_size) ] + if self.verbose: + self.token_scores = torch.cat( + (self.token_scores, path_selection.token_scores.unsqueeze(1)), dim=1 + ) + self.token_ranks = torch.cat( + (self.token_ranks, path_selection.token_ranks.unsqueeze(1)), dim=1 + ) + # check new hypos for eos label, if we have some, add to finished for hypid in range(self.beam_size): if self.outputs[-1][hypid] == self.eos: @@ -1446,6 +1530,12 @@ def advance(self, logprobs): hypid=hypid, score=self.all_scores[-1][hypid], tokenid=self.eos, + token_score=self.token_scores[hypid, -1] + if self.token_scores is not None + else None, + token_rank=self.token_ranks[hypid, -1] + if self.token_ranks is not None + else None, ) self.finished.append(eostail) self.n_best_counter += 1 @@ -1482,6 +1572,7 @@ def _get_hyp_from_finished(self, hypothesis_tail): """ hyp_idx = [] endback = hypothesis_tail.hypid + for i in range(hypothesis_tail.timestep, -1, -1): hyp_idx.append( _HypothesisTail( @@ -1489,6 +1580,12 @@ def _get_hyp_from_finished(self, hypothesis_tail): hypid=endback, score=self.all_scores[i][endback], tokenid=self.outputs[i][endback], + token_score=self.token_scores[endback, i] + if self.token_scores is not None + else None, + token_rank=self.token_ranks[endback, i] + if self.token_ranks is not None + else None, ) ) endback = self.bookkeep[i - 1][endback] @@ -1501,6 +1598,18 @@ def _get_pretty_hypothesis(self, list_of_hypotails): """ return torch.stack([ht.tokenid for ht in reversed(list_of_hypotails)]) + def _get_pretty_token_metadata(self, list_of_hypotails): + """ + Return token probabilities and ranks as two tensors. + """ + + return { + "logprobs": torch.stack( + [ht.token_score for ht in reversed(list_of_hypotails)] + ), + "ranks": torch.stack([ht.token_rank for ht in reversed(list_of_hypotails)]), + } + def get_rescored_finished(self, n_best=None): """ Return finished hypotheses according to adjusted scores. 
@@ -1512,9 +1621,12 @@ def get_rescored_finished(self, n_best=None): number of finalized hypotheses to return :return: - list of (tokens, score) pairs, in sorted order, where: + list of (tokens, score, token_metadata) 3-tuples, in sorted order, where: - tokens is a tensor of token ids - score is the adjusted log probability of the entire utterance + - token_metadata dictionary: + logprobs -> a tensor of conditional log probabilities of tokens + ranks -> a tensor of ranks of tokens in the vocabulary, by probability, when sampled """ # if we never actually finished, force one if not self.finished: @@ -1525,6 +1637,12 @@ def get_rescored_finished(self, n_best=None): hypid=0, score=self.all_scores[-1][0], tokenid=self.outputs[-1][0], + token_score=self.token_scores[0, -1] + if self.token_scores is not None + else None, + token_rank=self.token_ranks[0, -1] + if self.token_ranks is not None + else None, ) ) @@ -1539,6 +1657,8 @@ def get_rescored_finished(self, n_best=None): hypid=finished_item.hypid, score=finished_item.score / length_penalty, tokenid=finished_item.tokenid, + token_score=finished_item.token_score, + token_rank=finished_item.token_rank, ) ) @@ -1548,17 +1668,21 @@ def get_rescored_finished(self, n_best=None): if n_best is not None: srted = srted[:n_best] - n_best_list = [ - (self._get_pretty_hypothesis(self._get_hyp_from_finished(hyp)), hyp.score) - for hyp in srted - ] + n_best_list = [] + for hyp in srted: + hyp_data = self._get_hyp_from_finished(hyp) + token_ids = self._get_pretty_hypothesis(hyp_data) + token_metadata = ( + self._get_pretty_token_metadata(hyp_data) if self.verbose else None + ) + n_best_list.append((token_ids, hyp.score, token_metadata)) # check that there is at least one finished candidate # and assert that each of them contains only one EOS assert ( len(n_best_list) >= 1 ), f'TreeSearch returned {len(n_best_list)} candidates, must be >= 1' - for (pred, score) in n_best_list: + for (pred, score, _) in n_best_list: assert (pred == self.eos).sum() == 1, ( f'TreeSearch returned a finalized hypo with multiple end tokens ' f'with score {score.item():.2f}' @@ -1580,11 +1704,23 @@ def __init__(self, *args, **kwargs): if self.beam_size != 1: raise ValueError('Greedy search can only be run with beam size 1.') - def select_paths(self, logprobs, prior_scores, current_length): + def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: tok_scores, tok_ids = logprobs.max(1) best_scores = tok_scores + prior_scores - hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device) - return (hyp_ids, tok_ids, best_scores) + hyp_ids = torch.arange(logprobs.size(0), device=logprobs.device) + + tok_ranks = None + if self.verbose: + tok_scores = tok_scores.view(-1) + tok_ranks = torch.tensor([0], device=logprobs.device, dtype=torch.long) + + return _PathSelection( + hypothesis_ids=hyp_ids, + token_ids=tok_ids, + scores=best_scores, + token_scores=tok_scores, + token_ranks=tok_ranks, + ) class BeamSearch(TreeSearch): @@ -1592,7 +1728,7 @@ class BeamSearch(TreeSearch): Beam search. """ - def select_paths(self, logprobs, prior_scores, current_length): + def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: """ Select the next vocabulary item in these beams. 
""" @@ -1611,7 +1747,28 @@ def select_paths(self, logprobs, prior_scores, current_length): # get the actual word id from residual of the same division tok_ids = best_idxs % voc_size - return (hyp_ids, tok_ids, best_scores) + tok_scores, tok_ranks = None, None + if self.verbose: + tok_scores = ( + torch.index_select(logprobs, 0, hyp_ids) + .gather(1, tok_ids.unsqueeze(1)) + .view(-1) + ) + + tok_ranks = ( + logprobs.argsort(1, descending=True) + .argsort(1) + .view(-1) + .gather(0, best_idxs) + ) + + return _PathSelection( + hypothesis_ids=hyp_ids, + token_ids=tok_ids, + scores=best_scores, + token_scores=tok_scores, + token_ranks=tok_ranks, + ) class DelayedBeamSearch(TreeSearch): @@ -1630,7 +1787,7 @@ def __init__(self, k, delay, *args, **kwargs): self.k = k self.delay = delay - def select_paths(self, logprobs, prior_scores, current_length): + def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: if current_length < self.delay: return TopKSampling.select_paths( self, logprobs, prior_scores, current_length @@ -1655,7 +1812,7 @@ def __init__(self, k, *args, **kwargs): super().__init__(*args, **kwargs) self.k = k - def select_paths(self, logprobs, prior_scores, current_length): + def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: values, indices = logprobs.topk(self.k, dim=-1) probs = torch.softmax(values, dim=-1) choices = torch.multinomial(probs, 1)[:, 0] @@ -1663,7 +1820,19 @@ def select_paths(self, logprobs, prior_scores, current_length): tok_ids = indices[hyp_ids, choices] scores = values[hyp_ids, choices] best_scores = prior_scores.expand_as(scores) + scores - return (hyp_ids, tok_ids, best_scores) + + tok_scores, tok_ranks = None, None + if self.verbose: + tok_scores = scores.view(-1) + tok_ranks = choices.view(-1) + + return _PathSelection( + hypothesis_ids=hyp_ids, + token_ids=tok_ids, + scores=best_scores, + token_scores=tok_scores, + token_ranks=tok_ranks, + ) class NucleusSampling(TreeSearch): @@ -1682,7 +1851,7 @@ def __init__(self, p, *args, **kwargs): super().__init__(*args, **kwargs) self.p = p - def select_paths(self, logprobs, prior_scores, current_length): + def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: # Unlike the other treesearch methods, we have to switch to linspace # for the probabilities in order to compute the CDF. probs = torch.softmax(logprobs, dim=-1) @@ -1698,4 +1867,16 @@ def select_paths(self, logprobs, prior_scores, current_length): # Convert back to logspace. scores = sprobs[hyp_ids, choices].log() best_scores = prior_scores.expand_as(scores) + scores - return (hyp_ids, tok_ids, best_scores) + + tok_scores, tok_ranks = None, None + if self.verbose: + tok_scores = scores.view(-1) + tok_ranks = choices.view(-1) + + return _PathSelection( + hypothesis_ids=hyp_ids, + token_ids=tok_ids, + scores=best_scores, + token_scores=tok_scores, + token_ranks=tok_ranks, + ) diff --git a/parlai/utils/misc.py b/parlai/utils/misc.py index 66435c69693..3f882e4e303 100644 --- a/parlai/utils/misc.py +++ b/parlai/utils/misc.py @@ -43,6 +43,7 @@ 'text_candidates', 'reward', 'token_losses', + 'generated_text_token_info', 'metrics', } @@ -522,6 +523,29 @@ def _token_losses_line( ) return _pretty_lines(space, key, formatted_tl, 'text2') + def _text_token_info_line( + msg: Dict[str, Any], fields_to_show: List[str], space: str + ) -> Optional[str]: + """ + Displays the loss associated with each token. Can be used for debugging + generative models. 
+ + See TorchGeneratorAgent._generate for an example implementation. + """ + key = 'text_token_info' + text_token_info = msg.get(key, None) + + if key not in fields_to_show or not text_token_info: + return None + # Reduce losses to 4 significant figures + formatted_tl = ' | '.join( + [ + f"{tl[0]} {float('{:.4g}'.format(tl[1]))} {tl[2]}" + for tl in text_token_info + ] + ) + return _pretty_lines(space, key, formatted_tl, 'text2') + def _pretty_lines(indent_space, field, value, style): line = '{}{} {}'.format( indent_space, colorize('[' + field + ']:', 'field'), colorize(value, style) @@ -616,6 +640,10 @@ def _pretty_lines(indent_space, field, value, style): if token_loss_line: lines.append(token_loss_line) + text_token_info_line = _text_token_info_line(msg, fields_to_show, space) + if text_token_info_line: + lines.append(text_token_info_line) + if episode_done: lines.append( colorize('- - - - - - - END OF EPISODE - - - - - - - - - -', 'highlight') diff --git a/projects/dialogue_unlikelihood/agents.py b/projects/dialogue_unlikelihood/agents.py index db02f212966..4135434c857 100644 --- a/projects/dialogue_unlikelihood/agents.py +++ b/projects/dialogue_unlikelihood/agents.py @@ -180,7 +180,7 @@ def compute_loss(self, batch, return_output=False): beam_pred_scores, _ = self._generate(batch, self.beam_size, maxlen) # forward pass to create graph for beam search case - generations = [g[1:] for (g, s) in beam_pred_scores] + generations = [g[1:] for (g, s, _) in beam_pred_scores] pred_toks = torch.nn.utils.rnn.pad_sequence(generations, batch_first=True) model_output = self.model(*self._model_input(batch), ys=pred_toks) logits, preds, _ = model_output @@ -418,7 +418,7 @@ def compute_loss(self, batch, return_output=False): ) # forward pass to create graph for beam search case - generations = [g for (g, s) in beam_pred_scores] + generations = [g for (g, s, _) in beam_pred_scores] gentoks = torch.nn.utils.rnn.pad_sequence( generations, batch_first=True, padding_value=self.NULL_IDX ) diff --git a/projects/light_whoami/agents/pacer.py b/projects/light_whoami/agents/pacer.py index 79277c97965..ba47b5b37ef 100644 --- a/projects/light_whoami/agents/pacer.py +++ b/projects/light_whoami/agents/pacer.py @@ -16,7 +16,6 @@ from parlai.core.opt import Opt from parlai.core.params import ParlaiParser from parlai.core.torch_generator_agent import ( - TorchGeneratorAgent, TreeSearch, GreedySearch, BeamSearch, @@ -85,7 +84,7 @@ def _get_batch_context(self, batch): return batch.text_vec return batch.full_text_vec - def _treesearch_factory(self, device: int) -> TreeSearch: + def _treesearch_factory(self, device: int, verbose=False) -> TreeSearch: method = self.opt.get('inference', 'greedy') beam_size = self.opt.get('beam_size', 1) pacer_kwargs = { @@ -105,6 +104,7 @@ def _treesearch_factory(self, device: int) -> TreeSearch: bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, **pacer_kwargs, ) elif method == 'beam': @@ -118,6 +118,7 @@ def _treesearch_factory(self, device: int) -> TreeSearch: bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, **pacer_kwargs, ) elif method == 'delayedbeam': @@ -133,6 +134,7 @@ def _treesearch_factory(self, device: int) -> TreeSearch: bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, **pacer_kwargs, ) elif method == 'topk': @@ -147,6 +149,7 @@ def _treesearch_factory(self, device: int) -> TreeSearch: bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, 
**pacer_kwargs, ) elif method == 'nucleus': @@ -161,6 +164,7 @@ def _treesearch_factory(self, device: int) -> TreeSearch: bos_token=self.START_IDX, eos_token=self.END_IDX, device=device, + verbose=verbose, **pacer_kwargs, ) else: @@ -223,7 +227,7 @@ def modify_logprobs(self, logprobs: torch.Tensor) -> torch.Tensor: 1. With frequency r, select a token x_i+1 to re-rank. 2. Generate word probabilities for token x_i+1. - 3. Examine top k words {x_j | score(x_j) \in top_k(P(x_i+1 | x_0,...,x_i))}; use classifier to predict P(a|x1, ..., x_i, x_j) + 3. Examine top k words {x_j | score(x_j) in top_k(P(x_i+1 | x_0,...,x_i))}; use classifier to predict P(a|x1, ..., x_i, x_j) 4. Rescore top k words via multiplication, re-normalize, and advance the generation. :param logprobs: @@ -302,7 +306,7 @@ class PacerTopKSampling(PacerTreeSearchMixin, TopKSampling): class PacerNucleusSampling(PacerTreeSearchMixin, NucleusSampling): """ - Override Nucleus Sampling to work with PAcer + Override Nucleus Sampling to work with PAcer. """ pass @@ -330,7 +334,7 @@ def add_cmdline_args( class PacerAgent(PacerPartialOnlyAgent, RPARerankAgent): """ - PACER Agent: Combines Beam and Partial Re-ranking + PACER Agent: Combines Beam and Partial Re-ranking. """ @classmethod diff --git a/projects/light_whoami/agents/rpa_ul.py b/projects/light_whoami/agents/rpa_ul.py index c5794a7d5d2..01e27a34c6a 100644 --- a/projects/light_whoami/agents/rpa_ul.py +++ b/projects/light_whoami/agents/rpa_ul.py @@ -6,8 +6,8 @@ """ LIGHT RPA Unlikelihood Agent. -Utilizes a left-to-right RPA Classifier to predict tokens that -yield incorrect character classifications. +Utilizes a left-to-right RPA Classifier to predict tokens that yield incorrect character +classifications. """ from typing import Optional, List, Tuple from parlai.core.params import ParlaiParser @@ -36,8 +36,8 @@ class RpaUlAgent(TransformerGeneratorAgent): """ RPA UL Agent. - Performs unlikelihood such that tokens which lead to misclassification - are penalized. + Performs unlikelihood such that tokens which lead to misclassification are + penalized. 
""" def __init__(self, opt, shared=None): @@ -137,16 +137,19 @@ def compute_loss(self, batch, return_output=False): beam_pred_scores, _ = self._generate(batch, self.beam_size, maxlen) # forward pass to create graph for beam search case - generations = [g[1:] for (g, _) in beam_pred_scores] + generations = [g[1:] for (g, _, _) in beam_pred_scores] pred_toks = torch.nn.utils.rnn.pad_sequence(generations, batch_first=True) model_output = self.model(*self._model_input(batch), ys=pred_toks) logits, *_ = model_output # construct mask marking incorrectly classified characters label_mask = torch.zeros_like(pred_toks).type_as(logits) - label_mask, wrong_class_cnt, wrong_class_all_cnt, right_class_cnt = self.compute_ul_label_mask( - label_mask, generations, batch - ) + ( + label_mask, + wrong_class_cnt, + wrong_class_all_cnt, + right_class_cnt, + ) = self.compute_ul_label_mask(label_mask, generations, batch) # Compute unlikelihood loss ul_loss = self.compute_ul_loss(pred_toks, label_mask, logits) # type: ignore if label_mask.sum() > 0: @@ -307,7 +310,7 @@ def eval_step(self, batch): maxlen = self.label_truncate or 256 beam_preds_scores, _ = self._generate(batch, self.beam_size, maxlen) - preds, scores = zip(*beam_preds_scores) + preds, scores, _ = zip(*beam_preds_scores) cand_choices = None text = [self._v2t(p) for p in preds] if preds is not None else None diff --git a/tests/nightly/gpu/test_bb2.py b/tests/nightly/gpu/test_bb2.py index 0d1e9242718..de4b0119e20 100644 --- a/tests/nightly/gpu/test_bb2.py +++ b/tests/nightly/gpu/test_bb2.py @@ -88,8 +88,7 @@ def test_retrieval_none(self): _test_bb2_rag(KnowledgeAccessMethod.NONE, n_docs=1) -@testing_utils.skipUnlessGPU -@unittest.skipIf(LOCAL, "Skipping Test because its slow and mem intensive") +@testing_utils.skipIfCircleCI class TestBB2Fid(unittest.TestCase): """ Test retrieval methods for BB2 with FiD. diff --git a/tests/nightly/gpu/test_rag.py b/tests/nightly/gpu/test_rag.py index 45b1b68223e..230cb3a7a70 100644 --- a/tests/nightly/gpu/test_rag.py +++ b/tests/nightly/gpu/test_rag.py @@ -110,7 +110,7 @@ } -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestRagDpr(unittest.TestCase): """ Test all RAG DPR Model Types with Base Generators. @@ -168,7 +168,7 @@ def test_reddit_rag_turn_thorough(self): self._test_rag_type('turn:thorough=True', 'transformer/generator', no_cuda=True) -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestFidDpr(unittest.TestCase): """ Test FiD DPR Model. @@ -192,7 +192,7 @@ def test_reddit_fid(self): self._test_fid('transformer/generator') -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestRagDprPoly(unittest.TestCase): """ Test RAG DPR Poly model. @@ -218,7 +218,7 @@ def test_rag_turn(self): self._test_rag_type('turn', no_cuda=True) -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestRagTfidf(unittest.TestCase): """ Test RAG TFIDF model. @@ -231,7 +231,7 @@ def test_rag_token(self): testing_utils.eval_model(opt, skip_test=True) -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestFidRag(unittest.TestCase): """ Test Fid Rag. @@ -257,7 +257,7 @@ def test_reddit_fid(self): self._test_fid('transformer/generator') -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestRagPolyfaiss(unittest.TestCase): """ Test Rag PolyFAISS. 
@@ -272,7 +272,7 @@ def test_bart_rag_token(self): testing_utils.eval_model(opt, skip_test=True) -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestRegret(unittest.TestCase): """ Test ReGReT. @@ -296,7 +296,7 @@ def test_rag_regret_same(self): self._test_regret() -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestOtherOptions(unittest.TestCase): """ Test other RAG Options. @@ -315,7 +315,7 @@ def test_resize_embs(self): testing_utils.eval_model(opt, skip_test=True) -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestQueryModels(unittest.TestCase): """ Test other RAG Options. @@ -368,7 +368,7 @@ def _test_zoo_file(mf: str, fid: bool = False, fid_rag: bool = False): torch.cuda.empty_cache() -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestRagZooModels(unittest.TestCase): """ Test ZOO Models. @@ -390,7 +390,7 @@ def test_bart_rag_turn_do(self): _test_zoo_file(RAG_TURN_DO_ZOO_MODEL) -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestFidZooModels(unittest.TestCase): """ Test FiD zoo models. @@ -406,6 +406,7 @@ def test_bart_fid_rag_dpr_poly(self): _test_zoo_file(FID_RAG_DPR_POLY_ZOO_MODEL, True, True) +@testing_utils.skipIfCircleCI class TestLoadDPRModel(unittest.TestCase): """ Test loading different DPR models for RAG. @@ -504,7 +505,7 @@ def test_load_dpr(self): ) -@testing_utils.skipUnlessGPU +@testing_utils.skipIfCircleCI class TestRagSelfChat(unittest.TestCase): """ Test Self-Chat with RAG-based model. @@ -532,6 +533,7 @@ def test_self_chat(self): SelfChat.main(**opt) +@testing_utils.skipIfCircleCI class TestWOIChunking(unittest.TestCase): """ Test that the woi_chunk_retrieved_docs Chunker works as intended. diff --git a/tests/test_tga.py b/tests/test_tga.py index 5c6ddd38e86..f3a473db781 100644 --- a/tests/test_tga.py +++ b/tests/test_tga.py @@ -7,10 +7,18 @@ Test TorchGeneratorAgent. """ import unittest +import math +import torch from parlai.core.agents import create_agent import parlai.utils.testing as testing_utils from parlai.core.params import ParlaiParser -from parlai.core.torch_generator_agent import TorchGeneratorAgent +from parlai.core.torch_generator_agent import ( + BeamSearch, + GreedySearch, + NucleusSampling, + TopKSampling, + TorchGeneratorAgent, +) from parlai.agents.test_agents.transformer_generator_prefix import PREFIX_TEXT @@ -171,6 +179,207 @@ def test_prefix_tokens(self): PREFIX_TEXT ), f"[{beam}] does not start with [{PREFIX_TEXT}]" + def test_token_level_loss_logging(self): + """ + Test functionality of token level probability + ranking logging. 
+ + Regression for all inference types: 'beam', 'greedy', 'topk', 'nucleus', + 'delayedbeam' + """ + inference_types = ['beam', 'greedy', 'topk', 'nucleus', 'delayedbeam'] + gold_data = { + 'beam': { + 'text_token_info': [ + ('__start__', 0.0, 1.0), + ('5', -2.5510462364763953e-05, 0.0), + ('__end__', -1.1920922133867862e-06, 0.0), + ], + 'extra_args': ['--beam-size', '3'], + }, + 'greedy': { + 'text_token_info': [ + ('__start__', 0.0, 1.0), + ('5', -2.5510462364763953e-05, 0.0), + ('__end__', -1.1920922133867862e-06, 0.0), + ], + 'extra_args': [], + }, + # sampling based token selection will produce non-deterministic output, so we can't do data regression + 'topk': {'extra_args': ['--topk', '2']}, + 'topk_multiple_beams': {'extra_args': ['--topk', '2', '--beam-size', '5']}, + # sampling based token selection will produce non-deterministic output, so we can't do data regression + 'nucleus': {'extra_args': ['--topp', '0.3']}, + 'nucleus_multiple_beams': { + 'extra_args': ['--topp', '0.3', '--beam-size', '5'] + }, + # sampling based token selection will produce non-deterministic output, so we can't do data regression + 'delayedbeam': {'extra_args': ['--topk', '2', '--beam-delay', '2']}, + } + + for inference_type in inference_types: + args = [ + '--model-file', + 'zoo:unittest/transformer_generator2/model', + '--inference', + inference_type, + '--truncate', + '1024', + '-v', + ] + gold_data[inference_type]['extra_args'] + + pp = ParlaiParser(True, True) + agent = create_agent(pp.parse_args(args), True) + obs = {'text': '5', 'episode_done': False} + agent.observe(obs) + act = agent.act() + + if 'text_token_info' in gold_data[inference_type]: + for i, tok_data in enumerate(act['text_token_info']): + assert ( + gold_data[inference_type]['text_token_info'][i][0] + == tok_data[0] + ), f"failed token prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" + assert math.isclose( + gold_data[inference_type]['text_token_info'][i][1], tok_data[1] + ), f"failed token probability prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" + assert math.isclose( + gold_data[inference_type]['text_token_info'][i][2], tok_data[2] + ), f"failed token rank prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" + + def test_tree_search(self): + """ + Unit test `select_paths` for different decoding schemes. 
+ """ + tests = { + "greedy": { + "obj": GreedySearch(beam_size=1, verbose=True), + "logprobs": torch.Tensor([[-1.0, -1.0, -0.1, -0.3]]), + "prior_scores": torch.Tensor([-0.5]), + "expected_result": { + "hypothesis_ids": torch.LongTensor([0]), + "token_ids": torch.LongTensor([2]), + "scores": torch.Tensor([-0.6]), + "token_scores": torch.Tensor([-0.1]), + "token_ranks": torch.LongTensor([0]), + }, + }, + "beam_with_one_beam": { + "obj": BeamSearch(beam_size=1, verbose=True), + "logprobs": torch.Tensor([[-1.0, -1.0, -0.1, -0.3]]), + "prior_scores": torch.Tensor([-0.5]), + "expected_result": { + "hypothesis_ids": torch.LongTensor([0]), + "token_ids": torch.LongTensor([2]), + "scores": torch.Tensor([-0.6]), + "token_scores": torch.Tensor([-0.1]), + "token_ranks": torch.LongTensor([0]), + }, + }, + "beam_with_multiple_beams": { + "obj": BeamSearch(beam_size=2, verbose=True), + "logprobs": torch.Tensor( + [[-0.1, -2.0, -3.0, -3.0], [-1.0, -1.0, -0.2, -0.3]] + ), + "prior_scores": torch.Tensor([-1.0, -0.5]), + # logprobs + prior_scores = [[-1.1,-3.,-4.,-4.],[-1.5,-1.5,-0.7,-0.8]] + "expected_result": { + "hypothesis_ids": torch.LongTensor([1, 1]), + "token_ids": torch.LongTensor([2, 3]), + "scores": torch.Tensor([-0.7, -0.8]), + "token_scores": torch.Tensor([-0.2, -0.3]), + "token_ranks": torch.LongTensor([0, 1]), + }, + }, + "topk_with_one_beam": { + "obj": TopKSampling(beam_size=1, k=3, verbose=True), + "logprobs": torch.Tensor( + [[-float('inf'), -0.5, -float('inf'), -float('inf')]] + ), + "prior_scores": torch.Tensor([-3.0]), + "expected_result": { + "hypothesis_ids": torch.LongTensor([0]), + "token_ids": torch.LongTensor([1]), + "scores": torch.Tensor([-3.5]), + "token_scores": torch.Tensor([-0.5]), + "token_ranks": torch.LongTensor([0]), + }, + }, + "topk_with_multiple_beams": { + "obj": TopKSampling(beam_size=2, k=3, verbose=True), + "logprobs": torch.Tensor( + [ + [-float('inf'), -0.5, -float('inf'), -float('inf')], + [-float('inf'), -float('inf'), -0.6, -float('inf')], + ] + ), + "prior_scores": torch.Tensor([-3.0, -2.0]), + "expected_result": { + "hypothesis_ids": torch.LongTensor([0, 1]), + "token_ids": torch.LongTensor([1, 2]), + "scores": torch.Tensor([-3.5, -2.6]), + "token_scores": torch.Tensor([-0.5, -0.6]), + "token_ranks": torch.LongTensor([0, 0]), + }, + }, + "nucleus_with_one_beam": { + "obj": NucleusSampling(beam_size=1, p=0.9, verbose=True), + "logprobs": torch.Tensor( + [[-float('inf'), -0.5, -float('inf'), -float('inf')]] + ), + "prior_scores": torch.Tensor([-3.0]), + "expected_result": { + "hypothesis_ids": torch.LongTensor([0]), + "token_ids": torch.LongTensor([1]), + "scores": torch.Tensor( + [-3.0] + ), # the -0.5 logprob normalizes to 0 in truncated distribution + "token_scores": torch.Tensor([-0.0]), # same as above + "token_ranks": torch.LongTensor([0]), + }, + }, + "nucleus_with_multiple_beams": { + "obj": NucleusSampling(beam_size=2, p=0.9, verbose=True), + "logprobs": torch.Tensor( + [ + [-float('inf'), -0.5, -float('inf'), -float('inf')], + [-float('inf'), -float('inf'), -0.6, -float('inf')], + ] + ), + "prior_scores": torch.Tensor([-3.0, -2.0]), + "expected_result": { + "hypothesis_ids": torch.LongTensor([0, 1]), + "token_ids": torch.LongTensor([1, 2]), + "scores": torch.Tensor( + [-3.0, -2.0] + ), # the -0.5, -0.6 logprobs normalize to 0 in truncated distributions + "token_scores": torch.Tensor([-0.0, -0.0]), # same as above + "token_ranks": torch.LongTensor([0, 0]), + }, + }, + } + + for test_name, test_data in tests.items(): + path_selection = 
test_data["obj"].select_paths( + test_data["logprobs"], test_data["prior_scores"], None + ) + expected_result = test_data["expected_result"] + + assert torch.equal( + path_selection.hypothesis_ids, expected_result["hypothesis_ids"] + ), f"failed test_tree_search for test {test_name} on field hypothesis_ids" + assert torch.equal( + path_selection.token_ids, expected_result["token_ids"] + ), f"failed test_tree_search for test {test_name} on field token_ids" + assert torch.allclose( + path_selection.scores, expected_result["scores"] + ), f"failed test_tree_search for test {test_name} on field scores" + assert torch.allclose( + path_selection.token_scores, expected_result["token_scores"] + ), f"failed test_tree_search for test {test_name} on field token_scores" + assert torch.equal( + path_selection.token_ranks, expected_result["token_ranks"] + ), f"failed test_tree_search for test {test_name} on field token_ranks" + if __name__ == '__main__': unittest.main()
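Usage note (not part of the patch): a minimal sketch of how the new token-level metadata surfaces at inference time, based on the test_token_level_loss_logging regression test above. It assumes the zoo:unittest/transformer_generator2 model used by that test is available locally; enabling verbose output (-v) is what populates text_token_info in the act.

    from parlai.core.agents import create_agent
    from parlai.core.params import ParlaiParser

    # same arguments as the regression test above, using beam search
    args = [
        '--model-file',
        'zoo:unittest/transformer_generator2/model',
        '--inference',
        'beam',
        '--beam-size',
        '3',
        '--truncate',
        '1024',
        '-v',  # verbose mode enables token-level logging
    ]
    agent = create_agent(ParlaiParser(True, True).parse_args(args), True)
    agent.observe({'text': '5', 'episode_done': False})
    act = agent.act()

    # each entry of text_token_info is a (token_text, logprob, rank) tuple
    # for the top-scoring hypothesis, e.g. ('__start__', 0.0, 1.0)
    for token_text, logprob, rank in act['text_token_info']:
        print(token_text, logprob, rank)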