Skip to content

Commit

Permalink
feat: LLMMetadata Extractor (#125)
Browse files Browse the repository at this point in the history
* wip

* extracting from an initial range

* removing example

* range specification with abbreviations

* detecting wrongly specified ranges

* linting and licensing

* disable too-many-locals pylint

* refactoring page range

* refactoring page range

* Update haystack_experimental/components/extractors/llm_metadata_extractor.py

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>

* fixing range bug and cleaning unused code

* range can also be provided at runtime

* improving tests

* Update haystack_experimental/components/extractors/llm_metadata_extractor.py

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>

* PR comments

* fixing types for mypy

* fixing pylint and mypy

* Small changes to run method, helps with pylint

---------

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
  • Loading branch information
3 people authored Nov 5, 2024
1 parent 4e1b37a commit 142b646
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 32 deletions.
105 changes: 79 additions & 26 deletions haystack_experimental/components/extractors/llm_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.components.builders import PromptBuilder
from haystack.components.generators import AzureOpenAIGenerator, OpenAIGenerator
from haystack.components.preprocessors import DocumentSplitter
from haystack.lazy_imports import LazyImport
from haystack.utils import deserialize_secrets_inplace

from haystack_experimental.util.utils import expand_page_range

with LazyImport(message="Run 'pip install \"amazon-bedrock-haystack==1.0.2\"'") as amazon_bedrock_generator:
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

Expand Down Expand Up @@ -97,7 +100,7 @@ class LLMMetadataExtractor:
Document(content="Hugging Face is a company founded in Paris, France and is known for its Transformers library")
]
extractor = LLMMetadataExtractor(prompt=NER_PROMPT, expected_keys=["entities"], generator=OpenAIGenerator(), input_text='input_text')
extractor = LLMMetadataExtractor(prompt=NER_PROMPT, expected_keys=["entities"], generator_api="openai", prompt_variable='input_text')
extractor.run(documents=docs)
>> {'documents': [
Document(id=.., content: 'deepset was founded in 2018 in Berlin, and is known for its Haystack framework',
Expand All @@ -112,45 +115,55 @@ class LLMMetadataExtractor:
}
>>
```
""" # noqa: E501
""" # noqa: E501

def __init__( # pylint: disable=R0917
def __init__( # pylint: disable=R0917
self,
prompt: str,
input_text: str,
prompt_variable: str,
expected_keys: List[str],
generator_api: Union[str,LLMProvider],
generator_api: Union[str, LLMProvider],
generator_api_params: Optional[Dict[str, Any]] = None,
page_range: Optional[List[Union[str, int]]] = None,
raise_on_failure: bool = False,
):
"""
Initializes the LLMMetadataExtractor.
:param prompt: The prompt to be used for the LLM.
:param input_text: The input text to be processed by the PromptBuilder.
:param prompt_variable: The variable in the prompt to be processed by the PromptBuilder.
:param expected_keys: The keys expected in the JSON output from the LLM.
:param generator_api: The API provider for the LLM.
:param generator_api: The API provider for the LLM. Currently supported providers are:
"openai", "openai_azure", "aws_bedrock", "google_vertex"
:param generator_api_params: The parameters for the LLM generator.
:param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
metadata from the first and third pages of each document. It also accepts printable range
strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
11, 12. If None, metadata will be extracted from the entire document for each document in the
documents list.
This parameter is optional and can be overridden in the `run` method.
:param raise_on_failure: Whether to raise an error on failure to validate JSON output.
:returns:
"""
self.prompt = prompt
self.input_text = input_text
self.builder = PromptBuilder(prompt, required_variables=[input_text])
self.prompt_variable = prompt_variable
self.builder = PromptBuilder(prompt, required_variables=[prompt_variable])
self.raise_on_failure = raise_on_failure
self.expected_keys = expected_keys
self.generator_api = generator_api if isinstance(generator_api, LLMProvider)\
else LLMProvider.from_str(generator_api)
self.generator_api_params = generator_api_params or {}
self.llm_provider = self._init_generator(self.generator_api, self.generator_api_params)
if self.input_text not in self.prompt:
raise ValueError(f"Input text '{self.input_text}' must be in the prompt.")
if self.prompt_variable not in self.prompt:
raise ValueError(f"Prompt variable '{self.prompt_variable}' must be in the prompt.")
self.splitter = DocumentSplitter(split_by="page", split_length=1)
self.expanded_range = expand_page_range(page_range) if page_range else None

@staticmethod
def _init_generator(
generator_api: LLMProvider,
generator_api_params: Optional[Dict[str, Any]]
generator_api: LLMProvider,
generator_api_params: Optional[Dict[str, Any]]
) -> Union[OpenAIGenerator, AzureOpenAIGenerator, "AmazonBedrockGenerator", "VertexAIGeminiGenerator"]:
"""
Initialize the chat generator based on the specified API provider and parameters.
Expand Down Expand Up @@ -215,7 +228,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
prompt=self.prompt,
input_text=self.input_text,
input_text=self.prompt_variable,
expected_keys=self.expected_keys,
raise_on_failure=self.raise_on_failure,
generator_api=self.generator_api.value,
Expand All @@ -240,28 +253,68 @@ def from_dict(cls, data: Dict[str, Any]) -> "LLMMetadataExtractor":
deserialize_secrets_inplace(data["init_parameters"]["generator_api_params"], keys=["api_key"])
return default_from_dict(cls, data)

def _extract_metadata_and_update_doc(self, document: Document, content: str) -> None:
    """
    Extract metadata from `content` and update `document.meta` in place with the result.

    The prompt template is rendered with `content` bound to the configured prompt variable,
    sent to the LLM provider, and the first reply is parsed as JSON. Only when the reply is
    valid JSON containing all `self.expected_keys` is the document's metadata updated; on an
    invalid reply this method makes no change (the caller is responsible for error tracking).

    :param document: Document whose metadata is updated with the extracted values.
    :param content: Text to extract metadata from (may be a page subset of the document).
    """
    # Render the prompt with the document content substituted for the configured variable.
    prompt_with_doc = self.builder.run(
        template=self.prompt,
        template_variables={self.prompt_variable: content}
    )
    result = self.llm_provider.run(prompt=prompt_with_doc["prompt"])
    # Only the first reply is considered.
    llm_answer = result["replies"][0]
    # Validation may raise instead of returning False when raise_on_failure is True.
    if self.is_valid_json_and_has_expected_keys(expected=self.expected_keys, received=llm_answer):
        extracted_metadata = json.loads(llm_answer)
        for k in self.expected_keys:
            document.meta[k] = extracted_metadata[k]

@component.output_types(documents=List[Document], errors=Dict[str, Any])
def run(self, documents: List[Document]) -> Dict[str, Any]:
def run(self, documents: List[Document], page_range: Optional[List[Union[str, int]]] = None):
"""
Extract metadata from documents using a Language Model.
If `page_range` is provided, the metadata will be extracted from the specified range of pages. This component
will split the documents into pages and extract metadata from the specified range of pages. The metadata will be
extracted from the entire document if `page_range` is not provided.
The original documents will be returned updated with the extracted metadata.
:param documents: List of documents to extract metadata from.
:param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
metadata from the first and third pages of each document. It also accepts printable range
strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
11, 12.
If None, metadata will be extracted from the entire document for each document in the
documents list.
:returns:
A dictionary with the keys:
- "documents": List of documents with extracted metadata.
- "documents": The original list of documents updated with the extracted metadata.
- "errors": A dictionary with document IDs as keys and error messages as values.
"""
errors = {}

errors: Dict[str, Any] = {}
expanded_range = self.expanded_range
if page_range:
expanded_range = expand_page_range(page_range)

for document in documents:
prompt_with_doc = self.builder.run(input_text=document.content)
result = self.llm_provider.run(prompt=prompt_with_doc["prompt"])
llm_answer = result["replies"][0]
if self.is_valid_json_and_has_expected_keys(expected=self.expected_keys, received=llm_answer):
extracted_metadata = json.loads(llm_answer)
for k in self.expected_keys:
document.meta[k] = extracted_metadata[k]
if not document.content:
logger.warning(f"Document {document.id} has no content. Skipping metadata extraction.")
continue
if expanded_range:
pages = self.splitter.run(documents=[document])
content = ""
for idx, page in enumerate(pages["documents"]):
if idx + 1 in expanded_range:
content += page.content + "\f"
else:
errors[document.id] = llm_answer

# extract metadata from the entire document
content = document.content
self._extract_metadata_and_update_doc(document, content)
return {"documents": documents, "errors": errors}
43 changes: 43 additions & 0 deletions haystack_experimental/util/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import List, Union


def expand_page_range(page_range: List[Union[str, int]]) -> List[int]:
    """
    Takes a list of page numbers and ranges and expands them into a flat list of page numbers.

    For example, given page_range=['1-3', '5', '8', '10-12'] the function returns
    [1, 2, 3, 5, 8, 10, 11, 12]. Plain integers and digit strings are accepted as
    single pages; strings of the form 'start-end' are expanded inclusively.

    :param page_range: List of page numbers and ranges.
    :returns:
        An expanded list of page integers.
    :raises ValueError: If the input list is empty, or an entry is not a valid page or
        range specification — e.g. a negative integer (a range mistakenly written
        without quotes, so evaluated arithmetically), a non-numeric string, or a
        reversed range such as '5-3'.
    """
    expanded_page_range: List[int] = []

    for page in page_range:
        if isinstance(page, int):
            # A negative int is almost certainly a range wrongly passed as an integer
            # expression, e.g. 1-3 evaluates to -2 before it ever reaches us.
            if page < 0:
                msg = "range must be a string in the format 'start-end'"
                raise ValueError(f"Invalid page range: {page} - {msg}")
            expanded_page_range.append(page)

        elif isinstance(page, str) and page.isdigit():
            expanded_page_range.append(int(page))

        elif isinstance(page, str) and "-" in page:
            start, end = page.split("-")
            # A reversed range would otherwise silently expand to nothing.
            if int(start) > int(end):
                msg = "range start must not be greater than range end"
                raise ValueError(f"Invalid page range: {page} - {msg}")
            expanded_page_range.extend(range(int(start), int(end) + 1))

        else:
            msg = "range must be a string in the format 'start-end' or an integer"
            raise ValueError(f"Invalid page range: {page} - {msg}")

    if not expanded_page_range:
        raise ValueError("No valid page numbers or ranges found in the input list")

    return expanded_page_range
49 changes: 43 additions & 6 deletions test/components/extractors/test_llm_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ def test_init_default(self, monkeypatch):
prompt="prompt {{test}}",
expected_keys=["key1", "key2"],
generator_api=LLMProvider.OPENAI,
input_text="test"
prompt_variable="test"
)
assert isinstance(extractor.builder, PromptBuilder)
assert extractor.generator_api == LLMProvider.OPENAI
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.raise_on_failure is False
assert extractor.input_text == "test"
assert extractor.prompt_variable == "test"

def test_init_with_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
Expand All @@ -36,7 +36,9 @@ def test_init_with_parameters(self, monkeypatch):
'model': 'gpt-3.5-turbo',
'generation_kwargs': {"temperature": 0.5}
},
input_text="test")
prompt_variable="test",
page_range=['1-5']
)
assert isinstance(extractor.builder, PromptBuilder)
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.raise_on_failure is True
Expand All @@ -45,14 +47,25 @@ def test_init_with_parameters(self, monkeypatch):
'model': 'gpt-3.5-turbo',
'generation_kwargs': {"temperature": 0.5}
}
assert extractor.expanded_range == [1, 2, 3, 4, 5]

def test_init_missing_prompt_variable(self, monkeypatch):
    """The constructor raises ValueError when prompt_variable does not occur in the prompt template."""
    monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
    with pytest.raises(ValueError):
        # prompt contains {{test}} but the declared variable is "test2"
        _ = LLMMetadataExtractor(
            prompt="prompt {{test}}",
            expected_keys=["key1", "key2"],
            generator_api=LLMProvider.OPENAI,
            prompt_variable="test2"
        )

def test_to_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="some prompt that was used with the LLM {{test}}",
expected_keys=["key1", "key2"],
generator_api=LLMProvider.OPENAI,
input_text="test",
prompt_variable="test",
generator_api_params={'model': 'gpt-4o-mini', 'generation_kwargs': {"temperature": 0.5}},
raise_on_failure=True)
extractor_dict = extractor.to_dict()
Expand Down Expand Up @@ -84,7 +97,7 @@ def test_from_dict(self, monkeypatch):
'prompt': 'some prompt that was used with the LLM {{test}}',
'expected_keys': ['key1', 'key2'],
'raise_on_failure': True,
'input_text': 'test',
'prompt_variable': 'test',
'generator_api': 'openai',
'generator_api_params': {
'api_base_url': None,
Expand All @@ -103,6 +116,30 @@ def test_from_dict(self, monkeypatch):
assert extractor.prompt == "some prompt that was used with the LLM {{test}}"
assert extractor.generator_api == LLMProvider.OPENAI

def test_output_invalid_json_raise_on_failure_true(self, monkeypatch):
    """Validation raises ValueError on a reply that is not valid JSON when raise_on_failure=True."""
    monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
    extractor = LLMMetadataExtractor(
        prompt="prompt {{test}}",
        expected_keys=["key1", "key2"],
        generator_api=LLMProvider.OPENAI,
        prompt_variable="test",
        raise_on_failure=True
    )
    with pytest.raises(ValueError):
        # single-quoted payload is NOT valid JSON, so parsing fails
        extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="{'json': 'output'}")

def test_output_valid_json_not_expected_keys(self, monkeypatch):
    """Validation raises ValueError on valid JSON that lacks the expected keys when raise_on_failure=True."""
    monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
    extractor = LLMMetadataExtractor(
        prompt="prompt {{test}}",
        expected_keys=["key1", "key2"],
        generator_api=LLMProvider.OPENAI,
        prompt_variable="test",
        raise_on_failure=True
    )
    with pytest.raises(ValueError):
        # payload IS valid JSON but does not contain the expected "entities" key
        extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="""{"json": "output"}""")

@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
Expand Down Expand Up @@ -147,7 +184,7 @@ def test_live_run(self):
"""

doc_store = InMemoryDocumentStore()
extractor = LLMMetadataExtractor(prompt=ner_prompt, expected_keys=["entities"], input_text="input_text", generator_api=LLMProvider.OPENAI)
extractor = LLMMetadataExtractor(prompt=ner_prompt, expected_keys=["entities"], prompt_variable="input_text", generator_api=LLMProvider.OPENAI)
writer = DocumentWriter(document_store=doc_store)
pipeline = Pipeline()
pipeline.add_component("extractor", extractor)
Expand Down
22 changes: 22 additions & 0 deletions test/util/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from haystack_experimental.util.utils import expand_page_range

def test_expand_page_range_valid_input():
    """expand_page_range flattens mixed ints, digit strings, and 'start-end' ranges."""
    cases = [
        ([1, 3], [1, 3]),
        (['1-3'], [1, 2, 3]),
        (['1-3', 5, 8, '10-12'], [1, 2, 3, 5, 8, 10, 11, 12]),
        (['1-3', '5', '8', '10-12'], [1, 2, 3, 5, 8, 10, 11, 12]),
        (
            ['1-3', 5, 8, '10-12', '15-20', 50],
            [1, 2, 3, 5, 8, 10, 11, 12, 15, 16, 17, 18, 19, 20, 50],
        ),
    ]
    for given, expected in cases:
        assert expand_page_range(given) == expected


def test_expand_page_range_invalid_input():
    """Invalid page specifications must raise ValueError."""

    # non-numeric string entry
    with pytest.raises(ValueError):
        expand_page_range(['1-3', 'non_digit_string', 8, '10-12', '15-20', '50'])

    # unquoted range: 1-3 is arithmetic and evaluates to -2 before the call
    with pytest.raises(ValueError):
        expand_page_range([1-3, 5, 8])

    # empty input list
    with pytest.raises(ValueError):
        expand_page_range([])

0 comments on commit 142b646

Please sign in to comment.