mbr adds Sampling-based Minimum Bayes Risk (MBR) decoding to Hugging Face transformers. Originally proposed by Eikema & Aziz (2022), this technique is a risk-minimizing algorithm for generating text with a language model. This repository implements several optimizations for MBR decoding. Most notably, mbr introduces reference aggregation (Vamvas & Sennrich, 2024).
Pronounce: ember /ˈɛm.bɚ/
pip install mbr
Requirements:
- Python >= 3.9
- PyTorch
- Hugging Face transformers < 4.39
The main components of mbr are:
- mbr.MBRGenerationMixin: overrides a model's generate method to add MBR decoding.
- mbr.MBRConfig: specifies the parameters of MBR decoding, e.g., the number of samples to generate and the metric to optimize.
Models need to inherit from MBRGenerationMixin for MBR decoding to work. Here are two ways to achieve this, using the Llama model as an example:
Variant A:
from transformers import LlamaForCausalLM
from mbr import MBRGenerationMixin
class MBRLlamaForCausalLM(MBRGenerationMixin, LlamaForCausalLM):
    pass
Then, you can use MBRLlamaForCausalLM as you would use LlamaForCausalLM:
model = MBRLlamaForCausalLM.from_pretrained(...)
Variant B:
from transformers import LlamaForCausalLM
from mbr import MBR
model = MBR(LlamaForCausalLM).from_pretrained(...)
Create an MBRConfig object to pass to the model's generate method:
from mbr import MBRConfig
mbr_config = MBRConfig(
num_samples=10,
metric="chrf",
)
Call the model's generate method directly, or use the Pipeline API. Make sure to pass the mbr_config as well as the model's tokenizer.
from transformers import pipeline
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
output = generator("Hello,", mbr_config=mbr_config, tokenizer=tokenizer)
The following research papers, among many others, provide a description of Sampling-based Minimum Bayes Risk decoding:
- Sampling-Based Approximations to Minimum Bayes Risk Decoding for Neural Machine Translation (Eikema & Aziz, EMNLP 2022)
- Understanding the Properties of Minimum Bayes Risk Decoding in Neural Machine Translation (Müller & Sennrich, ACL-IJCNLP 2021)
In practice, MBR decoding is most commonly implemented as follows (using machine translation as an example):
- Instead of searching for the single most probable output sequence (e.g., using beam search), generate a number of samples.
- Score each sample against the other samples using a metric (e.g., BLEU).
- Return the sample with the highest score. Intuitively, this can be seen as returning the median of all samples.
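The three steps above can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not mbr's implementation: `mbr_decode` and `unigram_f1` are hypothetical names, and `unigram_f1` is a toy stand-in for a real metric such as BLEU or ChrF. Because the samples also serve as references here, each sample is compared with itself as well; with a metric that scores every identical pair the same, this adds the same constant to every sample and does not change the winner.

```python
from typing import Callable, List, Sequence, Tuple


def mbr_decode(
    samples: Sequence[str],
    references: Sequence[str],
    metric: Callable[[str, str], float],
) -> Tuple[str, List[float]]:
    """Return the sample with the highest average metric score against the references."""
    scores = []
    for sample in samples:
        # Average score over all references (estimate of expected utility)
        total = sum(metric(sample, reference) for reference in references)
        scores.append(total / len(references))
    best_index = max(range(len(samples)), key=lambda i: scores[i])
    return samples[best_index], scores


def unigram_f1(hypothesis: str, reference: str) -> float:
    """Toy stand-in metric (NOT BLEU): F1 over the sets of whitespace tokens."""
    hyp, ref = set(hypothesis.split()), set(reference.split())
    overlap = len(hyp & ref)
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(hyp), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


samples = [
    "the cat sat on the mat",
    "a cat sat on the mat",
    "the dog ran away",
]
# As in the default setup, the samples double as references
best, scores = mbr_decode(samples, samples, unigram_f1)
print(best)  # the cat sat on the mat
```

The outlier sample ("the dog ran away") gets the lowest average score, which illustrates the "median of all samples" intuition.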
The terminology around MBR decoding varies:
Term used in this codebase | Alternative terms |
---|---|
samples | candidates, hypotheses |
references | pseudo-references, evidence |
metric score | expected utility, (negative) expected risk, error |
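In this terminology, the selection rule can be written as the following worked equation. The notation is introduced here for illustration: $\mathcal{H}$ is the set of samples, $\mathcal{R}$ the set of references, and $u$ the metric. The average over references is a Monte Carlo estimate of the expected utility; treating $-u$ as the risk turns the same rule into minimizing the expected risk.

$$
\hat{y} = \operatorname*{arg\,max}_{h \in \mathcal{H}} \; \frac{1}{|\mathcal{R}|} \sum_{r \in \mathcal{R}} u(h, r)
$$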
The generation of the samples can be customized by passing a generation_config to the generate method or to the pipeline call:
from transformers import GenerationConfig
generation_config = GenerationConfig.from_pretrained("mymodel",
do_sample=True,
num_beams=1,
epsilon_cutoff=0.02,
)
model.generate(..., generation_config=generation_config)
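Setting `epsilon_cutoff` enables epsilon sampling, the strategy that Freitag et al. (2023) studied for MBR. As a rough sketch of what the truncation does at each decoding step, assuming a plain dictionary of next-token probabilities (the helper names are hypothetical; transformers applies the cutoff to logits tensors internally): tokens whose probability falls below epsilon are discarded, the rest are renormalized, and the next token is sampled from what remains.

```python
import random
from typing import Dict, Optional


def epsilon_truncate(probs: Dict[str, float], epsilon: float) -> Dict[str, float]:
    """Drop tokens with probability below epsilon, then renormalize."""
    kept = {tok: p for tok, p in probs.items() if p >= epsilon}
    if not kept:
        # Always keep at least the most probable token
        best = max(probs, key=probs.get)
        kept = {best: probs[best]}
    total = sum(kept.values())
    return {tok: p / total for tok, p in kept.items()}


def epsilon_sample(
    probs: Dict[str, float], epsilon: float, rng: Optional[random.Random] = None
) -> str:
    """Sample one token from the epsilon-truncated distribution."""
    rng = rng or random.Random()
    truncated = epsilon_truncate(probs, epsilon)
    tokens = list(truncated)
    return rng.choices(tokens, weights=[truncated[t] for t in tokens], k=1)[0]


next_token_probs = {"the": 0.5, "a": 0.3, "this": 0.19, "zebra": 0.01}
print(epsilon_truncate(next_token_probs, epsilon=0.02))  # "zebra" is removed
```

Unlike top-k, the number of surviving tokens adapts to the shape of the distribution, which keeps sampling diverse while excluding the unreliable low-probability tail.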
By default, the samples themselves are used as references (or a subset of the samples if num_references is smaller than num_samples).
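A hypothetical helper illustrating this default might look as follows. It assumes that the first `num_references` samples are taken as references; which subset mbr actually selects is not specified here and may differ. The error case reflects that references drawn from the samples cannot outnumber them.

```python
from typing import List, Sequence


def default_references(samples: Sequence[str], num_references: int) -> List[str]:
    """Hypothetical helper: reuse a subset of the samples as references.

    Assumption: the first `num_references` samples are used; mbr's actual
    choice of subset may differ.
    """
    if num_references > len(samples):
        raise ValueError(
            "num_references cannot exceed num_samples when the references "
            "are taken from the samples"
        )
    return list(samples[:num_references])


samples = ["s1", "s2", "s3", "s4"]
print(default_references(samples, 2))  # ['s1', 's2']
```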
You could also sample the reference set independently, using a custom generation config for the references:
from transformers import GenerationConfig
references_config = GenerationConfig.from_pretrained("mymodel",
do_sample=True,
num_beams=1,
top_p=0.9,
)
model.generate(..., references_config=references_config)
By default, mbr uses fastChrF, which is optimized for efficient comparison of many samples to many references.
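For intuition about what the default metric computes, here is a simplified sentence-level ChrF sketch with hypothetical function names. It follows the common ChrF recipe: character n-grams of orders 1 to 6 with whitespace removed, precision and recall averaged over the orders, and an F-score with β = 2 that weights recall more than precision. It is not fastChrF, and its scores are not guaranteed to match sacrebleu or fastChrF exactly.

```python
from collections import Counter


def char_ngrams(text: str, n: int) -> Counter:
    """Count the character n-grams of `text`, ignoring whitespace."""
    text = "".join(text.split())
    return Counter(text[i:i + n] for i in range(len(text) - n + 1))


def chrf(hypothesis: str, reference: str, max_n: int = 6, beta: float = 2.0) -> float:
    """Simplified ChrF on a 0-100 scale."""
    precisions, recalls = [], []
    for n in range(1, max_n + 1):
        hyp, ref = char_ngrams(hypothesis, n), char_ngrams(reference, n)
        matches = sum((hyp & ref).values())  # clipped n-gram matches
        # Orders longer than the text yield no n-grams and are skipped
        if hyp:
            precisions.append(matches / sum(hyp.values()))
        if ref:
            recalls.append(matches / sum(ref.values()))
    if not precisions or not recalls:
        return 0.0
    p = sum(precisions) / len(precisions)
    r = sum(recalls) / len(recalls)
    if p + r == 0:
        return 0.0
    return 100 * (1 + beta ** 2) * p * r / (beta ** 2 * p + r)


print(chrf("the cat sat", "the cat sat down"))
```

Computed pairwise, a function like this is called once per sample-reference pair, which is the cost that fastChrF is designed to reduce.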
You can also plug in metrics from the Hugging Face Evaluate library.
A full list of metrics is available in the Evaluate documentation. Some typical choices are COMET and BLEURT.
To use a metric from Hugging Face, either specify the metric's name (e.g., "comet", "bleurt") or pass an evaluate.Metric object directly.
Since different metrics output differently structured dicts, you need to specify the metric_output_field that should be used as the metric score.
from evaluate import load
from mbr import MBRConfig
metric = load('bleu')
mbr_config = MBRConfig(
    metric=metric,
    metric_output_field="bleu",  # the BLEU metric returns a dict with a "bleu" field
    # ... further MBR parameters, e.g., num_samples
)
Internally, mbr will call the metric's compute method to calculate the metric score for each sample. By default, mbr calls compute separately for each sample–reference pair. Since this requires many compute calls, it can make sense to optimize the metric computation. Different metrics will require different optimization strategies.
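The cost of the default strategy, and the effect of caching duplicate pairs, can be illustrated with a toy sketch. `expensive_metric` is a hypothetical stand-in for one compute call, and memoization with `functools.lru_cache` plays the role of the metric cache (the changelog notes that mbr's cache size is configurable via MBRConfig.metric_cache_size; the sketch uses an unbounded cache for simplicity).

```python
from functools import lru_cache
from typing import List, Sequence

call_count = 0


def expensive_metric(hypothesis: str, reference: str) -> float:
    """Hypothetical stand-in for one metric compute call (toy: exact match)."""
    global call_count
    call_count += 1
    return 1.0 if hypothesis == reference else 0.0


@lru_cache(maxsize=None)
def cached_metric(hypothesis: str, reference: str) -> float:
    # Duplicate sample-reference pairs hit the cache instead of the metric
    return expensive_metric(hypothesis, reference)


def score_matrix(samples: Sequence[str], references: Sequence[str]) -> List[List[float]]:
    # One lookup per (sample, reference) pair: N x M lookups in total
    return [[cached_metric(s, r) for r in references] for s in samples]


samples = ["Hallo Welt", "Hallo Welt", "Hallo, Welt"]
matrix = score_matrix(samples, samples)
print(len(samples) ** 2, call_count)  # 9 4
```

Nine pairs are scored, but only the four unique pairs reach the metric, because sampling tends to produce duplicate outputs.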
To override the default way of calling the metric, define a subclass of MetricRunner and pass an instance of it to the generate method:
from typing import Tuple
import torch
from mbr import MetricRunner
class MyMetricRunner(MetricRunner):
    def __call__(self,
                 input_ids: torch.LongTensor,
                 sample_ids: Tuple[torch.LongTensor],
                 reference_ids: Tuple[torch.LongTensor],
                 ):
        # Since v0.2.0, a MetricRunner returns a MetricOutput dict
        # instead of the raw tensor of scores
        ...  # TODO: implement your efficient metric computation here
model.generate(..., metric_runner=MyMetricRunner())
For COMET, an optimized implementation is already provided in CometMetricRunner:
from mbr import MBRConfig
from mbr.metrics.comet import CometMetricRunner
mbr_config = MBRConfig(
    metric="comet",
    metric_output_field="mean_score",
    # ... further MBR parameters, e.g., num_samples
)
metric_runner = CometMetricRunner(mbr_config, tokenizer)
model.generate(..., metric_runner=metric_runner)
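The core idea behind an optimized COMET runner, embedding each unique sequence once and reusing the embeddings for every pairwise comparison, can be sketched with toy components. `toy_embed` is a hypothetical stand-in for COMET's encoder, and cosine similarity is only a placeholder for COMET's actual scoring head, which combines source, sample, and reference embeddings. What the sketch does show is that the number of encoder calls depends on the number of unique sequences, not on the number of pairs.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = Tuple[float, ...]


def toy_embed(text: str) -> Vector:
    """Hypothetical stand-in for an expensive encoder pass: letter counts a-z."""
    counts = [0.0] * 26
    for ch in text.lower():
        if "a" <= ch <= "z":
            counts[ord(ch) - ord("a")] += 1
    return tuple(counts)


def cosine(u: Vector, v: Vector) -> float:
    """Placeholder similarity; COMET uses a learned scoring head instead."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def pairwise_scores(
    samples: Sequence[str], references: Sequence[str]
) -> Tuple[List[List[float]], int]:
    # Embed each unique sequence exactly once ...
    cache: Dict[str, Vector] = {}
    for text in list(samples) + list(references):
        if text not in cache:
            cache[text] = toy_embed(text)
    # ... then reuse the cached embeddings for all pairwise comparisons
    scores = [[cosine(cache[s], cache[r]) for r in references] for s in samples]
    return scores, len(cache)  # len(cache) = number of encoder calls


samples = ["Hallo Welt", "Hallo Welt", "Hallo, Welt!"]
scores, num_encoded = pairwise_scores(samples, samples)
print(num_encoded)  # 2
```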
MBR decoding is notoriously slow. mbr implements some optimizations:
- Cached encoder outputs: For encoder-decoder models, the encoder outputs are computed only once and reused during sampling.
- Optimized ChrF metric: fastChrF is used by default, which is a streamlined ChrF variant for MBR, implemented in Rust.
- Cached metrics: Most metrics are computed only once for each unique sample–reference pair (since there will be duplicate samples and references).
- Optimized COMET metric: Inspired by Amrhein & Sennrich (2022), CometMetricRunner caches sequence embeddings and reuses them for all pairwise comparisons.
- Reference aggregation for COMET (Vamvas & Sennrich, 2024): Consider using mbr.metrics.comet.AggregateCometMetricRunner instead of the default CometMetricRunner if you have many references.
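Reference aggregation (Vamvas & Sennrich, 2024) replaces the N × M pairwise comparisons between N samples and M references with N comparisons against a single aggregate of the references. For a ChrF-style metric, one way to build that aggregate is to average the character n-gram counts of all references and score each sample against the average (for COMET, the analogue would average reference embeddings). The following is a simplified, self-contained sketch under that assumption, with hypothetical function names; it is not fastChrF's or AggregateCometMetricRunner's implementation.

```python
from collections import Counter
from typing import Dict, List, Sequence, Tuple


def char_ngrams(text: str, n: int) -> Counter:
    """Count the character n-grams of `text`, ignoring whitespace."""
    text = "".join(text.split())
    return Counter(text[i:i + n] for i in range(len(text) - n + 1))


def aggregate_references(references: Sequence[str], max_n: int = 6) -> List[Dict[str, float]]:
    """Average the n-gram counts of all references, separately for each order n."""
    if not references:
        raise ValueError("at least one reference is required")
    aggregated = []
    for n in range(1, max_n + 1):
        total: Counter = Counter()
        for reference in references:
            total.update(char_ngrams(reference, n))
        aggregated.append({g: c / len(references) for g, c in total.items()})
    return aggregated


def chrf_against_aggregate(
    hypothesis: str, aggregated: List[Dict[str, float]], beta: float = 2.0
) -> float:
    """ChrF-style score of one sample against the averaged reference n-gram counts."""
    precisions, recalls = [], []
    for n, ref_counts in enumerate(aggregated, start=1):
        hyp_counts = char_ngrams(hypothesis, n)
        # Matched mass: overlap between hypothesis counts and fractional average counts
        matches = sum(min(c, ref_counts.get(g, 0.0)) for g, c in hyp_counts.items())
        if hyp_counts:
            precisions.append(matches / sum(hyp_counts.values()))
        ref_total = sum(ref_counts.values())
        if ref_total:
            recalls.append(matches / ref_total)
    if not precisions or not recalls:
        return 0.0
    p = sum(precisions) / len(precisions)
    r = sum(recalls) / len(recalls)
    if p + r == 0:
        return 0.0
    return 100 * (1 + beta ** 2) * p * r / (beta ** 2 * p + r)


def mbr_with_aggregation(
    samples: Sequence[str], references: Sequence[str]
) -> Tuple[str, List[float]]:
    aggregated = aggregate_references(references)  # one pass over the M references
    scores = [chrf_against_aggregate(s, aggregated) for s in samples]  # N comparisons
    best = max(range(len(samples)), key=lambda i: scores[i])
    return samples[best], scores


samples = ["the cat sat", "the cat sat down", "a dog barked"]
best, scores = mbr_with_aggregation(samples, samples)
print(best)
```

The aggregate is built once, so the cost grows linearly with the number of samples and references instead of with their product; the trade-off is that the score against an averaged reference only approximates the average of the pairwise scores.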
The experiments directory contains the code for reproductions of experiments from the following papers:
- MBR for (low-resource) machine translation (Müller & Sennrich, 2021)
- MBR with neural metrics and epsilon sampling for machine translation (Freitag et al., 2023)
- MBR for summarization (Bertsch et al., 2023)
- https://github.com/roxot/mbr-nmt: Original implementation (demo)
- https://github.com/ZurichNLP/understanding-mbr: MBR with Sockeye
- https://github.com/ZurichNLP/mbr-sensitivity and https://github.com/Unbabel/COMET#minimum-bayes-risk-decoding: COMET metric for MBR
- https://github.com/rainavyas/mbr_gec: MBR for Grammatical Error Correction
- v0.3.0 (draft)
  - New feature: Reference Aggregation (Vamvas & Sennrich, 2024):
    - Set fastChrF with reference aggregation as default metric
    - Add AggregateCometMetricRunner to allow for reference aggregation with COMET
  - Bugfix: Disable dropout for COMET metric
- v0.2.0
  - Breaking change: Rename MBRGenerationConfig to MBRConfig
  - Breaking change: MetricRunner now returns a MetricOutput dict instead of the raw tensor of scores.
  - Make the size of the metric cache configurable via MBRConfig.metric_cache_size
  - Allow the number of references to be larger than the number of samples (if the references are generated separately from the samples).
  - Remove GenerationConfig as parent class of MBRConfig
When using this code for research, please cite the following paper:
@misc{vamvas-sennrich-2024-linear,
title={Linear-time Minimum Bayes Risk Decoding with Reference Aggregation},
author={Jannis Vamvas and Rico Sennrich},
year={2024},
eprint={2402.04251},
archivePrefix={arXiv},
primaryClass={cs.CL}
}