From ddd4b2c2fabdee3dfeb0c1a5dd468a7c8f42a92f Mon Sep 17 00:00:00 2001 From: Colin Flaherty Date: Wed, 16 Mar 2022 14:05:46 -0700 Subject: [PATCH 1/4] passing test --- parlai/core/torch_generator_agent.py | 173 +++++++++++++-------------- parlai/utils/misc.py | 28 ----- tests/test_tga.py | 79 +++++++----- 3 files changed, 135 insertions(+), 145 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index c70ad95cd58..44dca07acab 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -17,6 +17,7 @@ * Beam class which provides some generic beam functionality for classes to use """ +from typing_extensions import TypedDict from parlai.core.params import ParlaiParser from abc import ABC, abstractmethod from typing import TypeVar, List, Dict, Optional, Tuple, Set, Iterable @@ -467,7 +468,7 @@ def __init__(self, opt: Opt, shared=None): self.beam_block_full_context = opt.get('beam_block_full_context', False) self.temperature = opt.get('temperature', 1.0) assert self.temperature > 0, '--temperature must be greater than 0' - self.output_token_losses = opt.get( + self.show_token_details = opt.get( 'verbose', False ) or 'token_losses' in opt.get('display_add_fields', '') self.compute_tokenized_bleu = opt.get('compute_tokenized_bleu', False) @@ -779,11 +780,7 @@ def _construct_label_token_losses(self, labels, model_output): def _construct_generated_token_details(self, tokens, tokens_metadata): tokens_as_txt = [self.dict[int(token)] for token in tokens] - - details_lists = [tokens_as_txt, tokens_metadata["logprobs"].tolist()] - if "ranks" in tokens_metadata: - details_lists.append(tokens_metadata["ranks"].tolist()) - return list(zip(*details_lists)) + return list(zip(tokens_as_txt, tokens_metadata)) def _compute_fairseq_bleu(self, batch: Batch, preds): """ @@ -873,7 +870,7 @@ def eval_step(self, batch): if batch.label_vec is not None: # calculate loss on targets with teacher forcing loss, model_output = self.compute_loss(batch, return_output=True) - if self.output_token_losses: + if self.show_token_details: token_losses = self._construct_label_token_losses( batch.label_vec, model_output ) @@ -896,12 +893,12 @@ def eval_step(self, batch): beam_texts_token_info: List[List[List[Tuple]]] = [] for beam in beams: beam_texts.append([]) - if self.output_token_losses: + if self.show_token_details: beam_texts_token_info.append([]) for tokens, score, token_metadata in beam.get_rescored_finished(): try: - if self.output_token_losses: + if self.show_token_details: beam_texts_token_info[-1].append( self._construct_generated_token_details( tokens, token_metadata @@ -923,7 +920,7 @@ def eval_step(self, batch): else None ) - if self.output_token_losses and beam_preds_scores is not None: + if self.show_token_details and beam_preds_scores is not None: text_token_info = [] for beam_text_token_info in beam_texts_token_info: text_token_info.append(beam_text_token_info[0]) @@ -1144,14 +1141,14 @@ def _generate( batchsize = batch.batchsize batch_context_list = self._get_batch_context(batch).tolist() beams = [ - self._treesearch_factory(dev, verbose=self.output_token_losses) + self._treesearch_factory(dev, verbose=self.show_token_details) .set_batch_context(batch_context_list, batch_idx) .set_block_list(self.beam_block_list) for batch_idx in range(batchsize) ] else: beams = [ - self._treesearch_factory(dev, verbose=self.output_token_losses) + self._treesearch_factory(dev, verbose=self.show_token_details) for _ in range(bsz) ] @@ -1246,15 
+1243,19 @@ class _HypothesisTail(object): """ # use slots because we don't want dynamic attributes here - __slots__ = ['timestep', 'hypid', 'score', 'tokenid', 'token_score', 'token_rank'] + __slots__ = ['timestep', 'hypid', 'score', 'tokenid', 'token_details'] - def __init__(self, timestep, hypid, score, tokenid, token_score, token_rank): + def __init__(self, timestep, hypid, score, tokenid, token_details): self.timestep = timestep self.hypid = hypid self.score = score self.tokenid = tokenid - self.token_score = token_score - self.token_rank = token_rank + self.token_details = token_details + + +class _PathSelectionTokenDetails(TypedDict, total=False): + token_score: float + token_rank: int class _PathSelection(object): @@ -1264,14 +1265,19 @@ class _PathSelection(object): Represents output of path selection process. """ - __slots__ = ['hypothesis_ids', 'token_ids', 'scores', 'token_scores', 'token_ranks'] + __slots__ = ['hypothesis_ids', 'token_ids', 'scores', 'token_details'] - def __init__(self, hypothesis_ids, token_ids, scores, token_scores, token_ranks): + def __init__( + self, + hypothesis_ids, + token_ids, + scores, + token_details: Optional[List[_PathSelectionTokenDetails]] = None, + ): self.hypothesis_ids = hypothesis_ids self.token_ids = token_ids self.scores = scores - self.token_scores = token_scores - self.token_ranks = token_ranks + self.token_details = token_details # length equal to beam size class TreeSearch(object): @@ -1339,16 +1345,12 @@ def __init__( ] self.verbose = verbose - self.token_scores, self.token_ranks = None, None + self.token_details: Optional[List[List[_PathSelectionTokenDetails]]] = None if self.verbose: - # beam token scores - self.token_scores = torch.zeros( - (self.beam_size, 1), device=self.device - ) # log prob of bos token is 0 - # beam token ranks - self.token_ranks = torch.ones( - (self.beam_size, 1), device=self.device - ) # bos token is prob 1 and so rank 1 + self.token_details = [] + for _ in range(self.beam_size): + self.token_details.append([{"token_score": 0.0, "token_rank": 1}]) + # keeps tuples (score, time_step, hyp_id) self.finished = [] self.eos_top = False @@ -1413,7 +1415,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection :param current_length: the current length in tokens :return: - a {hypothesis_ids, token_ids, scores, token_scores, token_ranks} , where: + a {hypothesis_ids, token_ids, scores, token_details} , where: - hypothesis_ids is a LongTensor of hypotheses we're extending. May have repeats, but should always be (beamsize) long. @@ -1421,10 +1423,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection each of the hypotheses. - scores is a (beamsize) Tensor with the updated cumulative log-probs of each beam. - - token_scores is a (beamsize) Tensor with the log-probs of the next-token choices for - each of the hypotheses. - - token_ranks is a (beamsize) Tensor with the ranks of the next-token choices for - each of the hypotheses. + - token_details is a (beamsize) list of objects with with metadata about each generated token. 
""" pass @@ -1515,12 +1514,10 @@ def advance(self, logprobs): ] if self.verbose: - self.token_scores = torch.cat( - (self.token_scores, path_selection.token_scores.unsqueeze(1)), dim=1 - ) - self.token_ranks = torch.cat( - (self.token_ranks, path_selection.token_ranks.unsqueeze(1)), dim=1 - ) + assert path_selection.token_details + assert self.token_details + for i in range(self.beam_size): + self.token_details[i].append(path_selection.token_details[i]) # check new hypos for eos label, if we have some, add to finished for hypid in range(self.beam_size): @@ -1528,16 +1525,14 @@ def advance(self, logprobs): if self.scores[hypid] <= neginf(self.scores.dtype): continue # this is finished hypo, adding to finished + eostail = _HypothesisTail( timestep=len(self.outputs) - 1, hypid=hypid, score=self.all_scores[-1][hypid], tokenid=self.eos, - token_score=self.token_scores[hypid, -1] - if self.token_scores is not None - else None, - token_rank=self.token_ranks[hypid, -1] - if self.token_ranks is not None + token_details=self.token_details[hypid][-1] + if self.token_details is not None else None, ) self.finished.append(eostail) @@ -1583,11 +1578,8 @@ def _get_hyp_from_finished(self, hypothesis_tail): hypid=endback, score=self.all_scores[i][endback], tokenid=self.outputs[i][endback], - token_score=self.token_scores[endback, i] - if self.token_scores is not None - else None, - token_rank=self.token_ranks[endback, i] - if self.token_ranks is not None + token_details=self.token_details[endback][i] + if self.token_details is not None else None, ) ) @@ -1601,18 +1593,6 @@ def _get_pretty_hypothesis(self, list_of_hypotails): """ return torch.stack([ht.tokenid for ht in reversed(list_of_hypotails)]) - def _get_pretty_token_metadata(self, list_of_hypotails): - """ - Return token probabilities and ranks as two tensors. - """ - - return { - "logprobs": torch.stack( - [ht.token_score for ht in reversed(list_of_hypotails)] - ), - "ranks": torch.stack([ht.token_rank for ht in reversed(list_of_hypotails)]), - } - def get_rescored_finished(self, n_best=None): """ Return finished hypotheses according to adjusted scores. 
@@ -1640,11 +1620,8 @@ def get_rescored_finished(self, n_best=None): hypid=0, score=self.all_scores[-1][0], tokenid=self.outputs[-1][0], - token_score=self.token_scores[0, -1] - if self.token_scores is not None - else None, - token_rank=self.token_ranks[0, -1] - if self.token_ranks is not None + token_details=self.token_details[0][-1] + if self.token_details is not None else None, ) ) @@ -1660,8 +1637,7 @@ def get_rescored_finished(self, n_best=None): hypid=finished_item.hypid, score=finished_item.score / length_penalty, tokenid=finished_item.tokenid, - token_score=finished_item.token_score, - token_rank=finished_item.token_rank, + token_details=finished_item.token_details, ) ) @@ -1676,7 +1652,9 @@ def get_rescored_finished(self, n_best=None): hyp_data = self._get_hyp_from_finished(hyp) token_ids = self._get_pretty_hypothesis(hyp_data) token_metadata = ( - self._get_pretty_token_metadata(hyp_data) if self.verbose else None + [tok.token_details for tok in reversed(hyp_data)] + if self.verbose + else None ) n_best_list.append((token_ids, hyp.score, token_metadata)) @@ -1712,17 +1690,19 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection best_scores = tok_scores + prior_scores hyp_ids = torch.arange(logprobs.size(0), device=logprobs.device) - tok_ranks = None + token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_scores = tok_scores.view(-1) - tok_ranks = torch.tensor([0], device=logprobs.device, dtype=torch.long) + tok_score = tok_scores[0].item() + tok_rank = 0 + token_details: Optional[List[_PathSelectionTokenDetails]] = [ + {"token_score": tok_score, "token_rank": tok_rank} + ] return _PathSelection( hypothesis_ids=hyp_ids, token_ids=tok_ids, scores=best_scores, - token_scores=tok_scores, - token_ranks=tok_ranks, + token_details=token_details, ) @@ -1750,14 +1730,13 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection # get the actual word id from residual of the same division tok_ids = best_idxs % voc_size - tok_scores, tok_ranks = None, None + token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: tok_scores = ( torch.index_select(logprobs, 0, hyp_ids) .gather(1, tok_ids.unsqueeze(1)) .view(-1) ) - tok_ranks = ( logprobs.argsort(1, descending=True) .argsort(1) @@ -1765,12 +1744,18 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection .gather(0, best_idxs) ) + token_details = [] + + for score, rank in zip(tok_scores.cpu().numpy(), tok_ranks.cpu().numpy()): + token_details.append( + {"token_score": score.item(), "token_rank": int(rank.item())} + ) + return _PathSelection( hypothesis_ids=hyp_ids, token_ids=tok_ids, scores=best_scores, - token_scores=tok_scores, - token_ranks=tok_ranks, + token_details=token_details, ) @@ -1824,17 +1809,22 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection scores = values[hyp_ids, choices] best_scores = prior_scores.expand_as(scores) + scores - tok_scores, tok_ranks = None, None + token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_scores = scores.view(-1) - tok_ranks = choices.view(-1) + tok_scores = scores.view(-1).cpu().numpy() + tok_ranks = choices.view(-1).cpu().numpy() + token_details = [] + + for tok_score, tok_rank in zip(tok_scores, tok_ranks): + token_details.append( + {"token_score": tok_score, "token_rank": int(tok_rank)} + ) return _PathSelection( hypothesis_ids=hyp_ids, token_ids=tok_ids, scores=best_scores, - 
token_scores=tok_scores, - token_ranks=tok_ranks, + token_details=token_details, ) @@ -1871,15 +1861,20 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection scores = sprobs[hyp_ids, choices].log() best_scores = prior_scores.expand_as(scores) + scores - tok_scores, tok_ranks = None, None + token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_scores = scores.view(-1) - tok_ranks = choices.view(-1) + tok_scores = scores.view(-1).cpu().numpy() + tok_ranks = choices.view(-1).cpu().numpy() + token_details = [] + + for tok_score, tok_rank in zip(tok_scores, tok_ranks): + token_details.append( + {"token_score": tok_score, "token_rank": int(tok_rank)} + ) return _PathSelection( hypothesis_ids=hyp_ids, token_ids=tok_ids, scores=best_scores, - token_scores=tok_scores, - token_ranks=tok_ranks, + token_details=token_details, ) diff --git a/parlai/utils/misc.py b/parlai/utils/misc.py index 3f882e4e303..66435c69693 100644 --- a/parlai/utils/misc.py +++ b/parlai/utils/misc.py @@ -43,7 +43,6 @@ 'text_candidates', 'reward', 'token_losses', - 'generated_text_token_info', 'metrics', } @@ -523,29 +522,6 @@ def _token_losses_line( ) return _pretty_lines(space, key, formatted_tl, 'text2') - def _text_token_info_line( - msg: Dict[str, Any], fields_to_show: List[str], space: str - ) -> Optional[str]: - """ - Displays the loss associated with each token. Can be used for debugging - generative models. - - See TorchGeneratorAgent._generate for an example implementation. - """ - key = 'text_token_info' - text_token_info = msg.get(key, None) - - if key not in fields_to_show or not text_token_info: - return None - # Reduce losses to 4 significant figures - formatted_tl = ' | '.join( - [ - f"{tl[0]} {float('{:.4g}'.format(tl[1]))} {tl[2]}" - for tl in text_token_info - ] - ) - return _pretty_lines(space, key, formatted_tl, 'text2') - def _pretty_lines(indent_space, field, value, style): line = '{}{} {}'.format( indent_space, colorize('[' + field + ']:', 'field'), colorize(value, style) @@ -640,10 +616,6 @@ def _pretty_lines(indent_space, field, value, style): if token_loss_line: lines.append(token_loss_line) - text_token_info_line = _text_token_info_line(msg, fields_to_show, space) - if text_token_info_line: - lines.append(text_token_info_line) - if episode_done: lines.append( colorize('- - - - - - - END OF EPISODE - - - - - - - - - -', 'highlight') diff --git a/tests/test_tga.py b/tests/test_tga.py index f3a473db781..92516501304 100644 --- a/tests/test_tga.py +++ b/tests/test_tga.py @@ -190,17 +190,23 @@ def test_token_level_loss_logging(self): gold_data = { 'beam': { 'text_token_info': [ - ('__start__', 0.0, 1.0), - ('5', -2.5510462364763953e-05, 0.0), - ('__end__', -1.1920922133867862e-06, 0.0), + ('__start__', {"token_score": 0.0, "token_rank": 1}), + ('5', {"token_score": -2.5510462364763953e-05, "token_rank": 0}), + ( + '__end__', + {"token_score": -1.1920922133867862e-06, "token_rank": 0}, + ), ], 'extra_args': ['--beam-size', '3'], }, 'greedy': { 'text_token_info': [ - ('__start__', 0.0, 1.0), - ('5', -2.5510462364763953e-05, 0.0), - ('__end__', -1.1920922133867862e-06, 0.0), + ('__start__', {"token_score": 0.0, "token_rank": 1}), + ('5', {"token_score": -2.5510462364763953e-05, "token_rank": 0}), + ( + '__end__', + {"token_score": -1.1920922133867862e-06, "token_rank": 0}, + ), ], 'extra_args': [], }, @@ -240,10 +246,16 @@ def test_token_level_loss_logging(self): == tok_data[0] ), f"failed token prediction for inference type {inference_type} at 
token {gold_data[inference_type]['text_token_info'][i][0]}" assert math.isclose( - gold_data[inference_type]['text_token_info'][i][1], tok_data[1] + gold_data[inference_type]['text_token_info'][i][1][ + "token_score" + ], + tok_data[1]["token_score"], ), f"failed token probability prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" assert math.isclose( - gold_data[inference_type]['text_token_info'][i][2], tok_data[2] + gold_data[inference_type]['text_token_info'][i][1][ + "token_rank" + ], + tok_data[1]["token_rank"], ), f"failed token rank prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" def test_tree_search(self): @@ -259,8 +271,7 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([2]), "scores": torch.Tensor([-0.6]), - "token_scores": torch.Tensor([-0.1]), - "token_ranks": torch.LongTensor([0]), + "token_details": [{"token_score": -0.1, "token_rank": 0}], }, }, "beam_with_one_beam": { @@ -271,8 +282,7 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([2]), "scores": torch.Tensor([-0.6]), - "token_scores": torch.Tensor([-0.1]), - "token_ranks": torch.LongTensor([0]), + "token_details": [{"token_score": -0.1, "token_rank": 0}], }, }, "beam_with_multiple_beams": { @@ -286,8 +296,10 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([1, 1]), "token_ids": torch.LongTensor([2, 3]), "scores": torch.Tensor([-0.7, -0.8]), - "token_scores": torch.Tensor([-0.2, -0.3]), - "token_ranks": torch.LongTensor([0, 1]), + "token_details": [ + {"token_score": -0.2, "token_rank": 0}, + {"token_score": -0.3, "token_rank": 1}, + ], }, }, "topk_with_one_beam": { @@ -300,8 +312,7 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([1]), "scores": torch.Tensor([-3.5]), - "token_scores": torch.Tensor([-0.5]), - "token_ranks": torch.LongTensor([0]), + "token_details": [{"token_score": -0.5, "token_rank": 0}], }, }, "topk_with_multiple_beams": { @@ -317,8 +328,10 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0, 1]), "token_ids": torch.LongTensor([1, 2]), "scores": torch.Tensor([-3.5, -2.6]), - "token_scores": torch.Tensor([-0.5, -0.6]), - "token_ranks": torch.LongTensor([0, 0]), + "token_details": [ + {"token_score": -0.5, "token_rank": 0}, + {"token_score": -0.6, "token_rank": 0}, + ], }, }, "nucleus_with_one_beam": { @@ -333,8 +346,7 @@ def test_tree_search(self): "scores": torch.Tensor( [-3.0] ), # the -0.5 logprob normalizes to 0 in truncated distribution - "token_scores": torch.Tensor([-0.0]), # same as above - "token_ranks": torch.LongTensor([0]), + "token_details": [{"token_score": -0.0, "token_rank": 0}], }, }, "nucleus_with_multiple_beams": { @@ -352,8 +364,10 @@ def test_tree_search(self): "scores": torch.Tensor( [-3.0, -2.0] ), # the -0.5, -0.6 logprobs normalize to 0 in truncated distributions - "token_scores": torch.Tensor([-0.0, -0.0]), # same as above - "token_ranks": torch.LongTensor([0, 0]), + "token_details": [ + {"token_score": -0.0, "token_rank": 0}, + {"token_score": -0.0, "token_rank": 0}, + ], }, }, } @@ -373,12 +387,21 @@ def test_tree_search(self): assert torch.allclose( path_selection.scores, expected_result["scores"] ), f"failed test_tree_search for test {test_name} on field scores" - assert torch.allclose( - path_selection.token_scores, expected_result["token_scores"] - ), f"failed 
test_tree_search for test {test_name} on field token_scores" - assert torch.equal( - path_selection.token_ranks, expected_result["token_ranks"] - ), f"failed test_tree_search for test {test_name} on field token_ranks" + + assert len(path_selection.token_details) == len( + expected_result["token_details"] + ), f"failed test_tree_search for test {test_name} on field token_details" + for token_details, expected_token_details in zip( + path_selection.token_details, expected_result["token_details"] + ): + assert math.isclose( + token_details["token_score"], + expected_token_details["token_score"], + rel_tol=1e-5, + ), f"failed test_tree_search for test {test_name} on field token_details" + assert ( + token_details["token_rank"] == expected_token_details["token_rank"] + ), f"failed test_tree_search for test {test_name} on field token_details" if __name__ == '__main__': From 58fd67c6f055db3f5db1219dee0626c3509d6bd4 Mon Sep 17 00:00:00 2001 From: Colin Flaherty Date: Wed, 16 Mar 2022 15:22:09 -0700 Subject: [PATCH 2/4] dont truncate nucleus sampling distr and output normalized probs instead of logprobs --- parlai/core/torch_generator_agent.py | 20 ++++++++------- tests/test_tga.py | 37 ++++++++++++---------------- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 44dca07acab..950df51f0d0 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1692,7 +1692,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_score = tok_scores[0].item() + tok_score = torch.softmax(logprobs.view(-1), dim=-1)[tok_ids].item() tok_rank = 0 token_details: Optional[List[_PathSelectionTokenDetails]] = [ {"token_score": tok_score, "token_rank": tok_rank} @@ -1732,13 +1732,14 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: + probs = torch.softmax(logprobs, dim=-1) tok_scores = ( - torch.index_select(logprobs, 0, hyp_ids) + torch.index_select(probs, 0, hyp_ids) .gather(1, tok_ids.unsqueeze(1)) .view(-1) ) tok_ranks = ( - logprobs.argsort(1, descending=True) + probs.argsort(1, descending=True) .argsort(1) .view(-1) .gather(0, best_idxs) @@ -1811,7 +1812,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_scores = scores.view(-1).cpu().numpy() + tok_scores = probs[hyp_ids, choices].view(-1).cpu().numpy() tok_ranks = choices.view(-1).cpu().numpy() token_details = [] @@ -1852,18 +1853,19 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection # The subtraction here is to get the exclusive prefix sum, # to guarantee the first element is not masked mask = (sprobs.cumsum(dim=-1) - sprobs) >= self.p - sprobs[mask] = 0 - sprobs.div_(sprobs.sum(dim=-1).unsqueeze(1)) - choices = torch.multinomial(sprobs, 1)[:, 0] + trunc_sprobs = sprobs.detach().clone() + trunc_sprobs[mask] = 0 + trunc_sprobs.div_(trunc_sprobs.sum(dim=-1).unsqueeze(1)) + choices = torch.multinomial(trunc_sprobs, 1)[:, 0] hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device) tok_ids = sinds[hyp_ids, choices] # Convert back to logspace. 
- scores = sprobs[hyp_ids, choices].log() + scores = trunc_sprobs[hyp_ids, choices].log() best_scores = prior_scores.expand_as(scores) + scores token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_scores = scores.view(-1).cpu().numpy() + tok_scores = sprobs[hyp_ids, choices].view(-1).cpu().numpy() tok_ranks = choices.view(-1).cpu().numpy() token_details = [] diff --git a/tests/test_tga.py b/tests/test_tga.py index 92516501304..fe9ba98d83b 100644 --- a/tests/test_tga.py +++ b/tests/test_tga.py @@ -191,22 +191,16 @@ def test_token_level_loss_logging(self): 'beam': { 'text_token_info': [ ('__start__', {"token_score": 0.0, "token_rank": 1}), - ('5', {"token_score": -2.5510462364763953e-05, "token_rank": 0}), - ( - '__end__', - {"token_score": -1.1920922133867862e-06, "token_rank": 0}, - ), + ('5', {"token_score": 0.999, "token_rank": 0}), + ('__end__', {"token_score": 0.999, "token_rank": 0}), ], 'extra_args': ['--beam-size', '3'], }, 'greedy': { 'text_token_info': [ ('__start__', {"token_score": 0.0, "token_rank": 1}), - ('5', {"token_score": -2.5510462364763953e-05, "token_rank": 0}), - ( - '__end__', - {"token_score": -1.1920922133867862e-06, "token_rank": 0}, - ), + ('5', {"token_score": 0.999, "token_rank": 0}), + ('__end__', {"token_score": 0.999, "token_rank": 0}), ], 'extra_args': [], }, @@ -250,6 +244,7 @@ def test_token_level_loss_logging(self): "token_score" ], tok_data[1]["token_score"], + rel_tol=1e-3, ), f"failed token probability prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" assert math.isclose( gold_data[inference_type]['text_token_info'][i][1][ @@ -271,7 +266,7 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([2]), "scores": torch.Tensor([-0.6]), - "token_details": [{"token_score": -0.1, "token_rank": 0}], + "token_details": [{"token_score": 0.3800, "token_rank": 0}], }, }, "beam_with_one_beam": { @@ -282,7 +277,7 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([2]), "scores": torch.Tensor([-0.6]), - "token_details": [{"token_score": -0.1, "token_rank": 0}], + "token_details": [{"token_score": 0.3800, "token_rank": 0}], }, }, "beam_with_multiple_beams": { @@ -297,8 +292,8 @@ def test_tree_search(self): "token_ids": torch.LongTensor([2, 3]), "scores": torch.Tensor([-0.7, -0.8]), "token_details": [ - {"token_score": -0.2, "token_rank": 0}, - {"token_score": -0.3, "token_rank": 1}, + {"token_score": 0.3567, "token_rank": 0}, + {"token_score": 0.3228, "token_rank": 1}, ], }, }, @@ -312,7 +307,7 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([1]), "scores": torch.Tensor([-3.5]), - "token_details": [{"token_score": -0.5, "token_rank": 0}], + "token_details": [{"token_score": 1.0, "token_rank": 0}], }, }, "topk_with_multiple_beams": { @@ -329,8 +324,8 @@ def test_tree_search(self): "token_ids": torch.LongTensor([1, 2]), "scores": torch.Tensor([-3.5, -2.6]), "token_details": [ - {"token_score": -0.5, "token_rank": 0}, - {"token_score": -0.6, "token_rank": 0}, + {"token_score": 1.0, "token_rank": 0}, + {"token_score": 1.0, "token_rank": 0}, ], }, }, @@ -346,7 +341,7 @@ def test_tree_search(self): "scores": torch.Tensor( [-3.0] ), # the -0.5 logprob normalizes to 0 in truncated distribution - "token_details": [{"token_score": -0.0, "token_rank": 0}], + "token_details": [{"token_score": 1.0, "token_rank": 0}], }, }, 
"nucleus_with_multiple_beams": { @@ -365,8 +360,8 @@ def test_tree_search(self): [-3.0, -2.0] ), # the -0.5, -0.6 logprobs normalize to 0 in truncated distributions "token_details": [ - {"token_score": -0.0, "token_rank": 0}, - {"token_score": -0.0, "token_rank": 0}, + {"token_score": 1.0, "token_rank": 0}, + {"token_score": 1.0, "token_rank": 0}, ], }, }, @@ -397,7 +392,7 @@ def test_tree_search(self): assert math.isclose( token_details["token_score"], expected_token_details["token_score"], - rel_tol=1e-5, + rel_tol=1e-3, ), f"failed test_tree_search for test {test_name} on field token_details" assert ( token_details["token_rank"] == expected_token_details["token_rank"] From 3b5675aee0dc8bf3f36004c857708372363da72d Mon Sep 17 00:00:00 2001 From: Colin Flaherty Date: Thu, 17 Mar 2022 08:32:29 -0700 Subject: [PATCH 3/4] address emily feedback --- parlai/core/torch_generator_agent.py | 32 +++++++++++----------- tests/test_tga.py | 40 ++++++++++++++-------------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 950df51f0d0..9cbe4ec81b3 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1254,8 +1254,8 @@ def __init__(self, timestep, hypid, score, tokenid, token_details): class _PathSelectionTokenDetails(TypedDict, total=False): - token_score: float - token_rank: int + token_prob: float # conditional probability of token (normalized) + token_rank: int # rank of token in conditional distribution class _PathSelection(object): @@ -1349,7 +1349,7 @@ def __init__( if self.verbose: self.token_details = [] for _ in range(self.beam_size): - self.token_details.append([{"token_score": 0.0, "token_rank": 1}]) + self.token_details.append([{"token_prob": 0.0, "token_rank": 1}]) # keeps tuples (score, time_step, hyp_id) self.finished = [] @@ -1692,11 +1692,9 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_score = torch.softmax(logprobs.view(-1), dim=-1)[tok_ids].item() + tok_prob = torch.softmax(logprobs.view(-1), dim=-1)[tok_ids].item() tok_rank = 0 - token_details: Optional[List[_PathSelectionTokenDetails]] = [ - {"token_score": tok_score, "token_rank": tok_rank} - ] + token_details = [{"token_prob": tok_prob, "token_rank": tok_rank}] return _PathSelection( hypothesis_ids=hyp_ids, @@ -1733,7 +1731,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: probs = torch.softmax(logprobs, dim=-1) - tok_scores = ( + tok_probs = ( torch.index_select(probs, 0, hyp_ids) .gather(1, tok_ids.unsqueeze(1)) .view(-1) @@ -1747,9 +1745,11 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details = [] - for score, rank in zip(tok_scores.cpu().numpy(), tok_ranks.cpu().numpy()): + for tok_prob, tok_rank in zip( + tok_probs.cpu().numpy(), tok_ranks.cpu().numpy() + ): token_details.append( - {"token_score": score.item(), "token_rank": int(rank.item())} + {"token_prob": tok_prob.item(), "token_rank": int(tok_rank.item())} ) return _PathSelection( @@ -1812,13 +1812,13 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_scores = probs[hyp_ids, choices].view(-1).cpu().numpy() + tok_probs = 
probs[hyp_ids, choices].view(-1).cpu().numpy() tok_ranks = choices.view(-1).cpu().numpy() token_details = [] - for tok_score, tok_rank in zip(tok_scores, tok_ranks): + for tok_prob, tok_rank in zip(tok_probs, tok_ranks): token_details.append( - {"token_score": tok_score, "token_rank": int(tok_rank)} + {"token_prob": tok_prob, "token_rank": int(tok_rank)} ) return _PathSelection( @@ -1865,13 +1865,13 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_scores = sprobs[hyp_ids, choices].view(-1).cpu().numpy() + tok_probs = sprobs[hyp_ids, choices].view(-1).cpu().numpy() tok_ranks = choices.view(-1).cpu().numpy() token_details = [] - for tok_score, tok_rank in zip(tok_scores, tok_ranks): + for tok_prob, tok_rank in zip(tok_probs, tok_ranks): token_details.append( - {"token_score": tok_score, "token_rank": int(tok_rank)} + {"token_prob": tok_prob, "token_rank": int(tok_rank)} ) return _PathSelection( diff --git a/tests/test_tga.py b/tests/test_tga.py index fe9ba98d83b..384e263f147 100644 --- a/tests/test_tga.py +++ b/tests/test_tga.py @@ -190,17 +190,17 @@ def test_token_level_loss_logging(self): gold_data = { 'beam': { 'text_token_info': [ - ('__start__', {"token_score": 0.0, "token_rank": 1}), - ('5', {"token_score": 0.999, "token_rank": 0}), - ('__end__', {"token_score": 0.999, "token_rank": 0}), + ('__start__', {"token_prob": 0.0, "token_rank": 1}), + ('5', {"token_prob": 0.999, "token_rank": 0}), + ('__end__', {"token_prob": 0.999, "token_rank": 0}), ], 'extra_args': ['--beam-size', '3'], }, 'greedy': { 'text_token_info': [ - ('__start__', {"token_score": 0.0, "token_rank": 1}), - ('5', {"token_score": 0.999, "token_rank": 0}), - ('__end__', {"token_score": 0.999, "token_rank": 0}), + ('__start__', {"token_prob": 0.0, "token_rank": 1}), + ('5', {"token_prob": 0.999, "token_rank": 0}), + ('__end__', {"token_prob": 0.999, "token_rank": 0}), ], 'extra_args': [], }, @@ -241,9 +241,9 @@ def test_token_level_loss_logging(self): ), f"failed token prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" assert math.isclose( gold_data[inference_type]['text_token_info'][i][1][ - "token_score" + "token_prob" ], - tok_data[1]["token_score"], + tok_data[1]["token_prob"], rel_tol=1e-3, ), f"failed token probability prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" assert math.isclose( @@ -266,7 +266,7 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([2]), "scores": torch.Tensor([-0.6]), - "token_details": [{"token_score": 0.3800, "token_rank": 0}], + "token_details": [{"token_prob": 0.3800, "token_rank": 0}], }, }, "beam_with_one_beam": { @@ -277,7 +277,7 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([2]), "scores": torch.Tensor([-0.6]), - "token_details": [{"token_score": 0.3800, "token_rank": 0}], + "token_details": [{"token_prob": 0.3800, "token_rank": 0}], }, }, "beam_with_multiple_beams": { @@ -292,8 +292,8 @@ def test_tree_search(self): "token_ids": torch.LongTensor([2, 3]), "scores": torch.Tensor([-0.7, -0.8]), "token_details": [ - {"token_score": 0.3567, "token_rank": 0}, - {"token_score": 0.3228, "token_rank": 1}, + {"token_prob": 0.3567, "token_rank": 0}, + {"token_prob": 0.3228, "token_rank": 1}, ], }, }, @@ -307,7 +307,7 @@ def test_tree_search(self): 
"hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([1]), "scores": torch.Tensor([-3.5]), - "token_details": [{"token_score": 1.0, "token_rank": 0}], + "token_details": [{"token_prob": 1.0, "token_rank": 0}], }, }, "topk_with_multiple_beams": { @@ -324,8 +324,8 @@ def test_tree_search(self): "token_ids": torch.LongTensor([1, 2]), "scores": torch.Tensor([-3.5, -2.6]), "token_details": [ - {"token_score": 1.0, "token_rank": 0}, - {"token_score": 1.0, "token_rank": 0}, + {"token_prob": 1.0, "token_rank": 0}, + {"token_prob": 1.0, "token_rank": 0}, ], }, }, @@ -341,7 +341,7 @@ def test_tree_search(self): "scores": torch.Tensor( [-3.0] ), # the -0.5 logprob normalizes to 0 in truncated distribution - "token_details": [{"token_score": 1.0, "token_rank": 0}], + "token_details": [{"token_prob": 1.0, "token_rank": 0}], }, }, "nucleus_with_multiple_beams": { @@ -360,8 +360,8 @@ def test_tree_search(self): [-3.0, -2.0] ), # the -0.5, -0.6 logprobs normalize to 0 in truncated distributions "token_details": [ - {"token_score": 1.0, "token_rank": 0}, - {"token_score": 1.0, "token_rank": 0}, + {"token_prob": 1.0, "token_rank": 0}, + {"token_prob": 1.0, "token_rank": 0}, ], }, }, @@ -390,8 +390,8 @@ def test_tree_search(self): path_selection.token_details, expected_result["token_details"] ): assert math.isclose( - token_details["token_score"], - expected_token_details["token_score"], + token_details["token_prob"], + expected_token_details["token_prob"], rel_tol=1e-3, ), f"failed test_tree_search for test {test_name} on field token_details" assert ( From 685eb233f36f6a8cf7f1c73bc2ea7c13f96d3734 Mon Sep 17 00:00:00 2001 From: Colin Flaherty Date: Thu, 24 Mar 2022 14:57:16 -0700 Subject: [PATCH 4/4] addressed emily and stephen pr feedback --- parlai/core/torch_generator_agent.py | 30 +++++++++-------- tests/test_tga.py | 50 +++++++++++++++------------- 2 files changed, 44 insertions(+), 36 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 9cbe4ec81b3..6b5d7343213 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1254,7 +1254,7 @@ def __init__(self, timestep, hypid, score, tokenid, token_details): class _PathSelectionTokenDetails(TypedDict, total=False): - token_prob: float # conditional probability of token (normalized) + token_logprob: float # conditional log-probability of token (normalized) token_rank: int # rank of token in conditional distribution @@ -1345,11 +1345,12 @@ def __init__( ] self.verbose = verbose + # (beam size, sample length) list of lists containing token-level data for each token in each hypo in the beam self.token_details: Optional[List[List[_PathSelectionTokenDetails]]] = None if self.verbose: self.token_details = [] for _ in range(self.beam_size): - self.token_details.append([{"token_prob": 0.0, "token_rank": 1}]) + self.token_details.append([{"token_logprob": 0.0, "token_rank": 0}]) # keeps tuples (score, time_step, hyp_id) self.finished = [] @@ -1692,9 +1693,9 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_prob = torch.softmax(logprobs.view(-1), dim=-1)[tok_ids].item() + tok_logprob = torch.softmax(logprobs.view(-1), dim=-1)[tok_ids].log().item() tok_rank = 0 - token_details = [{"token_prob": tok_prob, "token_rank": tok_rank}] + token_details = [{"token_logprob": tok_logprob, "token_rank": tok_rank}] return _PathSelection( 
hypothesis_ids=hyp_ids, @@ -1745,11 +1746,14 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details = [] - for tok_prob, tok_rank in zip( - tok_probs.cpu().numpy(), tok_ranks.cpu().numpy() + for tok_logprob, tok_rank in zip( + tok_probs.log().cpu().numpy(), tok_ranks.cpu().numpy() ): token_details.append( - {"token_prob": tok_prob.item(), "token_rank": int(tok_rank.item())} + { + "token_logprob": tok_logprob.item(), + "token_rank": int(tok_rank.item()), + } ) return _PathSelection( @@ -1812,13 +1816,13 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_probs = probs[hyp_ids, choices].view(-1).cpu().numpy() + tok_logprobs = probs[hyp_ids, choices].log().view(-1).cpu().numpy() tok_ranks = choices.view(-1).cpu().numpy() token_details = [] - for tok_prob, tok_rank in zip(tok_probs, tok_ranks): + for tok_logprob, tok_rank in zip(tok_logprobs, tok_ranks): token_details.append( - {"token_prob": tok_prob, "token_rank": int(tok_rank)} + {"token_logprob": tok_logprob, "token_rank": int(tok_rank)} ) return _PathSelection( @@ -1865,13 +1869,13 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection token_details: Optional[List[_PathSelectionTokenDetails]] = None if self.verbose: - tok_probs = sprobs[hyp_ids, choices].view(-1).cpu().numpy() + tok_logprobs = sprobs[hyp_ids, choices].log().view(-1).cpu().numpy() tok_ranks = choices.view(-1).cpu().numpy() token_details = [] - for tok_prob, tok_rank in zip(tok_probs, tok_ranks): + for tok_logprob, tok_rank in zip(tok_logprobs, tok_ranks): token_details.append( - {"token_prob": tok_prob, "token_rank": int(tok_rank)} + {"token_logprob": tok_logprob, "token_rank": int(tok_rank)} ) return _PathSelection( diff --git a/tests/test_tga.py b/tests/test_tga.py index 384e263f147..0b67a1d320a 100644 --- a/tests/test_tga.py +++ b/tests/test_tga.py @@ -190,17 +190,17 @@ def test_token_level_loss_logging(self): gold_data = { 'beam': { 'text_token_info': [ - ('__start__', {"token_prob": 0.0, "token_rank": 1}), - ('5', {"token_prob": 0.999, "token_rank": 0}), - ('__end__', {"token_prob": 0.999, "token_rank": 0}), + ('__start__', {"token_logprob": 0.0, "token_rank": 0}), + ('5', {"token_logprob": math.log(0.999), "token_rank": 0}), + ('__end__', {"token_logprob": math.log(0.999), "token_rank": 0}), ], 'extra_args': ['--beam-size', '3'], }, 'greedy': { 'text_token_info': [ - ('__start__', {"token_prob": 0.0, "token_rank": 1}), - ('5', {"token_prob": 0.999, "token_rank": 0}), - ('__end__', {"token_prob": 0.999, "token_rank": 0}), + ('__start__', {"token_logprob": 0.0, "token_rank": 0}), + ('5', {"token_logprob": math.log(0.999), "token_rank": 0}), + ('__end__', {"token_logprob": math.log(0.999), "token_rank": 0}), ], 'extra_args': [], }, @@ -241,11 +241,11 @@ def test_token_level_loss_logging(self): ), f"failed token prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" assert math.isclose( gold_data[inference_type]['text_token_info'][i][1][ - "token_prob" + "token_logprob" ], - tok_data[1]["token_prob"], - rel_tol=1e-3, - ), f"failed token probability prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}" + tok_data[1]["token_logprob"], + abs_tol=1e-3, + ), f"failed token log-probability prediction for inference type {inference_type} at token 
{gold_data[inference_type]['text_token_info'][i][0]}" assert math.isclose( gold_data[inference_type]['text_token_info'][i][1][ "token_rank" @@ -266,7 +266,9 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([2]), "scores": torch.Tensor([-0.6]), - "token_details": [{"token_prob": 0.3800, "token_rank": 0}], + "token_details": [ + {"token_logprob": math.log(0.3800), "token_rank": 0} + ], }, }, "beam_with_one_beam": { @@ -277,7 +279,9 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([2]), "scores": torch.Tensor([-0.6]), - "token_details": [{"token_prob": 0.3800, "token_rank": 0}], + "token_details": [ + {"token_logprob": math.log(0.3800), "token_rank": 0} + ], }, }, "beam_with_multiple_beams": { @@ -292,8 +296,8 @@ def test_tree_search(self): "token_ids": torch.LongTensor([2, 3]), "scores": torch.Tensor([-0.7, -0.8]), "token_details": [ - {"token_prob": 0.3567, "token_rank": 0}, - {"token_prob": 0.3228, "token_rank": 1}, + {"token_logprob": math.log(0.3567), "token_rank": 0}, + {"token_logprob": math.log(0.3228), "token_rank": 1}, ], }, }, @@ -307,7 +311,7 @@ def test_tree_search(self): "hypothesis_ids": torch.LongTensor([0]), "token_ids": torch.LongTensor([1]), "scores": torch.Tensor([-3.5]), - "token_details": [{"token_prob": 1.0, "token_rank": 0}], + "token_details": [{"token_logprob": 0.0, "token_rank": 0}], }, }, "topk_with_multiple_beams": { @@ -324,8 +328,8 @@ def test_tree_search(self): "token_ids": torch.LongTensor([1, 2]), "scores": torch.Tensor([-3.5, -2.6]), "token_details": [ - {"token_prob": 1.0, "token_rank": 0}, - {"token_prob": 1.0, "token_rank": 0}, + {"token_logprob": 0.0, "token_rank": 0}, + {"token_logprob": 0.0, "token_rank": 0}, ], }, }, @@ -341,7 +345,7 @@ def test_tree_search(self): "scores": torch.Tensor( [-3.0] ), # the -0.5 logprob normalizes to 0 in truncated distribution - "token_details": [{"token_prob": 1.0, "token_rank": 0}], + "token_details": [{"token_logprob": 0.0, "token_rank": 0}], }, }, "nucleus_with_multiple_beams": { @@ -360,8 +364,8 @@ def test_tree_search(self): [-3.0, -2.0] ), # the -0.5, -0.6 logprobs normalize to 0 in truncated distributions "token_details": [ - {"token_prob": 1.0, "token_rank": 0}, - {"token_prob": 1.0, "token_rank": 0}, + {"token_logprob": 0.0, "token_rank": 0}, + {"token_logprob": 0.0, "token_rank": 0}, ], }, }, @@ -390,9 +394,9 @@ def test_tree_search(self): path_selection.token_details, expected_result["token_details"] ): assert math.isclose( - token_details["token_prob"], - expected_token_details["token_prob"], - rel_tol=1e-3, + token_details["token_logprob"], + expected_token_details["token_logprob"], + abs_tol=1e-3, ), f"failed test_tree_search for test {test_name} on field token_details" assert ( token_details["token_rank"] == expected_token_details["token_rank"]
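
For reference, the short sketch below is not part of any patch above; it only illustrates the per-token metadata shape that the final revision settles on (a normalized conditional log-probability plus a 0-based rank per generated token), using the same probs.argsort(descending=True).argsort() ranking trick as the beam-search branch. The TokenDetails name, the helper function, and the toy tensor values are invented for illustration; the class actually added by the patch is the _PathSelectionTokenDetails TypedDict in parlai/core/torch_generator_agent.py.

    import torch
    from typing_extensions import TypedDict


    class TokenDetails(TypedDict, total=False):
        token_logprob: float  # log of the normalized conditional probability of the chosen token
        token_rank: int       # 0-based rank of the chosen token in that conditional distribution


    def token_details_for_choice(logprobs: torch.Tensor, chosen: int) -> TokenDetails:
        """Compute verbose-mode metadata for one chosen token from a 1-D row of log-probabilities."""
        probs = torch.softmax(logprobs, dim=-1)
        # argsort of the descending argsort gives, for every vocabulary id, its rank in the distribution
        rank = int(probs.argsort(descending=True).argsort()[chosen].item())
        return {"token_logprob": probs[chosen].log().item(), "token_rank": rank}


    if __name__ == "__main__":
        row = torch.tensor([-0.1, -2.0, -3.0])  # toy unnormalized log-probabilities over a 3-token vocab
        print(token_details_for_choice(row, chosen=0))
        # prints the chosen token's normalized log-prob (about -0.19 here) with rank 0

A list of these dicts, one per generated token, is what _construct_generated_token_details zips with the token strings to produce the text_token_info field checked in tests/test_tga.py.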