Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[TGA] Abstract TGA candidate ranking and fix ranking for BART #3455

Merged
merged 8 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 0 additions & 60 deletions parlai/agents/bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import History
from parlai.core.torch_generator_agent import PPLMetric
from parlai.core.metrics import AverageMetric
from parlai.utils.typing import TShared
from parlai.utils.io import PathManager
from parlai.zoo.bart.build import download, CONVERSION_ARGS, BART_ARGS
Expand Down Expand Up @@ -180,61 +178,3 @@ def _get_initial_decoder_input(
.expand(bsz * beam_size, 2)
.to(dev)
)

def compute_loss(self, batch, return_output=False):
    """
    Override TGA.compute_loss to ignore start token.

    Runs the model on the batch with its labels as forced targets, drops the
    first score/prediction position when the scores are not the same length
    as the labels, and records the 'loss', 'ppl' and 'token_acc' metrics.

    :param batch: batch to score; its ``label_vec`` must not be None
    :param return_output: if True, also return the raw model output
    :return: the loss averaged over non-null target tokens, or a
        ``(loss, model_output)`` tuple when ``return_output`` is set
    :raises ValueError: if the batch has no label vector
    """
    labels = batch.label_vec
    if labels is None:
        raise ValueError('Cannot compute loss without a label.')

    model_output = self.model(*self._model_input(batch), ys=labels)
    scores, preds, *_ = model_output

    # A length mismatch means the output still has a leading start-token
    # position; trim it so scores and labels line up.
    if scores.size(1) != labels.size(1):
        scores, preds = scores[:, 1:, :], preds[:, 1:]

    flat_scores = scores.reshape(-1, scores.size(-1))
    per_token_loss = self.criterion(flat_scores, labels.view(-1))
    per_example_loss = per_token_loss.view(scores.shape[:-1]).sum(dim=1)

    # Per-example statistics used by the metrics below.
    notnull = labels.ne(self.NULL_IDX)
    target_tokens = notnull.long().sum(dim=-1)
    correct = ((labels == preds) * notnull).sum(dim=-1)

    self.record_local_metric(
        'loss', AverageMetric.many(per_example_loss, target_tokens)
    )
    self.record_local_metric(
        'ppl', PPLMetric.many(per_example_loss, target_tokens)
    )
    self.record_local_metric(
        'token_acc', AverageMetric.many(correct, target_tokens)
    )

    # The value that is backpropagated: average loss per target token.
    loss = per_example_loss.sum()
    loss /= target_tokens.sum()
    return (loss, model_output) if return_output else loss

def _construct_token_losses(self, labels, model_output):
"""
Override TGA._construct_token_losses to ignore start token.
"""
# Get non-aggregated losses
scores, _, _ = model_output
scores = scores[:, 1:, :] # ignore start token
score_view = scores.reshape(-1, scores.size(-1))
losses = self.criterion(score_view, labels.view(-1)).view(len(labels), -1)

# Zip decoded tokens with losses
token_losses = []
for i, label in enumerate(labels):
token_losses.append(
list(
zip(
[self.dict[token] for token in label.tolist()],
losses[i].tolist(),
)
)
)
return token_losses
13 changes: 13 additions & 0 deletions parlai/agents/bart/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def output(self, tensor: torch.Tensor) -> torch.Tensor:
"""
# project back to vocabulary
output = F.linear(tensor, self.embeddings.weight)

return output

def _get_initial_forced_decoder_input(self, bsz: int, inputs: torch.LongTensor):
Expand Down Expand Up @@ -71,3 +72,15 @@ def reorder_decoder_incremental_state(
incr_state_l['self_attn']['prev_mask'] = self_attn_mask[:, -1:, :]

return super().reorder_decoder_incremental_state(incremental_state, inds)

def decode_forced(self, encoder_states, ys):
    """
    Override to cut off score for start token.

    Delegates to the parent ``decode_forced`` and, when the returned logits
    are not the same length as ``ys``, removes the first position from both
    the logits and the predictions.

    :param encoder_states: encoder output passed through to the parent
    :param ys: target token ids used for forced decoding
    :return: ``(logits, preds)`` tuple
    """
    logits, preds = super().decode_forced(encoder_states, ys)
    if logits.size(1) == ys.size(1):
        # Already the same length as the targets; nothing to trim.
        return logits, preds

    # ignore start
    return logits[:, 1:, :], preds[:, 1:]
61 changes: 38 additions & 23 deletions parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def compute_loss(self, batch, return_output=False):
raise ValueError('Cannot compute loss without a label.')
model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
scores, preds, *_ = model_output
score_view = scores.view(-1, scores.size(-1))
score_view = scores.reshape(-1, scores.size(-1))
loss = self.criterion(score_view, batch.label_vec.view(-1))
loss = loss.view(scores.shape[:-1]).sum(dim=1)
# save loss to metrics
Expand Down Expand Up @@ -775,7 +775,7 @@ def train_step(self, batch):
def _construct_token_losses(self, labels, model_output):
# Get non-aggregated losses
scores, _, _ = model_output
score_view = scores.view(-1, scores.size(-1))
score_view = scores.reshape(-1, scores.size(-1))
losses = self.criterion(score_view, labels.view(-1)).view(len(labels), -1)

# Zip decoded tokens with losses
Expand Down Expand Up @@ -864,6 +864,37 @@ def _add_generation_metrics(self, batch, preds):
"""
pass

