[TGA] Abstract TGA candidate ranking and fix ranking for BART #3455
Conversation
parlai/agents/bart/bart.py
Outdated
cands, _ = self._pad_tensor(batch.candidate_vecs[i])
scores, _ = self.model.decode_forced(enc, cands)
# ignore the score for the start token
scores = scores[:, 1:, :]
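The slice above can be illustrated with made-up shapes (a sketch only, assuming `decode_forced` returns scores shaped `(num_cands, seq_len, vocab_size)`; the real sizes depend on the batch):

```python
import torch

# Hypothetical shapes for illustration: 4 candidates, 10 decoded
# positions, vocab of 50. BART prepends a start token, so position 0
# scores that token and is dropped before ranking candidates.
scores = torch.randn(4, 10, 50)
scores = scores[:, 1:, :]  # shape becomes (4, 9, 50)
```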
is this the only line that is different? I'm wondering now if we can solve all of these issues by just overriding BartModel.output to do this computation... that would allow us to get rid of the duplicate compute_loss and _construct_token_losses functions as well
the other change is that it needs reshaping...
it looks like you just took the reshape out of the cross_entropy call? i.e. that's already in the base call
not quite -- in the original function it has a call to scores.view(num_cands * cands.size(1), -1), but this breaks with BART unless we call reshape. i think it would probably be fine to change the view call to reshape instead in TGA, but this is a slightly more expensive operation (copy instead of view). what do you think?
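The view-vs-reshape distinction being discussed can be reproduced on any non-contiguous tensor (a standalone PyTorch sketch, not ParlAI code):

```python
import torch

x = torch.arange(6).reshape(2, 3)
y = x.t()  # transposing makes the tensor non-contiguous

# y.view(6) raises a RuntimeError here, because view requires
# contiguous memory; reshape falls back to a copy when it must.
z = y.reshape(6)
```

This is why the BART path breaks under `.view` but works with `.reshape`, at the cost of a potential copy.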
since people don't seem to use this functionality too often i'll just change it to reshape
yeah given that this is rarely, if ever, used, it'll be better to just make them both reshape
🚀 awesome!
2 things:
- nit: a couple of lint errors
- Let's wait for BART tests to pass prior to merging. did you happen to run them locally?
ran bart locally but looks like distillbart is failing, looking into it
@EricMichaelSmith can you take a look at the changes to distillbart when you get a chance?
It's nice to see the cleanup of BART. Did you verify an older checkpoint didn't have results change?
parlai/core/torch_generator_agent.py
Outdated
@@ -864,6 +864,37 @@ def _add_generation_metrics(self, batch, preds):
    """
    pass

def _rank_eval_label_candidates(self, batch, batchsize):
Maybe this should be a public method?
Seems reasonable, minor comment
@@ -338,6 +341,9 @@ def _perform_forward_passes(self, batch: Batch) -> ForwardPassOutputs:
    mask = self._manipulate_mask(
nit: maybe mask should now be score_mask to differentiate it from decoder_mask?
yep! tried training for a bit on convai2 and comparing results
Patch description
Abstract Torch Generator Agent candidate ranking ability to a helper function.
Then, override this helper function in BART so that we may delete the score for the start token.
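The override pattern described above can be sketched as follows (hypothetical, simplified names and pure-Python score lists; the real helper operates on batched tensors):

```python
class GeneratorAgent:
    """Stand-in for the abstracted TGA ranking helper (simplified)."""

    def rank_candidates(self, scores):
        # rank candidate indices by summed per-token scores, best first
        return sorted(range(len(scores)), key=lambda i: -sum(scores[i]))


class BartAgent(GeneratorAgent):
    def rank_candidates(self, scores):
        # BART prepends a start token; drop its score before ranking
        return super().rank_candidates([s[1:] for s in scores])
```

For example, with scores `[[10, 1, 1], [0, 5, 5]]` the base agent ranks candidate 0 first (sum 12 vs 10), while the BART override ignores the first position and ranks candidate 1 first (2 vs 10).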
Testing steps
for BART:
for non-BART: