Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Logging token level losses at inference time (#4169)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-flaherty authored Jan 5, 2022
1 parent 6df9361 commit daa85bf
Show file tree
Hide file tree
Showing 13 changed files with 555 additions and 104 deletions.
12 changes: 6 additions & 6 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -235,26 +235,26 @@ commands:
- setupcuda
- fixgit
- restore_cache:
key: deps-20210722-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
key: deps-2021222-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
- setup
- installdeps
- << parameters.more_installs >>
- save_cache:
key: deps-20210722-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
key: deps-2021222-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
- findtests:
marker: << parameters.marker >>
- restore_cache:
key: data-20210722-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
key: data-2021222-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
- run:
name: Run tests
no_output_timeout: 60m
command: |
coverage run -m pytest -m << parameters.marker >> << parameters.pytest_flags >> --junitxml=test-results/junit.xml
- save_cache:
key: data-20210722-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
key: data-2021222-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
paths:
- "~/ParlAI/data"
- codecov
Expand All @@ -271,12 +271,12 @@ commands:
- checkout
- fixgit
- restore_cache:
key: deps-20210722-bw-{{ checksum "requirements.txt" }}
key: deps-2021222-bw-{{ checksum "requirements.txt" }}
- setup
- installdeps
- installtorchgpu17
- save_cache:
key: deps-20210722-bw-{{ checksum "requirements.txt" }}
key: deps-2021222-bw-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
Expand Down
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def pytest_collection_modifyitems(config, items):
class_mapping[class_name].append(item)

test_groupings = list(class_mapping.keys())
random.Random(1337).shuffle(test_groupings)
random.Random(1339).shuffle(test_groupings)

filtered_tests = filter_tests_with_circleci(test_groupings)
new_items = []
Expand Down
4 changes: 2 additions & 2 deletions parlai/agents/hugging_face/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

def check_hf_version(v: Tuple[int, int]) -> bool:
    """
    Check that the HF (Hugging Face) version is at least 4.3.

    :param v:
        ``(major, minor)`` version pair.

    :return:
        True when the version is 4.3 or newer, False otherwise.
    """
    main, sub = v
    # Tuples compare lexicographically: any major above 4 passes, and major 4
    # needs a minor of 3 or more.
    return (main, sub) >= (4, 3)
Expand Down Expand Up @@ -195,7 +195,7 @@ def _generate(
generation_params.update(overrides)

outputs = self.model.t5.generate(**generation_params)
outputs = [(outputs[i], 0) for i in range(outputs.size(0))]
outputs = [(outputs[i], 0, None) for i in range(outputs.size(0))]
return outputs, []


Expand Down
59 changes: 38 additions & 21 deletions parlai/agents/rag/model_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,10 @@ def rerank_beams(
self,
model: RagModel,
batch: Batch,
n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]],
) -> List[List[Tuple[torch.LongTensor, torch.Tensor]]]:
n_best_beam_preds_scores: List[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]
],
) -> List[List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]]:
"""
Optionally rerank beams.
"""
Expand Down Expand Up @@ -423,8 +425,10 @@ def rerank_beams(
self,
model: RagModel,
batch: Batch,
n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]],
) -> List[List[List[torch.LongTensor]]]:
n_best_beam_preds_scores: List[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]
],
) -> List[List[Tuple[torch.LongTensor, Optional[Dict]]]]:
"""
Rerank beams in RAG-Sequence, accounting for document probabilities as well.
Expand Down Expand Up @@ -517,7 +521,9 @@ def augment_batch_for_generation(self, batch: Batch, model: RagModel) -> Batch:
def fast_generation(
cls,
doc_indices: List[int],
n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]],
n_best_beam_preds_scores: List[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]
],
doc_log_probs: torch.Tensor,
n_docs: int,
):
Expand All @@ -534,22 +540,26 @@ def fast_generation(
number of docs per example
:return sorted_hyps:
return list of (hyp, score) tuples, sorted by their score.
return list of (hyp, score, token metadata) tuples, sorted by their score.
"""
marginalized_hypos: Dict[str, List[torch.Tensor]] = {}
marginalized_hypos: Dict[
str, Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]
] = {}
for doc_idx in doc_indices:
doc_hypos = n_best_beam_preds_scores[doc_idx]
doc_score = doc_log_probs[doc_idx % n_docs]
for hypo, hypo_score in doc_hypos:
for hypo, hypo_score, token_metadata in doc_hypos:
score = hypo_score + doc_score
hypo_tokens = str(hypo.tolist())
if hypo_tokens in marginalized_hypos:
marginalised_hypo = marginalized_hypos[hypo_tokens]
marginalised_hypo[1] = torch.log(
marginalised_hypo[1].exp() + score.exp()
marginalised_hypo = (
marginalised_hypo[0],
torch.log(marginalised_hypo[1].exp() + score.exp()),
marginalised_hypo[2],
)
else:
marginalized_hypos[hypo_tokens] = [hypo, score]
marginalized_hypos[hypo_tokens] = (hypo, score, token_metadata)
sorted_by_score = sorted(marginalized_hypos.values(), key=lambda h: -h[1])
return sorted_by_score

Expand All @@ -560,7 +570,7 @@ def thorough_generation(
new_input: torch.LongTensor,
null_idx: int,
model: RagModel,
) -> List[Tuple[torch.LongTensor, torch.Tensor]]:
) -> List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]:
"""
Apply RAG-sequence thorough generation for a single batch item.
Expand All @@ -572,7 +582,7 @@ def thorough_generation(
input for the model
:return sorted_hyps:
return list of (hyp, score) tuples, sorted by their score.
return list of (hyp, score, token_metadata) tuples, sorted by their score.
"""
# deduplicate, exclude BOS Token
hyps = list({str(h.tolist()): h[1:] for h in hyps}.values()) # type: ignore
Expand All @@ -586,7 +596,7 @@ def thorough_generation(
new_ys.unsqueeze(1).unsqueeze(-1), scores.unsqueeze(1), null_idx
) # type: ignore
sorted_by_score = [
(hyps[idx], loss[idx]) for idx in loss.sort()[-1]
(hyps[idx], loss[idx], None) for idx in loss.sort()[-1]
] # sort ascending
return sorted_by_score

