Merge branch 'main' into updating_docstring
davidsbatista authored Aug 22, 2024
2 parents fec92d2 + 0a1a64c commit f4ef2e7
Showing 10 changed files with 230 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release_notes.yml
@@ -27,7 +27,7 @@ jobs:

      - name: Get release note files
        id: changed-files
-       uses: tj-actions/changed-files@v44
+       uses: tj-actions/changed-files@v45
        with:
          files: releasenotes/notes/*.yaml

2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -171,7 +171,7 @@ jobs:

      - name: Get changed files
        id: files
-       uses: tj-actions/changed-files@v44
+       uses: tj-actions/changed-files@v45
        with:
          files: |
            **/*.py
43 changes: 42 additions & 1 deletion haystack/components/converters/docx.py
@@ -18,6 +18,7 @@
with LazyImport("Run 'pip install python-docx'") as docx_import:
    import docx
    from docx.document import Document as DocxDocument
+   from docx.text.paragraph import Paragraph


@dataclass
@@ -119,7 +120,7 @@ def run(
                continue
            try:
                file = docx.Document(io.BytesIO(bytestream.data))
-               paragraphs = [para.text for para in file.paragraphs]
+               paragraphs = self._extract_paragraphs_with_page_breaks(file.paragraphs)
                text = "\n".join(paragraphs)
            except Exception as e:
                logger.warning(
@@ -136,6 +137,46 @@ def run(

return {"documents": documents}

    def _extract_paragraphs_with_page_breaks(self, paragraphs: List["Paragraph"]) -> List[str]:
        """
        Extracts paragraphs from a DOCX file, including page breaks.

        Page breaks (both soft and hard page breaks) are not automatically extracted by python-docx as '\f' chars.
        This means we need to add them in ourselves, as done here. This allows the correct page number
        to be associated with each document if the file contents are split, e.g. by DocumentSplitter.

        :param paragraphs:
            List of paragraphs from a DOCX file.
        :returns:
            List of strings (paragraph text fields) with all page breaks added in as '\f' characters.
        """
        paragraph_texts = []
        for para in paragraphs:
            if para.contains_page_break:
                para_text_w_page_breaks = ""
                # Usually just one page break exists, but there could be more if the paragraph is very long,
                # so we loop over all of them
                for pb_index, page_break in enumerate(para.rendered_page_breaks):
                    # Can only extract text from the first page break of a paragraph, unfortunately
                    if pb_index == 0:
                        if page_break.preceding_paragraph_fragment:
                            para_text_w_page_breaks += page_break.preceding_paragraph_fragment.text
                        para_text_w_page_breaks += "\f"
                        if page_break.following_paragraph_fragment:
                            # following_paragraph_fragment contains all text for the remainder of the paragraph.
                            # However, if the remainder of the paragraph spans multiple page breaks, it won't
                            # include those later page breaks, so we have to add them at the end of the text in
                            # the `else` block below. This is not ideal, but this case should be very rare and
                            # this is likely good enough.
                            para_text_w_page_breaks += page_break.following_paragraph_fragment.text
                    else:
                        para_text_w_page_breaks += "\f"

                paragraph_texts.append(para_text_w_page_breaks)
            else:
                paragraph_texts.append(para.text)

        return paragraph_texts

    def _get_docx_metadata(self, document: "DocxDocument") -> DOCXMetadata:
        """
        Get all relevant data from the 'core_properties' attribute of a DOCX Document.
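To see the effect of this change end to end, here is a usage sketch. It is hypothetical (the file name is invented) and assumes that Haystack's `DocumentSplitter` with `split_by="page"` splits on the `'\f'` characters the converter now emits, attaching a page number to each split's metadata:

```python
from haystack.components.converters import DOCXToDocument
from haystack.components.preprocessors import DocumentSplitter

converter = DOCXToDocument()
result = converter.run(sources=["report.docx"])  # hypothetical file path
doc = result["documents"][0]

# With this commit, soft and hard page breaks appear in the content as '\f'
print(doc.content.count("\f"), "page breaks extracted")

# Splitting by page relies on those '\f' markers to assign page numbers
splitter = DocumentSplitter(split_by="page", split_length=1)
pages = splitter.run(documents=[doc])["documents"]
for page in pages:
    print(page.meta.get("page_number"), repr(page.content[:40]))
```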
115 changes: 75 additions & 40 deletions haystack/components/samplers/top_p.py
@@ -2,9 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0

-from typing import List, Optional
+from typing import List, Optional, Tuple

-from haystack import ComponentError, Document, component, logging
+from haystack import Document, component, logging
from haystack.lazy_imports import LazyImport

logger = logging.getLogger(__name__)
@@ -42,52 +42,58 @@ class TopPSampler:
```
"""

-   def __init__(self, top_p: float = 1.0, score_field: Optional[str] = None):
+   def __init__(self, top_p: float = 1.0, score_field: Optional[str] = None, min_top_k: Optional[int] = None):
        """
        Creates an instance of TopPSampler.

        :param top_p: Float between 0 and 1 representing the cumulative probability threshold for document selection.
            A value of 1.0 indicates no filtering (all documents are retained).
        :param score_field: Name of the field in each document's metadata that contains the score. If None, the default
            document score field is used.
+       :param min_top_k: If specified, the minimum number of documents to return. If top-p sampling selects
+           fewer documents, additional ones with the next highest scores are added to the selection.
        """
        torch_import.check()

        self.top_p = top_p
+       if not 0 <= top_p <= 1:
+           raise ValueError(f"top_p must be between 0 and 1. Got {top_p}.")
        self.score_field = score_field
+       self.min_top_k = min_top_k

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document], top_p: Optional[float] = None):
        """
        Filters documents using top-p sampling based on their scores.

        If the specified top_p results in no documents being selected (especially in cases of a low top_p value), the
-       method returns the document with the highest similarity score.
+       method returns the document with the highest score.

        :param documents: List of Document objects to be filtered.
-       :param top_p: Optional. A float to override the cumulative probability threshold set during initialization.
+       :param top_p: If specified, a float to override the cumulative probability threshold set during initialization.
        :returns: A dictionary with the following key:
            - `documents`: List of Document objects that have been selected based on the top-p sampling.
        :raises ValueError: If the top_p value is not within the range [0, 1].
        """
        if not documents:
            return {"documents": []}

-       top_p = top_p or self.top_p or 1.0  # default to 1.0 if both are None
+       top_p = top_p or self.top_p
+       if not 0 <= top_p <= 1:
+           raise ValueError(f"top_p must be between 0 and 1. Got {top_p}.")

-       similarity_scores = torch.tensor(self._collect_scores(documents), dtype=torch.float32)
+       documents_with_scores, scores = self._get_documents_and_scores(documents)
+       if len(documents_with_scores) == 0:
+           logger.warning("No documents with scores found. Returning the original documents.")
+           return {"documents": documents}

-       # Apply softmax normalization to the similarity scores
-       probs = torch.nn.functional.softmax(similarity_scores, dim=-1)
+       sorted_docs_with_scores = sorted(zip(documents_with_scores, scores), key=lambda x: x[1], reverse=True)
+       sorted_documents, sorted_scores = [list(t) for t in zip(*sorted_docs_with_scores)]

-       # Sort the probabilities and calculate their cumulative sum
-       sorted_probs, sorted_indices = torch.sort(probs, descending=True)
-       cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+       tensor_scores = torch.tensor(sorted_scores, dtype=torch.float32)
+       probs = torch.nn.functional.softmax(tensor_scores, dim=-1)
+       cumulative_probs = torch.cumsum(probs, dim=-1)

        # Check if the cumulative probabilities are close to top_p with a 1e-6 tolerance
        close_to_top_p = torch.isclose(cumulative_probs, torch.tensor(top_p, device=cumulative_probs.device), atol=1e-6)
@@ -99,44 +105,73 @@ def run(self, documents: List[Document], top_p: Optional[float] = None):
        top_p_indices = torch.where(torch.BoolTensor(condition))[0]

        # Map the selected indices back to their original indices
-       original_indices = sorted_indices[top_p_indices]
-       selected_docs = [documents[i.item()] for i in original_indices]
+       selected_docs = [sorted_documents[i.item()] for i in top_p_indices]

+       if self.min_top_k and len(selected_docs) < self.min_top_k:
+           selected_docs = sorted_documents[: self.min_top_k]

-       # If low p resulted in no documents being selected, then
-       # return at least one document
-       if not selected_docs:
+       # If low p resulted in no documents being selected, then return at least one document
+       if len(selected_docs) == 0:
            logger.warning(
                "Top-p sampling with p={top_p} resulted in no documents being selected. "
-               "Returning the document with the highest similarity score.",
+               "Returning the document with the highest score.",
                top_p=top_p,
            )
-           highest_prob_indices = torch.argsort(probs, descending=True)
-           selected_docs = [documents[int(highest_prob_indices[0].item())]]
+           selected_docs = [sorted_documents[0]]

        return {"documents": selected_docs}

-   def _collect_scores(self, documents: List[Document]) -> List[float]:
+   @staticmethod
+   def _get_doc_score(doc: Document, score_field: Optional[str] = None) -> Optional[float]:
        """
-       Collect the scores from the documents' metadata.
+       Get the score of a document.

+       :param doc: Document object.
+       :param score_field: Name of the field in the document's metadata that contains the score.
+           If None, the document score field is used.
+       :return: Score of the document.
        """
+       if score_field:
+           score = doc.meta.get(score_field)
+       else:
+           score = doc.score

+       if not isinstance(score, float):
+           score = None
+       return score

+   def _get_documents_and_scores(self, documents: List[Document]) -> Tuple[List[Document], List[float]]:
+       """
+       Checks if documents have scores in their metadata or score field and returns the documents with scores.

        :param documents: List of Documents.
-       :return: List of scores.
+       :return: A tuple of (documents that have a score, their scores).
        """
-       if self.score_field:
-           missing_scores_docs = [d for d in documents if self.score_field not in d.meta]
-           if missing_scores_docs:
-               missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id]
-               raise ComponentError(
-                   f"Score field '{self.score_field}' not found in metadata of documents "
-                   f"with IDs: {missing_scores_docs_ids}."
-                   f"Make sure that all documents have a score field '{self.score_field}' in their metadata."
-               )
-           return [d.meta[self.score_field] for d in documents]
-       else:
-           missing_scores_docs = [d for d in documents if d.score is None]
-           if missing_scores_docs:
-               missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id]
-               raise ComponentError(
-                   f"Ensure all documents have a valid score value. These docs {missing_scores_docs_ids} don't."
-               )
-           return [d.score for d in documents]  # type: ignore ## because Document score is Optional
+       docs_with_scores = []
+       scores = []
+       docs_missing_scores = []
+       for doc in documents:
+           score = self._get_doc_score(doc=doc, score_field=self.score_field)
+           if score is None:
+               docs_missing_scores.append(doc)
+           else:
+               scores.append(score)
+               docs_with_scores.append(doc)

+       if len(docs_missing_scores) > 0:
+           missing_scores_docs_ids = [d.id for d in docs_missing_scores if d.id]
+           if self.score_field:
+               logger.warning(
+                   "Score field '{score_field}' not found in metadata of documents with IDs: {doc_ids}."
+                   "Make sure that all documents have a score field '{score_field_2}' in their metadata.",
+                   score_field=self.score_field,
+                   doc_ids=",".join(missing_scores_docs_ids),
+                   score_field_2=self.score_field,
+               )
+           else:
+               logger.warning(
+                   "Ensure all documents have a valid score value. These documents {doc_ids} are missing scores.",
+                   doc_ids=",".join(missing_scores_docs_ids),
+               )
+       return docs_with_scores, scores
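Stripped of the torch machinery, the selection above amounts to: softmax the scores, sort descending, keep documents while the cumulative probability stays within top_p, then apply the min_top_k backstop and the at-least-one-document fallback. A minimal, dependency-free sketch of that logic (the helper name and scores are illustrative, not part of the codebase):

```python
import math

def top_p_select(scores, top_p=1.0, min_top_k=None):
    """Return indices kept by top-p selection over softmax-normalized scores."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    exps = [math.exp(scores[i]) for i in order]
    total = sum(exps)
    probs = [e / total for e in exps]

    kept, cumulative = [], 0.0
    for rank, idx in enumerate(order):
        cumulative += probs[rank]
        if cumulative <= top_p + 1e-6:  # tolerance mirrors the torch.isclose() check above
            kept.append(idx)

    if min_top_k and len(kept) < min_top_k:
        kept = order[:min_top_k]  # backstop: pad with the next highest-scoring items
    if not kept:
        kept = order[:1]  # always return at least the top-scoring document
    return kept

print(top_p_select([0.1, 0.9, 0.5], top_p=0.6))               # -> [1]
print(top_p_select([0.1, 0.9, 0.5], top_p=0.1, min_top_k=2))  # -> [1, 2]
```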
@@ -0,0 +1,4 @@
---
enhancements:
  - |
    Adds a `min_top_k` parameter to the `TopPSampler`. It sets the minimum number of documents to return when
    top-p sampling selects fewer than that many; the documents with the next highest scores are added to the
    selection. This is useful to guarantee that a set number of documents is always passed on, while still
    letting the top-p algorithm decide, based on document scores, whether more should be sent.
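A usage sketch for the new parameter (document contents and scores are invented for illustration; the import path follows Haystack 2.x):

```python
from haystack import Document
from haystack.components.samplers import TopPSampler

docs = [
    Document(content="Berlin is the capital of Germany.", score=0.72),
    Document(content="Paris is the capital of France.", score=0.68),
    Document(content="The Eiffel Tower is in Paris.", score=0.05),
]

# An aggressive top_p may keep no documents at all, but min_top_k
# guarantees the two highest-scoring ones are still returned.
sampler = TopPSampler(top_p=0.2, min_top_k=2)
result = sampler.run(documents=docs)
print(len(result["documents"]))  # -> 2
```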
@@ -0,0 +1,4 @@
---
fixes:
  - |
    Fixed an issue where page breaks were not being extracted from DOCX files.
5 changes: 5 additions & 0 deletions releasenotes/notes/docx-para-forwardref-31941f54ab3b679f.yaml
@@ -0,0 +1,5 @@
---
fixes:
  - |
    Use a forward reference for the `Paragraph` class in the `DOCXToDocument` converter
    to prevent import errors.
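For background, a minimal sketch of the pattern this note refers to (the helper function is illustrative, not from the codebase): `Paragraph` is only importable when python-docx is installed, so quoting the annotation keeps it unevaluated at import time and the module loads even without the optional dependency.

```python
from typing import List

from haystack.lazy_imports import LazyImport

with LazyImport("Run 'pip install python-docx'") as docx_import:
    from docx.text.paragraph import Paragraph  # may fail; LazyImport defers the error

def paragraph_texts(paragraphs: List["Paragraph"]) -> List[str]:
    # "Paragraph" is a forward reference (a string), so it is not evaluated
    # at import time and no ImportError is raised when python-docx is missing.
    return [p.text for p in paragraphs]
```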
10 changes: 10 additions & 0 deletions test/components/converters/test_docx_file_to_document.py
@@ -91,6 +91,16 @@ def test_run_error_non_existent_file(self, test_files_path, docx_converter, caplog):
            docx_converter.run(sources=paths)
            assert "Could not read non_existing_file.docx" in caplog.text

    def test_run_page_breaks(self, test_files_path, docx_converter):
        """
        Test if the component correctly parses page breaks.
        """
        paths = [test_files_path / "docx" / "sample_docx_2_page_breaks.docx"]
        output = docx_converter.run(sources=paths)
        docs = output["documents"]
        assert len(docs) == 1
        assert docs[0].content.count("\f") == 4

    def test_mixed_sources_run(self, test_files_path, docx_converter):
        """
        Test if the component runs correctly when mixed sources are provided.