Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
fix constraint bug in beam search, clean up tests (#5328)
Browse files Browse the repository at this point in the history
* fix constraint bug in beam search, clean up tests

* fix linting error

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
  • Loading branch information
epwalsh and dirkgr authored Aug 3, 2021
1 parent ec3e294 commit b72bbfc
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 107 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed a mispelling: the parameter `contructor_extras` in `Lazy()` is now correctly called `constructor_extras`.
- Fixed broken links in `allennlp.nn.initializers` docs.
- Fixed bug in `BeamSearch` where `last_backpointers` was not being passed to any `Constraint`s.
- `TransformerTextField` can now take tensors of shape `(1, n)` like the tensors produced from a HuggingFace tokenizer.
- `tqdm` lock is now set inside `MultiProcessDataLoading` when new workers are spawned to avoid contention when writing output.
- `ConfigurationError` is now pickleable.
Expand Down
6 changes: 4 additions & 2 deletions allennlp/nn/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,9 @@ def _search(
# dividing by per_node_beam_size gives the ancestor. (Note that this is integer
# division as the tensor is a LongTensor.)
# shape: (batch_size, beam_size)
backpointer = restricted_beam_indices // self.per_node_beam_size
backpointer = torch.divide(
restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc"
)
backpointers.append(backpointer)

# Keep only the pieces of the state tensors corresponding to the
Expand All @@ -1094,7 +1096,7 @@ def _search(

for i, constraint in enumerate(self.constraints):
constraint_states[i] = constraint.update_state(
constraint_states[i], restricted_predicted_classes
constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
)

# Warn about "-inf" log probabilities if not using any constraints (negligible
Expand Down
200 changes: 95 additions & 105 deletions tests/nn/beam_search_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Tuple
from typing import Dict, Tuple, Union

import numpy as np
import pytest
Expand All @@ -12,120 +12,98 @@
TopKSampler,
TopPSampler,
GumbelSampler,
SequenceLogProbabilityScorer,
LengthNormalizedSequenceLogProbabilityScorer,
RepeatedNGramBlockingConstraint,
StepFunctionTypeWithTimestep,
StepFunctionTypeNoTimestep,
)
from allennlp.common.params import Params
from allennlp.nn.util import min_value_of_dtype


# fmt: off
transition_probabilities = torch.tensor(
[
[0.0, 0.4, 0.3, 0.2, 0.1, 0.0], # start token -> jth token
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0], # 1st token -> jth token
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0], # 2nd token -> jth token
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0], # ...
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # ...
[0.2, 0.1, 0.2, 0.2, 0.2, 0.3],
] # end token -> jth token
[ # START 1 2 3 4 END
[0.0, 0.4, 0.3, 0.2, 0.1, 0.0], # START -> j
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0], # 1 -> j
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0], # 2 -> j
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0], # 3 -> j
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # 4 -> j
[0.2, 0.1, 0.2, 0.2, 0.2, 0.1], # END -> j (doesn't matter)
]
)

# A transition matrix that favors shorter sequences over longer ones
short_sequence_transition_probabilities = torch.tensor(
[
[0.0, 0.1, 0.0, 0.0, 0.0, 0.9], # start token -> jth token
[0.0, 0.0, 0.1, 0.0, 0.0, 0.9], # 1st token -> jth token
[0.0, 0.0, 0.0, 0.1, 0.0, 0.9], # 2nd token -> jth token
[0.0, 0.0, 0.0, 0.0, 0.1, 0.9], # ...
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # ...
[0.2, 0.1, 0.2, 0.2, 0.2, 0.3],
] # end token -> jth token
[ # START 1 2 3 4 END
[0.0, 0.1, 0.0, 0.0, 0.0, 0.9], # START -> j
[0.0, 0.0, 0.1, 0.0, 0.0, 0.9], # 1 -> j
[0.0, 0.0, 0.0, 0.1, 0.0, 0.9], # 2 -> j
[0.0, 0.0, 0.0, 0.0, 0.1, 0.9], # 3 -> j
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # 4 -> j
[0.2, 0.1, 0.2, 0.2, 0.2, 0.1], # END -> j (doesn't matter)
]
)