def rank_eval_label_candidates(self, batch, batchsize):
    """
    Rank label_candidates during eval_step.

    Can be overridden to allow for different ways of ranking candidates. Must have
    `--rank-candidates` set to True. By default, each candidate is scored by its
    mean per-token cross-entropy under forced decoding (roughly PPL), and
    candidates are returned in ascending order of that score.

    :param batch: batch providing ``candidate_vecs`` and ``candidates``
    :param batchsize: number of examples in the batch to rank
    :return: ``(cand_choices, cand_choices_scores)``; for each example, the
        candidates sorted by score and the matching sorted scores as a list
    """
    ranked_cands = []
    ranked_scores = []
    # Encode the contexts once; each example's states are repeated below.
    encoder_states = self.model.encoder(*self._encoder_input(batch))

    for idx in range(batchsize):
        cand_vecs = batch.candidate_vecs[idx]
        num_cands = len(cand_vecs)
        # Repeat this example's encoder states once per candidate.
        enc = self.model.reorder_encoder_states(encoder_states, [idx] * num_cands)
        cands, _ = self._pad_tensor(cand_vecs)
        seqlen = cands.size(1)

        scores, _ = self.model.decode_forced(enc, cands)
        token_losses = F.cross_entropy(
            scores.reshape(num_cands * seqlen, -1),
            cands.view(-1),
            reduction='none',
        ).view(num_cands, seqlen)

        # token_losses is cands x seqlen; padding positions must not count
        # towards a candidate's average.
        not_pad = (cands != self.NULL_IDX).float()
        summed = (token_losses * not_pad).sum(dim=1)
        counts = not_pad.sum(dim=1) + 1e-9  # epsilon guards against all-pad rows
        sorted_scores, ordering = (summed / counts).sort()

        ranked_cands.append([batch.candidates[idx][o] for o in ordering])
        ranked_scores.append(sorted_scores.tolist())

    return ranked_cands, ranked_scores

def eval_step(self, batch):
"""
Evaluate a single batch of examples.
Expand Down Expand Up @@ -907,34 +938,18 @@ def eval_step(self, batch):
continue

cand_choices = None
# TODO: abstract out the scoring here
cand_scores = None
if self.rank_candidates:
# compute roughly ppl to rank candidates
cand_choices = []
encoder_states = self.model.encoder(*self._encoder_input(batch))
for i in range(bsz):
num_cands = len(batch.candidate_vecs[i])
enc = self.model.reorder_encoder_states(encoder_states, [i] * num_cands)
cands, _ = self._pad_tensor(batch.candidate_vecs[i])
scores, _ = self.model.decode_forced(enc, cands)
cand_losses = F.cross_entropy(
scores.view(num_cands * cands.size(1), -1),
cands.view(-1),
reduction='none',
).view(num_cands, cands.size(1))
# now cand_losses is cands x seqlen size, but we still need to
# check padding and such
mask = (cands != self.NULL_IDX).float()
cand_scores = (cand_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
_, ordering = cand_scores.sort()
cand_choices.append([batch.candidates[i][o] for o in ordering])
cand_choices, cand_scores = self.rank_eval_label_candidates(batch, bsz)

text = [self._v2t(p) for p in preds] if preds is not None else None
if text and self.compute_tokenized_bleu:
# compute additional bleu scores
self._compute_fairseq_bleu(batch, preds)
self._compute_nltk_bleu(batch, text)
retval = Output(text, cand_choices, token_losses=token_losses)
retval = Output(
text, cand_choices, token_losses=token_losses, cand_scores=cand_scores
)
if not self.skip_generation:
retval.beam_texts = beam_texts
return retval
Expand Down
21 changes: 15 additions & 6 deletions projects/anti_scaling/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ForwardPassOutputs(AttrDict):
"""

