
Custom beam search scorer argument in generate function #32097

Open
wants to merge 2 commits into base: main
Conversation


@GM07 GM07 commented Jul 19, 2024

What does this PR do?

Added a beam_search_scorer_class argument to the generate() function so that users can run beam search and grouped beam search with a custom beam search scorer.

Before this change, the internal methods _beam_search() and _group_beam_search() had to be called directly to pass in a custom beam search scorer. This caused a lot of problems, since the generate() function does a lot of preprocessing on the input (for example, interleaving the input_ids in the case of beam search), and all of that preprocessing had to be reproduced manually. Now the scorer can be set directly in the generate() function by passing its type, in the following way:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BeamSearchScorer

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")

encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

class CustomBeamSearchScorer(BeamSearchScorer):
    finalize_called = False
    process_called = False

    def __init__(self, test_args, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.test_args = test_args

    def process(self, *args, **kwargs):
        results = super().process(*args, **kwargs)
        # Do stuff
        return results

    def finalize(self, *args, **kwargs):
        results = super().finalize(*args, **kwargs)
        # Do stuff
        return results

# let's run beam search using 3 beams
num_beams = 3
# define decoder start token ids
input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# add encoder_outputs to model keyword arguments
model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)}

outputs = model.generate(
    input_ids,
    num_beams=num_beams,
    min_length=5,
    eos_token_id=model.config.eos_token_id,
    beam_search_scorer_class=CustomBeamSearchScorer,
    beam_search_scorer_args={"test_args": True},
    **model_kwargs,
)

Why was it done this way?

Initially, beam_search_scorer was passed as an object rather than a type. However, that could lead to inconsistencies between the parameters of the generation config and those of the beam search scorer: for example, the number of beams could be set to 2 in the generate() call but to 4 when creating the scorer. Passing only the type allows the scorer to be created from the generation config inside the method (as before), preventing any inconsistency between the two objects.
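The class-plus-kwargs pattern described above can be sketched in plain Python, independent of transformers. All names here (BaseScorer, generate, scorer_args) are illustrative, not the actual library internals:

```python
# Illustrative sketch (not actual transformers code): the caller supplies a
# scorer *class* plus extra constructor kwargs, and the framework fills in the
# config-derived arguments, so num_beams can never diverge from the config.

class BaseScorer:
    def __init__(self, num_beams):
        self.num_beams = num_beams


class CustomScorer(BaseScorer):
    def __init__(self, test_args, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.test_args = test_args


def generate(scorer_class=BaseScorer, scorer_args=None, num_beams=1):
    # The framework, not the caller, supplies num_beams from its own config,
    # then merges in the user's extra constructor arguments.
    return scorer_class(num_beams=num_beams, **(scorer_args or {}))


scorer = generate(scorer_class=CustomScorer, scorer_args={"test_args": True}, num_beams=3)
assert scorer.num_beams == 3      # always taken from the framework's config
assert scorer.test_args is True   # user-supplied extra argument
```

Because the framework owns the num_beams keyword, an instance constructed with a conflicting value can never reach the decoding loop.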

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@gante
@ArthurZucker

@GM07 GM07 force-pushed the generate-with-custom-beam-search-scorer branch from 6c6e18a to d23a29a Compare July 19, 2024 18:43
@GM07 GM07 force-pushed the generate-with-custom-beam-search-scorer branch from d23a29a to 5437234 Compare July 22, 2024 18:27
@ArthurZucker
Collaborator

cc @zucchini-nlp and @gante

Member

@zucchini-nlp zucchini-nlp left a comment


Hi @GM07 !

We're currently refactoring generate and the beam methods, and would like not to introduce new changes until the refactoring is done. The refactoring tracker can be found at #30810.

@gante
Member

gante commented Jan 30, 2025

Hi @GM07 👋

Custom beam scoring will be enabled, but in a different format :) Beam search is being refactored first in #35802; after that, we can add much simpler scoring functions.

(This PR will not be accepted in its current form, but contributions for the new format are welcome 🤗 )

5 participants