Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[TGA] Abstract TGA candidate ranking and fix ranking for BART #3455

Merged
merged 8 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 0 additions & 58 deletions parlai/agents/bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,61 +180,3 @@ def _get_initial_decoder_input(
.expand(bsz * beam_size, 2)
.to(dev)
)

def compute_loss(self, batch, return_output=False):
    """
    Compute the token-level loss, skipping BART's extra start token.

    Overrides TorchGeneratorAgent.compute_loss: when the model emits one
    more timestep than the labels (the forced start token), that leading
    position is dropped before scoring so outputs align with the labels.
    """
    if batch.label_vec is None:
        raise ValueError('Cannot compute loss without a label.')
    model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
    scores, preds = model_output[0], model_output[1]

    # The decoder may have produced an extra leading timestep (the start
    # token); strip it so scores/preds line up with label_vec.
    if scores.size(1) != batch.label_vec.size(1):
        scores = scores[:, 1:, :]
        preds = preds[:, 1:]

    flat_scores = scores.reshape(-1, scores.size(-1))
    per_token_loss = self.criterion(flat_scores, batch.label_vec.view(-1))
    per_example_loss = per_token_loss.view(scores.shape[:-1]).sum(dim=1)

    # Per-example bookkeeping for local metrics: count non-pad targets
    # and correctly predicted (non-pad) tokens.
    notnull = batch.label_vec.ne(self.NULL_IDX)
    target_tokens = notnull.long().sum(dim=-1)
    correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

    self.record_local_metric(
        'loss', AverageMetric.many(per_example_loss, target_tokens)
    )
    self.record_local_metric(
        'ppl', PPLMetric.many(per_example_loss, target_tokens)
    )
    self.record_local_metric(
        'token_acc', AverageMetric.many(correct, target_tokens)
    )

    # The loss used for backprop: batch sum averaged per target token.
    loss = per_example_loss.sum() / target_tokens.sum()
    return (loss, model_output) if return_output else loss

def _construct_token_losses(self, labels, model_output):
"""
Override TGA._construct_token_losses to ignore start token.
"""
# Get non-aggregated losses
scores, _, _ = model_output
scores = scores[:, 1:, :] # ignore start token
score_view = scores.reshape(-1, scores.size(-1))
losses = self.criterion(score_view, labels.view(-1)).view(len(labels), -1)

# Zip decoded tokens with losses
token_losses = []
for i, label in enumerate(labels):
token_losses.append(
list(
zip(
[self.dict[token] for token in label.tolist()],
losses[i].tolist(),
)
)
)
return token_losses
13 changes: 13 additions & 0 deletions parlai/agents/bart/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def output(self, tensor: torch.Tensor) -> torch.Tensor:
"""
# project back to vocabulary
output = F.linear(tensor, self.embeddings.weight)

return output

def _get_initial_forced_decoder_input(self, bsz: int, inputs: torch.LongTensor):
Expand Down Expand Up @@ -71,3 +72,15 @@ def reorder_decoder_incremental_state(
incr_state_l['self_attn']['prev_mask'] = self_attn_mask[:, -1:, :]

return super().reorder_decoder_incremental_state(incremental_state, inds)

def decode_forced(self, encoder_states, ys):
    """
    Decode with teacher forcing, trimming the start-token score.

    Delegates to the parent decode_forced and, when the output carries
    one more timestep than ``ys`` (the forced start token), drops that
    leading position from both logits and predictions.
    """
    logits, preds = super().decode_forced(encoder_states, ys)
    extra_leading_step = logits.size(1) != ys.size(1)
    if extra_leading_step:
        # Strip the start-token timestep so outputs align with ys.
        logits, preds = logits[:, 1:, :], preds[:, 1:]
    return logits, preds
61 changes: 38 additions & 23 deletions parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def compute_loss(self, batch, return_output=False):
raise ValueError('Cannot compute loss without a label.')
model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
scores, preds, *_ = model_output
score_view = scores.view(-1, scores.size(-1))
score_view = scores.reshape(-1, scores.size(-1))
loss = self.criterion(score_view, batch.label_vec.view(-1))
loss = loss.view(scores.shape[:-1]).sum(dim=1)
# save loss to metrics
Expand Down Expand Up @@ -775,7 +775,7 @@ def train_step(self, batch):
def _construct_token_losses(self, labels, model_output):
# Get non-aggregated losses
scores, _, _ = model_output
score_view = scores.view(-1, scores.size(-1))
score_view = scores.reshape(-1, scores.size(-1))
losses = self.criterion(score_view, labels.view(-1)).view(len(labels), -1)

# Zip decoded tokens with losses
Expand Down Expand Up @@ -864,6 +864,37 @@ def _add_generation_metrics(self, batch, preds):
"""
pass