mask: torch.BoolTensor
decoder_mask: torch.BoolTensor
tokens_per_example: torch.Tensor
num_tokens: torch.Tensor
context_mask: torch.BoolTensor
Expand All @@ -76,6 +77,7 @@ class ForwardPassOutputs(AttrDict):
def __init__(
self,
mask,
decoder_mask,
tokens_per_example,
num_tokens,
context_mask,
Expand All @@ -96,6 +98,7 @@ def __init__(
):
super().__init__(
mask=mask,
decoder_mask=decoder_mask,
tokens_per_example=tokens_per_example,
num_tokens=num_tokens,
context_mask=context_mask,
Expand Down Expand Up @@ -338,6 +341,9 @@ def _perform_forward_passes(self, batch: Batch) -> ForwardPassOutputs:
mask = self._manipulate_mask(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe `mask` should now be renamed to `score_mask`, to differentiate it from `decoder_mask`?

mask=mask, student_scores=student_scores, batch=batch
)
decoder_mask = self._manipulate_mask(
mask=mask, student_scores=student_embedding_outputs["decoder"], batch=batch
)

# Record teacher accuracy
teacher_acc = ((student_preds == teacher_preds) * mask).sum(dim=-1)
Expand All @@ -348,6 +354,7 @@ def _perform_forward_passes(self, batch: Batch) -> ForwardPassOutputs:

return ForwardPassOutputs(
mask=mask,
decoder_mask=decoder_mask,
tokens_per_example=tokens_per_example,
num_tokens=num_tokens,
context_mask=context_mask,
Expand Down Expand Up @@ -502,7 +509,7 @@ def _get_embedding_losses(
dec_emb_loss, dec_emb_loss_per_example = self._get_component_embedding_loss(
student_emb_output=fwd_pass.student_embedding_outputs['decoder'],
teacher_emb_output=fwd_pass.teacher_embedding_outputs['decoder'],
mask=fwd_pass.mask,
mask=fwd_pass.decoder_mask,
num_tokens=fwd_pass.num_tokens,
)
self.record_local_metric(
Expand Down Expand Up @@ -559,7 +566,7 @@ def _get_hidden_losses(
dec_hidden_loss, dec_hidden_loss_per_example = self._get_component_hidden_loss(
student_hidden_states=fwd_pass.student_hidden_states['decoder'],
teacher_hidden_states=fwd_pass.teacher_hidden_states['decoder'],
mask=fwd_pass.mask,
mask=fwd_pass.decoder_mask,
num_tokens=fwd_pass.num_tokens,
mapped_layers=self.mapped_dec_layers,
)
Expand Down Expand Up @@ -631,7 +638,7 @@ def _get_attention_losses(
dec_self_attn_loss = self._get_and_record_component_attention_loss(
student_attention_matrices=fwd_pass.student_attention_matrices['decoder'],
teacher_attention_matrices=fwd_pass.teacher_attention_matrices['decoder'],
mask=fwd_pass.mask,
mask=fwd_pass.decoder_mask,
tokens_per_example=fwd_pass.tokens_per_example,
num_tokens=fwd_pass.num_tokens,
mapped_layers=self.mapped_dec_layers,
Expand All @@ -641,7 +648,7 @@ def _get_attention_losses(
enc_dec_attn_loss = self._get_and_record_component_attention_loss(
student_attention_matrices=fwd_pass.student_attention_matrices['decoder'],
teacher_attention_matrices=fwd_pass.teacher_attention_matrices['decoder'],
mask=fwd_pass.mask,
mask=fwd_pass.decoder_mask,
tokens_per_example=fwd_pass.tokens_per_example,
num_tokens=fwd_pass.num_tokens,
mapped_layers=self.mapped_dec_layers,
Expand Down Expand Up @@ -1038,9 +1045,11 @@ def _manipulate_mask(
) -> torch.BoolTensor:
"""
Add one extra (masked-out) token to the mask, for compatibility with BART.

Only necessary when examining decoder outputs directly.
"""
assert student_scores.size(1) == batch.label_vec.size(1) + 1
mask = torch.cat([mask.new_zeros([mask.size(0), 1]), mask], dim=1)
if student_scores.size(1) == batch.label_vec.size(1) + 1:
mask = torch.cat([mask.new_zeros([mask.size(0), 1]), mask], dim=1)
return mask


Expand Down