# A transition matrix that favors repeated ngrams
repeated_ngram_transition_probabilities = torch.tensor(
[
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0], # start token -> jth token
[0.0, 0.0, 0.4, 0.6, 0.0, 1e-9], # 1st token -> jth token
[0.0, 0.0, 0.0, 1.0, 0.0, 1e-9], # 2nd token -> jth token
[0.0, 1.0, 0.0, 0.0, 0.0, 1e-9], # ...
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # not used
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
] # end token -> jth token
repeated_ngram_transition_probabilities_0 = torch.tensor(
[ # START 1 2 3 4 END
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0], # START -> j
[0.0, 0.0, 0.4, 0.6, 0.0, 1e-9], # 1 -> j
[0.0, 0.0, 0.0, 1.0, 0.0, 1e-9], # 2 -> j
[0.0, 1.0, 0.0, 0.0, 0.0, 1e-9], # 3 -> j
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 4 -> j (not used)
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # END -> j (doesn't matter)
]
)

# Another transition matrix that favors repeated ngrams
repeated_ngram_transition_probabilities_1 = torch.tensor(
[ # START 1 2 3 4 END
[0.0, 0.4, 0.3, 0.2, 0.1, 0.0], # START -> j
[0.0, 0.4, 0.3, 0.2, 0.1, 0.1], # 1 -> j
[0.0, 0.0, 0.4, 0.3, 0.2, 0.1], # 2 -> j
[0.0, 0.0, 0.3, 0.4, 0.2, 0.1], # 3 -> j
[0.0, 0.0, 0.2, 0.3, 0.4, 0.1], # 4 -> j
[0.2, 0.1, 0.2, 0.2, 0.2, 0.1], # END -> j (doesn't matter)
]
)
# fmt: on

log_probabilities = torch.log(
torch.tensor([[0.1, 0.3, 0.3, 0.3, 0.0, 0.0], [0.0, 0.0, 0.4, 0.3, 0.2, 0.1]])
)


def take_step_no_timestep(
last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Take decoding step.
This is a simple function that defines how probabilities are computed for the
next time step during the beam search.
We use a simple target vocabulary of size 6. In this vocabulary, index 0 represents
the start token, and index 5 represents the end token. The transition probability
from a state where the last predicted token was token `j` to new token `i` is
given by the `(i, j)` element of the matrix `transition_probabilities`.
"""
log_probs_list = []
for last_token in last_predictions:
log_probs = torch.log(transition_probabilities[last_token.item()])
log_probs_list.append(log_probs)

return torch.stack(log_probs_list), state

def get_step_function(
transition_matrix: torch.Tensor, with_timestep: bool = False
) -> Union[StepFunctionTypeNoTimestep, StepFunctionTypeWithTimestep]:
def _step_function(
last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
log_probs_list = []
for last_token in last_predictions:
log_probs = torch.log(transition_matrix[last_token.item()])
log_probs_list.append(log_probs)

def take_step_with_timestep(
last_predictions: torch.Tensor,
state: Dict[str, torch.Tensor],
timestep: int,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
return take_step_no_timestep(last_predictions, state)
return torch.stack(log_probs_list), state

if not with_timestep:
return _step_function

def take_short_sequence_step(
last_predictions: torch.Tensor,
state: Dict[str, torch.Tensor],
timestep: int,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Take decoding step.
def _step_function_with_timestep(
last_predictions: torch.Tensor,
state: Dict[str, torch.Tensor],
timestep: int,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
return _step_function(last_predictions, state)

This method is the same as `take_step_no_timestep` except it uses the
`short_sequence_transition_probabilities` transitions instead of `transition_probabilities`
"""
log_probs_list = []
for last_token in last_predictions:
log_probs = torch.log(short_sequence_transition_probabilities[last_token.item()])
log_probs_list.append(log_probs)
return _step_function_with_timestep

return torch.stack(log_probs_list), state


def take_repeated_ngrams_step(
last_predictions: torch.Tensor,
state: Dict[str, torch.Tensor],
timestep: int,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Take decoding step.
This method is the same as `take_step_no_timestep` except it uses the
`short_sequence_transition_probabilities` transitions instead of `transition_probabilities`
"""
log_probs_list = []
for last_token in last_predictions:
log_probs = torch.log(repeated_ngram_transition_probabilities[last_token.item()])
log_probs_list.append(log_probs)

return torch.stack(log_probs_list), state
take_step_no_timestep = get_step_function(transition_probabilities)
take_step_with_timestep = get_step_function(transition_probabilities, with_timestep=True)
take_short_sequence_step = get_step_function(short_sequence_transition_probabilities)


class BeamSearchTest(AllenNlpTestCase):
Expand Down Expand Up @@ -575,11 +553,6 @@ def test_gumbel_sampler(self):
assert all([x >= 0 and x < 4 for x in indices[0]])
assert all([x > 1 and x <= 5 for x in indices[1]])

def test_sequence_log_prob_scorer(self):
# SequenceLogProbabilityScorer is the default, so manually setting the
# sequence scorer shouldn't actually change anything
self.beam_search.sequence_scorer = SequenceLogProbabilityScorer()

def test_length_normalized_sequence_log_prob_scorer(self):
"""
Tests to ensure the sequences are normalized by the correct values. The end token is
Expand Down Expand Up @@ -708,7 +681,7 @@ def test_repeated_ngram_blocking_constraint_update_state(self):

def test_take_repeated_ngram_step(self):
"""
Tests to ensure the top-k from the short_sequence_transition_probabilities
Tests to ensure the top-k from the `repeated_ngram_transition_probabilities_0`
transition matrix is expected. The transitions are:
- p(1|start) = 1.0
Expand Down Expand Up @@ -761,23 +734,19 @@ def test_take_repeated_ngram_step(self):
0.4 * 1e-9: [1, 2, 3, 1, end]
0.36 * 1e-9: [1, 3, 1, 3, end]
"""
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
self.beam_search.beam_size = 2
self.beam_search.max_steps = 5
expected_top_k = np.array([[1, 3, 1, 3, 1], [1, 2, 3, 1, 3]])
expected_log_probs = np.log(np.array([0.36, 0.24]))
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_repeated_ngrams_step,
take_step=step_function,
)

def test_repeated_ngram_blocking_end_to_end(self):
"""
This test checks to make sure the `RepeatedNGramBlockingConstraint` successfully blocks ngrams.
It works by blocking ngrams of different sizes and ensures that the result of beam search
is correctly changed. We rely on the beam search trace for `repeated_ngram_transition_probabilities`
in `test_take_repeated_ngram_step`.
"""
def test_repeated_ngram_blocking_end_to_end_unigrams(self):
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
self.beam_search.beam_size = 2

# Unigrams: On step 3, [1, 3, 1] will be blocked and [1, 3, end] will take its place
Expand All @@ -788,9 +757,25 @@ def test_repeated_ngram_blocking_end_to_end(self):
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_repeated_ngrams_step,
take_step=step_function,
)

