Skip to content

Commit

Permalink
Constrained Beam Search [without disjunctive decoding] (huggingface#1…
Browse files Browse the repository at this point in the history
…5416)

* added classes to get started with constrained beam search

* in progress, think i can directly force tokens now but not yet with the round robin

* think now i have total control, now need to code the bank selection

* technically works as desired, need to optimize and fix design choices leading to undersirable outputs

* complete PR #1 without disjunctive decoding

* removed incorrect tests

* Delete k.txt

* Delete test.py

* Delete test.sh

* revert changes to test scripts

* genutils

* full implementation with testing, no disjunctive yet

* shifted docs

* passing all tests realistically ran locally

* removing accidentally included print statements

* fixed source of error in initial PR test

* fixing the get_device() vs device trap

* fixed documentation docstrings about constrained_beam_search

* fixed tests having failing for Speech2TextModel's floating point inputs

* fix cuda long tensor

* added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search

* deleted accidentally added test halting code with assert False

* code reformat

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

* fixing based on comments on PR

* took out the testing code that should but work fails without the beam search moditification ; style changes

* fixing comments issues

* docstrings for ConstraintListState

* typo in PhrsalConstraint docstring

* docstrings improvements

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
cwkeam and patrickvonplaten authored Feb 9, 2022
1 parent 0113aae commit 2b5603f
Show file tree
Hide file tree
Showing 8 changed files with 1,871 additions and 16 deletions.
19 changes: 17 additions & 2 deletions docs/source/internal/generation_utils.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ This page lists all the utility functions used by [`~generation_utils.Generation
[`~generation_utils.GenerationMixin.greedy_search`],
[`~generation_utils.GenerationMixin.sample`],
[`~generation_utils.GenerationMixin.beam_search`],
[`~generation_utils.GenerationMixin.beam_sample`], and
[`~generation_utils.GenerationMixin.group_beam_search`].
[`~generation_utils.GenerationMixin.beam_sample`],
[`~generation_utils.GenerationMixin.group_beam_search`], and
[`~generation_utils.GenerationMixin.constrained_beam_search`].

Most of those are only useful if you are studying the code of the generate methods in the library.

Expand Down Expand Up @@ -190,6 +191,16 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
[[autodoc]] MaxTimeCriteria
- __call__

## Constraints

A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.

[[autodoc]] Constraint

[[autodoc]] PhrasalConstraint

[[autodoc]] ConstraintListState

## BeamSearch

[[autodoc]] BeamScorer
Expand All @@ -200,6 +211,10 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
- process
- finalize

[[autodoc]] ConstrainedBeamSearchScorer
- process
- finalize

## Utilities

[[autodoc]] top_k_top_p_filtering
Expand Down
10 changes: 8 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,12 @@
"TextDatasetForNextSentencePrediction",
]
_import_structure["deepspeed"] = []
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"]
_import_structure["generation_beam_constraints"] = [
"Constraint",
"ConstraintListState",
"PhrasalConstraint",
]
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
_import_structure["generation_logits_process"] = [
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
Expand Down Expand Up @@ -2750,7 +2755,8 @@
TextDataset,
TextDatasetForNextSentencePrediction,
)
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
Expand Down
Loading

0 comments on commit 2b5603f

Please sign in to comment.