[feat] Trainer with prompts and prompt masking #2964

Merged: 51 commits, Nov 8, 2024

@ArthurCamara (Contributor) commented Sep 27, 2024

Pull Request overview

  • Adds support for including prompts in the Trainer class
  • Supports masking the prompts in the Pooling when training.

Details

Currently, the encode method of SentenceTransformer supports adding prompts (or instructions) dynamically to the sentences by passing either prompt or prompt_name. However, this is not supported when training, as mentioned in #2945, as it uses the forward method instead.

This PR implements similar functionality in the Trainer by adding a prompt parameter that can be:

  • str: The prompt will be prepended to all sentences in the dataset
  • dict[str, str]: If the keys are column names, it will prepend the prompt to the respective column. If the training dataset is a dictionary of datasets, and the dictionary keys are names of the datasets, it will add the prompt to all the columns of the respective dataset.
  • dict[str, dict[str, str]]: Same as above, but assumes the first level is the dataset name and the second level are the column names.

As the prompts can be dynamic (changing for each dataset and column), they are injected into the sentences by the get_train_dataloader, get_eval_dataloader, and get_test_dataloader methods, which call add_prompts_to_dataset to resolve which prompt to inject for each dataset and column.
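
The resolution step described above could look roughly like the sketch below. This is not the PR's code: the function name resolve_prompt and the rule that dataset-name keys take precedence over column-name keys are assumptions made for illustration.

```python
from typing import Optional, Union

Prompts = Union[str, dict[str, str], dict[str, dict[str, str]]]


def resolve_prompt(
    prompts: Optional[Prompts],
    column_name: str,
    dataset_name: Optional[str] = None,
) -> Optional[str]:
    """Return the prompt for one (dataset, column) pair, or None if none applies."""
    if prompts is None:
        return None
    # A single string applies to every column of every dataset.
    if isinstance(prompts, str):
        return prompts
    # Keys are dataset names (assumed to take precedence over column names).
    if dataset_name is not None and dataset_name in prompts:
        entry = prompts[dataset_name]
        if isinstance(entry, str):
            return entry  # one prompt for all columns of this dataset
        return entry.get(column_name)  # nested dict: dataset -> column -> prompt
    # Otherwise the keys are treated as column names.
    entry = prompts.get(column_name)
    return entry if isinstance(entry, str) else None
```

For example, resolve_prompt({"nq": {"query": "search: "}}, "query", dataset_name="nq") returns "search: ", while a column without an entry resolves to None and is left untouched.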

Finally, add_prompts_to_dataset also adds <column_name>_prompt_length columns that, when passed to the Pooling module with include_prompt=False, let it mask the instructions properly as well. (Currently this is only set explicitly for Instructor models, but the user can set it by calling model.set_pooling_include_prompt(include_prompt=False).)
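
To illustrate what masking the prompt means for mean pooling, here is a minimal sketch in plain Python. The function name mean_pool and the list-based inputs are assumptions for illustration only; the actual Pooling module works on tensors and may count special tokens differently.

```python
def mean_pool(token_embeddings, attention_mask, prompt_length=0):
    """Mean-pool one sequence, skipping padding and the first `prompt_length` tokens."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for position, (vector, keep) in enumerate(zip(token_embeddings, attention_mask)):
        # Skip padding positions and any position that belongs to the prompt.
        if not keep or position < prompt_length:
            continue
        for i, value in enumerate(vector):
            total[i] += value
        count += 1
    # Guard against division by zero if every token was masked out.
    return [value / max(count, 1) for value in total]
```

With prompt_length=0 the prompt tokens contribute to the sentence embedding; with the real prompt length they are excluded, which is why the length has to travel with each batch.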

@ArthurCamara ArthurCamara changed the title Trainer with prompt masking [feat] Trainer with prompts and prompt masking Sep 27, 2024
@tomaarsen tomaarsen force-pushed the trainer-with-prompt-masking branch 2 times, most recently from 354cb65 to bf9eb80 Compare September 30, 2024 15:36
@tomaarsen
Collaborator

Hello!

Thanks for this PR. I rebased it to get rid of the leftover commits that aren't necessary here.
I have a few hesitations with the current approach, although I do quite like the idea of being able to specify prompts to use during training apart from manually adding them in your dataset(s). My current hesitations:

  1. Adding a column to the entire dataset before training will be incompatible with the datasets IterableDataset (i.e., load_dataset("...", streaming=True)).
  2. I'm in theory okay with adding a ..._prompt_length column: I recognize that it's crucial to get this information if include_prompt is False in the Pooling. However, I have two notes:
    • Could we e.g. only add the information if the Pooling module (if it exists) actually has include_prompt=False?
    • Could we perhaps not add an entire column, but instead create a nested dictionary with dataset names mapping to column names mapping to prompt lengths? Dataset names should be a column if there's multiple datasets.

Could we perhaps add the prompts (and prompt lengths) in the data collator? E.g. right here: https://github.com/ArthurCamara/sentence-transformers/blob/bf9eb803ce2dda26a8ef903c33d80cd1fcb55a3d/sentence_transformers/data_collator.py#L50-L56

The data collator knows the dataset name, the column name (see the snippet), and should then be able to use that information to "on the fly" prepend the prompts. In a perfect world we could even only tokenize the prompts once, but that gets complicated with padding and truncation, so it's better to keep it simpler.
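
A rough sketch of the collator idea, assuming a hypothetical collate_with_prompts function that receives the rows of one batch plus a column-to-prompt mapping. Tokenization, dataset names, and prompt lengths are left out to keep it short, so this is not the collator in the repository.

```python
from typing import Any, Optional


def collate_with_prompts(
    features: list[dict[str, Any]],
    prompts: Optional[dict[str, str]] = None,
) -> dict[str, list[Any]]:
    """Turn a list of rows into a dict of columns, prepending per-column prompts."""
    prompts = prompts or {}
    batch: dict[str, list[Any]] = {}
    for column_name in features[0]:
        values = [row[column_name] for row in features]
        prompt = prompts.get(column_name)
        if prompt is not None:
            # Prepend on the fly; the underlying dataset is never modified.
            values = [prompt + value for value in values]
        batch[column_name] = values
    return batch
```

Because the prompt is added per batch, this works the same way for streamed datasets, which is the compatibility concern raised in point 1.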

I also like your idea that prompt can be multiple things: a single prompt, a prompt per column, or a prompt per column per dataset. I do think prompts is a bit better though, because that's what we use in the encode etc.

I'm curious to hear your thoughts on this.

  • Tom Aarsen

@ArthurCamara
Contributor Author

> Could we perhaps add the prompts (and prompt lengths) in the data collator? E.g. right here: https://github.com/ArthurCamara/sentence-transformers/blob/bf9eb803ce2dda26a8ef903c33d80cd1fcb55a3d/sentence_transformers/data_collator.py#L50-L56
>
> The data collator knows the dataset name, the column name (see the snippet), and should then be able to use that information to "on the fly" prepend the prompts. In a perfect world we could even only tokenize the prompts once, but that gets complicated with padding and truncation, so it's better to keep it simpler.

This was one of the things I was considering: changing the collator instead of the dataset itself. I previously had issues with Accelerator and DDP when the data was not exclusively tensors (i.e., strings), but I think we can work around that within the collator. I will give it a shot and let you know.

> I also like your idea that prompt can be multiple things: a single prompt, a prompt per column, or a prompt per column per dataset. I do think prompts is a bit better though, because that's what we use in the encode etc.

Agreed. =)

…/sentence-transformers into trainer-with-prompt-masking
@ArthurCamara ArthurCamara force-pushed the trainer-with-prompt-masking branch from 86dd847 to bf9eb80 Compare October 2, 2024 08:40
@JosephGatto

Hi thanks for implementing this. Any guide on how to fine-tune with prompts?

@tomaarsen
Collaborator

Hello!

Until this is integrated, I would recommend manually adding the prompts to your training datasets. E.g.:

from datasets import load_dataset
from typing import Dict, List, Any

def prepend_prompt(batch: Dict[str, List[Any]], prompts: Dict[str, str] | None = None) -> Dict[str, List[Any]]:
    if not prompts:
        return batch

    for column_name, prompt in prompts.items():
        batch[column_name] = [prompt + value for value in batch[column_name]]
    return batch

train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
train_dataset = train_dataset.map(
    prepend_prompt,
    batched=True,
    fn_kwargs={"prompts": {"query": "Represent this sentence for searching relevant passages: "}}
)
print(train_dataset[0])
# {'query': 'Represent this sentence for searching relevant passages: when did richmond last play in a preliminary final', 'answer': "Richmond Football Club Richmond began 2017 with 5 straight wins, a feat it had not achieved since 1995. A series of close losses hampered the Tigers throughout the middle of the season, including a 5-point loss to the Western Bulldogs, 2-point loss to Fremantle, and a 3-point loss to the Giants. Richmond ended the season strongly with convincing victories over Fremantle and St Kilda in the final two rounds, elevating the club to 3rd on the ladder. Richmond's first final of the season against the Cats at the MCG attracted a record qualifying final crowd of 95,028; the Tigers won by 51 points. Having advanced to the first preliminary finals for the first time since 2001, Richmond defeated Greater Western Sydney by 36 points in front of a crowd of 94,258 to progress to the Grand Final against Adelaide, their first Grand Final appearance since 1982. The attendance was 100,021, the largest crowd to a grand final since 1986. The Crows led at quarter time and led by as many as 13, but the Tigers took over the game as it progressed and scored seven straight goals at one point. They eventually would win by 48 points – 16.12 (108) to Adelaide's 8.12 (60) – to end their 37-year flag drought.[22] Dustin Martin also became the first player to win a Premiership medal, the Brownlow Medal and the Norm Smith Medal in the same season, while Damien Hardwick was named AFL Coaches Association Coach of the Year. Richmond's jump from 13th to premiers also marked the biggest jump from one AFL season to the next."}

And the rest is the same as the normal training: https://sbert.net/docs/sentence_transformer/training_overview.html

  • Tom Aarsen

@JosephGatto

Hey thanks so much for the quick reply. My main concern here would be if pooling is being done on just the text (and excluding the prompt). I believe in the INSTRUCTOR paper they do not include the embeddings of the prompt during mean pooling. Would this solution take care of that?

@tomaarsen
Collaborator

Indeed, my solution only works if you're including the prompt in the pooling. If you're not, i.e. if include_prompt is set to False in the Pooling module, then you must use this PR. You can use:

pip install git+https://github.com/ArthurCamara/sentence-transformers@trainer-with-prompt-masking

and then use the regular training with one extra parameter in the SentenceTransformerTrainer:

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    prompts={
        "query": "Represent this sentence for searching relevant passages: ",
    },
    evaluator=dev_evaluator,
)

The prompts can be:

  1. a prompt string
  2. a column name to prompt mapping
  3. a dataset to prompt mapping (if you use a dataset dict to train on multiple datasets simultaneously)
  4. a dataset to column name to prompt mapping (i.e. nested dicts, only if you use a dataset dict to train on multiple datasets simultaneously)

I do want to warn you that I'm about to fully overhaul this PR, although the usage will remain the same.

  • Tom Aarsen

@JosephGatto

Thanks so much. And if I was interested in training with dynamic prompts (unique prompt per sample) would that be possible with the methods you described?

@tomaarsen
Collaborator

Unique per sample is not possible here without subclassing the Trainer, no. You could use a unique prompt per dataset, if that helps. I didn't think that a unique prompt per sample was a notable use case, so I didn't think to integrate it.
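
For what it's worth, when the prompt is included in the pooling, a per-sample prompt can probably be handled with the manual approach from earlier in this thread, without the Trainer parameter at all. The sketch below assumes the dataset has a hypothetical "prompt" column holding one prompt per row; the function and column names are illustrative only.

```python
from typing import Any


def prepend_per_sample_prompt(
    batch: dict[str, list[Any]],
    text_column: str = "query",
    prompt_column: str = "prompt",
) -> dict[str, list[Any]]:
    """Prefix each text with the prompt stored in the same row."""
    # Only the updated text column is returned, so the prompt column can be
    # dropped by the caller.
    return {
        text_column: [
            prompt + text
            for prompt, text in zip(batch[prompt_column], batch[text_column])
        ]
    }
```

It would be applied with something like train_dataset.map(prepend_per_sample_prompt, batched=True, remove_columns=["prompt"]). This does not help when the prompt must be excluded from pooling, since no prompt lengths are recorded.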

@JosephGatto

Got it. Thank you!

@tomaarsen
Collaborator

tomaarsen commented Nov 5, 2024

Heya @ArthurCamara,

I've overhauled the prompt prepending once more, as I still had some slight concerns with the previous implementations after some experimentation. You have worked on 2 implementations, and I'm now proposing a third as well:

  1. 'Greedily' .map over each dataset to add the prompt string to each dataset.
  2. 'Lazily' prepends prompts in the data collator in the Trainer.
  3. Use .set_transform for Dataset(Dict) and .map for IterableDataset(Dict) to add the prompt string to each dataset.

I had concerns with the first two:

  1. I'm wary that this results in large memory usage and/or cache files.
  2. During model card generation, I sample from the datasets to use in the model card (example). I'd also like the prompts to be included there, which isn't the case if the prompt prepending only exists in the Trainer.

After getting some valuable recommendations by the Datasets team and @lhoestq in particular, I'm now using .set_transform and .map to lazily apply 1) the prompts (if provided), 2) prompt lengths (if needed for pooling), and 3) dataset name (if needed for determining which loss to use). The implementation now lives as a 1-time "update" of the provided train/eval datasets, so the model card can easily fetch samples that include the prompts.
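
A simplified sketch of that third approach. The helper names add_prompt and apply_lazily are made up for illustration, and detecting the dataset type via hasattr is an assumption; the merged code may well differ. Prompt lengths and dataset names are omitted.

```python
from functools import partial
from typing import Any


def add_prompt(batch: dict[str, list[Any]], prompts: dict[str, str]) -> dict[str, list[Any]]:
    """Batch transform: prepend the configured prompt to each listed column."""
    for column_name, prompt in prompts.items():
        if column_name in batch:
            batch[column_name] = [prompt + value for value in batch[column_name]]
    return batch


def apply_lazily(dataset: Any, prompts: dict[str, str]) -> Any:
    """Attach `add_prompt` without materializing a modified copy of the data."""
    transform = partial(add_prompt, prompts=prompts)
    if hasattr(dataset, "set_transform"):
        # Map-style Dataset(Dict): the transform runs when rows are accessed.
        dataset.set_transform(transform)
        return dataset
    # IterableDataset(Dict): .map is already lazy and runs during iteration.
    return dataset.map(transform, batched=True)
```

Since the transform is attached to the dataset object itself, anything that later reads rows from it, such as the model card sampling mentioned above, sees the prompts as well.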

I've also trained 2 near-identical models, one without prompts and one with prompts.

The former consistently performs slightly worse than the model with the prompts:
[image]

Also, the prompts model shows the prompts in the model card easily: https://huggingface.co/tomaarsen/mpnet-base-nq-prompts#natural-questions

Lastly, I built an extensive training suite for this feature because there are a LOT of moving parts between training, evaluation, iterable datasets, and the various prompt formats.

I'm curious about your thoughts on my proposal @ArthurCamara, as I know you're using this yourself too! And one final question:

  • Do you think that prompts should be a parameter in SentenceTransformerTrainer or in SentenceTransformerTrainingArguments?
  • Tom Aarsen

@ArthurCamara
Contributor Author

> During model card generation, I sample from the datasets to use in the model card (example), I'd also like for the prompts to be included, which isn't the case if the prompt prepending only exists in the Trainer.

Adding the prompts to the model card is something very useful that I hadn't thought of. Nice.

> After getting some valuable recommendations by the Datasets team and @lhoestq in particular, I'm now using .set_transform and .map to lazily apply 1) the prompts (if provided), 2) prompt lengths (if needed for pooling), and 3) dataset name (if needed for determining which loss to use). The implementation now lives as a 1-time "update" of the provided train/eval datasets, so the model card can easily fetch samples that include the prompts.

Nice to learn something new. I didn't know about set_transform. This is a cleaner solution than doing multiple passes over the datasets.

> Also, the prompts model shows the prompts in the model card easily: https://huggingface.co/tomaarsen/mpnet-base-nq-prompts#natural-questions

Neat. I like the way prompting helps to disentangle the representations of queries and documents even in smaller models.

> Do you think that prompts should be a parameter in SentenceTransformerTrainer or in SentenceTransformerTrainingArguments?

Good question. I want to say it should be in the Arguments, so it can be easily swapped out when testing with different configurations. But I'm not sure how good a UX it would be to pass a doubly nested dictionary as an argument to a training script (of course, reading it from a JSON/YAML file is also an option).
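
As a sketch of the config-file option, the nested prompts could be kept in JSON and validated after loading. The dataset names, the "other-dataset" prompts, and the load_prompts helper below are all hypothetical; only the natural-questions prompt string comes from this thread.

```python
import json

config_text = """
{
  "prompts": {
    "natural-questions": {"query": "Represent this sentence for searching relevant passages: "},
    "other-dataset": {"anchor": "query: ", "positive": "document: "}
  }
}
"""


def load_prompts(text: str) -> dict:
    """Parse the `prompts` entry and check that it is at most two levels deep."""
    prompts = json.loads(text)["prompts"]
    for key, value in prompts.items():
        if isinstance(value, dict):
            # Nested level: column name -> prompt string.
            if not all(isinstance(v, str) for v in value.values()):
                raise ValueError(f"prompts[{key!r}] must map column names to strings")
        elif not isinstance(value, str):
            raise ValueError(f"prompts[{key!r}] must be a string or a dict of strings")
    return prompts
```

The resulting dictionary can then be passed as the prompts value wherever the parameter ends up living, which keeps the command line free of nested literals.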


@tomaarsen tomaarsen merged commit 7be3eac into UKPLab:master Nov 8, 2024
9 checks passed
@tomaarsen
Collaborator

Thanks a bunch for spearheading this. I didn't expect that the prompts would have such a notable impact (0.66% and 0.90% relative NDCG@10 across mpnet-base and bert-base-uncased, respectively), but I'm glad that they do.

This will be included as one of the 4 major features in Monday's v3.3 release, alongside the NanoBEIREvaluator which will be another major feature. I really appreciate your work on these.

  • Tom Aarsen
