feat: Add min_top_k to TopPSampler (#8228)
* Add feature to Top P Sampler

* Add release notes

* Fix zip call

* Fix mypy

* Restore doc string and make mypy happy hopefully

* Make mypy happy

* PR comment

* Revert change to make mypy happy

* Add back type ignore

* try to fix typing

* Update haystack/components/samplers/top_p.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/samplers/top_p.py

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
3 people authored Aug 21, 2024
1 parent 35b1215 commit 7fd0b6a
Showing 3 changed files with 167 additions and 85 deletions.
115 changes: 75 additions & 40 deletions haystack/components/samplers/top_p.py
@@ -2,9 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional
from typing import List, Optional, Tuple

from haystack import ComponentError, Document, component, logging
from haystack import Document, component, logging
from haystack.lazy_imports import LazyImport

logger = logging.getLogger(__name__)
@@ -42,52 +42,58 @@ class TopPSampler:
```
"""

def __init__(self, top_p: float = 1.0, score_field: Optional[str] = None):
def __init__(self, top_p: float = 1.0, score_field: Optional[str] = None, min_top_k: Optional[int] = None):
"""
Creates an instance of TopPSampler.
:param top_p: Float between 0 and 1 representing the cumulative probability threshold for document selection.
A value of 1.0 indicates no filtering (all documents are retained).
:param score_field: Name of the field in each document's metadata that contains the score. If None, the default
document score field is used.
:param min_top_k: If specified, the minimum number of documents to return. If top-p sampling selects
fewer documents, additional ones with the next highest scores are added to the selection.
"""
torch_import.check()

self.top_p = top_p
if not 0 <= top_p <= 1:
raise ValueError(f"top_p must be between 0 and 1. Got {top_p}.")
self.score_field = score_field
self.min_top_k = min_top_k

@component.output_types(documents=List[Document])
def run(self, documents: List[Document], top_p: Optional[float] = None):
"""
Filters documents using top-p sampling based on their scores.
If the specified top_p results in no documents being selected (especially in cases of a low top_p value), the
method returns the document with the highest similarity score.
method returns the document with the highest score.
:param documents: List of Document objects to be filtered.
:param top_p: Optional. A float to override the cumulative probability threshold set during initialization.
:param top_p: If specified, a float to override the cumulative probability threshold set during initialization.
:returns: A dictionary with the following key:
- `documents`: List of Document objects that have been selected based on the top-p sampling.
:raises ValueError: If the top_p value is not within the range [0, 1].
"""
if not documents:
return {"documents": []}

top_p = top_p or self.top_p or 1.0 # default to 1.0 if both are None

top_p = top_p or self.top_p
if not 0 <= top_p <= 1:
raise ValueError(f"top_p must be between 0 and 1. Got {top_p}.")

similarity_scores = torch.tensor(self._collect_scores(documents), dtype=torch.float32)
documents_with_scores, scores = self._get_documents_and_scores(documents)
if len(documents_with_scores) == 0:
logger.warning("No documents with scores found. Returning the original documents.")
return {"documents": documents}

# Apply softmax normalization to the similarity scores
probs = torch.nn.functional.softmax(similarity_scores, dim=-1)
sorted_docs_with_scores = sorted(zip(documents_with_scores, scores), key=lambda x: x[1], reverse=True)
sorted_documents, sorted_scores = [list(t) for t in zip(*sorted_docs_with_scores)]

# Sort the probabilities and calculate their cumulative sum
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
tensor_scores = torch.tensor(sorted_scores, dtype=torch.float32)
probs = torch.nn.functional.softmax(tensor_scores, dim=-1)
cumulative_probs = torch.cumsum(probs, dim=-1)

# Check if the cumulative probabilities are close to top_p with a 1e-6 tolerance
close_to_top_p = torch.isclose(cumulative_probs, torch.tensor(top_p, device=cumulative_probs.device), atol=1e-6)
@@ -99,44 +105,73 @@ def run(self, documents: List[Document], top_p: Optional[float] = None):
top_p_indices = torch.where(torch.BoolTensor(condition))[0]

# Map the selected indices back to their original indices
original_indices = sorted_indices[top_p_indices]
selected_docs = [documents[i.item()] for i in original_indices]
selected_docs = [sorted_documents[i.item()] for i in top_p_indices]

# If low p resulted in no documents being selected, then
# return at least one document
if not selected_docs:
if self.min_top_k and len(selected_docs) < self.min_top_k:
selected_docs = sorted_documents[: self.min_top_k]

# If low p resulted in no documents being selected, then return at least one document
if len(selected_docs) == 0:
logger.warning(
"Top-p sampling with p={top_p} resulted in no documents being selected. "
"Returning the document with the highest similarity score.",
"Returning the document with the highest score.",
top_p=top_p,
)
highest_prob_indices = torch.argsort(probs, descending=True)
selected_docs = [documents[int(highest_prob_indices[0].item())]]
selected_docs = [sorted_documents[0]]

return {"documents": selected_docs}

def _collect_scores(self, documents: List[Document]) -> List[float]:
@staticmethod
def _get_doc_score(doc: Document, score_field: Optional[str] = None) -> Optional[float]:
"""
Collect the scores from the documents' metadata.
Get the score of a document.
:param doc: Document object.
:param score_field: Name of the field in the document's metadata that contains the score.
If None, the document score field is used.
:return: Score of the document.
"""
if score_field:
score = doc.meta.get(score_field)
else:
score = doc.score

if not isinstance(score, float):
score = None
return score

def _get_documents_and_scores(self, documents: List[Document]) -> Tuple[List[Document], List[float]]:
"""
Checks if documents have scores in their metadata or score field and returns the documents with scores.
:param documents: List of Documents.
:return: A tuple containing the list of documents with scores and the list of their scores.
"""
if self.score_field:
missing_scores_docs = [d for d in documents if self.score_field not in d.meta]
if missing_scores_docs:
missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id]
raise ComponentError(
f"Score field '{self.score_field}' not found in metadata of documents "
f"with IDs: {missing_scores_docs_ids}."
f"Make sure that all documents have a score field '{self.score_field}' in their metadata."
docs_with_scores = []
scores = []
docs_missing_scores = []
for doc in documents:
score = self._get_doc_score(doc=doc, score_field=self.score_field)
if score is None:
docs_missing_scores.append(doc)
else:
scores.append(score)
docs_with_scores.append(doc)

if len(docs_missing_scores) > 0:
missing_scores_docs_ids = [d.id for d in docs_missing_scores if d.id]
if self.score_field:
logger.warning(
"Score field '{score_field}' not found in metadata of documents with IDs: {doc_ids}."
"Make sure that all documents have a score field '{score_field_2}' in their metadata.",
score_field=self.score_field,
doc_ids=",".join(missing_scores_docs_ids),
score_field_2=self.score_field,
)
return [d.meta[self.score_field] for d in documents]
else:
missing_scores_docs = [d for d in documents if d.score is None]
if missing_scores_docs:
missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id]
raise ComponentError(
f"Ensure all documents have a valid score value. These docs {missing_scores_docs_ids} don't."
else:
logger.warning(
"Ensure all documents have a valid score value. These documents {doc_ids} are missing scores.",
doc_ids=",".join(missing_scores_docs_ids),
)
return [d.score for d in documents] # type: ignore ## because Document score is Optional
return docs_with_scores, scores
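For reference, here is a minimal standalone sketch of the selection logic introduced above, not code from the PR itself. The cutoff condition collapsed in the diff (between the `isclose` check and the `torch.where` call) is assumed to be `cumulative_probs <= top_p` combined with the tolerance check, and the helper name `top_p_count` is invented for illustration:

```python
from typing import List, Optional

import torch


def top_p_count(scores: List[float], top_p: float = 1.0, min_top_k: Optional[int] = None) -> int:
    """Return how many of the highest-scoring items a top-p cutoff would keep."""
    sorted_scores = sorted(scores, reverse=True)
    probs = torch.nn.functional.softmax(torch.tensor(sorted_scores, dtype=torch.float32), dim=-1)
    cumulative_probs = torch.cumsum(probs, dim=-1)
    # Keep items while the cumulative probability stays at or below top_p (with a small tolerance).
    close_to_top_p = torch.isclose(cumulative_probs, torch.tensor(top_p), atol=1e-6)
    kept = int(((cumulative_probs <= top_p) | close_to_top_p).sum().item())
    if min_top_k is not None:
        # Backfill with the next highest-scoring items, capped at the number of items available.
        kept = min(max(kept, min_top_k), len(scores))
    return max(kept, 1)  # the component always returns at least one document


print(top_p_count([0.7, 0.01, 0.001], top_p=0.95))              # -> 2
print(top_p_count([0.7, 0.01, 0.001], top_p=0.2))               # -> 1 (fallback)
print(top_p_count([0.7, 0.01, 0.001], top_p=0.2, min_top_k=2))  # -> 2 (backfilled)
```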
@@ -0,0 +1,4 @@
---
enhancements:
- |
The parameter min_top_k is added to the TopPSampler. It sets the minimum number of documents to return when top-p sampling selects fewer documents; the documents with the next highest scores are added to the selection. This is useful when you want to guarantee that a set number of documents is always passed on, while still letting the top-p algorithm decide, based on document scores, whether more documents should be included.
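A usage sketch of the new parameter, mirroring the fixture and assertions in the test diff below (document contents and scores are taken from those tests):

```python
from haystack import Document
from haystack.components.samplers.top_p import TopPSampler

# With top_p=0.2 alone the sampler would fall back to the single highest-scoring
# document; min_top_k=2 guarantees that the two best documents are returned.
sampler = TopPSampler(top_p=0.2, min_top_k=2)
docs = [
    Document(content="Sarajevo", score=0.7),
    Document(content="Belgrade", score=0.01),
    Document(content="Berlin", score=0.001),
]
result = sampler.run(documents=docs)
print([d.content for d in result["documents"]])  # ['Sarajevo', 'Belgrade']
```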
133 changes: 88 additions & 45 deletions test/components/samplers/test_top_p.py
@@ -2,87 +2,130 @@
#
# SPDX-License-Identifier: Apache-2.0
import random
from typing import List

import pytest

from haystack import Document, ComponentError
from haystack import Document
from haystack.components.samplers.top_p import TopPSampler


@pytest.fixture
def documents_with_score_field() -> List[Document]:
return [
Document(content="Sarajevo", meta={"similarity_score": 0.7}),
Document(content="Belgrade", meta={"similarity_score": 0.01}),
Document(content="Berlin", meta={"similarity_score": 0.001}),
]


@pytest.fixture
def documents_with_score() -> List[Document]:
return [
Document(content="Sarajevo", score=0.7),
Document(content="Belgrade", score=0.01),
Document(content="Berlin", score=0.001),
]


class TestTopPSampler:
def test_run_scores_from_metadata(self):
"""
Test if the component runs correctly with scores already in the metadata.
"""
def test_init_raises_value_error(self) -> None:
with pytest.raises(ValueError):
TopPSampler(top_p=2.0)

def test_run_raises_value_error(self, documents_with_score: List[Document]) -> None:
sampler = TopPSampler(top_p=0.95)
with pytest.raises(ValueError):
sampler.run(documents=documents_with_score, top_p=2.0)

def test_run_score_field(self, documents_with_score_field: List[Document]) -> None:
sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
docs = [
Document(content="Berlin", meta={"similarity_score": -10.6}),
Document(content="Belgrade", meta={"similarity_score": -8.9}),
Document(content="Sarajevo", meta={"similarity_score": -4.6}),
]
docs = documents_with_score_field
output = sampler.run(documents=docs)
docs = output["documents"]
assert len(docs) == 1
assert len(docs) == 2
assert docs[0].content == "Sarajevo"
assert docs[1].content == "Belgrade"

def test_run_scores(self):
"""
Test if the component runs correctly with scores in the Document score field.
"""
sampler = TopPSampler(top_p=0.99)
def test_run_score_field_missing_scores(self, caplog: pytest.LogCaptureFixture) -> None:
sampler = TopPSampler(top_p=1.0, score_field="similarity_score")
docs = [
Document(content="Berlin", score=-10.6),
Document(content="Belgrade", score=-8.9),
Document(content="Sarajevo", score=-4.6),
Document(content="Sarajevo", meta={"similarity_score": 0.7}),
Document(content="Belgrade", meta={"similarity_score": 0.01}),
Document(content="Berlin", meta={"similarity_score": None}),
]
output = sampler.run(documents=docs)
docs = output["documents"]
assert len(docs) == 2
assert docs[0].content == "Sarajevo"
assert docs[1].content == "Belgrade"
assert "Score field" in caplog.text

def test_run(self, documents_with_score: List[Document]) -> None:
sampler = TopPSampler(top_p=0.99)
docs = documents_with_score
random.shuffle(docs)
sorted_scores = sorted([doc.score for doc in docs], reverse=True)

# top_p = 0.99 will get the top 1 document
output = sampler.run(documents=docs)
docs_filtered = output["documents"]
assert len(docs_filtered) == 1
assert len(docs_filtered) == 2
assert docs_filtered[0].content == "Sarajevo"
assert docs_filtered[1].content == "Belgrade"

assert [doc.score for doc in docs_filtered] == sorted_scores[:1]
assert [doc.score for doc in docs_filtered] == sorted_scores[:2]

def test_run_scores_top_p_1(self):
"""
Test if the component runs correctly top_p=1.
"""
def test_run_top_p_1(self, documents_with_score: List[Document]) -> None:
sampler = TopPSampler(top_p=1.0)
docs = [
Document(content="Berlin", score=-10.6),
Document(content="Belgrade", score=-8.9),
Document(content="Sarajevo", score=-4.6),
]

docs = documents_with_score
random.shuffle(docs)
output = sampler.run(documents=docs)
docs_filtered = output["documents"]
assert len(docs_filtered) == len(docs)
assert docs_filtered[0].content == "Sarajevo"

assert [doc.score for doc in docs_filtered] == sorted([doc.score for doc in docs], reverse=True)

# Returns an empty list if no documents are provided
def test_run_top_p_0(self, caplog: pytest.LogCaptureFixture, documents_with_score: List[Document]) -> None:
sampler = TopPSampler(top_p=0.0)
docs = documents_with_score
output = sampler.run(documents=docs)
docs = output["documents"]
assert len(docs) == 1
assert docs[0].content == "Sarajevo"
assert "Top-p sampling with p=" in caplog.text

def test_returns_empty_list_if_no_documents_are_provided(self):
def test_run_returns_empty_list_no_documents(self) -> None:
sampler = TopPSampler()
output = sampler.run(documents=[])
assert output["documents"] == []

def test_run_scores_no_metadata_present(self):
"""
Test if the component runs correctly with scores missing from the metadata yet being specified in the
score_field.
"""
def test_run_no_score_field(self, caplog: pytest.LogCaptureFixture, documents_with_score: List[Document]) -> None:
sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
docs = documents_with_score
output = sampler.run(documents=docs)
docs = output["documents"]
assert len(docs) == 3
assert docs[0].content == "Sarajevo"
assert "Score field 'similarity_score' not found" in caplog.text

def test_run_missing_scores(self, caplog: pytest.LogCaptureFixture) -> None:
sampler = TopPSampler(top_p=0.95)
docs = [
Document(content="Berlin", score=-10.6),
Document(content="Belgrade", score=-8.9),
Document(content="Sarajevo", score=-4.6),
Document(content="Sarajevo", score=0.7),
Document(content="Belgrade", score=0.01),
Document(content="Berlin", score=None),
]
with pytest.raises(ComponentError, match="Score field 'similarity_score' not found"):
sampler.run(documents=docs)
output = sampler.run(documents=docs)
docs = output["documents"]
assert len(docs) == 1
assert docs[0].content == "Sarajevo"
assert "Ensure all documents have a valid score value" in caplog.text

def test_run_min_top_k(self, documents_with_score: List[Document]) -> None:
sampler = TopPSampler(min_top_k=2, top_p=0.2)
docs = documents_with_score
output = sampler.run(documents=docs)
docs = output["documents"]
assert len(docs) == 2
assert docs[0].content == "Sarajevo"
assert docs[1].content == "Belgrade"
