Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

More flexible token metadata logging #4427

Merged
merged 4 commits into from
Mar 25, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
don't truncate nucleus sampling distribution, and output normalized probs instead of logprobs
c-flaherty committed Mar 16, 2022
commit 58fd67c6f055db3f5db1219dee0626c3509d6bd4
20 changes: 11 additions & 9 deletions parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
@@ -1692,7 +1692,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection

token_details: Optional[List[_PathSelectionTokenDetails]] = None
if self.verbose:
tok_score = tok_scores[0].item()
tok_score = torch.softmax(logprobs.view(-1), dim=-1)[tok_ids].item()
c-flaherty marked this conversation as resolved.
Show resolved Hide resolved
tok_rank = 0
token_details: Optional[List[_PathSelectionTokenDetails]] = [
{"token_score": tok_score, "token_rank": tok_rank}
@@ -1732,13 +1732,14 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection

token_details: Optional[List[_PathSelectionTokenDetails]] = None
if self.verbose:
probs = torch.softmax(logprobs, dim=-1)
tok_scores = (
torch.index_select(logprobs, 0, hyp_ids)
torch.index_select(probs, 0, hyp_ids)
.gather(1, tok_ids.unsqueeze(1))
.view(-1)
)
tok_ranks = (
logprobs.argsort(1, descending=True)
probs.argsort(1, descending=True)
.argsort(1)
.view(-1)
.gather(0, best_idxs)
@@ -1811,7 +1812,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection

token_details: Optional[List[_PathSelectionTokenDetails]] = None
if self.verbose:
tok_scores = scores.view(-1).cpu().numpy()
tok_scores = probs[hyp_ids, choices].view(-1).cpu().numpy()
tok_ranks = choices.view(-1).cpu().numpy()
token_details = []

@@ -1852,18 +1853,19 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection
# The subtraction here is to get the exclusive prefix sum,
# to guarantee the first element is not masked
mask = (sprobs.cumsum(dim=-1) - sprobs) >= self.p
sprobs[mask] = 0
sprobs.div_(sprobs.sum(dim=-1).unsqueeze(1))
choices = torch.multinomial(sprobs, 1)[:, 0]
trunc_sprobs = sprobs.detach().clone()
c-flaherty marked this conversation as resolved.
Show resolved Hide resolved
trunc_sprobs[mask] = 0
trunc_sprobs.div_(trunc_sprobs.sum(dim=-1).unsqueeze(1))
choices = torch.multinomial(trunc_sprobs, 1)[:, 0]
hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device)
tok_ids = sinds[hyp_ids, choices]
# Convert back to logspace.
scores = sprobs[hyp_ids, choices].log()
scores = trunc_sprobs[hyp_ids, choices].log()
best_scores = prior_scores.expand_as(scores) + scores

token_details: Optional[List[_PathSelectionTokenDetails]] = None
if self.verbose:
tok_scores = scores.view(-1).cpu().numpy()
tok_scores = sprobs[hyp_ids, choices].view(-1).cpu().numpy()
tok_ranks = choices.view(-1).cpu().numpy()
token_details = []

37 changes: 16 additions & 21 deletions tests/test_tga.py
Original file line number Diff line number Diff line change
@@ -191,22 +191,16 @@ def test_token_level_loss_logging(self):
'beam': {
'text_token_info': [
('__start__', {"token_score": 0.0, "token_rank": 1}),
('5', {"token_score": -2.5510462364763953e-05, "token_rank": 0}),
(
'__end__',
{"token_score": -1.1920922133867862e-06, "token_rank": 0},
),
('5', {"token_score": 0.999, "token_rank": 0}),
('__end__', {"token_score": 0.999, "token_rank": 0}),
],
'extra_args': ['--beam-size', '3'],
},
'greedy': {
'text_token_info': [
('__start__', {"token_score": 0.0, "token_rank": 1}),
('5', {"token_score": -2.5510462364763953e-05, "token_rank": 0}),
(
'__end__',
{"token_score": -1.1920922133867862e-06, "token_rank": 0},
),
('5', {"token_score": 0.999, "token_rank": 0}),
('__end__', {"token_score": 0.999, "token_rank": 0}),
],
'extra_args': [],
},
@@ -250,6 +244,7 @@ def test_token_level_loss_logging(self):
"token_score"
],
tok_data[1]["token_score"],
rel_tol=1e-3,
), f"failed token probability prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}"
assert math.isclose(
gold_data[inference_type]['text_token_info'][i][1][
@@ -271,7 +266,7 @@ def test_tree_search(self):
"hypothesis_ids": torch.LongTensor([0]),
"token_ids": torch.LongTensor([2]),
"scores": torch.Tensor([-0.6]),
"token_details": [{"token_score": -0.1, "token_rank": 0}],
"token_details": [{"token_score": 0.3800, "token_rank": 0}],
},
},
"beam_with_one_beam": {
@@ -282,7 +277,7 @@ def test_tree_search(self):
"hypothesis_ids": torch.LongTensor([0]),
"token_ids": torch.LongTensor([2]),
"scores": torch.Tensor([-0.6]),
"token_details": [{"token_score": -0.1, "token_rank": 0}],
"token_details": [{"token_score": 0.3800, "token_rank": 0}],
},
},
"beam_with_multiple_beams": {
@@ -297,8 +292,8 @@ def test_tree_search(self):
"token_ids": torch.LongTensor([2, 3]),
"scores": torch.Tensor([-0.7, -0.8]),
"token_details": [
{"token_score": -0.2, "token_rank": 0},
{"token_score": -0.3, "token_rank": 1},
{"token_score": 0.3567, "token_rank": 0},
{"token_score": 0.3228, "token_rank": 1},
],
},
},
@@ -312,7 +307,7 @@ def test_tree_search(self):
"hypothesis_ids": torch.LongTensor([0]),
"token_ids": torch.LongTensor([1]),
"scores": torch.Tensor([-3.5]),
"token_details": [{"token_score": -0.5, "token_rank": 0}],
"token_details": [{"token_score": 1.0, "token_rank": 0}],
},
},
"topk_with_multiple_beams": {
@@ -329,8 +324,8 @@ def test_tree_search(self):
"token_ids": torch.LongTensor([1, 2]),
"scores": torch.Tensor([-3.5, -2.6]),
"token_details": [
{"token_score": -0.5, "token_rank": 0},
{"token_score": -0.6, "token_rank": 0},
{"token_score": 1.0, "token_rank": 0},
{"token_score": 1.0, "token_rank": 0},
],
},
},
@@ -346,7 +341,7 @@ def test_tree_search(self):
"scores": torch.Tensor(
[-3.0]
), # the -0.5 logprob normalizes to 0 in truncated distribution
"token_details": [{"token_score": -0.0, "token_rank": 0}],
"token_details": [{"token_score": 1.0, "token_rank": 0}],
},
},
"nucleus_with_multiple_beams": {
@@ -365,8 +360,8 @@ def test_tree_search(self):
[-3.0, -2.0]
), # the -0.5, -0.6 logprobs normalize to 0 in truncated distributions
"token_details": [
{"token_score": -0.0, "token_rank": 0},
{"token_score": -0.0, "token_rank": 0},
{"token_score": 1.0, "token_rank": 0},
{"token_score": 1.0, "token_rank": 0},
],
},
},
@@ -397,7 +392,7 @@ def test_tree_search(self):
assert math.isclose(
token_details["token_score"],
expected_token_details["token_score"],
rel_tol=1e-5,
rel_tol=1e-3,
), f"failed test_tree_search for test {test_name} on field token_details"
assert (
token_details["token_rank"] == expected_token_details["token_rank"]