Expand Down Expand Up @@ -834,8 +844,8 @@ def reorder_decoder_incremental_state(
"""
For RAG Token, we send each decoder input through n_docs times.
Similarly to reordering the encoder states, we need to reorder according
to the documents dimensions.
Similarly to reordering the encoder states, we need to reorder according to the
documents dimensions.
"""
assert incremental_state is not None
incremental_state = fix_incremental_state(
Expand Down Expand Up @@ -866,8 +876,10 @@ def rerank_beams(
self,
model: RagModel,
batch: Batch,
n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]],
) -> List[List[Tuple[torch.LongTensor, torch.Tensor]]]:
n_best_beam_preds_scores: List[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]
],
) -> List[List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]]:
"""
We don't re-rank beams for RAG Token.
"""
Expand Down Expand Up @@ -1169,8 +1181,10 @@ def rerank_beams(
self,
model: RagModel,
batch: Batch,
n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]],
) -> List[List[Tuple[torch.LongTensor, torch.Tensor]]]:
n_best_beam_preds_scores: List[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]
],
) -> List[List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]]:
"""
Re-rank beams.
Expand All @@ -1183,7 +1197,9 @@ def rerank_beams(
Thorough decoding is identical to RAG Sequence.
"""
new_n_best: List[List[Tuple[torch.LongTensor, torch.Tensor]]] = []
new_n_best: List[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]
] = []
if self.turn_marginalize == 'doc_only' and not self.thorough:
# no doc log probs here; just re-sorting beams
input_turns_cnt = batch.input_turns_cnt
Expand All @@ -1197,6 +1213,7 @@ def rerank_beams(
new_beam = (
beam[0],
beam[1] * self.discount_factor ** (it - i - 1),
beam[2],
)
n_best_i.append(new_beam)
new_n_best.append(sorted(n_best_i, key=lambda x: -x[1]))
Expand Down
22 changes: 15 additions & 7 deletions parlai/agents/rag/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def _generate(
beam_size: int,
max_ts: int,
prefix_tokens: Optional[torch.LongTensor] = None,
) -> Tuple[List[Tuple[torch.LongTensor, torch.Tensor]], List[TreeSearch]]:
) -> Tuple[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]], List[TreeSearch]
]:
"""
Override since T5 needs to call TGA generate.
"""
Expand Down Expand Up @@ -664,7 +666,9 @@ def _generate(
beam_size: int,
max_ts: int,
prefix_tokens: Optional[torch.LongTensor] = None,
) -> Tuple[List[Tuple[torch.LongTensor, torch.Tensor]], List[TreeSearch]]:
) -> Tuple[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]], List[TreeSearch]
]:
"""
Override TGA._generate to potentially call ReGReT.
Expand All @@ -674,7 +678,7 @@ def _generate(
beam_preds_scores, _ = self._regret_generate(
batch, beam_size, self.regret_intermediate_maxlen, prefix_tokens
)
preds, _ = zip(*beam_preds_scores)
preds, _, _ = zip(*beam_preds_scores)
new_batch = self._regret_rebatchify(batch, preds) # type: ignore
gen_outs = self._rag_generate(new_batch, beam_size, max_ts, prefix_tokens)
else:
Expand All @@ -685,8 +689,10 @@ def _generate(
def _rerank_beams(
self,
batch: Batch,
n_best_beam_preds_scores: List[List[Tuple[torch.LongTensor, torch.Tensor]]],
) -> List[List[Tuple[torch.LongTensor, torch.Tensor]]]:
n_best_beam_preds_scores: List[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]
],
) -> List[List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]]]:
"""
Optional rerank beams, according to RAG Model type.
Expand All @@ -710,7 +716,9 @@ def _rag_generate(
beam_size: int,
max_ts: int,
prefix_tokens: Optional[torch.LongTensor] = None,
) -> Tuple[List[Tuple[torch.LongTensor, torch.Tensor]], List[TreeSearch]]:
) -> Tuple[
List[Tuple[torch.LongTensor, torch.Tensor, Optional[Dict]]], List[TreeSearch]
]:
"""
Separate from _generate to handle regret.
"""
Expand Down Expand Up @@ -887,7 +895,7 @@ def get_model_output(self, batch: Batch) -> Tuple[Any, ...]:
beam_preds_scores, beams = self._regret_generate(
batch, self.beam_size, self.regret_intermediate_maxlen
)
regret_preds, _ = zip(*beam_preds_scores)
regret_preds, _, _ = zip(*beam_preds_scores)
new_batch = self._regret_rebatchify(batch, regret_preds) # type: ignore
regret_model_output = self.model(
*self._model_input(new_batch), ys=batch.label_vec
Expand Down
Loading

0 comments on commit daa85bf

Please sign in to comment.