Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[TGA] Abstract TGA candidate ranking and fix ranking for BART #3455

Merged
merged 8 commits into from
Feb 22, 2021

Conversation

emilydinan
Copy link
Contributor

Patch description
Abstract Torch Generator Agent candidate ranking ability to a helper function.

Then, override this helper function in BART so that we may delete the score for the start token.

Testing steps
for BART:

parlai em -t convai2 --rank-candidates True -m bart --skip-generation True

for non-BART:

parlai em -t convai2 --rank-candidates True -mf zoo:blender/blender_90M/model --skip-generation Tru

cands, _ = self._pad_tensor(batch.candidate_vecs[i])
scores, _ = self.model.decode_forced(enc, cands)
# ignore the score for the start token
scores = scores[:, 1:, :]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this the only line that is different?

I'm wondering now if we can solve all of these issues by just overriding BartModel.output to do this computation...

that would allow us to get rid of the duplicate compute_loss and _construct_token_losses functions as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the other change is that it needs reshaping...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like you just took the reshape out of the cross_entropy call? i.e. that's already in the base call

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not quite -- in the original function it has a call to scores.view(num_cands * cands.size(1), -1), but this breaks with BART unless we call reshape. i think it would probably be fine to change the view call to reshape instead in TGA, but this is a slightly more expensive operation (copy instead of view). what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since people don't seem to use this functionality too often i'll just change

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah given that this is rarely, if ever, used, it'll be better to just make them both reshape

Copy link
Contributor

@klshuster klshuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀 awesome!

2 things:

  1. nit: a couple of lint errors
  2. Let's wait for BART tests to pass prior to merging. did you happen to run them locally?

@emilydinan
Copy link
Contributor Author

🚀 awesome!

2 things:

  1. nit: a couple of lint errors
  2. Let's wait for BART tests to pass prior to merging. did you happen to run them locally?

ran bart locally but looks like distillbart, looking into it

@emilydinan
Copy link
Contributor Author

@EricMichaelSmith can you take a look at the changes to distillbart when you get a chance?

Copy link
Contributor

@stephenroller stephenroller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's nice to see the cleanup of BART. Did you verify an older checkpoint didn't have results change?

@@ -864,6 +864,37 @@ def _add_generation_metrics(self, batch, preds):
"""
pass

def _rank_eval_label_candidates(self, batch, batchsize):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be a public method?

Copy link
Contributor

@EricMichaelSmith EricMichaelSmith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable, minor comment

@@ -338,6 +341,9 @@ def _perform_forward_passes(self, batch: Batch) -> ForwardPassOutputs:
mask = self._manipulate_mask(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe mask should now be score_mask to differentiate it from decoder_mask?

@emilydinan
Copy link
Contributor Author

It's nice to see the cleanup of BART. Did you verify an older checkpoint didn't have results change?

yep! tried training for a bit on convai2 and comparing results

@emilydinan
Copy link
Contributor Author

Circle CI hanging on self chat tests, I ran locally and they passed
Screen Shot 2021-02-22 at 1 21 37 PM

@EricMichaelSmith
Copy link
Contributor

Circle CI hanging on self chat tests, I ran locally and they passed
Screen Shot 2021-02-22 at 1 21 37 PM

hmm, it looks like tests/crowdsourcing/tasks/model_chat/test_model_image_chat.py::TestModelImageChat::test_base_task in the crowdsourcing CI check is hanging too - it might be good to run that locally as well. Sometimes there is flakiness with the crowdsourcing CI tests (that's an issue that Jack and I know about), but it might be worth checking that this PR doesn't affect that test in some way

@emilydinan emilydinan merged commit 7293ffe into master Feb 22, 2021
@emilydinan emilydinan deleted the tga_ranker branch February 22, 2021 19:18
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants