Blocking repeated ngrams during beam search #5216
Conversation
This looks pretty good so far. I do think having an end-to-end test is important. Is it feasible to at least make a test that assumes `ngram_size` is 1?

One comment I have about the API for the `Constraint` class is that having to deal with backpointers in `update_state()` is not ideal. Can we automatically handle reordering the `state` list outside of `update_state()`, either in the `BeamSearch` class or the `Constraint` base class?
In this version:
I removed "WIP" from the PR title because I think this version is complete, pending comments.
Looks great @danieldeutsch! Just a few more minor comments.
CHANGELOG.md
Outdated
```diff
@@ -33,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences.
 - Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`.
 - Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling.
+- Added a `Constrant` abstract class to `BeamSearch`, which allows for incorporating constraints on the predictions found by `BeamSearch`.
```
Suggested change:
```diff
-- Added a `Constrant` abstract class to `BeamSearch`, which allows for incorporating constraints on the predictions found by `BeamSearch`.
+- Added a `Constraint` abstract class to `BeamSearch`, which allows for incorporating constraints on the predictions found by `BeamSearch`.
```
allennlp/nn/beam_search.py
Outdated
```python
class Constraint(Registrable):
    """
    An abstract class that can be used to enforce constraints on the output predictions
    by manipulate the class log probabilities during beam search.
```
Suggested change:
```diff
-    by manipulate the class log probabilities during beam search.
+    by manipulating the class log probabilities during beam search.
```
allennlp/nn/beam_search.py
Outdated
```python
    The `apply()` method should manipulate the `class_log_probabilities` in place to enforce the constraint
    for this step of beam search. For instance, it may prevent a specific class from being selected by setting
    the corresponding log probability to `-inf` (by using `min_value_of_dtype(class_log_probabilities.dtype)`).
```
Suggested change:
```diff
-    the corresponding log probability to `-inf` (by using `min_value_of_dtype(class_log_probabilities.dtype)`).
+    the corresponding log probability to a negligible value by using `min_value_of_dtype(class_log_probabilities.dtype)`, which is essentially equivalent to `-inf`.
```
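The masking idea can be sketched without torch. This hypothetical helper gives blocked classes a negligible log probability (a stand-in for `min_value_of_dtype`, which is effectively `-inf`) and returns a new sequence rather than modifying in place; the function name and list-based representation are illustrative only.

```python
def mask_blocked(class_log_probabilities, blocked_indices, min_value=-1e38):
    # Hypothetical, torch-free sketch: any class in blocked_indices gets a
    # negligible log probability so beam search will never select it.
    return [
        min_value if i in blocked_indices else lp
        for i, lp in enumerate(class_log_probabilities)
    ]
```

With log probabilities `[-0.1, -2.0, -3.0]` and class 1 blocked, the result keeps classes 0 and 2 unchanged while class 1 drops to the negligible value.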
allennlp/nn/beam_search.py
Outdated
```python
for constraint, constraint_state in zip(self.constraints, constraint_states):
    # shape: (batch_size, beam_size, num_classes)
    reshaped_class_log_probabilities = class_log_probabilities.view(
        batch_size, self.beam_size, -1
    )
```
Any reason not to put the reshape outside of the loop?
Suggested change:
```diff
-for constraint, constraint_state in zip(self.constraints, constraint_states):
-    # shape: (batch_size, beam_size, num_classes)
-    reshaped_class_log_probabilities = class_log_probabilities.view(
-        batch_size, self.beam_size, -1
-    )
+# shape: (batch_size, beam_size, num_classes)
+reshaped_class_log_probabilities = class_log_probabilities.view(
+    batch_size, self.beam_size, -1
+)
+for constraint, constraint_state in zip(self.constraints, constraint_states):
```
allennlp/nn/beam_search.py
Outdated
```python
    - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
      log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.

    The `apply()` method should manipulate the `class_log_probabilities` in place to enforce the constraint
```
I think it's more natural for the `apply()` method to return new `class_log_probabilities`. It's a little easier to reason about.
LGTM, thanks @danieldeutsch!
Thanks @epwalsh for finishing it. I forgot about it over the holiday.
```python
                backpointer = last_backpointer[i, j].item()
                batch_state.append(copy.deepcopy(state[i][backpointer]))
            new_state.append(batch_state)
        return new_state
```
@danieldeutsch @epwalsh Sorry, I know this PR is closed but I have a question that's easiest to ask here because of the existing context.

It seems that `last_backpointer`, which is passed to `_copy_state` by `update_state`, will always be `None` (it's never provided to `update_state` in `BeamSearch`). That would mean that `backpointer` will always be `0`, and then it will always be `state[i][0]` being copied, regardless of the actual timestep. Isn't this a problem? The tests don't catch this because they provide `backpointer` to `update` manually (see my other comment).
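To illustrate the bug being reported, here is a hypothetical, self-contained sketch of the reordering `_copy_state` is meant to perform: beam `j` at this timestep was extended from beam `backpointers[i][j]` at the previous timestep, so it must inherit that beam's constraint state. The function name and plain-list representation are illustrative, not the actual AllenNLP implementation.

```python
import copy


def reorder_states(state, backpointers):
    # state: per-batch list of per-beam constraint states.
    # backpointers: for each batch entry, which previous beam each new
    # beam was extended from. If backpointers is None (the situation
    # described above), every beam silently inherits beam 0's state.
    new_state = []
    for i, batch in enumerate(state):
        if backpointers is None:
            new_state.append([copy.deepcopy(batch[0]) for _ in batch])
        else:
            new_state.append([copy.deepcopy(batch[bp]) for bp in backpointers[i]])
    return new_state
```

With two beams whose states differ, passing real backpointers reorders the states correctly, while passing `None` wrongly duplicates beam 0's state.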
Yes, it looks like you are right to me. I think the fix should be to pass `backpointer` here, right?
allennlp/allennlp/nn/beam_search.py
Lines 1095 to 1098 in 5b1da90
```python
for i, constraint in enumerate(self.constraints):
    constraint_states[i] = constraint.update_state(
        constraint_states[i], restricted_predicted_classes
    )
```
```python
    ],
]
predictions = torch.LongTensor([[5, 6], [0, 3]])
backpointers = torch.LongTensor([[1, 1], [0, 1]])
```
`backpointers` is provided here to `update`, but not in `BeamSearch` (see my other comment).
I would need to retrace the beam search that I used to write this test, but my guess is that this is not caught because the backpointer is always 0.
allennlp/tests/nn/beam_search_test.py
Line 774 in 5b1da90
```python
def test_repeated_ngram_blocking_end_to_end(self):
```
* Implementing blocking repeated ngrams
* Adding comment
* Adding unit tests for the end to end beam search
* Renaming class
* Adding comment about function
* Simplifying indexing to variable
* Refactoring the state copying into the class
* Reformatting
* Editing changelog
* fix line too long
* comments
* doc updates

Co-authored-by: Pete <petew@allenai.org>
Co-authored-by: epwalsh <epwalsh10@gmail.com>
Closes #5205.

Changes proposed in this pull request:

- `Constraint` abstract class for enforcing constraints during beam search
- `NGramBlockingConstraint` class to prevent decoding repeated ngrams.

This is a work in progress and may be close to finishing, so I wanted to get feedback. @epwalsh, any thoughts on the `Constraint` interface? I limited the methods to just those which I needed to implement the ngram blocking.
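The core idea of ngram blocking can be sketched independently of beam search. This hypothetical helper (its name and signature are illustrative, not AllenNLP's API) returns the set of next tokens that would repeat an ngram already present in the sequence so far; those are the tokens whose log probabilities the constraint would mask.

```python
def blocked_next_tokens(tokens, ngram_size):
    # A next token is banned if appending it to the last (ngram_size - 1)
    # tokens would complete an ngram that already occurred in the sequence.
    if len(tokens) < ngram_size - 1:
        return set()
    seen = {}
    # Record every (prefix -> completing token) pair for ngrams of the
    # given size that appear in the sequence so far.
    for i in range(len(tokens) - ngram_size + 1):
        prefix = tuple(tokens[i : i + ngram_size - 1])
        seen.setdefault(prefix, set()).add(tokens[i + ngram_size - 1])
    current_prefix = tuple(tokens[len(tokens) - ngram_size + 1 :])
    return seen.get(current_prefix, set())
```

Note that with `ngram_size` 1 the prefix is empty, so every previously generated token is blocked, which matches the unigram test case suggested in the review.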