def _rank_eval_label_candidates(self, batch, batchsize):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be a public method?

"""
Rank label_candidates during eval_step.

Can be overridden to allow for different ways of ranking candidates. Must have
`--rank-candidates` set to True. By default, we roughly compute PPL to rank the
candidates.
"""
# compute roughly ppl to rank candidates
cand_choices = []
cand_choices_scores = []
encoder_states = self.model.encoder(*self._encoder_input(batch))
for i in range(batchsize):
num_cands = len(batch.candidate_vecs[i])
enc = self.model.reorder_encoder_states(encoder_states, [i] * num_cands)
cands, _ = self._pad_tensor(batch.candidate_vecs[i])
scores, _ = self.model.decode_forced(enc, cands)
score_view = scores.reshape(num_cands * cands.size(1), -1)
cand_losses = F.cross_entropy(
score_view, cands.view(-1), reduction='none'
).view(num_cands, cands.size(1))
# now cand_losses is cands x seqlen size, but we still need to
# check padding and such
mask = (cands != self.NULL_IDX).float()
cand_scores = (cand_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
sorted_scores, ordering = cand_scores.sort()
cand_choices.append([batch.candidates[i][o] for o in ordering])
cand_choices_scores.append(sorted_scores.tolist())

return cand_choices, cand_choices_scores

def eval_step(self, batch):
"""
Evaluate a single batch of examples.
Expand Down Expand Up @@ -907,34 +938,18 @@ def eval_step(self, batch):
continue

cand_choices = None
# TODO: abstract out the scoring here
cand_scores = None
if self.rank_candidates:
# compute roughly ppl to rank candidates
cand_choices = []
encoder_states = self.model.encoder(*self._encoder_input(batch))
for i in range(bsz):
num_cands = len(batch.candidate_vecs[i])
enc = self.model.reorder_encoder_states(encoder_states, [i] * num_cands)
cands, _ = self._pad_tensor(batch.candidate_vecs[i])
scores, _ = self.model.decode_forced(enc, cands)
cand_losses = F.cross_entropy(
scores.view(num_cands * cands.size(1), -1),
cands.view(-1),
reduction='none',
).view(num_cands, cands.size(1))
# now cand_losses is cands x seqlen size, but we still need to
# check padding and such
mask = (cands != self.NULL_IDX).float()
cand_scores = (cand_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
_, ordering = cand_scores.sort()
cand_choices.append([batch.candidates[i][o] for o in ordering])
cand_choices, cand_scores = self._rank_eval_label_candidates(batch, bsz)

text = [self._v2t(p) for p in preds] if preds is not None else None
if text and self.compute_tokenized_bleu:
# compute additional bleu scores
self._compute_fairseq_bleu(batch, preds)
self._compute_nltk_bleu(batch, text)
retval = Output(text, cand_choices, token_losses=token_losses)
retval = Output(
text, cand_choices, token_losses=token_losses, cand_scores=cand_scores
)
if not self.skip_generation:
retval.beam_texts = beam_texts
return retval
Expand Down