step_function = get_step_function(repeated_ngram_transition_probabilities_1)
self.beam_search.max_steps = 5
expected_top_k = np.array([[1, 2, 3, 4, 5], [1, 2, 4, 3, 5]])
expected_log_probs = np.log(
np.array([0.4 * 0.3 * 0.3 * 0.2 * 0.1, 0.4 * 0.3 * 0.2 * 0.3 * 0.1])
)
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=step_function,
)

def test_repeated_ngram_blocking_end_to_end_bigrams(self):
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
self.beam_search.beam_size = 2

# Bigrams: On step 4, [1, 3, 1, 3] will be blocked and [1, 3, 1, 2] will take its place
self.beam_search.max_steps = 4
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=2)]
Expand All @@ -799,9 +784,13 @@ def test_repeated_ngram_blocking_end_to_end(self):
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_repeated_ngrams_step,
take_step=step_function,
)

def test_repeated_ngram_blocking_end_to_end_trigrams(self):
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
self.beam_search.beam_size = 2

# Trigrams: On step 5, [1, 3, 1, 3, 1] will be blocked and [1, 2, 3, 1, 2] will take its place
self.beam_search.max_steps = 5
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=3)]
Expand All @@ -810,7 +799,7 @@ def test_repeated_ngram_blocking_end_to_end(self):
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_repeated_ngrams_step,
take_step=step_function,
)

def test_repeated_ngram_blocking_end_indices(self):
Expand All @@ -820,12 +809,13 @@ def test_repeated_ngram_blocking_end_indices(self):
"""
# We block unigrams, but 5 (the end symbol) is repeated and it does not mess
# up the sequence's probability
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
self.beam_search.beam_size = 2
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)]
expected_top_k = np.array([[1, 3, 5, 5], [1, 2, 3, 5]])
expected_log_probs = np.log(np.array([0.6 * 1e-9, 0.4 * 1e-9]))
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_repeated_ngrams_step,
take_step=step_function,
)

0 comments on commit b72bbfc

Please sign in to comment.